AI

The AI layer in JOSHUA provides an extensible model architecture for robot intelligence. It supports three primary workflows: inference (deploying trained models for real-time autonomous control), training (reinforcement learning and imitation learning pipelines), and data collection (recording teleoperation demonstrations for learning from human expertise).

Overview

The AI system is built around a pluggable model architecture where new models can be integrated without modifying the core inference pipeline. A central ModelRegistry maps model type identifiers to concrete model implementations, while a generic ROS2 inference node handles all communication with the rest of the robot system. This separation of concerns allows researchers and developers to focus on model logic without worrying about the robotics middleware layer.

Key Design Principles

The AI system follows a configuration-driven approach: the model type, checkpoint path, input/output topics, and hardware acceleration settings are all specified in the system's Protocol Buffers configuration file. No code changes are required to switch between models.

The AI system spans three major subsystems: model management and inference, the training pipeline, and data collection.

Architecture Diagram

+------------------------------------------------------------------+
|                        AI SYSTEM LAYER                           |
+------------------------------------------------------------------+
|                                                                  |
|  +---------------------+    +------------------------------+     |
|  |   Model Registry    |    |     Training Pipeline        |     |
|  |   (Singleton)       |    |                              |     |
|  |                     |    |  +----------+ +-----------+  |     |
|  |  ModelType -> Class |    |  | MJX PPO  | | Isaac Sim |  |     |
|  |  RANDOM_NOISE  --+  |    |  | JAX/Flax | | RL Tasks  |  |     |
|  |  SMOL_VLA  ------+  |    |  +----------+ +-----------+  |     |
|  |  (extensible)    |  |    |         |                    |     |
|  +------------------+--+    +---------|--------------------+     |
|                     |                 |                           |
|           +---------v---------+       |    +------------------+  |
|           |    ModelBase      |       |    | Data Collection  |  |
|           |  (Abstract Class) |       |    |                  |  |
|           |                   |       |    | DataStore config |  |
|           | handle_input()    |       |    | Episode indexing |  |
|           | preprocess_input()|       |    | HF / LeRobot     |  |
|           | inference()       |       |    | CSV / Parquet    |  |
|           | postprocess()     |       |    +------------------+  |
|           | forward()         |       |                          |
|           +---------+---------+       |                          |
|                     |                 |                          |
|  +------------------v-----------------v---------------------+    |
|  |              ROS2 Inference Node (inference.py)          |    |
|  |                                                          |    |
|  |  Config ---> Registry ---> Model Instance                |    |
|  |  Subscribe: /camera/*, /joint_states, /task_description  |    |
|  |  Publish:   /action_commands                             |    |
|  +----------------------------------------------------------+    |
|                                                                  |
+------------------------------------------------------------------+
           |                                          ^
           v                                          |
   +---------------+                          +---------------+
   |  Perception   |                          |   Actuators   |
   |  (Cameras,    |                          |   (Servos,    |
   |   Sensors)    |                          |    Motors)    |
   +---------------+                          +---------------+

ModelBase (Abstract Class)

ModelBase is the abstract base class that defines the standard interface every AI model must implement. It enforces a consistent data flow pattern: input handling, preprocessing, inference, and postprocessing. This ensures that all models, regardless of their internal architecture, can be seamlessly loaded and executed by the generic ROS2 inference node.

from abc import ABC, abstractmethod
from typing import Any


class ModelBase(ABC):
    """Abstract base class for all JOSHUA AI models."""

    def __init__(self, config: ModelConfig):
        self.config = config
        self.device = self._resolve_device()

    @abstractmethod
    def handle_input(self, ros_msg) -> dict:
        """Convert ROS2 message(s) into a model-friendly dict."""
        pass

    @abstractmethod
    def preprocess_input(self, raw_input: dict) -> Any:
        """Normalize, resize, tokenize, or batch inputs for inference."""
        pass

    @abstractmethod
    def inference(self, processed_input: Any) -> Any:
        """Run the forward pass / model prediction."""
        pass

    @abstractmethod
    def postprocess_output(self, raw_output: Any) -> dict:
        """Convert model output to action commands dict."""
        pass

    def forward(self, ros_msg) -> dict:
        """Full pipeline: handle -> preprocess -> infer -> postprocess."""
        raw = self.handle_input(ros_msg)
        processed = self.preprocess_input(raw)
        output = self.inference(processed)
        return self.postprocess_output(output)

Method Responsibilities

Method               | Purpose                                             | Input            | Output
---------------------|-----------------------------------------------------|------------------|-----------------
handle_input()       | Converts ROS2 messages into a model-friendly dict   | ROS2 message(s)  | dict
preprocess_input()   | Normalizes, resizes, tokenizes, or batches inputs   | dict             | Tensor / array
inference()          | Runs the forward pass or model prediction           | Tensor / array   | Raw model output
postprocess_output() | Converts raw output to an action command dictionary | Raw model output | dict
forward()            | Orchestrates the full pipeline end-to-end           | ROS2 message(s)  | dict
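To make the data flow concrete, here is a self-contained toy model that walks all four stages. The minimal ModelBase re-declared below mirrors the interface above so the snippet runs outside the repo; EchoModel and its 0.5 scaling are purely illustrative.

```python
from abc import ABC, abstractmethod
from typing import Any


class ModelBase(ABC):
    """Standalone mirror of the four-stage interface described above."""

    @abstractmethod
    def handle_input(self, ros_msg) -> dict: ...
    @abstractmethod
    def preprocess_input(self, raw_input: dict) -> Any: ...
    @abstractmethod
    def inference(self, processed_input: Any) -> Any: ...
    @abstractmethod
    def postprocess_output(self, raw_output: Any) -> dict: ...

    def forward(self, ros_msg) -> dict:
        """Full pipeline: handle -> preprocess -> infer -> postprocess."""
        raw = self.handle_input(ros_msg)
        processed = self.preprocess_input(raw)
        output = self.inference(processed)
        return self.postprocess_output(output)


class EchoModel(ModelBase):
    """Toy model: scales a list of joint values by 0.5."""

    def handle_input(self, ros_msg) -> dict:
        return {"joints": list(ros_msg)}  # pretend ros_msg is iterable

    def preprocess_input(self, raw_input: dict) -> list:
        return raw_input["joints"]

    def inference(self, processed_input: list) -> list:
        return [0.5 * x for x in processed_input]

    def postprocess_output(self, raw_output: list) -> dict:
        return {"actions": raw_output}


print(EchoModel().forward([1.0, 2.0]))  # → {'actions': [0.5, 1.0]}
```

Because forward() is defined once on the base class, every concrete model inherits the same orchestration and only supplies the four stage methods.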

ModelRegistry

The ModelRegistry implements the singleton pattern for centralized model discovery and instantiation. It maintains a mapping from ModelType enum values (defined in Protocol Buffers) to their corresponding model classes. When the inference node starts, it queries the registry with the configured model type to obtain the correct class, then instantiates it with the provided configuration.

class ModelRegistry:
    """Singleton registry mapping ModelType enum to model classes."""

    _instance = None
    _registry: dict[ModelType, type[ModelBase]] = {}

    @classmethod
    def instance(cls) -> "ModelRegistry":
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    @classmethod
    def register(cls, model_type: ModelType):
        """Decorator to register a model class."""
        def decorator(model_cls: type[ModelBase]):
            cls._registry[model_type] = model_cls
            return model_cls
        return decorator

    def get_model(self, model_type: ModelType, config: ModelConfig) -> ModelBase:
        """Instantiate and return a model by type."""
        if model_type not in self._registry:
            raise ValueError(f"Unknown model type: {model_type}")
        return self._registry[model_type](config)


# Usage: registering a model
@ModelRegistry.register(ModelType.SMOL_VLA)
class SmolVLAModel(ModelBase):
    ...

Extensibility

New models are added simply by decorating a ModelBase subclass with @ModelRegistry.register(ModelType.YOUR_TYPE). The registry picks up the new model at startup, as long as the module containing it is imported.
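A condensed, standalone sketch of the registration round trip (the protobuf ModelType is replaced by a plain Python Enum, and the registry methods are collapsed to classmethods for brevity):

```python
from enum import Enum, auto


class ModelType(Enum):  # stand-in for the protobuf enum
    RANDOM_NOISE = auto()
    SMOL_VLA = auto()


class ModelRegistry:
    _registry: dict = {}

    @classmethod
    def register(cls, model_type):
        """Decorator that maps a ModelType to a model class."""
        def decorator(model_cls):
            cls._registry[model_type] = model_cls
            return model_cls
        return decorator

    @classmethod
    def get_model(cls, model_type, config):
        """Instantiate a registered model, or fail loudly."""
        if model_type not in cls._registry:
            raise ValueError(f"Unknown model type: {model_type}")
        return cls._registry[model_type](config)


@ModelRegistry.register(ModelType.RANDOM_NOISE)
class Dummy:
    def __init__(self, config):
        self.config = config


model = ModelRegistry.get_model(ModelType.RANDOM_NOISE, {"seed": 0})
print(type(model).__name__)  # → Dummy
```

Requesting a type that was never registered (here, SMOL_VLA) raises the ValueError, which surfaces configuration mistakes at startup rather than mid-run.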

ROS2 Inference Node

The inference.py module provides a generic ROS2 node that serves as the bridge between the robot's perception system and the AI model. It is model-agnostic: it loads whichever model is specified in the configuration, sets up the appropriate publishers and subscribers based on the model's input/output requirements, and delegates all data processing to the model's forward() method.

class InferenceNode(Node):
    """Generic ROS2 node for AI model inference."""

    def __init__(self, config: SystemConfig):
        super().__init__("inference_node")
        model_config = config.ai.model
        registry = ModelRegistry.instance()

        # Load model from registry
        self.model = registry.get_model(model_config.type, model_config)

        # Set up subscribers based on model input requirements
        for topic in model_config.input_topics:
            self.create_subscription(topic.msg_type, topic.name,
                                     self._on_input, topic.qos)

        # Set up publisher for action commands
        self.action_pub = self.create_publisher(
            ActionCommand, model_config.output_topic, 10)

    def _on_input(self, msg):
        """Callback: run model forward pass and publish actions."""
        actions = self.model.forward(msg)
        action_msg = self._dict_to_action_msg(actions)
        self.action_pub.publish(action_msg)

Implemented Models: RandomNoise

The RandomNoise model is a test and demonstration model that generates random action values within configurable bounds. It does not perform any meaningful inference but serves a critical role in system integration testing: it validates the entire inference pipeline from ROS2 subscription through model execution to action command publishing, without requiring a GPU or trained weights.

Characteristics

- Ignores all input; no observation processing is performed
- Uniform random actions within [noise_lower_bound, noise_upper_bound]
- NumPy-only dependency; no GPU or trained weights required

import numpy as np


@ModelRegistry.register(ModelType.RANDOM_NOISE)
class RandomNoiseModel(ModelBase):
    """Test model generating random action values."""

    def __init__(self, config: ModelConfig):
        super().__init__(config)
        self.lower = config.params.get("noise_lower_bound", -1.0)
        self.upper = config.params.get("noise_upper_bound", 1.0)
        self.action_dim = config.action_dim

    def handle_input(self, ros_msg) -> dict:
        return {}  # No input processing needed

    def preprocess_input(self, raw_input: dict) -> None:
        return None  # No preprocessing needed

    def inference(self, processed_input) -> np.ndarray:
        return np.random.uniform(self.lower, self.upper,
                                 size=self.action_dim)

    def postprocess_output(self, raw_output: np.ndarray) -> dict:
        return {"actions": raw_output.tolist()}
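As a quick sanity check of the bounds behavior, the core of inference() can be exercised on its own; the bounds and action dimension below are illustrative defaults, not values from a real config:

```python
import numpy as np

# Illustrative config values mirroring the defaults above.
lower, upper, action_dim = -1.0, 1.0, 6

# This is exactly what RandomNoiseModel.inference() produces.
actions = np.random.uniform(lower, upper, size=action_dim)

assert actions.shape == (action_dim,)
assert np.all(actions >= lower) and np.all(actions <= upper)
print({"actions": actions.tolist()})
```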

Implemented Models: SmolVLA (Vision-Language-Action)

SmolVLA is the production inference model integrated from the HuggingFace LeRobot library. It is a Vision-Language-Action (VLA) model that processes multi-camera images, robot joint states, and natural language task descriptions to produce normalized action values for robot control.

Pretrained Checkpoint

SmolVLA uses the lerobot/smolvla_base pretrained checkpoint from HuggingFace Hub, which has accumulated over 18,000 downloads. This checkpoint provides generalized manipulation capabilities that can be fine-tuned for specific tasks.

Input Processing

- Multi-camera images, resized and normalized for the policy
- Robot joint states, normalized before inference
- Natural language task description

Output

- Normalized action values, denormalized to joint commands during postprocessing

Hardware Requirements

- NVIDIA GPU with CUDA support (see the model comparison below)

import torch


@ModelRegistry.register(ModelType.SMOL_VLA)
class SmolVLAModel(ModelBase):
    """Vision-Language-Action model from HuggingFace LeRobot."""

    def __init__(self, config: ModelConfig):
        super().__init__(config)
        self.policy = AutoPolicy.from_pretrained(
            config.checkpoint or "lerobot/smolvla_base"
        )
        self.policy.to(self.device)
        self.policy.eval()

    def handle_input(self, ros_msg) -> dict:
        return {
            "images": self._extract_images(ros_msg),
            "joint_states": self._extract_joints(ros_msg),
            "task": ros_msg.task_description,
        }

    def preprocess_input(self, raw_input: dict) -> dict:
        images = self._resize_and_normalize(raw_input["images"])
        joints = self._normalize_joints(raw_input["joint_states"])
        return {
            "observation.images": images,
            "observation.state": joints,
            "task": raw_input["task"],
        }

    def inference(self, processed_input: dict) -> torch.Tensor:
        with torch.no_grad():
            return self.policy.select_action(processed_input)

    def postprocess_output(self, raw_output: torch.Tensor) -> dict:
        actions = self._denormalize_actions(raw_output)
        return {"actions": actions.cpu().numpy().tolist()}

Model Comparison

Property           | RandomNoise                     | SmolVLA
-------------------|---------------------------------|----------------------------------------------
Purpose            | Integration testing             | Production inference
Input Modalities   | None (ignores input)            | Multi-camera images + joint states + language
Output             | Random values in [lower, upper] | Normalized action values
GPU Required       | No                              | Yes (CUDA)
Pretrained Weights | None                            | lerobot/smolvla_base (18k+ downloads)
Dependencies       | NumPy only                      | PyTorch, HuggingFace Transformers, LeRobot
Latency            | < 1 ms                          | ~10–50 ms (GPU-dependent)
Use Case           | CI/CD, pipeline validation      | Autonomous manipulation tasks

Training Pipeline

The training pipeline supports both reinforcement learning (RL) and imitation learning workflows. The entry point is trainer.py, which reads the training configuration and dispatches to the appropriate backend and method based on the specified training type.

+---------------------------------------------------------+
|                   trainer.py (Entry Point)              |
|                                                         |
|  Reads TrainingConfig from .pbtxt                       |
|  Dispatches to appropriate training backend:            |
|                                                         |
|  +-------------------+    +-------------------------+   |
|  |    MJX PPO        |    |    Isaac Sim RL         |   |
|  |  (mjx_rl.py)      |    |  (isaac_tasks/*.py)     |   |
|  |                   |    |                         |   |
|  |  JAX/Flax backend |    |  NVIDIA Isaac backend   |   |
|  |  MuJoCo-XLA envs  |    |  GPU-accelerated sim    |   |
|  |  CleanRL-style PPO|    |  Ant, Trileg tasks      |   |
|  +-------------------+    +-------------------------+   |
|                                                         |
+---------------------------------------------------------+
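The dispatch step can be sketched as a simple table from training type to backend entry point; the enum, function names, and config shape below are hypothetical stand-ins for the real TrainingConfig plumbing:

```python
from enum import Enum, auto


class TrainingBackend(Enum):  # illustrative stand-in for the config enum
    MJX_PPO = auto()
    ISAAC_SIM = auto()


def train_mjx_ppo(config: dict) -> str:
    return f"mjx_ppo:{config['task']}"  # placeholder for mjx_rl.py


def train_isaac(config: dict) -> str:
    return f"isaac:{config['task']}"  # placeholder for isaac_tasks/


DISPATCH = {
    TrainingBackend.MJX_PPO: train_mjx_ppo,
    TrainingBackend.ISAAC_SIM: train_isaac,
}


def run_training(backend: TrainingBackend, config: dict) -> str:
    """Look up and invoke the backend entry point for a training run."""
    try:
        return DISPATCH[backend](config)
    except KeyError:
        raise ValueError(f"Unsupported training backend: {backend}") from None


print(run_training(TrainingBackend.MJX_PPO, {"task": "reach"}))  # → mjx_ppo:reach
```

A table-driven dispatch like this keeps trainer.py free of backend-specific branching: adding a backend means registering one more entry.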

MJX PPO (JAX/Flax)

The primary RL training implementation uses a CleanRL-style Proximal Policy Optimization (PPO) algorithm built on JAX and Flax. This backend leverages MuJoCo-XLA (MJX) for massively parallel environment simulation, running up to 2048 environments simultaneously on a single GPU.
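The key idea behind the 2048-environment throughput is that a single step call operates on a whole batch of environment states at once (jax.vmap over MJX physics in the real code). A NumPy stand-in with a toy point-mass "physics" shows the shape discipline:

```python
import numpy as np

# Toy stand-in for vmapped MJX environments: each env is a 1-D point mass
# whose state advances by dt * action; all envs are stepped in one call.
num_envs, dt = 2048, 0.01


def step_batch(states: np.ndarray, actions: np.ndarray) -> np.ndarray:
    """Advance all environments in one vectorized call (no Python loop)."""
    return states + dt * actions


states = np.zeros(num_envs)
actions = np.ones(num_envs)
states = step_batch(states, actions)

assert states.shape == (num_envs,)
assert np.allclose(states, dt)
```

In the real pipeline the per-env function is MJX physics and jax.vmap supplies the batching, but the invariant is the same: every array carries a leading num_envs dimension.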

Actor-Critic Network Architecture

+------------------------------------------+
|          Actor-Critic Network            |
|       (Shared observation input)         |
+------------------------------------------+
|                                          |
|  Observation (state vector)              |
|         |                                |
|         v                                |
|  +-------------------+                   |
|  | Dense(256) + ReLU |                   |
|  +-------------------+                   |
|         |                                |
|         v                                |
|  +-------------------+                   |
|  | Dense(256) + ReLU |                   |
|  +-------------------+                   |
|         |                                |
|    +----+----+                           |
|    |         |                           |
|    v         v                           |
| +------+  +-------+                      |
| |Actor |  |Critic |                      |
| |Head  |  |Head   |                      |
| +------+  +-------+                      |
|    |          |                          |
|    v          v                          |
| Actions   Value V(s)                     |
| (mean)                                   |
+------------------------------------------+
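A NumPy stand-in for the network above (the real implementation is a Flax module; the dimensions and weight initialization here are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
obs_dim, act_dim, hidden = 17, 6, 256

# Illustrative weights; in the real pipeline these live in a Flax module.
params = {
    "w1": rng.normal(0, 0.1, (obs_dim, hidden)), "b1": np.zeros(hidden),
    "w2": rng.normal(0, 0.1, (hidden, hidden)),  "b2": np.zeros(hidden),
    "w_actor": rng.normal(0, 0.1, (hidden, act_dim)),
    "w_critic": rng.normal(0, 0.1, (hidden, 1)),
}


def actor_critic(obs: np.ndarray):
    """Shared 256-256 ReLU trunk feeding separate actor and critic heads."""
    h = np.maximum(obs @ params["w1"] + params["b1"], 0.0)
    h = np.maximum(h @ params["w2"] + params["b2"], 0.0)
    action_mean = h @ params["w_actor"]             # actor head: action means
    value = float((h @ params["w_critic"])[0])      # critic head: scalar V(s)
    return action_mean, value


mean, v = actor_critic(np.ones(obs_dim))
assert mean.shape == (act_dim,)
```

The shared trunk is the standard PPO arrangement: one observation encoder, with the actor emitting action means and the critic emitting the state value used for advantage estimation.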

PPO Training Details

Parameter            | Value              | Description
---------------------|--------------------|-----------------------------------------------
Hidden Layers        | 256-256            | Two hidden layers with 256 units each
Activation           | ReLU               | Rectified Linear Unit
Parallel Envs        | 2048               | Simultaneous MuJoCo-XLA environments
Advantage Estimation | GAE                | Generalized Advantage Estimation
Loss Function        | Clipped PPO        | Clipped surrogate objective with entropy bonus
Framework            | JAX / Flax         | JIT-compiled, GPU-accelerated training
Checkpoint Format    | msgpack + JSON     | Model weights (msgpack) with metadata (JSON)
Visualization        | Live MuJoCo viewer | Real-time rendering during training
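The clipped surrogate objective named in the table can be written in a few lines; this sketch omits the entropy bonus and value loss and is not the project's actual loss code:

```python
import numpy as np


def ppo_clipped_objective(log_prob_new, log_prob_old, advantages, clip_coef=0.2):
    """Clipped surrogate: E[min(r*A, clip(r, 1-eps, 1+eps)*A)]."""
    ratio = np.exp(log_prob_new - log_prob_old)          # importance ratio r
    unclipped = ratio * advantages
    clipped = np.clip(ratio, 1 - clip_coef, 1 + clip_coef) * advantages
    return np.mean(np.minimum(unclipped, clipped))


# With identical old/new policies the ratio is 1 and the objective is mean(A).
adv = np.array([1.0, -2.0, 0.5])
lp = np.zeros(3)
assert np.isclose(ppo_clipped_objective(lp, lp, adv), adv.mean())
```

Clipping the ratio bounds how far a single update can push the policy away from the one that collected the rollout, which is what makes large-batch PPO updates stable.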

Training Loop

# Simplified MJX PPO training loop (mjx_rl.py)

# Initialize 2048 parallel environments via MuJoCo-XLA
env_state = jax.vmap(env.reset)(rng_keys)

for update in range(num_updates):
    # Collect rollouts across all parallel envs
    for step in range(num_steps):
        action, log_prob, value = agent.get_action_and_value(obs)
        env_state = jax.vmap(env.step)(env_state, action)
        buffer.store(obs, action, reward, done, log_prob, value)

    # Compute advantages using GAE
    advantages = compute_gae(rewards, values, dones,
                             gamma=0.99, gae_lambda=0.95)

    # PPO update with clipped objective
    for epoch in range(update_epochs):
        for batch in buffer.iterate_minibatches(batch_size):
            loss = ppo_clipped_loss(batch, advantages, clip_coef=0.2)
            params = optimizer.step(loss, params)

    # Save checkpoint periodically
    if update % save_interval == 0:
        save_checkpoint(params, metadata, path="checkpoints/")

    # Update live MuJoCo viewer
    viewer.render(env_state)
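A reference implementation of the compute_gae step used above (the signature is sketched here with an explicit last_value for bootstrapping, which may differ from the real mjx_rl.py helper):

```python
import numpy as np


def compute_gae(rewards, values, dones, last_value, gamma=0.99, gae_lambda=0.95):
    """Generalized Advantage Estimation over one rollout of T steps."""
    T = len(rewards)
    advantages = np.zeros(T)
    gae = 0.0
    for t in reversed(range(T)):
        # Bootstrap from last_value at the final step, V(s_{t+1}) otherwise.
        next_value = last_value if t == T - 1 else values[t + 1]
        not_done = 1.0 - dones[t]
        delta = rewards[t] + gamma * next_value * not_done - values[t]
        gae = delta + gamma * gae_lambda * not_done * gae
        advantages[t] = gae
    return advantages


# One-step sanity check: with T=1, A_0 = r_0 + gamma*V_1 - V_0.
adv = compute_gae(np.array([1.0]), np.array([0.5]), np.array([0.0]), last_value=0.0)
assert np.isclose(adv[0], 1.0 - 0.5)
```

The backward recursion blends one-step TD errors into a lambda-weighted advantage, trading bias against variance via gae_lambda.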

MJX Environments

The MJX environment system provides a modular framework for defining RL tasks compatible with MuJoCo-XLA. Each environment implements standardized reset, step, observation, and reward functions. The modular design separates environment logic into composable "terms" for rewards, terminations, resets, and observations.

Environment Architecture

+------------------------------------------------------+
|                  MJX Environment                     |
+------------------------------------------------------+
|                                                      |
|  reset(rng) ---------> initial MJX state             |
|  step(state, action) -> next state + reward + done   |
|                                                      |
|  +------------------+  +------------------+          |
|  |    Observations  |  |     Rewards      |          |
|  |    (modular)     |  |    (modular)     |          |
|  |                  |  |                  |          |
|  |  joint_positions |  |  distance_reward |          |
|  |  joint_velocities|  |  energy_penalty  |          |
|  |  end_effector_pos|  |  success_bonus   |          |
|  |  target_position |  |  alive_bonus     |          |
|  +------------------+  +------------------+          |
|                                                      |
|  +------------------+  +------------------+          |
|  |  Terminations    |  |     Resets       |          |
|  |  (modular)       |  |    (modular)     |          |
|  |                  |  |                  |          |
|  |  out_of_bounds   |  |  randomize_target|          |
|  |  max_steps       |  |  randomize_init  |          |
|  |  self_collision  |  |  domain_randomize|          |
|  +------------------+  +------------------+          |
|                                                      |
+------------------------------------------------------+
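The "modular terms" idea amounts to composing small functions over the environment state; the reach-style state and term weights below are hypothetical, with names taken from the diagram above:

```python
import numpy as np

# Hypothetical state for a reach-style task.
state = {
    "ee_pos": np.array([0.1, 0.0, 0.2]),
    "target": np.array([0.1, 0.0, 0.3]),
    "joint_vel": np.array([0.2, -0.1, 0.0]),
}


def distance_reward(s):
    return -float(np.linalg.norm(s["ee_pos"] - s["target"]))


def energy_penalty(s):
    return -0.01 * float(np.sum(s["joint_vel"] ** 2))


def alive_bonus(s):
    return 0.05


REWARD_TERMS = [distance_reward, energy_penalty, alive_bonus]


def total_reward(s) -> float:
    """Sum the modular terms; tasks mix and match by editing the list."""
    return sum(term(s) for term in REWARD_TERMS)


r = total_reward(state)
```

Termination, reset, and observation terms compose the same way, so a new task is mostly a new list of term functions rather than a new environment class.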

Available Tasks

Task       | Type             | Description                                       | Observation Space
-----------|------------------|---------------------------------------------------|-----------------------------------------------
reach      | Arm manipulation | Robot arm reaching a target position in 3D space  | Joint positions + velocities + target position
ant        | Locomotion       | Quadruped ant locomotion (forward movement)       | Body pose + joint angles + contact forces
pick_place | Manipulation     | Pick up an object and place it at a goal location | Joint states + object pose + goal position

Modular Terms

Each task is composed of modular term functions that can be mixed and matched:

- Reward terms: distance_reward, energy_penalty, success_bonus, alive_bonus
- Observation terms: joint_positions, joint_velocities, end_effector_pos, target_position
- Termination terms: out_of_bounds, max_steps, self_collision
- Reset terms: randomize_target, randomize_init, domain_randomize

Isaac Sim Integration

JOSHUA integrates with NVIDIA Isaac Sim for industrial-grade physics simulation and RL training. Isaac Sim provides photorealistic rendering, advanced contact physics, and GPU-accelerated simulation through its Omniverse platform.

Available Isaac Sim Tasks

- Ant (quadruped locomotion)
- Trileg

Isaac Sim Requirements

Isaac Sim integration requires an NVIDIA GPU with driver version 525+ and Isaac Sim 2023.1+ installed. The Isaac Sim backend is optional and not required for core JOSHUA functionality.

Data Collection System

The data collection system records teleoperation demonstrations for use in imitation learning pipelines. It integrates with the teleoperation mode to capture synchronized sensor data, joint states, and action commands, organizing them into episodes for structured dataset creation.

DataStore Configuration

The DataStore configuration specifies how recorded demonstrations are stored and organized. It controls episode indexing, file format, storage backend, and metadata tagging.

# Example DataStore configuration in .pbtxt
data_store {
  base_path: "/data/demonstrations"
  format: HUGGINGFACE
  episode_index: true
  metadata {
    robot_type: "so100_arm"
    task: "pick_and_place"
    operator: "human_expert"
  }
  storage {
    type: LOCAL
    # type: CLOUD  # for remote storage
  }
}

Episode Indexing

Each teleoperation session is recorded as a numbered episode. Episode indexing ensures organized dataset creation with consistent naming, automatic incrementing, and metadata association. Episodes contain timestamped frames of:

- Camera images from all configured sensors
- Robot joint states
- Action commands issued by the teleoperator

Dataset Formats

Format               | Description                                          | Use Case
---------------------|------------------------------------------------------|---------------------------------------------
HuggingFace Datasets | Apache Arrow-backed columnar format with lazy loading| Large-scale training with streaming support
CSV                  | Plain-text comma-separated values                    | Quick inspection, simple analysis scripts
JSONL                | Newline-delimited JSON for structured records        | Flexible schema, metadata-rich recordings
Parquet              | Columnar binary format with compression              | Efficient storage, fast analytical queries
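The JSONL row shape for one episode can be sketched with the standard library alone (field names are hypothetical; the HuggingFace and Parquet writers record the same per-frame structure in columnar form):

```python
import io
import json

# Illustrative frame records for one episode.
frames = [
    {"t": 0.00, "joints": [0.0, 0.1], "action": [0.0, 0.0]},
    {"t": 0.05, "joints": [0.1, 0.1], "action": [0.2, 0.0]},
]

buf = io.StringIO()
for frame in frames:
    buf.write(json.dumps(frame) + "\n")  # one JSON record per line

lines = buf.getvalue().splitlines()
assert len(lines) == 2
assert json.loads(lines[0])["t"] == 0.0
```

One record per line means episodes can be appended frame-by-frame during teleoperation and streamed back without loading the whole file.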

Storage Backends

Recorded datasets can be written to local disk (LOCAL) or synchronized to remote storage (CLOUD), as selected by the storage block in the DataStore configuration.

LeRobot Dataset Integration

The data collection system produces datasets directly compatible with the LeRobot dataset format. This enables seamless use of collected demonstrations for fine-tuning SmolVLA and other LeRobot-compatible models through imitation learning. The format includes standardized observation keys, action normalization statistics, and episode boundary markers.

Inference Pipeline Flow

The end-to-end inference pipeline transforms a system configuration into real-time autonomous robot behavior. Below is the complete data flow from startup to action execution:

INFERENCE PIPELINE FLOW
=======================

1. STARTUP
   .pbtxt Config
       |
       v
   Parse ModelConfig (type, checkpoint, topics, device)
       |
       v
   ModelRegistry.instance().get_model(type, config)
       |
       v
   Model class instantiation (loads weights, moves to GPU)
       |
       v
   ROS2 Inference Node created

2. RUNTIME (per cycle)

   /camera/front/image_raw ----+
   /camera/wrist/image_raw ----+---> handle_input()
   /joint_states --------------+         |
   /task_description ----------+         v
                                   preprocess_input()
                                     - Resize images
                                     - Normalize joints
                                     - Tokenize language
                                         |
                                         v
                                     inference()
                                     - Forward pass (GPU)
                                     - torch.no_grad()
                                         |
                                         v
                                   postprocess_output()
                                     - Denormalize actions
                                     - Clip to joint limits
                                         |
                                         v
                                   /action_commands (publish)
                                         |
                                         v
                                   Robot Actuators Execute

Latency Considerations

The inference pipeline is designed for real-time operation. On an NVIDIA Jetson Orin Nano, SmolVLA inference typically completes within 10–50 ms per cycle, depending on input resolution and batch configuration. The ROS2 node's callback frequency is configurable to match the desired control rate.

Adding New Models

The AI system is designed for extensibility. Follow these steps to integrate a new model into the JOSHUA inference pipeline:

Step 1: Define the ModelType

Add a new entry to the ModelType enum in the Protocol Buffers schema:

// In proto/joshua.proto
enum ModelType {
  UNKNOWN_MODEL = 0;
  RANDOM_NOISE = 1;
  SMOL_VLA = 2;
  MY_NEW_MODEL = 3;  // Add your new model type
}

Step 2: Implement ModelBase

Create a new class inheriting from ModelBase and implement all required methods:

# In ai/models/my_new_model.py

from ai.model_base import ModelBase
from ai.model_registry import ModelRegistry, ModelType

@ModelRegistry.register(ModelType.MY_NEW_MODEL)
class MyNewModel(ModelBase):
    """Custom model implementation."""

    def __init__(self, config):
        super().__init__(config)
        # Load your model weights, initialize components
        self.net = load_my_network(config.checkpoint)

    def handle_input(self, ros_msg) -> dict:
        # Extract relevant data from ROS2 messages
        return {"sensor_data": ros_msg.data}

    def preprocess_input(self, raw_input: dict):
        # Normalize, transform, batch
        return self.transform(raw_input["sensor_data"])

    def inference(self, processed_input):
        # Run forward pass
        return self.net(processed_input)

    def postprocess_output(self, raw_output) -> dict:
        # Convert to action commands
        return {"actions": raw_output.tolist()}

Step 3: Register in ModelRegistry

The @ModelRegistry.register() decorator (shown above) handles registration automatically. Ensure that your model module is imported at startup so the decorator executes. Add the import to ai/models/__init__.py:

# In ai/models/__init__.py
from ai.models.random_noise import RandomNoiseModel
from ai.models.smolvla import SmolVLAModel
from ai.models.my_new_model import MyNewModel  # Add this import

Step 4: Create a Config Preset

Create a .pbtxt configuration preset that uses your new model:

# In config/presets/my_new_model_inference.pbtxt
operation_mode: AI_INFERENCE

ai {
  model {
    type: MY_NEW_MODEL
    checkpoint: "path/to/my_weights.pt"
    input_topics {
      name: "/camera/front/image_raw"
      msg_type: "sensor_msgs/Image"
    }
    input_topics {
      name: "/joint_states"
      msg_type: "sensor_msgs/JointState"
    }
    output_topic: "/action_commands"
    device: "cuda:0"
  }
}

Step 5: Verify Integration

Run the inference stack with the RandomNoise model first to verify the end-to-end pipeline, then switch to your model:

# Test pipeline with RandomNoise
./run.sh --config config/presets/random_noise_test.pbtxt

# Switch to your new model
./run.sh --config config/presets/my_new_model_inference.pbtxt

Checklist for Adding New Models
  1. Add ModelType entry to the protobuf enum
  2. Inherit from ModelBase and implement all abstract methods
  3. Decorate with @ModelRegistry.register(ModelType.YOUR_TYPE)
  4. Add import to ai/models/__init__.py
  5. Create a .pbtxt config preset for the model
  6. Test with the inference pipeline end-to-end