Source code for protomotions.agents.utils.training

# SPDX-FileCopyrightText: Copyright (c) 2025 The ProtoMotions Developers
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Training utility functions for agent training.

This module provides helper functions used during agent training including
gradient clipping, bounds loss, distributed metrics aggregation, and model utilities.

Key Functions:
    - bounds_loss: Penalize actions near joint limits
    - handle_model_grad_clipping: Clip gradients and handle bad gradients
    - aggregate_scalar_metrics: Aggregate metrics across distributed processes
    - get_activation_func: Get activation function by name
"""

import torch
from torch import Tensor
import torch.nn as nn
from torch.nn import functional as F
from typing import Dict

from protomotions.utils import torch_utils
from lightning.fabric import Fabric

from protomotions.agents.common.common import get_params


def bounds_loss(mu: Tensor, soft_bound: float = 1.0) -> Tensor:
    """Compute a soft bounds loss for action means outside ``[-soft_bound, soft_bound]``.

    Each action component contributes the squared distance by which it
    exceeds the soft bound; components inside the bound contribute zero.
    The per-component penalties are summed over the last dimension.

    Args:
        mu: Action means from the policy. The last dimension is reduced,
            e.g. (batch_size, action_dim) -> (batch_size,).
        soft_bound: Magnitude of the symmetric soft limit. Defaults to 1.0,
            which reproduces the previously hard-coded behaviour.

    Returns:
        Bounds loss with the last dimension of ``mu`` summed out. Zero where
        every component is within bounds, quadratic penalty beyond them.

    Example:
        >>> actions = policy(obs)
        >>> loss = bounds_loss(actions)  # Penalize extreme actions
    """
    # Clamping against Python scalars gives the same values/gradients as
    # torch.maximum/minimum with a zero tensor, without allocating (and, on
    # GPU, copying) a fresh device tensor on every call.
    excess_high = torch.clamp(mu - soft_bound, min=0)
    excess_low = torch.clamp(mu + soft_bound, max=0)
    return (excess_low**2 + excess_high**2).sum(dim=-1)
def handle_model_grad_clipping(config, fabric, model, optimizer, model_name):
    """Measure, sanity-check, and optionally clip a model's gradients.

    The gradient norm is computed first. A NaN norm always marks the
    gradients as bad; when ``config.check_grad_mag`` is set, a norm above
    1e6 does too. Bad gradients either raise or are zeroed in place,
    depending on ``config.fail_on_bad_grads``. Clipping is then applied via
    ``fabric.clip_gradients`` when ``config.gradient_clip_val > 0``, and the
    norm is measured a second time.

    Args:
        config: Agent config providing ``check_grad_mag``,
            ``fail_on_bad_grads`` and ``gradient_clip_val``.
        fabric: Fabric instance used to clip the gradients.
        model: Neural network model whose parameters are inspected.
        optimizer: Optimizer for the model (forwarded to the clip call).
        model_name: Prefix for the metric keys (e.g., "actor", "critic").

    Returns:
        Dictionary with ``<model_name>/grad_norm_before_clip``,
        ``<model_name>/grad_norm_after_clip`` (both detached) and
        ``<model_name>/bad_grads_count`` (0 or 1).

    Raises:
        ValueError: If the gradients are flagged as bad and
            ``config.fail_on_bad_grads`` is truthy. The message carries the
            non-finite fraction, min/max absolute gradient and the norm.
    """
    params = get_params(list(model.parameters()))
    grad_norm_before_clip = torch_utils.grad_norm(params)

    # NaN is always bad; the magnitude threshold is opt-in.
    bad_grads = torch.isnan(grad_norm_before_clip)
    if config.check_grad_mag:
        bad_grads = bad_grads or grad_norm_before_clip > 1000000.0

    bad_grads_count = 0
    if bad_grads:
        if config.fail_on_bad_grads:
            # Flatten every available gradient to build diagnostic statistics.
            flat_grads = torch.cat(
                [p.grad.view(-1) for p in params if p.grad is not None],
                dim=0,
            )
            nonfinite_fraction = (
                flat_grads.isfinite().logical_not().float().mean().item()
            )
            smallest_abs = flat_grads.abs().min().item()
            largest_abs = flat_grads.abs().max().item()
            raise ValueError(
                f"NaN gradient in {model_name}"
                f" {nonfinite_fraction}"
                f" {smallest_abs}"
                f" {largest_abs}"
                f" {grad_norm_before_clip.item()}"
            )

        # Tolerated failure: wipe the gradients and record the event.
        bad_grads_count = 1
        for grad in (p.grad for p in params if p.grad is not None):
            grad.zero_()

    if config.gradient_clip_val > 0:
        fabric.clip_gradients(
            model,
            optimizer,
            max_norm=config.gradient_clip_val,
            error_if_nonfinite=True,
        )

    # Re-measured unconditionally, so it also reflects any zeroing above.
    grad_norm_after_clip = torch_utils.grad_norm(params)

    return {
        f"{model_name}/grad_norm_before_clip": grad_norm_before_clip.detach(),
        f"{model_name}/grad_norm_after_clip": grad_norm_after_clip.detach(),
        f"{model_name}/bad_grads_count": bad_grads_count,
    }
def aggregate_scalar_metrics(log_dict: Dict, fabric: Fabric) -> Dict:
    """Aggregate scalar metrics across all devices via all_gather + mean.

    Every numeric entry is reduced to a single float32 scalar, gathered from
    all ranks with ``fabric.all_gather`` (when ``fabric.world_size > 1``) and
    averaged, so every rank ends up with the same value. A caller that logs
    from rank 0 only therefore records the cross-rank average rather than
    rank 0's local metric.

    Args:
        log_dict: Mapping of metric name to value. ``int``/``float`` values
            (including ``bool``) and tensors are aggregated; multi-element
            tensors are first reduced to their mean. Any other value is
            passed through unchanged.
        fabric: Fabric instance providing ``device``, ``world_size`` and
            ``all_gather``.

    Returns:
        New dict with the same keys, where numeric entries are Python floats.

    Note:
        One gather is issued per numeric key, so presumably every rank must
        pass the same numeric keys in the same order - verify against callers.
    """
    aggregated_dict = {}
    for key, value in log_dict.items():
        if isinstance(value, (int, float)):
            # Convert to tensor for aggregation
            value_tensor = torch.tensor(
                value, device=fabric.device, dtype=torch.float32
            )
        elif isinstance(value, torch.Tensor):
            # Cast BEFORE reducing: Tensor.mean() raises for integer/bool
            # dtypes, and reducing in float32 is also more accurate for
            # half-precision inputs.
            value_tensor = value.float()
            if value_tensor.numel() != 1:
                # For non-scalar tensors, take the mean and treat as scalar
                value_tensor = value_tensor.mean()
            value_tensor = value_tensor.to(fabric.device)
        else:
            # For non-numeric values, keep as is (no aggregation needed)
            aggregated_dict[key] = value
            continue

        if fabric.world_size > 1:
            # Gather the scalar from all devices and average it
            aggregated_value = fabric.all_gather(value_tensor).mean().item()
        else:
            aggregated_value = value_tensor.item()

        aggregated_dict[key] = aggregated_value

    return aggregated_dict
def get_activation_func(activation_name, return_type="nn"):
    """Look up an activation by name, as a module or as a plain function.

    Recognized names (matched case-insensitively): tanh, relu, elu, gelu,
    identity, silu, mish.

    Args:
        activation_name: Name of the activation function.
        return_type: "nn" to get a freshly constructed ``nn.Module``,
            "functional" to get the corresponding callable.

    Returns:
        The activation as an ``nn.Module`` instance (``return_type="nn"``)
        or as a function (``return_type="functional"``).

    Raises:
        NotImplementedError: If the activation name is unknown, or - checked
            afterwards - if ``return_type`` is neither "nn" nor "functional".

    Example:
        >>> act_module = get_activation_func("relu", return_type="nn")
        >>> act_func = get_activation_func("relu", return_type="functional")
    """
    # Thunks keep the lookup lazy: only the requested activation's module and
    # functional are touched, exactly as with a branch-per-name chain.
    builders = {
        "tanh": lambda: (nn.Tanh(), F.tanh),
        "relu": lambda: (nn.ReLU(), F.relu),
        "elu": lambda: (nn.ELU(), F.elu),
        "gelu": lambda: (nn.GELU(), F.gelu),
        "identity": lambda: (nn.Identity(), lambda x: x),
        "silu": lambda: (nn.SiLU(), F.silu),
        "mish": lambda: (nn.Mish(), F.mish),
    }

    build = builders.get(activation_name.lower())
    if build is None:
        raise NotImplementedError(
            "Activation func {} not defined".format(activation_name)
        )
    module, functional = build()

    if return_type == "nn":
        return module
    if return_type == "functional":
        return functional
    raise NotImplementedError("Return type {} not implemented".format(return_type))