protomotions.agents.masked_mimic.agent module¶
MaskedMimic agent implementation for versatile motion control.
This module implements the MaskedMimic algorithm which learns to reconstruct expert tracker actions from partial observations. It trains on data from a full-body motion tracker while randomly masking observations, enabling the agent to handle diverse control tasks from incomplete information.
- Training Process:
Phase 1: Train expert full-body tracker (separate)
Phase 2: Train MaskedMimic to imitate expert with masked observations
- Key Classes:
MaskedMimic: Main MaskedMimic agent class extending BaseAgent
References
Tessler et al. “MaskedMimic: Unified Physics-Based Character Control Through Masked Motion Inpainting” (2024)
- class protomotions.agents.masked_mimic.agent.MaskedMimic(fabric, env, config, root_dir=None)[source]¶
Bases: BaseAgent
MaskedMimic agent for versatile motion control.
Learns to reconstruct expert tracker actions from partial observations by training on data from a full-body motion tracker. The agent uses masked observations (randomly occluded body parts or features) and learns to infer the complete action from incomplete information. This enables versatile control where the agent can respond to various types of motion commands.
Training process:
1. Phase 1: Train expert full-body tracker (separate training)
2. Phase 2: Train MaskedMimic to imitate expert with masked observations
Key features:
- Masked observations: Randomly masks input features during training
- Action reconstruction: Learns to predict expert tracker actions
- Optional VAE: Can use VAE latent codes for additional control
- Versatile control: Single policy handles diverse motion tasks
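The masking and reconstruction ideas above can be sketched in a few lines. This is an illustrative toy, not the protomotions implementation: the real agent masks structured body-part and feature groups on tensors, whereas here independent scalar features are masked and a simple mean-squared behavioral-cloning loss compares policy actions to expert tracker actions.

```python
import random

def mask_observation(obs, mask_prob=0.5, rng=random):
    # Randomly zero out observation entries and return a parallel
    # visibility mask (1.0 = visible, 0.0 = masked). Toy stand-in for
    # MaskedMimic-style structured masking.
    mask = [1.0 if rng.random() > mask_prob else 0.0 for _ in obs]
    masked = [o * m for o, m in zip(obs, mask)]
    return masked, mask

def bc_loss(policy_actions, expert_actions):
    # Mean-squared behavioral-cloning loss between the masked policy's
    # actions and the expert tracker's actions (the Phase 2 objective,
    # simplified).
    n = len(policy_actions)
    return sum((p - e) ** 2 for p, e in zip(policy_actions, expert_actions)) / n
```

During Phase 2 training, the policy only ever sees the masked observation, while the loss target comes from the expert run on the full observation.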
- Parameters:
fabric (MockFabric) – Lightning Fabric instance for distributed training.
env (Mimic) – Mimic environment for motion tracking.
config (MaskedMimicAgentConfig) – MaskedMimic configuration including expert model path and masking parameters.
root_dir (Path | None) – Optional root directory for saving outputs.
- expert_model¶
Pre-trained full-body tracker model (loaded from config).
- vae_noise¶
VAE latent codes for each environment (if using VAE).
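A minimal sketch of how such a per-environment latent buffer could be populated, assuming standard-normal codes of a hypothetical shape [num_envs, latent_dim]; the actual sampling and resampling schedule is defined by the agent's config.

```python
import random

def sample_vae_noise(num_envs, latent_dim, rng=random):
    # One standard-normal latent code per environment, analogous to the
    # per-env vae_noise buffer described above (shape is an assumption).
    return [[rng.gauss(0.0, 1.0) for _ in range(latent_dim)]
            for _ in range(num_envs)]
```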
Example
>>> fabric = Fabric(devices=4)
>>> env = Mimic(config, robot_config, simulator_config, device)
>>> config.expert_model_path = "results/expert_tracker/"
>>> agent = MaskedMimic(fabric, env, config)
>>> agent.setup()
>>> agent.train()
Note
Requires pre-trained expert tracker model specified in config.expert_model_path.
- env: Mimic¶
- model: MaskedMimicModel¶
- __init__(fabric, env, config, root_dir=None)[source]¶
Initialize the base agent.
Sets up distributed training infrastructure, initializes tracking metrics, and creates the evaluator. Subclasses should call super().__init__() first.
- Parameters:
fabric (MockFabric) – Lightning Fabric for distributed training and device management.
env (Mimic) – Environment instance for agent-environment interaction.
config (MaskedMimicAgentConfig) – Configuration containing hyperparameters and training settings.
root_dir (Path | None) – Optional directory for saving outputs (uses logger dir if None).
- config: MaskedMimicAgentConfig¶
- post_env_step_modifications(dones, terminated, extras)[source]¶
Allow subclasses to modify dones/terminated after env.step().
This hook allows algorithm-specific modifications (e.g., AMP discriminator termination).
- Parameters:
dones – Reset flags from environment
terminated – Termination flags from environment
extras – Info dictionary from environment
- Returns:
Modified (dones, terminated, extras) tuple
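The hook pattern can be illustrated with a simplified subclass (not the real protomotions class; the "disc_termination" key is a hypothetical example of an algorithm-specific signal in extras):

```python
class AMPStyleAgent:
    # Illustrative subclass showing the post_env_step_modifications
    # hook: inject algorithm-specific termination after env.step().

    def post_env_step_modifications(self, dones, terminated, extras):
        # Example rule: also terminate any env that a (hypothetical)
        # discriminator flagged as off-distribution.
        disc_term = extras.get("disc_termination", [False] * len(dones))
        terminated = [t or d for t, d in zip(terminated, disc_term)]
        # Keep dones consistent: a terminated env must also reset.
        dones = [dn or t for dn, t in zip(dones, terminated)]
        return dones, terminated, extras
```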
- add_agent_info_to_obs(obs)[source]¶
Add agent-specific observations to the environment observations.
- load_parameters(state_dict)[source]¶
Load agent parameters from state dictionary.
Restores training state including epoch counter, step count, timing info, best scores, normalization statistics, and model weights.
- Parameters:
state_dict – Dictionary containing saved agent state from checkpoint. Expected keys: epoch, step_count, run_start_time, best_evaluated_score, running_reward_norm (if normalization enabled), model.
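The restore pattern described above can be sketched as a plain function over dictionaries. Key names follow the docstring; the real method also restores model weights through the model's own state dict, which is omitted here.

```python
def load_parameters(agent_state, state_dict):
    # Copy training counters and optional normalization stats from a
    # checkpoint dict into the agent's state (simplified sketch).
    agent_state["epoch"] = state_dict["epoch"]
    agent_state["step_count"] = state_dict["step_count"]
    agent_state["best_evaluated_score"] = state_dict.get("best_evaluated_score")
    # Normalization statistics are only present when normalization
    # was enabled at save time.
    if "running_reward_norm" in state_dict:
        agent_state["running_reward_norm"] = state_dict["running_reward_norm"]
    return agent_state
```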
- register_algorithm_experience_buffer_keys()[source]¶
Register algorithm-specific keys in the experience buffer.
Subclasses override this to add custom keys to the experience buffer (e.g., AMP adds discriminator observations, ASE adds latent codes).
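A toy sketch of this override point, assuming a buffer with a register_key method; for MaskedMimic the natural extra slot is the expert tracker's actions (the "expert_actions" key name here is hypothetical):

```python
class ExperienceBuffer:
    # Toy buffer: maps key -> list of per-step entries.
    def __init__(self):
        self.data = {}

    def register_key(self, key):
        self.data.setdefault(key, [])

def register_algorithm_experience_buffer_keys(buffer):
    # Algorithm-specific registration: reserve a slot for the expert
    # tracker's actions alongside the standard rollout keys.
    buffer.register_key("expert_actions")
```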
- collect_rollout_step(obs_td, step)[source]¶
Collect MaskedMimic-specific data: policy actions and expert actions.
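The dual-query pattern behind this method can be sketched as follows: the masked policy produces the action actually stepped in the environment, while the pre-trained expert, fed the full observation, supplies the supervision target stored beside it. Function signatures and the "expert_actions" key are illustrative, not the real API.

```python
def collect_rollout_step(policy, expert, full_obs, masked_obs, buffer):
    # Query both networks on the same state: the policy sees only the
    # masked observation, the expert sees everything.
    action = policy(masked_obs)        # action actually stepped in the env
    expert_action = expert(full_obs)   # reconstruction target for training
    buffer.setdefault("actions", []).append(action)
    buffer.setdefault("expert_actions", []).append(expert_action)
    return action
```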
- get_state_dict(state_dict)[source]¶
Get complete state dictionary for checkpointing.
Collects all agent state including model weights, training progress, and normalization statistics into a single dictionary for saving.
- Parameters:
state_dict – Existing state dict to update (typically empty dict).
- Returns:
Updated state dictionary containing all agent state.