protomotions.agents.ase.agent module

Adversarial Skill Embeddings (ASE) agent implementation.

This module implements the ASE algorithm which extends AMP with learned skill embeddings. The discriminator encodes motions into a latent skill space, and the policy is conditioned on these latent codes. This enables learning diverse skills from motion data and composing them for complex tasks.

Key Classes:
  • ASE: Main ASE agent class extending AMP

References

Peng et al. “ASE: Large-Scale Reusable Adversarial Skill Embeddings for Physically Simulated Characters” (2022)

class protomotions.agents.ase.agent.ASE(fabric, env, config, root_dir=None)[source]

Bases: AMP

Adversarial Skill Embeddings (ASE) agent.

Extends AMP with a low-level policy conditioned on learned skill embeddings. The discriminator learns to encode skills from motion data into a latent space, while the policy learns to execute behaviors conditioned on these latent codes. This enables learning diverse skills from motion data and composing them for tasks.

Key components:
  • Low-level policy: Conditioned on latent skill codes
  • Discriminator: Encodes motions into skill embeddings
  • Mutual information: Encourages skill diversity
  • Latent sampling: Periodically samples new skills during rollouts
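
How these pieces fit together can be seen in a minimal, self-contained sketch. All dimensions, networks, and reward weights below are illustrative stand-ins and do not reflect the protomotions implementation:

    import torch

    num_envs, latent_dim, obs_dim, act_dim = 8, 24, 64, 12  # illustrative sizes

    def sample_unit_sphere(n: int, dim: int) -> torch.Tensor:
        """Sample n latent skill codes uniformly on the unit sphere."""
        z = torch.randn(n, dim)
        return z / z.norm(dim=-1, keepdim=True)

    # Stand-ins for the latent-conditioned policy and the shared discriminator/encoder.
    policy = torch.nn.Linear(obs_dim + latent_dim, act_dim)
    disc_enc = torch.nn.Linear(obs_dim, 1 + latent_dim)  # one real/fake logit + latent encoding

    obs = torch.randn(num_envs, obs_dim)
    z = sample_unit_sphere(num_envs, latent_dim)

    # The low-level policy acts on the observation concatenated with the latent code.
    action = policy(torch.cat([obs, z], dim=-1))

    # Combined reward: a discriminator "style" term plus a mutual-information term
    # that rewards states whose encoding aligns with the active latent.
    out = disc_enc(obs)
    disc_logit, enc = out[:, :1], out[:, 1:]
    enc = enc / enc.norm(dim=-1, keepdim=True)
    style_reward = -torch.log(torch.clamp(1.0 - torch.sigmoid(disc_logit), min=1e-4)).squeeze(-1)
    mi_reward = torch.clamp_min((enc * z).sum(dim=-1), 0.0)
    reward = 0.5 * style_reward + 0.5 * mi_reward  # weights are illustrative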

Parameters:
  • fabric (Fabric) – Lightning Fabric instance for distributed training.

  • env (BaseEnv) – Environment instance with diverse motion library.

  • config – ASE-specific configuration including latent dimensions.

  • root_dir (Path | None) – Optional root directory for saving outputs.

latents

Current latent skill codes for each environment.

latent_reset_steps

Steps until next latent resample.

Example

>>> fabric = Fabric(devices=4)
>>> env = Mimic(config, robot_config, simulator_config, device)
>>> agent = ASE(fabric, env, config)
>>> agent.setup()
>>> agent.train()

Note

Requires a large, diverse motion dataset for effective skill learning.

model: AMPModel
discriminator: ASEDiscriminatorEncoder
config: ASEAgentConfig
__init__(fabric, env, config, root_dir=None)[source]

Initialize the base agent.

Sets up distributed training infrastructure, initializes tracking metrics, and creates the evaluator. Subclasses should call super().__init__() first.

Parameters:
  • fabric (Fabric) – Lightning Fabric for distributed training and device management.

  • env (BaseEnv) – Environment instance for agent-environment interaction.

  • config – Configuration containing hyperparameters and training settings.

  • root_dir (Path | None) – Optional directory for saving outputs (uses logger dir if None).

setup()[source]
load_parameters(state_dict)[source]

Load PPO-specific parameters from a checkpoint.

Loads actor, critic, and optimizer states. Preserves config overrides for actor_logstd if specified on the command line.

Parameters:

state_dict – Checkpoint state dictionary containing model and optimizer states.

get_state_dict(state_dict)[source]

Get complete state dictionary for checkpointing.

Collects all agent state including model weights, training progress, and normalization statistics into a single dictionary for saving.

Parameters:

state_dict – Existing state dict to update (typically empty dict).

Returns:

Updated state dictionary containing all agent state.
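
A usage sketch, continuing from the class-level example above (ckpt_path is a hypothetical path; the exact checkpoint contents depend on the agent configuration):

>>> state = agent.get_state_dict({})
>>> torch.save(state, ckpt_path)
>>> agent.load_parameters(torch.load(ckpt_path, map_location="cpu"))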

update_latents()[source]

Resamples latent variables for environments whose latent reset steps have elapsed.

reset_latents(env_ids=None)[source]

Resets latent variables for the specified environments, or for all environments if env_ids is None.

Parameters:

env_ids (torch.Tensor, optional) – Environment indices to reset latents for. Defaults to None (all envs).
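
A minimal sketch of the resampling pattern these two methods implement. The names mirror the documented latents and latent_reset_steps attributes, but the countdown range and internals are illustrative:

    import torch

    num_envs, latent_dim = 4, 24
    latents = torch.zeros(num_envs, latent_dim)                    # current skill codes
    latent_reset_steps = torch.zeros(num_envs, dtype=torch.long)   # steps until resample

    def reset_latents(env_ids: torch.Tensor) -> None:
        """Draw fresh unit-sphere latents and a new countdown for the given envs."""
        z = torch.randn(len(env_ids), latent_dim)
        latents[env_ids] = z / z.norm(dim=-1, keepdim=True)
        # The resample horizon here is an arbitrary illustrative range.
        latent_reset_steps[env_ids] = torch.randint(50, 150, (len(env_ids),))

    def update_latents() -> None:
        """Count down once per rollout step and resample latents for expired envs."""
        latent_reset_steps.sub_(1)
        expired = (latent_reset_steps <= 0).nonzero(as_tuple=False).flatten()
        if len(expired) > 0:
            reset_latents(expired)

    reset_latents(torch.arange(num_envs))   # initialize all environments
    for _ in range(200):                    # called once per rollout step
        update_latents()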

store_latents(latents, env_ids)[source]

Stores latent variables for specified environments.

Parameters:
  • latents (torch.Tensor) – Latent variables to store. Shape (num_envs, latent_dim).

  • env_ids (torch.Tensor) – Environment indices to store latents for. Shape (num_envs,).

sample_latents(n)[source]

Samples new latent variables uniformly on the unit sphere.

Parameters:

n (int) – Number of latent variables to sample.

Returns:

Sampled latent variables. Shape (n, latent_dim).

Return type:

torch.Tensor
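
The standard recipe for sampling uniformly on the unit sphere is to normalize isotropic Gaussian draws; a minimal sketch (latent_dim is illustrative):

    import torch

    def sample_latents(n: int, latent_dim: int = 24) -> torch.Tensor:
        """Uniform unit-sphere samples: normalize isotropic Gaussian draws."""
        z = torch.randn(n, latent_dim)
        return z / z.norm(dim=-1, keepdim=True)

    z = sample_latents(5)
    assert torch.allclose(z.norm(dim=-1), torch.ones(5), atol=1e-6)  # unit norm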

mi_enc_forward(obs_dict)[source]

Forward pass through the Mutual Information encoder.

Parameters:

obs_dict (dict) – Dictionary containing observations.

Returns:

Encoded observation tensor. Shape (batch_size, encoder_output_dim).

Return type:

Tensor
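
In ASE-style training, the encoder output is typically projected onto the unit sphere and scored against the latent that conditioned the policy (a von Mises-Fisher style log-likelihood, proportional to the dot product). A sketch with a stand-in encoder operating on a plain tensor rather than the observation dictionary; the actual protomotions encoder and reward may differ:

    import torch
    import torch.nn.functional as F

    obs_dim, latent_dim = 64, 24
    encoder = torch.nn.Sequential(                       # stand-in MI encoder
        torch.nn.Linear(obs_dim, 128), torch.nn.ReLU(), torch.nn.Linear(128, latent_dim))

    def mi_enc_forward(obs: torch.Tensor) -> torch.Tensor:
        """Encode observations and project the result onto the unit sphere."""
        return F.normalize(encoder(obs), dim=-1)

    obs = torch.randn(8, obs_dim)
    z = F.normalize(torch.randn(8, latent_dim), dim=-1)  # latents that conditioned the policy
    enc = mi_enc_forward(obs)
    # Mutual-information (encoder) reward: alignment between encoding and latent,
    # clipped at zero so it stays non-negative (a common choice in ASE implementations).
    mi_reward = torch.clamp_min((enc * z).sum(dim=-1), 0.0)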

register_algorithm_experience_buffer_keys()[source]

Register algorithm-specific keys in the experience buffer.

Subclasses override this to add custom keys to the experience buffer (e.g., AMP adds discriminator observations, ASE adds latent codes).

add_agent_info_to_obs(obs)[source]

Perform an environment step and inject current latents into observations.

get_combined_experience_buffer_rewards()[source]
get_expert_disc_obs(num_samples)[source]
produce_negative_expert_obs(batch_dict)[source]
discriminator_step(batch_dict)[source]

Performs a discriminator update step.

Parameters:

batch_dict (dict) – Batch of data from the experience buffer.

Returns:

Discriminator loss and logging dictionary.

Return type:

Tuple[Tensor, Dict]
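
A simplified sketch of the kind of objective such a step optimizes when the discriminator and encoder share a trunk: a real/fake loss on expert versus agent transitions plus an encoder term that predicts the conditioning latent. The loss formulation, network shapes, and omitted regularizers (e.g. gradient penalties) are illustrative, not the exact protomotions loss:

    import torch
    import torch.nn.functional as F

    obs_dim, latent_dim, batch = 64, 24, 32
    trunk = torch.nn.Linear(obs_dim, 128)        # shared trunk (stand-in)
    disc_head = torch.nn.Linear(128, 1)          # real/fake logit
    enc_head = torch.nn.Linear(128, latent_dim)  # latent encoding
    params = [*trunk.parameters(), *disc_head.parameters(), *enc_head.parameters()]
    opt = torch.optim.Adam(params, lr=1e-4)

    def discriminator_step(agent_obs, expert_obs, agent_latents):
        h_agent, h_expert = torch.relu(trunk(agent_obs)), torch.relu(trunk(expert_obs))
        # Discriminator: expert transitions are "real" (1), agent transitions are "fake" (0).
        disc_loss = 0.5 * (
            F.binary_cross_entropy_with_logits(disc_head(h_expert), torch.ones(len(expert_obs), 1))
            + F.binary_cross_entropy_with_logits(disc_head(h_agent), torch.zeros(len(agent_obs), 1)))
        # Encoder: push the encoding of agent transitions toward the latent that produced them.
        enc = F.normalize(enc_head(h_agent), dim=-1)
        enc_loss = -(enc * agent_latents).sum(dim=-1).mean()
        loss = disc_loss + enc_loss
        opt.zero_grad()
        loss.backward()
        opt.step()
        return loss, {"disc_loss": disc_loss.item(), "enc_loss": enc_loss.item()}

    agent_obs = torch.randn(batch, obs_dim)
    expert_obs = torch.randn(batch, obs_dim)
    agent_latents = F.normalize(torch.randn(batch, latent_dim), dim=-1)
    loss, log_dict = discriminator_step(agent_obs, expert_obs, agent_latents)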

compute_uniformity_loss(encodings)[source]

Computes a uniformity loss that encourages the encodings to be uniformly distributed on the unit sphere.

Parameters:

encodings (Tensor) – Normalized encodings on unit sphere. Shape (batch_size, latent_dim).

Returns:

Uniformity loss value.

Return type:

Tensor
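
One standard formulation of such a loss is the pairwise Gaussian-potential objective of Wang & Isola (2020): the log of the mean kernel value over all pairs, which is minimized when the points spread uniformly over the sphere. A sketch under that assumption; the protomotions formulation may differ:

    import torch

    def compute_uniformity_loss(encodings: torch.Tensor, t: float = 2.0) -> torch.Tensor:
        """Log of the mean pairwise Gaussian potential; lower means more uniform."""
        sq_dists = torch.pdist(encodings, p=2).pow(2)  # all pairwise squared distances
        return sq_dists.mul(-t).exp().mean().log()

    enc = torch.nn.functional.normalize(torch.randn(64, 24), dim=-1)
    loss = compute_uniformity_loss(enc)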

calculate_extra_actor_loss(batch_td)[source]

Adds the diversity loss, if enabled.

Parameters:

batch_td (TensorDict) – Batch of data from the experience buffer and the actor.

Returns:

Extra actor loss and logging dictionary.

Return type:

Tuple[Tensor, Dict]

diversity_loss(batch_td)[source]

Calculates the diversity loss, which encourages different latents to produce distinct behaviors.

Parameters:

batch_td (TensorDict) – Batch of data from the experience buffer and the actor.

Returns:

Diversity loss.

Return type:

Tensor
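
One common formulation compares the policy's action means for the same observation under the rollout latent and a freshly sampled latent, and pushes the ratio of action change to latent change toward a target value. A sketch with a stand-in actor; the exact protomotions loss may differ:

    import torch
    import torch.nn.functional as F

    obs_dim, latent_dim, act_dim = 64, 24, 12
    actor = torch.nn.Linear(obs_dim + latent_dim, act_dim)  # stand-in latent-conditioned actor

    def diversity_loss(obs, z, diversity_target: float = 1.0) -> torch.Tensor:
        """Penalize latent changes that produce too little change in the action mean."""
        z_alt = F.normalize(torch.randn_like(z), dim=-1)       # freshly sampled latents
        mu = actor(torch.cat([obs, z], dim=-1))
        mu_alt = actor(torch.cat([obs, z_alt], dim=-1))
        action_diff = (mu - mu_alt).pow(2).mean(dim=-1)
        latent_diff = 0.5 * (1.0 - (z * z_alt).sum(dim=-1))    # in [0, 1] for unit latents
        ratio = action_diff / (latent_diff + 1e-5)
        return (diversity_target - ratio).pow(2).mean()

    obs = torch.randn(16, obs_dim)
    z = F.normalize(torch.randn(16, latent_dim), dim=-1)
    loss = diversity_loss(obs, z)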

post_epoch_logging(training_log_dict)[source]

Performs post-epoch logging, including mutual information reward statistics.

Parameters:

training_log_dict (Dict) – Dictionary to update with logging information.