protomotions.agents.ase.agent module
Adversarial Skill Embeddings (ASE) agent implementation.
This module implements the ASE algorithm which extends AMP with learned skill embeddings. The discriminator encodes motions into a latent skill space, and the policy is conditioned on these latent codes. This enables learning diverse skills from motion data and composing them for complex tasks.
- Key Classes:
ASE: Main ASE agent class extending AMP
References
Peng et al. “ASE: Large-Scale Reusable Adversarial Skill Embeddings for Physically Simulated Characters” (2022)
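To make the conditioning concrete, here is a minimal sketch of a latent-conditioned policy head; the network layout and class name are illustrative, not the actual protomotions modules:

import torch
import torch.nn as nn

class LatentConditionedPolicy(nn.Module):
    # Illustrative pi(a | s, z): the skill code z is concatenated to the
    # observation, so switching z while holding the state fixed switches
    # the skill being executed.
    def __init__(self, obs_dim: int, latent_dim: int, action_dim: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim + latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, action_dim),
        )

    def forward(self, obs: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        return self.net(torch.cat([obs, z], dim=-1))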
- class protomotions.agents.ase.agent.ASE(fabric, env, config, root_dir=None)
Bases: AMP

Adversarial Skill Embeddings (ASE) agent.
Extends AMP with a low-level policy conditioned on learned skill embeddings. The discriminator learns to encode skills from motion data into a latent space, while the policy learns to execute behaviors conditioned on these latent codes. This enables learning diverse skills from motion data and composing them for tasks.
Key components:
- Low-level policy: Conditioned on latent skill codes
- Discriminator: Encodes motions into skill embeddings
- Mutual information: Encourages skill diversity
- Latent sampling: Periodically samples new skills during rollouts
- Parameters:
fabric (MockFabric) – Lightning Fabric instance for distributed training.
env (BaseEnv) – Environment instance with diverse motion library.
config – ASE-specific configuration including latent dimensions.
root_dir (Path | None) – Optional root directory for saving outputs.
- latents
Current latent skill codes for each environment.
- latent_reset_steps
Steps until the next latent resample.
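A hedged sketch of how these two attributes drive the rollout-time resampling scheme; the countdown bookkeeping shown here is illustrative, not the exact protomotions rollout loop:

import torch

def maybe_resample_latents(agent) -> None:
    # Each environment keeps its skill code until its countdown expires;
    # expired environments then receive fresh latents via reset_latents(),
    # which is assumed to also re-arm latent_reset_steps.
    agent.latent_reset_steps -= 1
    expired = torch.nonzero(agent.latent_reset_steps <= 0).flatten()
    if expired.numel() > 0:
        agent.reset_latents(expired)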
Example
>>> fabric = Fabric(devices=4)
>>> env = Mimic(config, robot_config, simulator_config, device)
>>> agent = ASE(fabric, env, config)
>>> agent.setup()
>>> agent.train()
Note
Requires a large, diverse motion dataset for effective skill learning.
- discriminator: ASEDiscriminatorEncoder
- config: ASEAgentConfig
- __init__(fabric, env, config, root_dir=None)
Initialize the base agent.
Sets up distributed training infrastructure, initializes tracking metrics, and creates the evaluator. Subclasses should call super().__init__() first.
- Parameters:
fabric (MockFabric) – Lightning Fabric for distributed training and device management.
env (BaseEnv) – Environment instance for agent-environment interaction.
config – Configuration containing hyperparameters and training settings.
root_dir (Path | None) – Optional directory for saving outputs (uses logger dir if None).
- load_parameters(state_dict)
Load PPO-specific parameters from checkpoint.
Loads actor, critic, and optimizer states. Preserves config overrides for actor_logstd if specified at the command line.
- Parameters:
state_dict – Checkpoint state dictionary containing model and optimizer states.
- get_state_dict(state_dict)
Get complete state dictionary for checkpointing.
Collects all agent state including model weights, training progress, and normalization statistics into a single dictionary for saving.
- Parameters:
state_dict – Existing state dict to update (typically empty dict).
- Returns:
Updated state dictionary containing all agent state.
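A minimal save/restore round trip using the two methods above (the checkpoint path is illustrative):

>>> state = agent.get_state_dict({})
>>> torch.save(state, "ase_checkpoint.pt")
>>> agent.load_parameters(torch.load("ase_checkpoint.pt"))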
- reset_latents(env_ids=None)
Resets latent variables for the specified environments, or for all environments if None.
- Parameters:
env_ids (torch.Tensor, optional) – Environment indices to reset latents for. Defaults to None (all envs).
- store_latents(latents, env_ids)
Stores latent variables for specified environments.
- Parameters:
latents (torch.Tensor) – Latent variables to store. Shape (num_envs, latent_dim).
env_ids (torch.Tensor) – Environment indices to store latents for. Shape (num_envs,).
- sample_latents(n)
Samples new latent variables uniformly on the unit sphere.
- Parameters:
n (int) – Number of latent variables to sample.
- Returns:
Sampled latent variables. Shape (n, latent_dim).
- Return type:
Tensor
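The standard way to sample uniformly on the unit sphere is to normalize i.i.d. Gaussian draws; a minimal sketch of the idea, not necessarily the exact implementation:

import torch

def sample_unit_sphere(n: int, latent_dim: int) -> torch.Tensor:
    # Normalized isotropic Gaussian samples are uniformly distributed
    # on the unit sphere in R^latent_dim.
    z = torch.randn(n, latent_dim)
    return torch.nn.functional.normalize(z, dim=-1)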
- mi_enc_forward(obs_dict)
Forward pass through the Mutual Information encoder.
- Parameters:
obs_dict (dict) – Dictionary containing observations.
- Returns:
Encoded observation tensor. Shape (batch_size, encoder_output_dim).
- Return type:
Tensor
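In the ASE formulation, the encoder output and the active latent are both unit vectors, and the latent-recovery (mutual-information) reward is essentially their dot product, i.e. a von Mises-Fisher log-likelihood up to constants. A sketch under that assumption; clamping at zero follows the reference ASE implementation and may differ here:

import torch

def latent_consistency_reward(enc_out: torch.Tensor, latents: torch.Tensor) -> torch.Tensor:
    # Both inputs unit-norm, shape (batch, latent_dim). Reward is high
    # when the encoder recovers the latent that conditioned the behavior.
    return torch.clamp_min((enc_out * latents).sum(dim=-1), 0.0)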
- register_algorithm_experience_buffer_keys()
Register algorithm-specific keys in the experience buffer.
Subclasses override this to add custom keys to the experience buffer (e.g., AMP adds discriminator observations, ASE adds latent codes).
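A hypothetical override illustrating the pattern; the buffer method (register_key) and the latent_dim config field are assumptions, not protomotions' verified interface:

def register_algorithm_experience_buffer_keys(self):
    # Keep AMP's keys (e.g., discriminator observations), then add the
    # per-step latent codes so they can be replayed during updates.
    super().register_algorithm_experience_buffer_keys()
    self.experience_buffer.register_key("latents", shape=(self.config.latent_dim,))  # hypothetical API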
- add_agent_info_to_obs(obs)
Injects the agent's current latent codes into the observations gathered at each environment step.
- discriminator_step(batch_dict)
Performs a discriminator update step.
- Parameters:
batch_dict (dict) – Batch of data from the experience buffer.
- Returns:
Discriminator loss and logging dictionary.
- Return type:
Tuple[Tensor, Dict]
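For context, the AMP-style objective a discriminator step typically optimizes is binary cross-entropy separating reference-motion transitions from policy transitions; the full AMP loss also adds a gradient penalty on real samples, omitted here. A sketch, not the exact protomotions loss:

import torch
import torch.nn.functional as F

def disc_bce_loss(disc, real_obs: torch.Tensor, fake_obs: torch.Tensor) -> torch.Tensor:
    # real_obs: transitions from the motion dataset; fake_obs: from the policy.
    real_logits = disc(real_obs)
    fake_logits = disc(fake_obs)
    real_loss = F.binary_cross_entropy_with_logits(real_logits, torch.ones_like(real_logits))
    fake_loss = F.binary_cross_entropy_with_logits(fake_logits, torch.zeros_like(fake_logits))
    return real_loss + fake_loss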
- compute_uniformity_loss(encodings)
Computes uniformity loss to encourage uniform distribution on unit sphere.
- Parameters:
encodings (Tensor) – Normalized encodings on unit sphere. Shape (batch_size, latent_dim).
- Returns:
Uniformity loss value.
- Return type:
Tensor
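A sketch of a standard uniformity objective in the style of Wang & Isola (2020): the log of the mean Gaussian potential over pairwise distances, which is minimized when the encodings spread uniformly over the sphere. The kernel temperature t, and whether this exact kernel is used here, are assumptions:

import torch

def uniformity_loss(encodings: torch.Tensor, t: float = 2.0) -> torch.Tensor:
    # encodings: unit-norm, shape (batch, latent_dim). pdist yields all
    # pairwise Euclidean distances; clustered encodings increase the loss.
    sq_dists = torch.pdist(encodings, p=2).pow(2)
    return sq_dists.mul(-t).exp().mean().log()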
- calculate_extra_actor_loss(batch_td)
Adds the diversity loss, if enabled.
- Parameters:
batch_td (TensorDict) – Batch of data from the experience buffer and the actor.
- Returns:
Extra actor loss and logging dictionary.
- Return type:
Tuple[Tensor, Dict]
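For reference, the diversity term in the public ASE implementation penalizes deviation of the ratio between action change and latent change from a fixed target, pushing distinct latents toward distinct behaviors. A sketch following that codebase; treat the exact form and the target value as assumptions for protomotions:

import torch

def diversity_loss(mu_rollout, mu_new, z_rollout, z_new, target: float = 1.0) -> torch.Tensor:
    # mu_*: action means under the rollout latent vs. a freshly sampled one;
    # z_*: the corresponding unit-norm latent codes.
    a_diff = (mu_rollout.clamp(-1, 1) - mu_new.clamp(-1, 1)).pow(2).mean(dim=-1)
    z_dist = 0.5 - 0.5 * (z_rollout * z_new).sum(dim=-1)  # in [0, 1] for unit vectors
    ratio = a_diff / (z_dist + 1e-5)
    return (target - ratio).pow(2).mean()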