Source code for protomotions.agents.ase.model

# SPDX-FileCopyrightText: Copyright (c) 2025 The ProtoMotions Developers
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
from torch import nn
from typing import List
from tensordict import TensorDict
from protomotions.agents.amp.model import Discriminator, AMPModel
from protomotions.agents.ase.config import ASEDiscriminatorEncoderConfig
from protomotions.agents.common.common import MultiOutputModule

DISC_LOGIT_INIT_SCALE = 1.0
ENC_LOGIT_INIT_SCALE = 0.1


class ASEDiscriminatorEncoder(Discriminator):
    """Discriminator with MI encoder head for ASE.

    Inherits from Discriminator and adds an encoder head for mutual
    information learning.
    """

    config: ASEDiscriminatorEncoderConfig

    def __init__(self, config: ASEDiscriminatorEncoderConfig):
        super().__init__(config)
        self._encoder_initialized = False

    def _initialize_encoder_weights(self):
        """Initialize encoder weights after materialization."""
        encoder = None
        final_module = self.sequential_models[-1]
        assert isinstance(
            final_module, MultiOutputModule
        ), "Final module must be a MultiOutputModule"
        for model in final_module.output_models:
            if model.out_keys[0] == "mi_enc_output":
                # Found the encoder module
                encoder = model
                break
        assert encoder is not None, "Encoder module not found"
        if (
            not self._encoder_initialized
            and hasattr(encoder, "weight")
            and encoder.weight is not None
        ):
            torch.nn.init.uniform_(
                encoder.weight, -ENC_LOGIT_INIT_SCALE, ENC_LOGIT_INIT_SCALE
            )
            torch.nn.init.zeros_(encoder.bias)
            self._encoder_initialized = True

    def forward(self, tensordict: TensorDict) -> TensorDict:
        """Forward pass computing discriminator and MI encoder outputs.

        Args:
            tensordict: TensorDict containing observations and latents.

        Returns:
            TensorDict with disc_logits and mi_enc_output added.
        """
        # Call parent (Discriminator forward) - adds disc_logits
        tensordict = super().forward(tensordict)
        # Initialize encoder weights after materialization
        self._initialize_encoder_weights()
        return tensordict
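
    # A minimal usage sketch (not from the source): the observation key "obs"
    # below is hypothetical and comes from the config in practice; "latents",
    # "disc_logits", and "mi_enc_output" are the keys this class actually uses.
    #
    #   disc = ASEDiscriminatorEncoder(config)
    #   td = TensorDict({"obs": obs, "latents": latents}, batch_size=[N])
    #   td = disc(td)  # adds "disc_logits" and "mi_enc_output"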

    def compute_mi_reward(
        self, tensordict: TensorDict, mi_hypersphere_reward_shift: bool
    ):
        """Computes the mutual-information-based reward.

        Args:
            tensordict: TensorDict with mi_enc_output and latents.
            mi_hypersphere_reward_shift: Whether to shift the reward to [0, 1].

        Returns:
            torch.Tensor: Mutual information reward tensor.
        """
        enc_pred = tensordict["mi_enc_output"]
        latents = tensordict["latents"]
        neg_err = -self.calc_von_mises_fisher_enc_error(enc_pred, latents)
        if mi_hypersphere_reward_shift:
            reward = (neg_err + 1) / 2
        else:
            reward = torch.clamp_min(neg_err, 0.0)
        return reward
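
    # Worked example of the reward mapping above, assuming unit-norm vectors
    # (so neg_err is a cosine similarity in [-1, 1]); the tensors here are
    # illustrative, not from the source:
    #
    #   z = torch.nn.functional.normalize(torch.randn(4, 8), dim=-1)
    #   td = TensorDict({"mi_enc_output": z, "latents": z}, batch_size=[4])
    #   disc.compute_mi_reward(td, mi_hypersphere_reward_shift=True)   # all 1.0
    #   disc.compute_mi_reward(td, mi_hypersphere_reward_shift=False)  # all 1.0
    #
    # With the shift, cos = -1 maps to reward 0.0 and cos = 0 to 0.5; without
    # it, negative similarities are clamped to 0.0 instead.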

    def calc_von_mises_fisher_enc_error(self, enc_pred, latent):
        """Calculates the von Mises-Fisher error between predicted and true latent vectors.

        Args:
            enc_pred (torch.Tensor): Predicted encoded latent vector.
                Shape (batch_size, latent_dim).
            latent (torch.Tensor): True latent vector.
                Shape (batch_size, latent_dim).

        Returns:
            torch.Tensor: von Mises-Fisher error. Shape (batch_size, 1).
        """
        err = enc_pred * latent
        err = -torch.sum(err, dim=-1, keepdim=True)
        return err
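
    # Note: for unit-length inputs this is just the negative cosine similarity,
    # i.e. the negated (scale- and constant-free) log-likelihood of a von
    # Mises-Fisher distribution centered at the true latent. Illustrative
    # check, not from the source:
    #
    #   a = torch.tensor([[1.0, 0.0]])
    #   b = torch.tensor([[0.0, 1.0]])
    #   disc.calc_von_mises_fisher_enc_error(a, a)  # tensor([[-1.]]), aligned
    #   disc.calc_von_mises_fisher_enc_error(a, b)  # 0., orthogonal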

    def _get_weights_from_module(self, module):
        """Helper to recursively get weights by explicitly traversing structure.

        Args:
            module: Module to extract weights from.

        Returns:
            List of weight parameters.
        """
        weights = []
        # If it's a SequentialModule, recursively process its sequential_models
        if hasattr(module, "sequential_models"):
            for sub_model in module.sequential_models:
                weights.extend(self._get_weights_from_module(sub_model))
        # If it's a MultiInputModule, recursively process its input_models
        elif hasattr(module, "input_models") and isinstance(
            module.input_models, nn.ModuleList
        ):
            for sub_model in module.input_models:
                weights.extend(self._get_weights_from_module(sub_model))
        # If it's a MultiOutputModule, recursively process its output_models
        elif hasattr(module, "output_models"):
            for sub_model in module.output_models:
                weights.extend(self._get_weights_from_module(sub_model))
        # If it has an mlp Sequential, process that
        elif hasattr(module, "mlp") and isinstance(module.mlp, nn.Sequential):
            for layer in module.mlp:
                if hasattr(layer, "weight") and isinstance(layer.weight, nn.Parameter):
                    weights.append(layer.weight)
        # Otherwise, check if this module itself has weights
        elif hasattr(module, "weight") and isinstance(module.weight, nn.Parameter):
            weights.append(module.weight)
        return weights

    def all_weights(self):
        """Returns all weights from all sequential modules (trunk + discriminator + encoder).

        Uses explicit walking to avoid duplicates in nested structures.

        Returns:
            List[nn.Parameter]: List of all weight parameters.
        """
        weights: list[nn.Parameter] = []
        # Walk through all sequential models
        for seq_model in self.sequential_models:
            if isinstance(seq_model, MultiOutputModule):
                # Include all output models (both discriminator and encoder)
                for output_model in seq_model.output_models:
                    weights.extend(self._get_weights_from_module(output_model))
            else:
                weights.extend(self._get_weights_from_module(seq_model))
        return weights

    def all_discriminator_weights(self):
        """Returns weights of the discriminator part only (excludes the encoder head).

        Explicitly walks through sequential_models to avoid including the encoder head.

        Returns:
            List[nn.Parameter]: List of discriminator weight parameters.
        """
        weights: list[nn.Parameter] = []
        # Walk through sequential models
        for seq_model in self.sequential_models:
            if isinstance(seq_model, MultiOutputModule):
                # Only include the discriminator head, not the encoder
                for output_model in seq_model.output_models:
                    if (
                        hasattr(output_model, "out_keys")
                        and "mi_enc_output" in output_model.out_keys
                    ):
                        continue  # Skip encoder head
                    # Include this output module's weights
                    weights.extend(self._get_weights_from_module(output_model))
            else:
                # Include all weights from this module
                weights.extend(self._get_weights_from_module(seq_model))
        return weights

    def logit_weights(self) -> List[nn.Parameter]:
        """Returns the weights of the final discriminator layer.

        Returns:
            List[nn.Parameter]: List containing the weight parameter of the
                discriminator's output layer.
        """
        weights = []
        # Find the discriminator head in the MultiOutputModule
        for seq_model in self.sequential_models:
            if isinstance(seq_model, MultiOutputModule):
                for output_model in seq_model.output_models:
                    # Find the discriminator head (outputs disc_logits)
                    if (
                        hasattr(output_model, "out_keys")
                        and "disc_logits" in output_model.out_keys
                    ):
                        # Get the final layer weight
                        if hasattr(output_model, "mlp") and len(output_model.mlp) > 0:
                            final_layer = output_model.mlp[-1]
                            if hasattr(final_layer, "weight"):
                                weights.append(final_layer.weight)
                        break
                break
        return weights

    def all_enc_weights(self):
        """Returns all weights of the encoder part only (includes trunk + encoder head).

        Returns:
            List[nn.Parameter]: List of encoder weight parameters.
        """
        weights: list[nn.Parameter] = []
        # Get trunk weights (all Sequential modules before the MultiOutputModule)
        for seq_model in self.sequential_models:
            if isinstance(seq_model, MultiOutputModule):
                # Found MultiOutput, only get encoder head weights
                for output_model in seq_model.output_models:
                    if (
                        hasattr(output_model, "out_keys")
                        and "mi_enc_output" in output_model.out_keys
                    ):
                        # This is the encoder head
                        weights.extend(self._get_weights_from_module(output_model))
                break  # Don't continue past MultiOutput
            else:
                # Include trunk weights
                weights.extend(self._get_weights_from_module(seq_model))
        return weights

    def enc_weights(self) -> List[nn.Parameter]:
        """Returns the weights of the final encoder layer only.

        Returns:
            List[nn.Parameter]: List containing the weight parameter of the
                encoder's output layer.
        """
        weights = []
        # Find the encoder head in the MultiOutputModule
        for seq_model in self.sequential_models:
            if isinstance(seq_model, MultiOutputModule):
                for output_model in seq_model.output_models:
                    if (
                        hasattr(output_model, "out_keys")
                        and "mi_enc_output" in output_model.out_keys
                    ):
                        # Get the final layer weight
                        if hasattr(output_model, "mlp") and len(output_model.mlp) > 0:
                            final_layer = output_model.mlp[-1]
                            if hasattr(final_layer, "weight"):
                                weights.append(final_layer.weight)
                        break
                break
        return weights
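
    # Sketch of how these per-part weight lists are typically consumed by the
    # training agent, e.g. for weight-decay / logit-regularization terms over
    # an instance `disc`. The loss names and coefficients below are
    # hypothetical, not from this module:
    #
    #   logit_reg = sum(w.pow(2).sum() for w in disc.logit_weights())
    #   enc_reg = sum(w.pow(2).sum() for w in disc.enc_weights())
    #   disc_wd = sum(w.pow(2).sum() for w in disc.all_discriminator_weights())
    #   loss = disc_loss + 5e-2 * logit_reg + 1e-3 * enc_reg + 1e-4 * disc_wd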

# ASE just uses AMPModel - the discriminator is ASEDiscriminatorEncoder instead.
ASEModel = AMPModel
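
# Construction sketch (hypothetical; AMPModel's exact constructor signature is
# defined in protomotions.agents.amp.model and is assumed, not verified, here):
#
#   model = ASEModel(config)  # config selects ASEDiscriminatorEncoder as the discriminator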