# Source code for protomotions.agents.amp.model

# SPDX-FileCopyrightText: Copyright (c) 2025 The ProtoMotions Developers
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""AMP model components including discriminator network.

This module implements the AMP-specific neural networks, particularly the
discriminator that distinguishes between agent and reference motion data.

Key Classes:
    - Discriminator: Binary classifier for agent vs. reference motions
    - AMPModel: PPO model extended with discriminator
"""

import torch
from torch import nn
from typing import List
from tensordict import TensorDict
from tensordict.nn import TensorDictModuleBase
from protomotions.utils.hydra_replacement import get_class
from protomotions.agents.ppo.model import PPOModel
from protomotions.agents.amp.config import DiscriminatorConfig, AMPModelConfig


class Discriminator(TensorDictModuleBase):
    """Discriminator network for AMP style rewards.

    Binary classifier that distinguishes between agent-generated and reference
    motion data. Uses SequentialModule structure - just chains modules together.

    Args:
        config: DiscriminatorConfig (extends SequentialModuleConfig).

    Attributes:
        sequential_models: Sequential list of modules.
        in_keys: Input keys from config.
        out_keys: Output keys from config.
    """

    config: DiscriminatorConfig

    def __init__(self, config: DiscriminatorConfig):
        super().__init__()
        self.config = config

        # Build sequential modules: each config entry names its own class via
        # ``_target_`` and is instantiated through the hydra-style resolver.
        sequential_models = []
        for input_model in config.input_models:
            model = get_class(input_model._target_)(config=input_model)
            sequential_models.append(model)
        self.sequential_models = nn.ModuleList(sequential_models)

        # Set TensorDict keys from config
        self.in_keys = self.config.in_keys
        self.out_keys = self.config.out_keys

    def forward(self, tensordict: TensorDict) -> TensorDict:
        """Forward pass through discriminator.

        Args:
            tensordict: TensorDict containing observations.

        Returns:
            TensorDict with discriminator output added.
        """
        # Chain through all modules; each reads from / writes to the TensorDict.
        for model in self.sequential_models:
            tensordict = model(tensordict)
        return tensordict

    def compute_disc_reward(
        self, disc_logits: torch.Tensor, eps: float = 1e-4
    ) -> torch.Tensor:
        """Compute style reward from discriminator logits.

        Converts discriminator logits to reward using negative log probability.
        Higher reward means motion is more similar to reference data.

        Args:
            disc_logits: Discriminator logits.
            eps: Small constant for numerical stability.

        Returns:
            Style rewards for each sample (higher = more reference-like).
        """
        # torch.sigmoid is the numerically stable equivalent of
        # 1 / (1 + exp(-x)) (avoids exp overflow for large-magnitude logits).
        prob = torch.sigmoid(disc_logits)
        # Clamp keeps the log argument away from 0 when prob -> 1.
        reward = -torch.log(torch.clamp(1 - prob, min=eps))
        return reward

    def all_discriminator_weights(self) -> List[nn.Parameter]:
        """Get all discriminator weight matrices (works with LazyLinear).

        Returns:
            List of weight parameters from all linear layers in discriminator.
        """
        weights: List[nn.Parameter] = []
        # hasattr + isinstance keeps this robust to LazyLinear (whose weight
        # is an UninitializedParameter subclass of nn.Parameter) and skips
        # container modules without weights.
        for mod in self.modules():
            if hasattr(mod, "weight") and isinstance(mod.weight, nn.Parameter):
                weights.append(mod.weight)
        return weights

    def logit_weights(self) -> List[nn.Parameter]:
        """Get the final layer weights (logit layer).

        Returns:
            List containing the output layer weight parameter.
        """
        last_module = self.sequential_models[-1]
        # Some wrapper modules expose their layers via an ``mlp`` attribute;
        # unwrap to reach the actual final linear layer.
        if hasattr(last_module, "mlp"):
            last_module = last_module.mlp[-1]
        return [last_module.weight]
class AMPModel(PPOModel):
    """AMP model with actor, critic, and discriminator networks.

    Extends PPOModel by adding a discriminator network that provides style
    rewards. The complete model includes policy, value function, and style
    discriminator.

    Args:
        config: AMPModelConfig specifying all three networks.

    Attributes:
        _actor: Policy network.
        _critic: Value network.
        _discriminator: Style discriminator network.
    """

    config: AMPModelConfig

    def __init__(self, config: AMPModelConfig):
        super().__init__(config)

        # Instantiate the discriminator class named in the config.
        discriminator_cls = get_class(config.discriminator._target_)
        self._discriminator: Discriminator = discriminator_cls(
            config=self.config.discriminator
        )

        # Validate that every key the discriminator produces/consumes is
        # declared on the model config, so forward() wiring cannot silently
        # drop data.
        for key in self._discriminator.out_keys:
            assert (
                key in self.config.out_keys
            ), f"Discriminator output key {key} not in out_keys {self.config.out_keys}"
        for key in self._discriminator.in_keys:
            assert (
                key in self.config.in_keys
            ), f"Discriminator input key {key} not in in_keys {self.config.in_keys}"

    def forward(self, tensordict: TensorDict) -> TensorDict:
        """Forward pass through PPO, and discriminator.

        Args:
            tensordict: TensorDict containing observations.

        Returns:
            TensorDict with all model outputs added.
        """
        # PPO forward (actor + critic) first, then discriminator output on top.
        outputs = super().forward(tensordict)
        return self._discriminator(outputs)