protomotions.agents.utils.training module

Training utility functions for agent training.

This module provides helper functions used during agent training, including gradient clipping, bounds loss, distributed metrics aggregation, and model utilities.

Key Functions:
  • bounds_loss: Penalize actions near joint limits

  • handle_model_grad_clipping: Clip gradients and handle bad gradients

  • aggregate_scalar_metrics: Aggregate metrics across distributed processes

  • get_activation_func: Get activation function by name

protomotions.agents.utils.training.bounds_loss(mu)[source]

Compute soft bounds loss for actions near limits.

Penalizes actions that exceed soft bounds (±1.0) to keep actions within reasonable ranges and prevent extreme joint angles.

Parameters:

mu (Tensor) – Action means from policy (batch_size, action_dim).

Returns:

Bounds loss for each sample (batch_size,). Zero if within bounds, quadratic penalty beyond soft bounds.

Return type:

Tensor

Example

>>> actions = policy(obs)
>>> loss = bounds_loss(actions)  # Penalize extreme actions
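The penalty described above can be sketched as follows. This is a minimal reconstruction, not the actual source: the soft bound of ±1.0 comes from the docstring, while the `soft_bound` parameter name and the exact reduction are assumptions.

```python
import torch


def bounds_loss(mu: torch.Tensor, soft_bound: float = 1.0) -> torch.Tensor:
    """Quadratic penalty on action means outside [-soft_bound, soft_bound]."""
    # Positive overshoot above the upper bound; zero when within bounds.
    upper = torch.clamp_min(mu - soft_bound, 0.0) ** 2
    # Negative overshoot below the lower bound; zero when within bounds.
    lower = torch.clamp_max(mu + soft_bound, 0.0) ** 2
    # Sum the per-dimension penalties into one scalar per sample.
    return (upper + lower).sum(dim=-1)
```

Actions inside the bounds contribute exactly zero, so the penalty only activates near the limits and leaves ordinary actions unaffected.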
protomotions.agents.utils.training.handle_model_grad_clipping(config, fabric, model, optimizer, model_name)[source]

Handle gradient clipping and detect bad gradients.

Computes gradient norm, clips if configured, and checks for NaN or extremely large gradients. Optionally zeros gradients if bad.

Parameters:
  • config – Agent config with grad clipping settings.

  • fabric – Fabric instance for distributed operations.

  • model – Neural network model.

  • optimizer – Optimizer for the model.

  • model_name – Name for logging (e.g., “actor”, “critic”).

Returns:

Dictionary with gradient norm metrics for logging.

Note

If bad gradients are detected and fail_on_bad_grads=True, an assertion error is raised; otherwise the bad gradients are zeroed so the optimizer step becomes a no-op.
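A rough sketch of the clip-then-check flow described above, assuming a dict-like config. The field names `gradient_clip_val` and `bad_grad_threshold` are hypothetical (only `fail_on_bad_grads` appears in the docs), and the real version presumably clips through `fabric` in distributed runs; here `fabric` is accepted but unused.

```python
import math

import torch


def handle_model_grad_clipping(config, fabric, model, optimizer, model_name):
    """Compute grad norm, clip if configured, and neutralize bad gradients."""
    # fabric is unused in this sketch; the real version likely uses it
    # so clipping and norm computation agree across distributed ranks.
    params = [p for p in model.parameters() if p.grad is not None]
    grad_norm = torch.norm(torch.stack([p.grad.norm(2) for p in params]), 2)

    # "Bad" means NaN/inf or an implausibly large norm (threshold assumed).
    bad = (not math.isfinite(grad_norm.item())
           or grad_norm.item() > config.get("bad_grad_threshold", 1e6))
    if bad:
        assert not config.get("fail_on_bad_grads", False), (
            f"Bad gradients in {model_name}: norm={grad_norm.item()}")
        for p in params:  # zero out so optimizer.step() is a no-op
            p.grad.zero_()
    elif config.get("gradient_clip_val", 0) > 0:
        torch.nn.utils.clip_grad_norm_(
            model.parameters(), config["gradient_clip_val"])

    # Pre-clip norm is what you want to log to spot training instability.
    return {f"{model_name}/grad_norm": grad_norm.item()}
```

Logging the pre-clip norm (rather than the clipped one) is the usual choice, since it reveals when clipping is actually firing.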

protomotions.agents.utils.training.aggregate_scalar_metrics(log_dict, fabric)[source]

Aggregate scalar metrics across all devices using all_gather and mean reduction.

All ranks compute the same averaged metrics. Then fabric.log_dict() only uploads from rank 0 (via Lightning’s rank_zero_only pattern), so wandb logs the average across all ranks rather than just rank 0’s local metrics.
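The gather-and-mean pattern can be sketched as below. This is a hypothetical reconstruction: the only Fabric API assumed is `fabric.all_gather`, which stacks each rank's tensor along a leading world-size dimension, and the key ordering is an implementation choice.

```python
import torch


def aggregate_scalar_metrics(log_dict, fabric):
    """Average scalar metrics across all ranks; every rank gets the result."""
    # Fix a deterministic key order so all ranks pack tensors identically.
    keys = sorted(log_dict.keys())
    local = torch.tensor([float(log_dict[k]) for k in keys])
    # all_gather returns one row per rank: (world_size, num_metrics).
    gathered = fabric.all_gather(local)
    mean = gathered.reshape(-1, local.numel()).mean(dim=0)
    return {k: mean[i].item() for i, k in enumerate(keys)}
```

Because every rank computes the identical averaged dict, it is safe to hand the result to a rank-zero-only logger afterwards.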

protomotions.agents.utils.training.get_activation_func(activation_name, return_type='nn')[source]

Get activation function by name.

Returns either nn.Module or functional version of the activation. Supports common activations: tanh, relu, elu, gelu, silu, mish, identity.

Parameters:
  • activation_name – Name of activation function (case-insensitive).

  • return_type – Either “nn” for nn.Module or “functional” for functional version.

Returns:

Activation function (nn.Module if return_type=”nn”, function if return_type=”functional”).

Raises:

NotImplementedError – If activation name or return type not recognized.

Example

>>> act_module = get_activation_func("relu", return_type="nn")
>>> act_func = get_activation_func("relu", return_type="functional")
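A plausible implementation of the lookup described above, covering the activations the docs list. The table-based dispatch is an assumption about the internals; the supported names, case-insensitivity, and NotImplementedError behavior come from the documentation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def get_activation_func(activation_name: str, return_type: str = "nn"):
    """Look up an activation by name as an nn.Module or a function."""
    name = activation_name.lower()  # names are case-insensitive
    modules = {
        "tanh": nn.Tanh, "relu": nn.ReLU, "elu": nn.ELU, "gelu": nn.GELU,
        "silu": nn.SiLU, "mish": nn.Mish, "identity": nn.Identity,
    }
    functionals = {
        "tanh": torch.tanh, "relu": F.relu, "elu": F.elu, "gelu": F.gelu,
        "silu": F.silu, "mish": F.mish, "identity": lambda x: x,
    }
    if return_type == "nn":
        table = modules
    elif return_type == "functional":
        table = functionals
    else:
        raise NotImplementedError(f"Unknown return_type: {return_type}")
    if name not in table:
        raise NotImplementedError(f"Unknown activation: {activation_name}")
    # Modules are instantiated; functionals are returned as-is.
    return table[name]() if return_type == "nn" else table[name]
```

Returning an instantiated module (rather than the class) lets callers drop the result straight into an `nn.Sequential`.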