AI
The AI layer in JOSHUA provides an extensible model architecture for robot intelligence. It supports three primary workflows: inference (deploying trained models for real-time autonomous control), training (reinforcement learning and imitation learning pipelines), and data collection (recording teleoperation demonstrations for learning from human expertise).
Overview
The AI system is built around a pluggable model architecture where new models can be integrated without modifying the core inference pipeline. A central ModelRegistry maps model type identifiers to concrete model implementations, while a generic ROS2 inference node handles all communication with the rest of the robot system. This separation of concerns allows researchers and developers to focus on model logic without worrying about the robotics middleware layer.
The AI system spans three major subsystems:
- Inference Pipeline — Loads trained models via the registry, subscribes to perception topics (cameras, joint states), runs forward passes, and publishes action commands.
- Training Pipeline — Provides RL training (PPO via JAX/Flax with MuJoCo-XLA) and integration points for imitation learning using collected demonstration data.
- Data Collection — Records teleoperation sessions with episode indexing, producing datasets compatible with HuggingFace and LeRobot formats for downstream training.
Architecture Diagram
+------------------------------------------------------------------+
| AI SYSTEM LAYER |
+------------------------------------------------------------------+
| |
| +---------------------+ +------------------------------+ |
| | Model Registry | | Training Pipeline | |
| | (Singleton) | | | |
| | | | +----------+ +-----------+ | |
| | ModelType -> Class | | | MJX PPO | | Isaac Sim | | |
| | RANDOM_NOISE --+ | | | JAX/Flax | | RL Tasks | | |
| | SMOL_VLA ------+ | | +----------+ +-----------+ | |
| | (extensible) | | | | | |
| +------------------+--+ +---------|--------------------+ |
| | | |
| +---------v---------+ | +------------------+ |
| | ModelBase | | | Data Collection | |
| | (Abstract Class) | | | | |
| | | | | DataStore config | |
| | handle_input() | | | Episode indexing | |
| | preprocess_input()| | | HF / LeRobot | |
| | inference() | | | CSV / Parquet | |
| | postprocess() | | +------------------+ |
| | forward() | | |
| +---------+---------+ | |
| | | |
| +------------------v-----------------v---------------------+ |
| | ROS2 Inference Node (inference.py) | |
| | | |
| | Config ---> Registry ---> Model Instance | |
| | Subscribe: /camera/*, /joint_states, /task_description | |
| | Publish: /action_commands | |
| +----------------------------------------------------------+ |
| |
+------------------------------------------------------------------+
| ^
v |
+---------------+ +---------------+
| Perception | | Actuators |
| (Cameras, | | (Servos, |
| Sensors) | | Motors) |
+---------------+ +---------------+
ModelBase (Abstract Class)
ModelBase is the abstract base class that defines the standard interface every AI model must implement. It enforces a consistent data flow pattern: input handling, preprocessing, inference, and postprocessing. This ensures that all models, regardless of their internal architecture, can be seamlessly loaded and executed by the generic ROS2 inference node.
class ModelBase(ABC):
"""Abstract base class for all JOSHUA AI models."""
def __init__(self, config: ModelConfig):
self.config = config
self.device = self._resolve_device()
@abstractmethod
def handle_input(self, ros_msg) -> dict:
"""Convert ROS2 message(s) into a model-friendly dict."""
pass
@abstractmethod
def preprocess_input(self, raw_input: dict) -> Any:
"""Normalize, resize, tokenize, or batch inputs for inference."""
pass
@abstractmethod
def inference(self, processed_input: Any) -> Any:
"""Run the forward pass / model prediction."""
pass
@abstractmethod
def postprocess_output(self, raw_output: Any) -> dict:
"""Convert model output to action commands dict."""
pass
def forward(self, ros_msg) -> dict:
"""Full pipeline: handle -> preprocess -> infer -> postprocess."""
raw = self.handle_input(ros_msg)
processed = self.preprocess_input(raw)
output = self.inference(processed)
return self.postprocess_output(output)
Method Responsibilities
| Method | Purpose | Input | Output |
|---|---|---|---|
| handle_input() | Converts ROS2 messages into a model-friendly dictionary | ROS2 message(s) | dict |
| preprocess_input() | Normalizes, resizes, tokenizes, or batches inputs | dict | Tensor / array |
| inference() | Runs the forward pass or model prediction | Tensor / array | Raw model output |
| postprocess_output() | Converts raw output to an action command dictionary | Raw model output | dict |
| forward() | Orchestrates the full pipeline end-to-end | ROS2 message(s) | dict |
ModelRegistry
The ModelRegistry implements the singleton pattern for centralized model discovery and instantiation. It maintains a mapping from ModelType enum values (defined in Protocol Buffers) to their corresponding model classes. When the inference node starts, it queries the registry with the configured model type to obtain the correct class, then instantiates it with the provided configuration.
class ModelRegistry:
"""Singleton registry mapping ModelType enum to model classes."""
_instance = None
_registry: dict[ModelType, type[ModelBase]] = {}
@classmethod
def instance(cls) -> "ModelRegistry":
if cls._instance is None:
cls._instance = cls()
return cls._instance
@classmethod
def register(cls, model_type: ModelType):
"""Decorator to register a model class."""
def decorator(model_cls: type[ModelBase]):
cls._registry[model_type] = model_cls
return model_cls
return decorator
def get_model(self, model_type: ModelType, config: ModelConfig) -> ModelBase:
"""Instantiate and return a model by type."""
if model_type not in self._registry:
raise ValueError(f"Unknown model type: {model_type}")
return self._registry[model_type](config)
# Usage: registering a model
@ModelRegistry.register(ModelType.SMOL_VLA)
class SmolVLAModel(ModelBase):
...
To add a new model, decorate a ModelBase subclass with @ModelRegistry.register(ModelType.YOUR_TYPE). As long as the module is imported at startup, the registry picks up the new model automatically.
ROS2 Inference Node
The inference.py module provides a generic ROS2 node that serves as the bridge between the robot's perception system and the AI model. It is model-agnostic: it loads whichever model is specified in the configuration, sets up the appropriate publishers and subscribers based on the model's input/output requirements, and delegates all data processing to the model's forward() method.
class InferenceNode(Node):
"""Generic ROS2 node for AI model inference."""
def __init__(self, config: SystemConfig):
super().__init__("inference_node")
model_config = config.ai.model
registry = ModelRegistry.instance()
# Load model from registry
self.model = registry.get_model(model_config.type, model_config)
# Set up subscribers based on model input requirements
for topic in model_config.input_topics:
self.create_subscription(topic.msg_type, topic.name,
self._on_input, topic.qos)
# Set up publisher for action commands
self.action_pub = self.create_publisher(
ActionCommand, model_config.output_topic, 10)
def _on_input(self, msg):
"""Callback: run model forward pass and publish actions."""
actions = self.model.forward(msg)
action_msg = self._dict_to_action_msg(actions)
self.action_pub.publish(action_msg)
Implemented Models: RandomNoise
The RandomNoise model is a test and demonstration model that generates random action values within configurable bounds. It does not perform any meaningful inference but serves a critical role in system integration testing: it validates the entire inference pipeline from ROS2 subscription through model execution to action command publishing, without requiring a GPU or trained weights.
Characteristics
- Purpose: System integration testing and pipeline validation
- Input: Any ROS2 message (ignored during processing)
- Output: Random action values within configurable noise bounds
- Dependencies: None (no GPU, no pretrained weights)
- Configuration: noise_lower_bound and noise_upper_bound parameters
@ModelRegistry.register(ModelType.RANDOM_NOISE)
class RandomNoiseModel(ModelBase):
"""Test model generating random action values."""
def __init__(self, config: ModelConfig):
super().__init__(config)
self.lower = config.params.get("noise_lower_bound", -1.0)
self.upper = config.params.get("noise_upper_bound", 1.0)
self.action_dim = config.action_dim
def handle_input(self, ros_msg) -> dict:
return {} # No input processing needed
def preprocess_input(self, raw_input: dict) -> None:
return None # No preprocessing needed
def inference(self, processed_input) -> np.ndarray:
return np.random.uniform(self.lower, self.upper,
size=self.action_dim)
def postprocess_output(self, raw_output: np.ndarray) -> dict:
return {"actions": raw_output.tolist()}
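Because the model depends only on NumPy, its core behavior can be smoke-tested without ROS2 or the real ModelConfig. A minimal sketch, using SimpleNamespace as a stand-in for the parsed configuration object:

```python
from types import SimpleNamespace

import numpy as np

# Stand-in for the real ModelConfig parsed from .pbtxt; field names
# mirror the RandomNoise example above.
config = SimpleNamespace(
    params={"noise_lower_bound": -0.5, "noise_upper_bound": 0.5},
    action_dim=6,
)

lower = config.params.get("noise_lower_bound", -1.0)
upper = config.params.get("noise_upper_bound", 1.0)

# The inference step: uniform random actions within the configured bounds.
actions = np.random.uniform(lower, upper, size=config.action_dim)

assert actions.shape == (config.action_dim,)
assert np.all(actions >= lower) and np.all(actions <= upper)
```

A check like this is exactly what the model is for: confirming that configuration, action dimensionality, and bounds flow through the pipeline correctly before any trained model is involved.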
Implemented Models: SmolVLA (Vision-Language-Action)
SmolVLA is the production inference model integrated from the HuggingFace LeRobot library. It is a Vision-Language-Action (VLA) model that processes multi-camera images, robot joint states, and natural language task descriptions to produce normalized action values for robot control.
The model loads the lerobot/smolvla_base pretrained checkpoint from HuggingFace Hub, which has accumulated over 18,000 downloads. This checkpoint provides generalized manipulation capabilities that can be fine-tuned for specific tasks.
Input Processing
- Multi-camera images: Subscribes to multiple camera topics (e.g., /camera/front/image_raw, /camera/wrist/image_raw) and processes them as visual observations.
- Joint states: Current robot joint positions and velocities from /joint_states.
- Language task description: Natural language string describing the desired task (e.g., "pick up the red cube and place it on the plate").
Output
- Normalized action values representing target joint positions or end-effector deltas.
- Action values are denormalized in postprocessing to match the robot's physical joint limits.
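The denormalization step can be sketched as a per-joint linear map from the model's normalized range (assumed here to be [-1, 1]) to the physical joint limits, with clipping so out-of-range model outputs never exceed the limits. The limit values below are illustrative, not JOSHUA's actual joint ranges:

```python
import numpy as np


def denormalize_actions(normalized: np.ndarray,
                        joint_min: np.ndarray,
                        joint_max: np.ndarray) -> np.ndarray:
    """Map normalized actions in [-1, 1] to [joint_min, joint_max]."""
    midpoint = (joint_max + joint_min) / 2.0
    half_range = (joint_max - joint_min) / 2.0
    # Clip first so out-of-range model outputs stay within joint limits.
    return midpoint + np.clip(normalized, -1.0, 1.0) * half_range


# Illustrative limits for a two-joint arm (radians).
joint_min = np.array([-1.57, -1.0])
joint_max = np.array([1.57, 1.0])

out = denormalize_actions(np.array([-1.0, 0.5]), joint_min, joint_max)
# -1.0 maps to the first joint's lower limit; 0.5 maps halfway up
# the second joint's positive range.
```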
Hardware Requirements
- GPU: CUDA-capable GPU required for real-time inference.
- VRAM: Minimum 4 GB recommended for the base checkpoint.
- Supported platforms: NVIDIA Jetson Orin Nano (ARM64), desktop GPUs (AMD64).
@ModelRegistry.register(ModelType.SMOL_VLA)
class SmolVLAModel(ModelBase):
"""Vision-Language-Action model from HuggingFace LeRobot."""
def __init__(self, config: ModelConfig):
super().__init__(config)
self.policy = AutoPolicy.from_pretrained(
config.checkpoint or "lerobot/smolvla_base"
)
self.policy.to(self.device)
self.policy.eval()
def handle_input(self, ros_msg) -> dict:
return {
"images": self._extract_images(ros_msg),
"joint_states": self._extract_joints(ros_msg),
"task": ros_msg.task_description,
}
def preprocess_input(self, raw_input: dict) -> dict:
images = self._resize_and_normalize(raw_input["images"])
joints = self._normalize_joints(raw_input["joint_states"])
return {
"observation.images": images,
"observation.state": joints,
"task": raw_input["task"],
}
def inference(self, processed_input: dict) -> torch.Tensor:
with torch.no_grad():
return self.policy.select_action(processed_input)
def postprocess_output(self, raw_output: torch.Tensor) -> dict:
actions = self._denormalize_actions(raw_output)
return {"actions": actions.cpu().numpy().tolist()}
Model Comparison
| Property | RandomNoise | SmolVLA |
|---|---|---|
| Purpose | Integration testing | Production inference |
| Input Modalities | None (ignores input) | Multi-camera images + joint states + language |
| Output | Random values in [lower, upper] | Normalized action values |
| GPU Required | No | Yes (CUDA) |
| Pretrained Weights | None | lerobot/smolvla_base (18k+ downloads) |
| Dependencies | NumPy only | PyTorch, HuggingFace Transformers, LeRobot |
| Latency | < 1ms | ~10–50ms (GPU-dependent) |
| Use Case | CI/CD, pipeline validation | Autonomous manipulation tasks |
Training Pipeline
The training pipeline supports both reinforcement learning (RL) and imitation learning workflows. The entry point is trainer.py, which reads the training configuration and dispatches to the appropriate backend and method based on the specified training type.
+---------------------------------------------------------+
| trainer.py (Entry Point) |
| |
| Reads TrainingConfig from .pbtxt |
| Dispatches to appropriate training backend: |
| |
| +-------------------+ +-------------------------+ |
| | MJX PPO | | Isaac Sim RL | |
| | (mjx_rl.py) | | (isaac_tasks/*.py) | |
| | | | | |
| | JAX/Flax backend | | NVIDIA Isaac backend | |
| | MuJoCo-XLA envs | | GPU-accelerated sim | |
| | CleanRL-style PPO | | Ant, Trileg tasks | |
| +-------------------+ +-------------------------+ |
| |
+---------------------------------------------------------+
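The dispatch step can be sketched as a table lookup from the configured backend to a training entry point. Everything here is illustrative: the real field names live in trainer.py and the protobuf schema, and the two train functions stand in for mjx_rl and the isaac_tasks modules:

```python
from types import SimpleNamespace


def train_mjx_ppo(config):
    # Stand-in for the MJX PPO entry point in mjx_rl.py.
    return f"mjx_ppo:{config.task}"


def train_isaac(config):
    # Stand-in for the Isaac Sim task dispatch in isaac_tasks/.
    return f"isaac:{config.task}"


# Hypothetical mapping from training type to backend entry point.
BACKENDS = {
    "MJX_PPO": train_mjx_ppo,
    "ISAAC_SIM": train_isaac,
}


def dispatch(config):
    try:
        backend = BACKENDS[config.backend]
    except KeyError:
        raise ValueError(f"Unknown training backend: {config.backend}")
    return backend(config)


config = SimpleNamespace(backend="MJX_PPO", task="reach")
result = dispatch(config)  # "mjx_ppo:reach"
```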
MJX PPO (JAX/Flax)
The primary RL training implementation uses a CleanRL-style Proximal Policy Optimization (PPO) algorithm built on JAX and Flax. This backend leverages MuJoCo-XLA (MJX) for massively parallel environment simulation, running up to 2048 environments simultaneously on a single GPU.
Actor-Critic Network Architecture
+------------------------------------------+
| Actor-Critic Network |
| (Shared observation input) |
+------------------------------------------+
| |
| Observation (state vector) |
| | |
| v |
| +------------------+ |
| | Dense(256) + ReLU | |
| +------------------+ |
| | |
| v |
| +------------------+ |
| | Dense(256) + ReLU | |
| +------------------+ |
| | |
| +----+----+ |
| | | |
| v v |
| +------+ +-------+ |
| |Actor | |Critic | |
| |Head | |Head | |
| +------+ +-------+ |
| | | |
| v v |
| Actions Value V(s) |
| (mean) |
+------------------------------------------+
PPO Training Details
| Parameter | Value | Description |
|---|---|---|
| Hidden Layers | 256-256 | Two hidden layers with 256 units each |
| Activation | ReLU | Rectified Linear Unit |
| Parallel Envs | 2048 | Simultaneous MuJoCo-XLA environments |
| Advantage Estimation | GAE | Generalized Advantage Estimation |
| Loss Function | Clipped PPO | Clipped surrogate objective with entropy bonus |
| Framework | JAX / Flax | JIT-compiled, GPU-accelerated training |
| Checkpoint Format | msgpack + JSON | Model weights (msgpack) with metadata (JSON) |
| Visualization | Live MuJoCo viewer | Real-time rendering during training |
Training Loop
# Simplified MJX PPO training loop (mjx_rl.py)
# Initialize 2048 parallel environments via MuJoCo-XLA
env_state = jax.vmap(env.reset)(rng_keys)
for update in range(num_updates):
# Collect rollouts across all parallel envs
for step in range(num_steps):
action, log_prob, value = agent.get_action_and_value(obs)
env_state = jax.vmap(env.step)(env_state, action)
buffer.store(obs, action, reward, done, log_prob, value)
# Compute advantages using GAE
advantages = compute_gae(rewards, values, dones,
gamma=0.99, gae_lambda=0.95)
# PPO update with clipped objective
for epoch in range(update_epochs):
for batch in buffer.iterate_minibatches(batch_size):
loss = ppo_clipped_loss(batch, advantages, clip_coef=0.2)
params = optimizer.step(loss, params)
# Save checkpoint periodically
if update % save_interval == 0:
save_checkpoint(params, metadata, path="checkpoints/")
# Update live MuJoCo viewer
viewer.render(env_state)
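The compute_gae step in the loop above can be written out explicitly. Below is a minimal NumPy sketch of Generalized Advantage Estimation over a single rollout; it bootstraps the value after the final step as zero, which the real implementation may handle differently (e.g., by carrying the next observation's value):

```python
import numpy as np


def compute_gae(rewards, values, dones, gamma=0.99, gae_lambda=0.95):
    """Generalized Advantage Estimation for one rollout of length T.

    rewards[t], values[t] = V(s_t), and dones[t] = 1.0 if the episode
    ended at step t. The value after the last step is taken as 0 here.
    """
    T = len(rewards)
    advantages = np.zeros(T)
    last_gae = 0.0
    # Walk backwards, accumulating the exponentially weighted TD errors.
    for t in reversed(range(T)):
        next_value = values[t + 1] if t + 1 < T else 0.0
        next_nonterminal = 1.0 - dones[t]
        delta = rewards[t] + gamma * next_value * next_nonterminal - values[t]
        last_gae = delta + gamma * gae_lambda * next_nonterminal * last_gae
        advantages[t] = last_gae
    return advantages


rewards = np.array([1.0, 1.0, 1.0])
values = np.array([0.5, 0.5, 0.5])
dones = np.array([0.0, 0.0, 1.0])  # episode terminates at the last step
adv = compute_gae(rewards, values, dones)
```

With gamma=0.99 and gae_lambda=0.95, earlier steps accumulate discounted TD errors from later steps, so advantages grow toward the start of the rollout; the terminal step's advantage reduces to its immediate TD error.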
MJX Environments
The MJX environment system provides a modular framework for defining RL tasks compatible with MuJoCo-XLA. Each environment implements standardized reset, step, observation, and reward functions. The modular design separates environment logic into composable "terms" for rewards, terminations, resets, and observations.
Environment Architecture
+------------------------------------------------------+
| MJX Environment |
+------------------------------------------------------+
| |
| reset(rng) ---------> initial MJX state |
| step(state, action) -> next state + reward + done |
| |
| +------------------+ +------------------+ |
| | Observations | | Rewards | |
| | (modular) | | (modular) | |
| | | | | |
| | joint_positions | | distance_reward | |
| | joint_velocities| | energy_penalty | |
| | end_effector_pos| | success_bonus | |
| | target_position | | alive_bonus | |
| +------------------+ +------------------+ |
| |
| +------------------+ +------------------+ |
| | Terminations | | Resets | |
| | (modular) | | (modular) | |
| | | | | |
| | out_of_bounds | | randomize_target| |
| | max_steps | | randomize_init | |
| | self_collision | | domain_randomize| |
| +------------------+ +------------------+ |
| |
+------------------------------------------------------+
Available Tasks
| Task | Type | Description | Observation Space |
|---|---|---|---|
| reach | Arm manipulation | Robot arm reaching a target position in 3D space | Joint positions + velocities + target position |
| ant | Locomotion | Quadruped ant locomotion (forward movement) | Body pose + joint angles + contact forces |
| pick_place | Manipulation | Pick up an object and place it at a goal location | Joint states + object pose + goal position |
Modular Terms
Each task is composed of modular term functions that can be mixed and matched:
- Reward terms: Distance-based rewards, energy penalties, success bonuses, alive bonuses. Combined with configurable weights.
- Termination terms: Out-of-bounds detection, maximum step limits, self-collision checks.
- Reset terms: Target position randomization, initial state randomization, domain randomization for sim-to-real transfer.
- Observation terms: Joint positions, joint velocities, end-effector position, target position, contact sensor readings.
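The weighted combination of reward terms can be sketched as follows. The term names echo the architecture diagram above, but the state layout and weights are illustrative, not JOSHUA's actual configuration:

```python
import numpy as np


# Each reward term maps an environment state to a scalar.
def distance_reward(state):
    """Negative distance from end effector to target."""
    return -np.linalg.norm(state["ee_pos"] - state["target_pos"])


def energy_penalty(state):
    """Penalize large actuation commands."""
    return -np.sum(np.square(state["action"]))


def alive_bonus(state):
    """Constant bonus for every step the episode survives."""
    return 1.0


# Terms are combined with configurable weights, so tasks can mix
# and match them without new environment code.
REWARD_TERMS = [
    (distance_reward, 1.0),
    (energy_penalty, 0.01),
    (alive_bonus, 0.1),
]


def total_reward(state):
    return sum(weight * term(state) for term, weight in REWARD_TERMS)


state = {
    "ee_pos": np.array([0.0, 0.0, 0.0]),
    "target_pos": np.array([0.0, 3.0, 4.0]),
    "action": np.zeros(6),
}
r = total_reward(state)  # 1.0 * (-5.0) + 0.01 * 0.0 + 0.1 * 1.0 = -4.9
```

Termination, reset, and observation terms compose the same way: lists of functions over the environment state, selected per task in configuration.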
Isaac Sim Integration
JOSHUA integrates with NVIDIA Isaac Sim for industrial-grade physics simulation and RL training. Isaac Sim provides photorealistic rendering, advanced contact physics, and GPU-accelerated simulation through its Omniverse platform.
Available Isaac Sim Tasks
- Ant: Quadruped locomotion task mirroring the MJX ant environment, enabling direct comparison between MuJoCo-XLA and Isaac Sim training backends.
- Trileg: Three-legged robot locomotion task specific to custom JOSHUA robot configurations, testing asymmetric gait learning.
Data Collection System
The data collection system records teleoperation demonstrations for use in imitation learning pipelines. It integrates with the teleoperation mode to capture synchronized sensor data, joint states, and action commands, organizing them into episodes for structured dataset creation.
DataStore Configuration
The DataStore configuration specifies how recorded demonstrations are stored and organized. It controls episode indexing, file format, storage backend, and metadata tagging.
# Example DataStore configuration in .pbtxt
data_store {
base_path: "/data/demonstrations"
format: HUGGINGFACE
episode_index: true
metadata {
robot_type: "so100_arm"
task: "pick_and_place"
operator: "human_expert"
}
storage {
type: LOCAL
# type: CLOUD # for remote storage
}
}
Episode Indexing
Each teleoperation session is recorded as a numbered episode. Episode indexing ensures organized dataset creation with consistent naming, automatic incrementing, and metadata association. Episodes contain timestamped frames of:
- Camera images (all configured cameras)
- Joint positions and velocities
- Action commands (leader arm positions in teleoperation)
- Task description metadata
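Automatic incrementing can be sketched as scanning the existing episode directories for the next free index. The episode_000042 directory naming below is hypothetical, not necessarily what DataStore produces:

```python
import os
import re
import tempfile


def next_episode_index(base_path: str) -> int:
    """Return the next free episode index under base_path.

    Assumes (hypothetically) episode directories named episode_000042.
    Returns 0 for an empty dataset.
    """
    pattern = re.compile(r"episode_(\d+)$")
    indices = [
        int(m.group(1))
        for name in os.listdir(base_path)
        if (m := pattern.match(name))
    ]
    return max(indices, default=-1) + 1


# Demonstrate against a throwaway dataset directory.
with tempfile.TemporaryDirectory() as base:
    os.makedirs(os.path.join(base, "episode_000000"))
    os.makedirs(os.path.join(base, "episode_000001"))
    idx = next_episode_index(base)  # 2
```

Deriving the index from what is already on disk keeps recording sessions safe to stop and resume: no counter file can drift out of sync with the actual episodes.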
Dataset Formats
| Format | Description | Use Case |
|---|---|---|
| HuggingFace Datasets | Apache Arrow-backed columnar format with lazy loading | Large-scale training with streaming support |
| CSV | Plain-text comma-separated values | Quick inspection, simple analysis scripts |
| JSONL | Newline-delimited JSON for structured records | Flexible schema, metadata-rich recordings |
| Parquet | Columnar binary format with compression | Efficient storage, fast analytical queries |
Storage Backends
- Local filesystem: Direct storage on the robot's disk or an NFS mount. Best for low-latency recording during teleoperation.
- Cloud storage: Upload to S3-compatible object storage for centralized dataset management and team collaboration.
LeRobot Dataset Integration
The data collection system produces datasets directly compatible with the LeRobot dataset format. This enables seamless use of collected demonstrations for fine-tuning SmolVLA and other LeRobot-compatible models through imitation learning. The format includes standardized observation keys, action normalization statistics, and episode boundary markers.
Inference Pipeline Flow
The end-to-end inference pipeline transforms a system configuration into real-time autonomous robot behavior. Below is the complete data flow from startup to action execution:
INFERENCE PIPELINE FLOW
=======================
1. STARTUP
.pbtxt Config
|
v
Parse ModelConfig (type, checkpoint, topics, device)
|
v
ModelRegistry.instance().get_model(type, config)
|
v
Model class instantiation (loads weights, moves to GPU)
|
v
ROS2 Inference Node created
2. RUNTIME (per cycle)
/camera/front/image_raw ----+
/camera/wrist/image_raw ----+---> handle_input()
/joint_states --------------+ |
/task_description ----------+ v
preprocess_input()
- Resize images
- Normalize joints
- Tokenize language
|
v
inference()
- Forward pass (GPU)
- torch.no_grad()
|
v
postprocess_output()
- Denormalize actions
- Clip to joint limits
|
v
/action_commands (publish)
|
v
Robot Actuators Execute
Adding New Models
The AI system is designed for extensibility. Follow these steps to integrate a new model into the JOSHUA inference pipeline:
Step 1: Define the ModelType
Add a new entry to the ModelType enum in the Protocol Buffers schema:
// In proto/joshua.proto
enum ModelType {
UNKNOWN_MODEL = 0;
RANDOM_NOISE = 1;
SMOL_VLA = 2;
MY_NEW_MODEL = 3; // Add your new model type
}
Step 2: Implement ModelBase
Create a new class inheriting from ModelBase and implement all required methods:
# In ai/models/my_new_model.py
from ai.model_base import ModelBase
from ai.model_registry import ModelRegistry, ModelType
@ModelRegistry.register(ModelType.MY_NEW_MODEL)
class MyNewModel(ModelBase):
"""Custom model implementation."""
def __init__(self, config):
super().__init__(config)
# Load your model weights, initialize components
self.net = load_my_network(config.checkpoint)
def handle_input(self, ros_msg) -> dict:
# Extract relevant data from ROS2 messages
return {"sensor_data": ros_msg.data}
def preprocess_input(self, raw_input: dict):
# Normalize, transform, batch
return self.transform(raw_input["sensor_data"])
def inference(self, processed_input):
# Run forward pass
return self.net(processed_input)
def postprocess_output(self, raw_output) -> dict:
# Convert to action commands
return {"actions": raw_output.tolist()}
Step 3: Register in ModelRegistry
The @ModelRegistry.register() decorator (shown above) handles registration automatically. Ensure that your model module is imported at startup so the decorator executes. Add the import to ai/models/__init__.py:
# In ai/models/__init__.py
from ai.models.random_noise import RandomNoiseModel
from ai.models.smolvla import SmolVLAModel
from ai.models.my_new_model import MyNewModel # Add this import
Step 4: Create a Config Preset
Create a .pbtxt configuration preset that uses your new model:
# In config/presets/my_new_model_inference.pbtxt
operation_mode: AI_INFERENCE
ai {
model {
type: MY_NEW_MODEL
checkpoint: "path/to/my_weights.pt"
input_topics {
name: "/camera/front/image_raw"
msg_type: "sensor_msgs/Image"
}
input_topics {
name: "/joint_states"
msg_type: "sensor_msgs/JointState"
}
output_topic: "/action_commands"
device: "cuda:0"
}
}
Step 5: Verify Integration
Run the RandomNoise model first to verify the end-to-end pipeline, then switch to your model:
# Test pipeline with RandomNoise
./run.sh --config config/presets/random_noise_test.pbtxt
# Switch to your new model
./run.sh --config config/presets/my_new_model_inference.pbtxt
Integration Checklist
- Add a ModelType entry to the protobuf enum
- Inherit from ModelBase and implement all abstract methods
- Decorate with @ModelRegistry.register(ModelType.YOUR_TYPE)
- Add the import to ai/models/__init__.py
- Create a .pbtxt config preset for the model
- Test with the inference pipeline end-to-end