protomotions.agents.amp.model module
AMP (Adversarial Motion Priors) model components, including the discriminator network.
This module implements the AMP-specific neural networks, particularly the discriminator that distinguishes between agent and reference motion data.
- Key Classes:
Discriminator: Binary classifier for agent vs. reference motions
AMPModel: PPO model extended with discriminator
- class protomotions.agents.amp.model.Discriminator(*args, **kwargs)
Bases: TensorDictModuleBase
Discriminator network for AMP style rewards.
Binary classifier that distinguishes between agent-generated and reference motion data. It uses the SequentialModule structure, which simply chains its modules together.
- Parameters:
config (DiscriminatorConfig) – Discriminator configuration (extends SequentialModuleConfig).
- sequential_models
Sequential list of modules.
- in_keys
Input keys from config.
- out_keys
Output keys from config.
- config: DiscriminatorConfig
- forward(tensordict)
Forward pass through the discriminator.
- Parameters:
tensordict (MockTensorDict) – TensorDict containing observations.
- Returns:
TensorDict with the discriminator output added.
- Return type:
MockTensorDict
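The module chaining described for Discriminator can be sketched in plain Python. This is a minimal illustration, not the ProtoMotions code: a dict stands in for the TensorDict, and `KeyedModule`, `SequentialSketch`, and the key names `disc_obs`, `hidden`, and `disc_logits` are hypothetical. Each module reads its input keys and writes its output key, so the final module's output (here a single logit) ends up in the returned dict.

```python
from typing import Callable, Dict, List, Sequence

# Hypothetical stand-in for a TensorDict: a plain dict mapping keys to values.
Data = Dict[str, List[float]]


class KeyedModule:
    """Hypothetical module that reads `in_keys` and writes one `out_key`."""

    def __init__(self, fn: Callable[..., List[float]], in_keys: Sequence[str], out_key: str):
        self.fn = fn
        self.in_keys = list(in_keys)
        self.out_key = out_key

    def __call__(self, td: Data) -> Data:
        td[self.out_key] = self.fn(*(td[k] for k in self.in_keys))
        return td


class SequentialSketch:
    """Runs modules in order; later modules can read keys written by earlier ones."""

    def __init__(self, modules: Sequence[KeyedModule]):
        self.sequential_models = list(modules)

    def __call__(self, td: Data) -> Data:
        for module in self.sequential_models:
            td = module(td)
        return td


# Toy two-stage "discriminator": ReLU features, then a single-logit linear head.
weights, bias = [0.5, -1.0, 2.0], 0.1
disc = SequentialSketch([
    KeyedModule(lambda obs: [max(0.0, v) for v in obs], ["disc_obs"], "hidden"),
    KeyedModule(lambda h: [sum(w * v for w, v in zip(weights, h)) + bias], ["hidden"], "disc_logits"),
])

out = disc({"disc_obs": [1.0, 2.0, -3.0]})
print(out["disc_logits"])
```

Because every stage communicates only through named keys, a stage can be swapped or inserted without changing the others, as long as the keys line up.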
- compute_disc_reward(disc_logits, eps=0.0001)
Compute the style reward from discriminator logits.
Converts discriminator logits to a reward using a negative log probability. A higher reward means the motion is more similar to the reference data.
- Parameters:
disc_logits (MockTensor) – Discriminator logits.
eps (float) – Small constant for numerical stability.
- Returns:
Style rewards for each sample (higher = more reference-like).
- Return type:
MockTensor
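The docstring only says the reward is a negative log probability. The sketch below assumes the formulation commonly used in AMP implementations, r = -log(max(1 - sigmoid(logit), eps)), where sigmoid(logit) is read as the probability that a sample is reference-like. The function names are hypothetical, it works on plain floats rather than tensors, and ProtoMotions' exact expression may differ.

```python
import math


def stable_sigmoid(x: float) -> float:
    # Numerically stable logistic function (avoids overflow in math.exp).
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def disc_reward(disc_logit: float, eps: float = 1e-4) -> float:
    # Assumed AMP-style reward: -log(max(1 - D, eps)) with D = sigmoid(logit).
    # The eps clamp keeps the log finite when D is (numerically) 1.
    prob = stable_sigmoid(disc_logit)
    return -math.log(max(1.0 - prob, eps))


for logit in (-5.0, 0.0, 5.0, 50.0):
    print(f"logit={logit:+.1f}  reward={disc_reward(logit):.4f}")
```

Without the clamp, -log(1 - sigmoid(x)) equals softplus(x) = log(1 + e^x), so the reward increases monotonically with the logit; eps caps it at -log(eps), about 9.21 for the default 0.0001.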
- class protomotions.agents.amp.model.AMPModel(*args, **kwargs)
Bases: PPOModel
AMP model with actor, critic, and discriminator networks.
Extends PPOModel by adding a discriminator network that provides style rewards. The complete model comprises a policy, a value function, and a style discriminator.
- Parameters:
config (AMPModelConfig) – AMPModelConfig specifying all three networks.
- _actor
Policy network.
- _critic
Value network.
- _discriminator
Style discriminator network.
- config: AMPModelConfig
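The three-network composition that AMPModel describes can be sketched as follows. This is a hypothetical illustration, not the ProtoMotions class: the public field names (standing in for `_actor`, `_critic`, `_discriminator`), the list-of-floats observations, and the `style_reward` helper are assumptions, and the reward uses the commonly seen AMP form -log(max(1 - sigmoid(logit), eps)), which may differ from what compute_disc_reward actually computes.

```python
import math
from dataclasses import dataclass
from typing import Callable, List, Tuple

Obs = List[float]


@dataclass
class AMPModelSketch:
    """Hypothetical AMP-style model: PPO actor/critic plus a style discriminator."""

    actor: Callable[[Obs], List[float]]    # policy: observation -> action
    critic: Callable[[Obs], float]         # value function: observation -> V(obs)
    discriminator: Callable[[Obs], float]  # discriminator observation -> logit

    def act_and_value(self, obs: Obs) -> Tuple[List[float], float]:
        # PPO part: the policy and value networks both consume the observation.
        return self.actor(obs), self.critic(obs)

    def style_reward(self, disc_obs: Obs, eps: float = 1e-4) -> float:
        # Assumed AMP reward: -log(max(1 - sigmoid(logit), eps)).
        # sigmoid is computed via tanh, which cannot overflow for large |logit|.
        logit = self.discriminator(disc_obs)
        prob = 0.5 * (1.0 + math.tanh(0.5 * logit))
        return -math.log(max(1.0 - prob, eps))


model = AMPModelSketch(
    actor=lambda o: [2.0 * x for x in o],
    critic=lambda o: sum(o),
    discriminator=lambda o: sum(o) - 1.0,
)
action, value = model.act_and_value([1.0, 2.0])
reward = model.style_reward([1.0])  # logit 0.0 -> reward log(2)
```

Keeping the discriminator beside the actor and critic, rather than inside the environment, lets the style reward be recomputed from the current discriminator whenever a batch is processed.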