Source code for protomotions.agents.ase.agent

# SPDX-FileCopyrightText: Copyright (c) 2025 The ProtoMotions Developers
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Adversarial Skill Embeddings (ASE) agent implementation.

This module implements the ASE algorithm which extends AMP with learned skill embeddings.
The discriminator encodes motions into a latent skill space, and the policy is conditioned
on these latent codes. This enables learning diverse skills from motion data and composing
them for complex tasks.

Key Classes:
    - ASE: Main ASE agent class extending AMP

References:
    Peng et al. "ASE: Large-Scale Reusable Adversarial Skill Embeddings for Physically Simulated Characters" (2022)
"""

import torch
import logging

from torch import Tensor
from tensordict import TensorDict

from lightning.fabric import Fabric

from protomotions.agents.ase.model import ASEModel, ASEDiscriminatorEncoder
from protomotions.agents.utils.step_tracker import StepTracker
from protomotions.envs.base_env.env import BaseEnv
from protomotions.agents.amp.agent import AMP
from typing import Tuple, Dict, Optional
from pathlib import Path
from protomotions.agents.ase.config import ASEAgentConfig
from protomotions.agents.utils.normalization import RewardRunningMeanStd


log = logging.getLogger(__name__)


class ASE(AMP):
    """Adversarial Skill Embeddings (ASE) agent.

    Extends AMP with a low-level policy conditioned on learned skill embeddings.
    The discriminator learns to encode skills from motion data into a latent space,
    while the policy learns to execute behaviors conditioned on these latent codes.
    This enables learning diverse skills from motion data and composing them for tasks.

    Key components:
        - **Low-level policy**: Conditioned on latent skill codes
        - **Discriminator**: Encodes motions into skill embeddings
        - **Mutual information**: Encourages skill diversity
        - **Latent sampling**: Periodically samples new skills during rollouts

    Args:
        fabric: Lightning Fabric instance for distributed training.
        env: Environment instance with diverse motion library.
        config: ASE-specific configuration including latent dimensions.
        root_dir: Optional root directory for saving outputs.

    Attributes:
        latents: Current latent skill codes for each environment.
        latent_reset_steps: Steps until next latent resample.

    Example:
        >>> fabric = Fabric(devices=4)
        >>> env = Mimic(config, robot_config, simulator_config, device)
        >>> agent = ASE(fabric, env, config)
        >>> agent.setup()
        >>> agent.train()

    Note:
        Requires a large, diverse motion dataset for effective skill learning.
    """

    model: ASEModel
    discriminator: ASEDiscriminatorEncoder
    config: ASEAgentConfig

    # -----------------------------
    # Initialization and Setup
    # -----------------------------

    def __init__(
        self, fabric: Fabric, env: BaseEnv, config, root_dir: Optional[Path] = None
    ):
        super().__init__(fabric, env, config, root_dir=root_dir)
        if self.config.normalize_rewards:
            self.running_mi_enc_norm = RewardRunningMeanStd(
                shape=(1,),
                fabric=self.fabric,
                gamma=self.gamma,
                device=self.device,
                clamp_value=self.config.normalized_reward_clamp_value,
            )
        else:
            self.running_mi_enc_norm = None

    def setup(self):
        self.latents = torch.zeros(
            (self.num_envs, self.config.ase_parameters.latent_dim),
            dtype=torch.float,
            device=self.device,
        )
        self.latent_reset_steps = StepTracker(
            self.num_envs,
            min_steps=self.config.ase_parameters.latent_steps_min,
            max_steps=self.config.ase_parameters.latent_steps_max,
            device=self.device,
        )
        super().setup()

    def load_parameters(self, state_dict):
        super().load_parameters(state_dict)
        if self.config.normalize_rewards:
            self.running_mi_enc_norm.load_state_dict(state_dict["running_mi_enc_norm"])

    def get_state_dict(self, state_dict):
        extra_state_dict = super().get_state_dict(state_dict)
        if self.config.normalize_rewards:
            extra_state_dict["running_mi_enc_norm"] = (
                self.running_mi_enc_norm.state_dict()
            )
        state_dict.update(extra_state_dict)
        return state_dict

    # -----------------------------
    # Latent Management
    # -----------------------------

    def update_latents(self):
        """Updates latent variables based on latent reset steps."""
        self.latent_reset_steps.advance()
        reset_ids = self.latent_reset_steps.done_indices()

        if reset_ids.numel() > 0:
            self.reset_latents(reset_ids)
            self.latent_reset_steps.reset_steps(reset_ids)

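    # Illustrative usage sketch: each environment keeps its latent for a random number of
    # steps drawn from [latent_steps_min, latent_steps_max] (tracked by StepTracker), after
    # which a fresh skill is sampled. Assuming, for illustration, latent_steps_min=1 and
    # latent_steps_max=150, a rollout loop might look like:
    #
    #     >>> agent.reset_latents()                        # fresh skills for all envs
    #     >>> for _ in range(horizon):                     # `horizon` is illustrative
    #     ...     obs = agent.add_agent_info_to_obs(obs)   # advances trackers, swaps latents
    #     ...     # ... step the environment with the latent-conditioned policy ...
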
    def reset_latents(self, env_ids=None):
        """Resets latent variables for specified environments, or all environments if None.

        Args:
            env_ids (torch.Tensor, optional): Environment indices to reset latents for.
                Defaults to None (all envs).
        """
        if env_ids is None:
            env_ids = torch.arange(self.num_envs, device=self.device)

        latents = self.sample_latents(len(env_ids))
        self.store_latents(latents, env_ids)

    def store_latents(self, latents, env_ids):
        """Stores latent variables for the specified environments.

        Args:
            latents (torch.Tensor): Latent variables to store. Shape (num_envs, latent_dim).
            env_ids (torch.Tensor): Environment indices to store latents for. Shape (num_envs,).
        """
        self.latents[env_ids] = latents

    def sample_latents(self, n):
        """Samples new latent variables uniformly on the unit sphere.

        Args:
            n (int): Number of latent variables to sample.

        Returns:
            torch.Tensor: Sampled latent variables. Shape (n, latent_dim).
        """
        zeros = torch.zeros(
            [n, self.config.ase_parameters.latent_dim], device=self.device
        )
        gaussian_sample = torch.normal(zeros)
        latents = torch.nn.functional.normalize(gaussian_sample, dim=-1)
        return latents

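    # A minimal standalone sketch of the sampling scheme above: normalizing isotropic
    # Gaussian draws yields points distributed uniformly on the unit hypersphere, because
    # the Gaussian density depends only on the norm of the sample. The latent dimension
    # (64) is illustrative:
    #
    #     >>> z = torch.nn.functional.normalize(torch.randn(4, 64), dim=-1)
    #     >>> torch.allclose(z.norm(dim=-1), torch.ones(4))
    #     True
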
    def mi_enc_forward(self, obs_dict: dict) -> Tensor:
        """Forward pass through the mutual information encoder.

        Args:
            obs_dict: Dictionary containing observations.

        Returns:
            Tensor: Encoded observation tensor. Shape (batch_size, encoder_output_dim).
        """
        # Convert to TensorDict and forward through the discriminator
        batch_size = obs_dict[list(obs_dict.keys())[0]].shape[0]
        obs_td = TensorDict(obs_dict, batch_size=batch_size)
        obs_td = self.discriminator(obs_td)
        return obs_td["mi_enc_output"]

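    # A minimal sketch of the dict-to-TensorDict conversion used above (the feature size
    # is illustrative): TensorDict wraps a plain dict of tensors sharing a leading batch
    # dimension, which the discriminator module consumes directly.
    #
    #     >>> from tensordict import TensorDict
    #     >>> obs = {"historical_self_obs": torch.zeros(8, 120)}
    #     >>> td = TensorDict(obs, batch_size=8)
    #     >>> td.batch_size
    #     torch.Size([8])
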
    # -----------------------------
    # Experience Buffer and Dataset Processing
    # -----------------------------

    def register_algorithm_experience_buffer_keys(self):
        super().register_algorithm_experience_buffer_keys()
        if self.config.normalize_rewards:
            self.experience_buffer.register_key("unnormalized_mi_rewards")
        self.experience_buffer.register_key("mi_rewards")  # mi = mutual information

    # -----------------------------
    # Environment Interaction
    # -----------------------------

    def add_agent_info_to_obs(self, obs):
        """Inject the current latents into the observations passed to the policy."""
        obs = super().add_agent_info_to_obs(obs)
        self.update_latents()
        obs["latents"] = self.latents.clone()
        return obs

    # -----------------------------
    # Reward Calculation
    # -----------------------------
    @torch.no_grad()
    def record_rollout_step(
        self,
        next_obs_td,
        actions,
        rewards,
        dones,
        terminated,
        done_indices,
        extras,
        step,
    ):
        """Record ASE-specific data: mutual information rewards for skill diversity."""
        super().record_rollout_step(
            next_obs_td, actions, rewards, dones, terminated, done_indices, extras, step
        )

        # Get encoder outputs
        next_obs_td = self.discriminator(next_obs_td)

        # Compute MI reward from outputs
        mi_r = self.discriminator.compute_mi_reward(
            next_obs_td, self.config.ase_parameters.mi_hypersphere_reward_shift
        ).view(-1)

        if self.config.normalize_rewards:
            self.experience_buffer.update_data("unnormalized_mi_rewards", step, mi_r)
            self.running_mi_enc_norm.record_reward(mi_r, terminated)
            mi_r = self.running_mi_enc_norm.normalize(mi_r)
        self.experience_buffer.update_data("mi_rewards", step, mi_r)

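    # Sketch of the MI reward (the exact form lives in ASEDiscriminatorEncoder.compute_mi_reward
    # and may differ): in the standard ASE formulation the encoder reward is based on the
    # cosine similarity between the conditioning latent z and the encoder prediction e(s),
    # optionally shifted so it stays non-negative on the hypersphere. A plausible shape,
    # with `shift` standing in for mi_hypersphere_reward_shift:
    #
    #     r_mi = clamp((z * e).sum(-1) + shift, min=0.0)
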
    def get_combined_experience_buffer_rewards(self):
        rewards = super().get_combined_experience_buffer_rewards()
        return (
            rewards
            + self.experience_buffer.mi_rewards * self.config.ase_parameters.mi_reward_w
        )

    # -----------------------------
    # Optimization
    # -----------------------------

    def get_expert_disc_obs(self, num_samples: int):
        expert_disc_obs = super().get_expert_disc_obs(num_samples)
        # Add the predicted latents to the batch dict. This produces a conditional
        # discriminator, conditioned on the latents.
        mi_enc_pred = self.mi_enc_forward(expert_disc_obs)
        expert_disc_obs["latents"] = mi_enc_pred.clone()
        return expert_disc_obs

    def produce_negative_expert_obs(self, batch_dict):
        negative_expert_obs = {}
        for key in self.config.amp_parameters.discriminator_keys:
            negative_expert_obs[key] = batch_dict[f"expert_{key}"][
                : self.config.amp_parameters.discriminator_batch_size
            ]
        random_conditioned_latent = torch.rand_like(
            batch_dict["latents"][: self.config.amp_parameters.discriminator_batch_size]
        )
        projected_latent = torch.nn.functional.normalize(
            random_conditioned_latent, dim=-1
        )
        negative_expert_obs["latents"] = projected_latent
        return negative_expert_obs

    def discriminator_step(self, batch_dict):
        """Performs a discriminator update step.

        Args:
            batch_dict (dict): Batch of data from the experience buffer.

        Returns:
            Tuple[Tensor, Dict]: Discriminator loss and logging dictionary.
        """
        discriminator_loss, discriminator_log_dict = super().discriminator_step(
            batch_dict
        )

        agent_obs = batch_dict["historical_self_obs"][
            : self.config.amp_parameters.discriminator_batch_size
        ]
        latents = batch_dict["latents"][
            : self.config.amp_parameters.discriminator_batch_size
        ]
        expert_obs = batch_dict["expert_historical_self_obs"][
            : self.config.amp_parameters.discriminator_batch_size
        ]

        if self.config.ase_parameters.mi_enc_grad_penalty > 0:
            agent_obs.requires_grad_(True)

        # Compute MI encoder predictions for both agent and expert observations
        agent_disc_obs = {"historical_self_obs": agent_obs, "latents": latents}
        mi_enc_pred_agent = self.mi_enc_forward(agent_disc_obs)

        expert_disc_obs = {"historical_self_obs": expert_obs, "latents": latents}
        mi_enc_pred_expert = self.mi_enc_forward(expert_disc_obs)

        # Original MI encoder loss for agent observations
        mi_enc_err = self.discriminator.calc_von_mises_fisher_enc_error(
            mi_enc_pred_agent, latents
        )
        mi_enc_loss = torch.mean(mi_enc_err)

        # Uniformity loss for uniform coverage on the unit sphere.
        # Concatenate agent and expert encodings for the uniformity term.
        all_encodings = torch.cat([mi_enc_pred_agent, mi_enc_pred_expert], dim=0)
        uniformity_loss = self.compute_uniformity_loss(all_encodings)

        total_mi_loss = (
            mi_enc_loss
            + uniformity_loss * self.config.ase_parameters.latent_uniformity_weight
        )

        if self.config.ase_parameters.mi_enc_weight_decay > 0:
            enc_weight_params = self.discriminator.enc_weights()
            total: Tensor = sum([p.pow(2).sum() for p in enc_weight_params])
            weight_decay_loss: Tensor = (
                total * self.config.ase_parameters.mi_enc_weight_decay
            )
        else:
            weight_decay_loss = torch.tensor(0.0, device=self.device)

        if self.config.ase_parameters.mi_enc_grad_penalty > 0:
            mi_enc_obs_grad = torch.autograd.grad(
                mi_enc_err,
                agent_obs,
                grad_outputs=torch.ones_like(mi_enc_err),
                create_graph=True,
                retain_graph=True,
            )[0]
            mi_enc_obs_grad = torch.sum(torch.square(mi_enc_obs_grad), dim=-1)
            mi_enc_grad_penalty = torch.mean(mi_enc_obs_grad)
            grad_loss: Tensor = (
                mi_enc_grad_penalty * self.config.ase_parameters.mi_enc_grad_penalty
            )
        else:
            grad_loss = torch.tensor(0.0, device=self.device)

        mi_loss = total_mi_loss + weight_decay_loss + grad_loss

        log_dict = {
            "encoder/loss": mi_loss.detach(),
            "encoder/mi_enc_loss": mi_enc_loss.detach(),
            "encoder/uniformity_loss": uniformity_loss.detach(),
            "encoder/l2_loss": weight_decay_loss.detach(),
            "encoder/grad_penalty": grad_loss.detach(),
        }
        discriminator_log_dict.update(log_dict)

        return mi_loss + discriminator_loss, discriminator_log_dict

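    # Worked form of the encoder objective above, assuming the standard von Mises-Fisher
    # parameterization with unit-norm encoder outputs e and latents z (the exact expression
    # is implemented in calc_von_mises_fisher_enc_error):
    #
    #     err(e, z) = -(e . z)                          # negative log-likelihood up to constants
    #     L_mi      = mean(err)
    #                 + latent_uniformity_weight * L_uniform
    #                 + mi_enc_weight_decay * sum_W ||W||^2
    #                 + mi_enc_grad_penalty * E[||d err / d s||^2]
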
    def compute_uniformity_loss(self, encodings: Tensor) -> Tensor:
        """Computes a uniformity loss that encourages a uniform distribution on the unit sphere.

        Args:
            encodings (Tensor): Normalized encodings on the unit sphere.
                Shape (batch_size, latent_dim).

        Returns:
            Tensor: Uniformity loss value.
        """
        # Compute pairwise L2 distances between all encodings
        pairwise_distances = torch.cdist(encodings, encodings, p=2)

        # Convert distances to Gaussian kernel values
        t = self.config.ase_parameters.uniformity_kernel_scale
        kernel_values = torch.exp(-t * pairwise_distances**2)

        # Uniformity loss: log of the average kernel value.
        # Lower values indicate a more uniform distribution.
        uniformity_loss = torch.log(torch.mean(kernel_values))

        return uniformity_loss

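    # The uniformity term above in equation form, for N unit-norm encodings z_1..z_N and
    # kernel scale t = uniformity_kernel_scale:
    #
    #     L_uniform = log( (1 / N^2) * sum_{i,j} exp(-t * ||z_i - z_j||^2) )
    #
    # Since ||z_i - z_j||^2 = 2 - 2 (z_i . z_j) on the unit sphere, the loss decreases as the
    # encodings spread apart, matching the hypersphere uniformity objective of Wang & Isola (2020).
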
    def calculate_extra_actor_loss(self, batch_td) -> Tuple[Tensor, Dict]:
        """Adds the diversity loss, if enabled.

        Args:
            batch_td (TensorDict): Batch of data from the experience buffer and the actor.

        Returns:
            Tuple[Tensor, Dict]: Extra actor loss and logging dictionary.
        """
        extra_loss, extra_actor_log_dict = super().calculate_extra_actor_loss(batch_td)

        if self.config.ase_parameters.diversity_bonus <= 0:
            return extra_loss, extra_actor_log_dict

        diversity_loss = self.diversity_loss(batch_td)
        extra_actor_log_dict["actor/diversity_loss"] = diversity_loss.detach()

        return (
            extra_loss + diversity_loss * self.config.ase_parameters.diversity_bonus,
            extra_actor_log_dict,
        )

    def diversity_loss(self, batch_td):
        """Calculates the diversity loss to encourage latents to lead to diverse behaviors.

        Args:
            batch_td (TensorDict): Batch of data from the experience buffer and the actor.

        Returns:
            Tensor: Diversity loss.
        """
        old_latents = batch_td["latents"].clone()
        clipped_old_mu = torch.clamp(batch_td["mean_action"], -1.0, 1.0)

        new_latents = self.sample_latents(batch_td.batch_size[0])
        batch_td["latents"] = new_latents
        batch_td = self.actor(batch_td)
        clipped_new_mu = torch.clamp(batch_td["mean_action"], -1.0, 1.0)

        mu_diff = clipped_new_mu - clipped_old_mu
        mu_diff = torch.mean(torch.square(mu_diff), dim=-1)

        z_diff = new_latents * old_latents
        z_diff = torch.sum(z_diff, dim=-1)
        z_diff = 0.5 - 0.5 * z_diff

        diversity_bonus = mu_diff / (z_diff + 1e-5)

        diversity_loss = torch.square(
            self.config.ase_parameters.diversity_tar - diversity_bonus
        ).mean()

        return diversity_loss

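    # The diversity objective above in equation form, for old/new latents z, z' and the
    # corresponding clipped action means mu, mu':
    #
    #     b     = mean_a[(mu'_a - mu_a)^2] / (0.5 - 0.5 * (z . z') + 1e-5)
    #     L_div = E[(diversity_tar - b)^2]
    #
    # The denominator grows as z and z' point in different directions, so the policy is pushed
    # to change its action mean in proportion to how different the newly sampled skill latent is.
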
    # -----------------------------
    # Termination and Logging
    # -----------------------------

    def post_epoch_logging(self, training_log_dict: Dict):
        """Performs post-epoch logging, including mutual information reward logging.

        Args:
            training_log_dict (Dict): Dictionary to update with logging information.
        """
        training_log_dict["rewards/mi_enc_rewards"] = (
            self.experience_buffer.mi_rewards.mean()
        )
        if self.config.normalize_rewards:
            training_log_dict["rewards/unnormalized_mi_enc_rewards"] = (
                self.experience_buffer.unnormalized_mi_rewards.mean()
            )
        super().post_epoch_logging(training_log_dict)