protomotions.agents.amp.agent module

Adversarial Motion Priors (AMP) agent implementation.

This module implements the AMP algorithm, which extends PPO with a discriminator network that provides style rewards. The discriminator learns to distinguish between agent-generated and reference motion data, encouraging the agent to produce naturalistic movements while accomplishing tasks.

Key Classes:
  • AMP: Main AMP agent class extending PPO

References

Peng et al. “AMP: Adversarial Motion Priors for Stylized Physics-Based Character Control” (2021)
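
As a rough illustration of the style-reward idea, the least-squares form from Peng et al. (2021) maps a discriminator logit to a bounded reward. This is a sketch for intuition only; the exact transform used by this module is defined by its implementation and config.

import torch

def style_reward_from_logits(disc_logits: torch.Tensor) -> torch.Tensor:
    # Least-squares style reward from Peng et al. (2021):
    #   r_style = max(0, 1 - 0.25 * (D(s, s') - 1)^2)
    # Bounded in [0, 1]; higher when the discriminator believes the
    # transition came from the reference motion data.
    return torch.clamp(1.0 - 0.25 * (disc_logits - 1.0) ** 2, min=0.0)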

class protomotions.agents.amp.agent.AMP(fabric, env, config, root_dir=None)[source]

Bases: PPO

Adversarial Motion Priors (AMP) agent.

Extends PPO with a discriminator network that learns to distinguish between agent-generated and reference motion data. The discriminator provides a style reward that encourages the agent to produce motions with characteristics similar to the reference dataset. This enables training agents that perform tasks while maintaining natural motion styles.

The agent combines task rewards with discriminator-based style rewards (see the sketch after this list):
  • Task reward: from the environment (e.g., reaching a target)

  • Style reward: from the discriminator (similarity to reference motions)
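
A minimal sketch of that combination, assuming simple scalar weights (the names task_reward_w and disc_reward_w are illustrative, not necessarily the actual config fields):

import torch

def combine_rewards(
    task_reward: torch.Tensor,
    style_reward: torch.Tensor,
    task_reward_w: float = 0.5,  # illustrative weights, not the actual defaults
    disc_reward_w: float = 0.5,
) -> torch.Tensor:
    # Weighted sum of the environment (task) reward and the
    # discriminator (style) reward for each transition.
    return task_reward_w * task_reward + disc_reward_w * style_reward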

Parameters:
  • fabric (Fabric) – Lightning Fabric instance for distributed training.

  • env (BaseEnv) – Environment instance with motion library for reference data.

  • config (AMPAgentConfig) – AMP-specific configuration including discriminator parameters.

  • root_dir (Path | None) – Optional root directory for saving outputs.

amp_replay_buffer

Replay buffer storing agent transitions for discriminator training.

discriminator

Network that distinguishes agent from reference motions.
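
For intuition, a minimal discriminator could be an MLP that maps a discriminator observation to a single logit. The architecture below is an assumption for illustration; the actual network is built from the AMP config.

import torch
from torch import nn

class SketchDiscriminator(nn.Module):
    # Illustrative only: maps a flattened discriminator observation to one logit.
    def __init__(self, disc_obs_dim: int, hidden_dim: int = 1024):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(disc_obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),  # single real-valued logit per sample
        )

    def forward(self, disc_obs: torch.Tensor) -> torch.Tensor:
        return self.net(disc_obs)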

Example

>>> fabric = Fabric(devices=4)
>>> env = Mimic(config, robot_config, simulator_config, device)
>>> agent = AMP(fabric, env, config)
>>> agent.setup()
>>> agent.train()

Note

Requires environment with motion library (motion_lib) for sampling reference data.

config: AMPAgentConfig
__init__(fabric, env, config, root_dir=None)[source]

Initialize the base agent.

Sets up distributed training infrastructure, initializes tracking metrics, and creates the evaluator. Subclasses should call super().__init__() first.
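
A minimal sketch of the intended call order in a subclass (the subclass name and extra state are hypothetical):

from protomotions.agents.amp.agent import AMP

class MyAMPVariant(AMP):
    def __init__(self, fabric, env, config, root_dir=None):
        # Call the parent constructor first so distributed setup, metrics,
        # and the evaluator exist before any subclass-specific state.
        super().__init__(fabric, env, config, root_dir=root_dir)
        # ... subclass-specific initialization goes here ...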

Parameters:
  • fabric (Fabric) – Lightning Fabric for distributed training and device management.

  • env (BaseEnv) – Environment instance for agent-environment interaction.

  • config – Configuration containing hyperparameters and training settings.

  • root_dir (Path | None) – Optional directory for saving outputs (uses logger dir if None).

create_optimizers(model)[source]

Create separate optimizers for actor and critic.

Sets up Adam optimizers for policy and value networks with independent learning rates. Uses Fabric for distributed training setup.

Parameters:

model (AMPModel) – Model with actor and critic networks.
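
A rough sketch of the setup described above, assuming the model exposes actor and critic submodules and that learning rates come from the agent config (the attribute and argument names here are assumptions):

import torch

def create_optimizers_sketch(fabric, model, actor_lr=3e-4, critic_lr=1e-3):
    # Independent Adam optimizers so the policy and value networks can use
    # different learning rates.
    actor_optimizer = torch.optim.Adam(model.actor.parameters(), lr=actor_lr)
    critic_optimizer = torch.optim.Adam(model.critic.parameters(), lr=critic_lr)
    # Register both optimizers with Lightning Fabric for distributed training.
    actor_optimizer, critic_optimizer = fabric.setup_optimizers(
        actor_optimizer, critic_optimizer
    )
    return actor_optimizer, critic_optimizer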

load_parameters(state_dict)[source]

Load PPO-specific parameters from checkpoint.

Loads actor, critic, and optimizer states. Preserves config overrides for actor_logstd if specified at the command line.

Parameters:

state_dict – Checkpoint state dictionary containing model and optimizer states.

get_state_dict(state_dict)[source]

Get complete state dictionary for checkpointing.

Collects all agent state including model weights, training progress, and normalization statistics into a single dictionary for saving.

Parameters:

state_dict – Existing state dict to update (typically empty dict).

Returns:

Updated state dictionary containing all agent state.
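
Continuing the class-level Example above, a minimal save/load round trip might look as follows (the checkpoint path is illustrative; in practice Fabric's checkpointing utilities may be used instead):

>>> import torch
>>> state = agent.get_state_dict({})                     # collect full agent state
>>> torch.save(state, "checkpoint.pt")                   # illustrative path
>>> agent.load_parameters(torch.load("checkpoint.pt"))   # restore model/optimizer state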

register_algorithm_experience_buffer_keys()[source]

Register algorithm-specific keys in the experience buffer.

Subclasses override this to add custom keys to the experience buffer (e.g., AMP adds discriminator observations, ASE adds latent codes).

update_disc_replay_buffer(data_dict)[source]

Update the discriminator replay buffer with the latest agent transitions.

get_expert_disc_obs(num_samples)[source]

Sample a batch of discriminator observations from the reference (expert) motion data.

post_env_step_modifications(dones, terminated, extras)[source]

Add AMP-specific discriminator-based termination.
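
As a hedged illustration of what a discriminator-based termination criterion can look like (the threshold and the exact criterion used by this agent are assumptions):

import torch

def discriminator_termination(style_reward: torch.Tensor,
                              dones: torch.Tensor,
                              threshold: float = 0.05) -> torch.Tensor:
    # Illustrative only: additionally terminate environments whose style
    # reward has dropped below a threshold, i.e. the discriminator is
    # confident the motion no longer resembles the reference data.
    low_style = style_reward.squeeze(-1) < threshold
    return torch.logical_or(dones, low_style)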

get_combined_experience_buffer_rewards()[source]

Combine environment task rewards with discriminator style rewards into the total reward stored in the experience buffer.

perform_optimization_step(batch_dict, batch_idx)[source]

Perform one PPO optimization step on a minibatch.

Computes actor and critic losses, performs backpropagation, clips gradients, and updates both networks.

Parameters:
  • batch_dict – Dictionary containing minibatch data (obs, actions, advantages, etc.).

  • batch_idx (int) – Index of current batch (unused but kept for compatibility).

Returns:

Dictionary of training metrics (losses, clip fraction, etc.).

Return type:

Dict
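
For reference, the standard PPO clipped surrogate objective that such a step minimizes looks like the sketch below; the exact losses, entropy and value coefficients, and gradient clipping used here are defined in the implementation and config.

import torch

def ppo_actor_loss(new_log_prob: torch.Tensor,
                   old_log_prob: torch.Tensor,
                   advantages: torch.Tensor,
                   clip_eps: float = 0.2) -> torch.Tensor:
    # Standard PPO clipped surrogate objective, negated for minimization.
    ratio = torch.exp(new_log_prob - old_log_prob)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    return -torch.min(unclipped, clipped).mean()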

discriminator_step(batch_dict)[source]

Perform one discriminator optimization step on a minibatch of agent and reference observations.

static compute_pos_acc(positive_logit)[source]

Compute discriminator accuracy on positive (reference motion) samples.

static compute_neg_acc(negative_logit)[source]

Compute discriminator accuracy on negative (agent) samples.

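
A minimal sketch of a discriminator update and the two accuracy measures, assuming the least-squares objective from Peng et al. (2021) with reference logits pushed toward +1 and agent logits toward -1; the actual loss, regularizers (e.g., gradient penalty), and thresholds are implementation details.

import torch
import torch.nn.functional as F

def discriminator_losses(positive_logit: torch.Tensor, negative_logit: torch.Tensor):
    # Least-squares objective: reference (positive) samples toward +1,
    # agent (negative) samples toward -1. Regularizers are omitted here.
    pos_loss = F.mse_loss(positive_logit, torch.ones_like(positive_logit))
    neg_loss = F.mse_loss(negative_logit, -torch.ones_like(negative_logit))
    loss = 0.5 * (pos_loss + neg_loss)

    # Accuracies in the spirit of compute_pos_acc / compute_neg_acc:
    # a positive sample counts as correct if its logit is above zero,
    # a negative sample if its logit is below zero.
    pos_acc = (positive_logit > 0).float().mean()
    neg_acc = (negative_logit < 0).float().mean()
    return loss, pos_acc, neg_acc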
terminate_early()[source]

Request early termination of training.

Sets a flag that will cause the training loop to exit gracefully after the current epoch completes.

post_epoch_logging(training_log_dict)[source]

Log training metrics at the end of each epoch.