protomotions.agents.utils.training module¶
Training utility functions for agent training.
This module provides helper functions used during agent training, including gradient clipping, bounds loss, distributed metrics aggregation, and model utilities.
- Key Functions:
bounds_loss: Penalize actions near joint limits
handle_model_grad_clipping: Clip gradients and handle bad gradients
aggregate_scalar_metrics: Aggregate metrics across distributed processes
get_activation_func: Get activation function by name
- protomotions.agents.utils.training.bounds_loss(mu)[source]¶
Compute soft bounds loss for actions near limits.
Penalizes actions that exceed soft bounds (±1.0) to keep actions within reasonable ranges and prevent extreme joint angles.
- Parameters:
mu (Tensor) – Action means from policy (batch_size, action_dim).
- Returns:
Bounds loss for each sample (batch_size,). Zero if within bounds, quadratic penalty beyond soft bounds.
- Return type:
Tensor
Example
>>> actions = policy(obs)
>>> loss = bounds_loss(actions)  # Penalize extreme actions
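The penalty described above can be sketched as follows. This is a minimal illustration, not the library's implementation; the name bounds_loss_sketch and the soft_bound parameter are assumptions for the example.

```python
import torch


def bounds_loss_sketch(mu: torch.Tensor, soft_bound: float = 1.0) -> torch.Tensor:
    # Zero inside [-soft_bound, soft_bound]; quadratic penalty beyond it.
    upper = torch.clamp_min(mu - soft_bound, 0.0) ** 2  # excess above +bound
    lower = torch.clamp_max(mu + soft_bound, 0.0) ** 2  # excess below -bound
    # Sum the per-dimension penalties to get one loss value per sample.
    return (upper + lower).sum(dim=-1)
```

An action of -1.5 with a soft bound of 1.0 overshoots by 0.5, contributing a penalty of 0.25 to its sample's loss, while in-bounds actions contribute nothing.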
- protomotions.agents.utils.training.handle_model_grad_clipping(config, fabric, model, optimizer, model_name)[source]¶
Handle gradient clipping and detect bad gradients.
Computes gradient norm, clips if configured, and checks for NaN or extremely large gradients. Optionally zeros gradients if bad.
- Parameters:
config – Agent config with grad clipping settings.
fabric – Fabric instance for distributed operations.
model – Neural network model.
optimizer – Optimizer for the model.
model_name – Name for logging (e.g., “actor”, “critic”).
- Returns:
Dictionary with gradient norm metrics for logging.
Note
If bad gradients are detected and fail_on_bad_grads=True, an AssertionError is raised.
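The clip-then-check flow can be sketched in plain PyTorch. This is a simplified single-process illustration, assuming the config and fabric plumbing are stripped away; the name clip_and_check_grads and its signature are hypothetical.

```python
import torch
import torch.nn as nn


def clip_and_check_grads(model: nn.Module, max_norm: float,
                         fail_on_bad_grads: bool = False) -> dict:
    # clip_grad_norm_ clips in place and returns the total norm *before* clipping.
    grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    if not torch.isfinite(grad_norm):
        if fail_on_bad_grads:
            raise AssertionError("Non-finite gradient norm detected")
        # Otherwise zero the bad gradients so the optimizer step is a no-op.
        for p in model.parameters():
            if p.grad is not None:
                p.grad.zero_()
        return {"grad_norm": float("nan")}
    return {"grad_norm": grad_norm.item()}
```

Returning the pre-clip norm as a metric is the useful part: it shows how often training is hitting the clipping threshold.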
- protomotions.agents.utils.training.aggregate_scalar_metrics(log_dict, fabric)[source]¶
Aggregate scalar metrics across all devices using all_gather and mean reduction.
All ranks compute the same averaged metrics. Then fabric.log_dict() only uploads from rank 0 (via Lightning’s rank_zero_only pattern), so wandb logs the average across all ranks rather than just rank 0’s local metrics.
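A sketch of this gather-and-average step, assuming a fabric-like object whose all_gather returns a (world_size, num_metrics) tensor; the function name is hypothetical and real usage would pass a Lightning Fabric instance.

```python
import torch


def aggregate_scalar_metrics_sketch(log_dict: dict, fabric) -> dict:
    # Pack the local scalar metrics into one tensor in a fixed key order.
    keys = list(log_dict.keys())
    local = torch.tensor([float(log_dict[k]) for k in keys])
    # all_gather stacks every rank's tensor: shape (world_size, num_metrics).
    gathered = fabric.all_gather(local)
    # Mean over the rank dimension gives the same averaged value on all ranks.
    mean = gathered.mean(dim=0)
    return {k: mean[i].item() for i, k in enumerate(keys)}
```

Because every rank computes identical averages, it does not matter that only rank 0 uploads to the logger; the logged value already reflects all ranks.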
- protomotions.agents.utils.training.get_activation_func(activation_name, return_type='nn')[source]¶
Get activation function by name.
Returns either nn.Module or functional version of the activation. Supports common activations: tanh, relu, elu, gelu, silu, mish, identity.
- Parameters:
activation_name – Name of activation function (case-insensitive).
return_type – Either “nn” for nn.Module or “functional” for functional version.
- Returns:
Activation function (nn.Module if return_type=”nn”, function if return_type=”functional”).
- Raises:
NotImplementedError – If activation name or return type not recognized.
Example
>>> act_module = get_activation_func("relu", return_type="nn")
>>> act_func = get_activation_func("relu", return_type="functional")
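A lookup of this shape can be sketched as a name-to-(module, function) table. This is an illustrative reimplementation, not the library's code; the table contents simply mirror the activations listed above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Map lowercase names to (nn.Module class, functional form).
_ACTIVATIONS = {
    "tanh": (nn.Tanh, torch.tanh),
    "relu": (nn.ReLU, F.relu),
    "elu": (nn.ELU, F.elu),
    "gelu": (nn.GELU, F.gelu),
    "silu": (nn.SiLU, F.silu),
    "mish": (nn.Mish, F.mish),
    "identity": (nn.Identity, lambda x: x),
}


def get_activation_func_sketch(activation_name: str, return_type: str = "nn"):
    name = activation_name.lower()  # case-insensitive lookup
    if name not in _ACTIVATIONS:
        raise NotImplementedError(f"Unknown activation: {activation_name}")
    module_cls, func = _ACTIVATIONS[name]
    if return_type == "nn":
        return module_cls()  # fresh nn.Module instance
    if return_type == "functional":
        return func
    raise NotImplementedError(f"Unknown return type: {return_type}")
```

Returning a new module instance (rather than a shared one) keeps each network layer independent, which matters for stateful activations even though the ones listed here are stateless.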