protomotions.agents.ppo.agent module¶
Proximal Policy Optimization (PPO) agent implementation.
This module implements the PPO algorithm for reinforcement learning. PPO is an on-policy algorithm that uses clipped surrogate objectives for stable policy updates. The agent collects experience through environment interaction, estimates advantages with Generalized Advantage Estimation (GAE), and performs multiple epochs of minibatch updates on the collected data.
- Key Classes:
PPO: Main PPO agent class extending BaseAgent
References
Schulman et al. “Proximal Policy Optimization Algorithms” (2017)
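As a rough illustration of the GAE computation referenced above, here is a minimal, self-contained sketch (the function name, signature, and termination handling are simplified assumptions, not this module's actual code):

```python
import torch


def compute_gae(rewards, values, next_values, dones, gamma=0.99, tau=0.95):
    """GAE over a rollout with shape (num_steps, num_envs); illustrative only."""
    advantages = torch.zeros_like(rewards)
    gae = torch.zeros_like(rewards[0])
    for t in reversed(range(rewards.shape[0])):
        not_done = 1.0 - dones[t].float()
        # TD residual: r_t + gamma * V(s_{t+1}) - V(s_t)
        delta = rewards[t] + gamma * next_values[t] * not_done - values[t]
        # Discounted, lambda-weighted sum of future residuals, reset at episode ends
        gae = delta + gamma * tau * not_done * gae
        advantages[t] = gae
    returns = advantages + values
    return advantages, returns
```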
- class protomotions.agents.ppo.agent.PPO(fabric, env, config, root_dir=None)[source]¶
Bases: BaseAgent
Proximal Policy Optimization (PPO) agent.
Implements the PPO algorithm for training reinforcement learning policies. PPO uses clipped surrogate objectives to enable stable policy updates while maintaining sample efficiency. This implementation supports actor-critic architecture with separate optimizers for policy and value networks.
The agent collects experience through environment interaction, computes advantages using Generalized Advantage Estimation (GAE), and performs multiple epochs of minibatch updates on the collected data.
- Parameters:
fabric (Fabric) – Lightning Fabric instance for distributed training.
env (BaseEnv) – Environment instance to train on.
config (PPOAgentConfig) – PPO-specific configuration including learning rates, clip parameters, etc.
root_dir (Path | None) – Optional root directory for saving outputs.
- tau¶
GAE lambda parameter for advantage estimation.
- e_clip¶
PPO clipping parameter for policy updates.
- actor¶
Policy network.
- critic¶
Value network.
Example
>>> fabric = Fabric(devices=4)
>>> env = Steering(config, robot_config, simulator_config, device)
>>> agent = PPO(fabric, env, config)
>>> agent.setup()
>>> agent.train()
- config: PPOAgentConfig¶
- __init__(fabric, env, config, root_dir=None)[source]¶
Initialize the base agent.
Sets up distributed training infrastructure, initializes tracking metrics, and creates the evaluator. Subclasses should call super().__init__() first.
- Parameters:
fabric (Fabric) – Lightning Fabric for distributed training and device management.
env (BaseEnv) – Environment instance for agent-environment interaction.
config (PPOAgentConfig) – Configuration containing hyperparameters and training settings.
root_dir (Path | None) – Optional directory for saving outputs (uses logger dir if None).
- create_model()[source]¶
Create PPO actor-critic model.
Instantiates the PPO model with actor and critic networks, applies weight initialization, and returns the model.
- Returns:
PPOModel instance with initialized weights.
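The real PPOModel is defined elsewhere in the package; purely as an illustration of the actor-critic layout and weight-initialization pass described here, a toy stand-in might look like this (layer sizes, activations, and the orthogonal init are assumptions):

```python
import torch
from torch import nn


class ToyActorCritic(nn.Module):
    """Toy stand-in for an actor-critic model; not the actual PPOModel."""

    def __init__(self, obs_dim: int, act_dim: int, hidden: int = 256):
        super().__init__()
        self.actor = nn.Sequential(
            nn.Linear(obs_dim, hidden), nn.ELU(), nn.Linear(hidden, act_dim)
        )
        self.critic = nn.Sequential(
            nn.Linear(obs_dim, hidden), nn.ELU(), nn.Linear(hidden, 1)
        )
        # Learnable log-std for a Gaussian policy, a common PPO choice
        self.actor_logstd = nn.Parameter(torch.zeros(act_dim))
        # Example weight-initialization pass
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.orthogonal_(m.weight, gain=1.0)
                nn.init.zeros_(m.bias)
```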
- create_optimizers(model)[source]¶
Create separate optimizers for actor and critic.
Sets up Adam optimizers for policy and value networks with independent learning rates. Uses Fabric for distributed training setup.
- Parameters:
model (PPOModel) – PPOModel with actor and critic networks.
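As a hedged sketch of what separate Adam optimizers wrapped by Fabric can look like in practice (learning rates and network shapes below are placeholders, not the configured defaults):

```python
import torch
from torch import nn
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cpu", devices=1)
fabric.launch()

# Stand-ins for the policy and value networks
actor, critic = nn.Linear(8, 2), nn.Linear(8, 1)

# Independent learning rates for actor and critic (values are placeholders)
actor_optimizer = torch.optim.Adam(actor.parameters(), lr=2e-5)
critic_optimizer = torch.optim.Adam(critic.parameters(), lr=5e-4)

# Fabric wraps the optimizers so they work under the chosen distributed strategy
actor_optimizer, critic_optimizer = fabric.setup_optimizers(
    actor_optimizer, critic_optimizer
)
```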
- load_parameters(state_dict)[source]¶
Load PPO-specific parameters from checkpoint.
Loads actor, critic, and optimizer states. Preserves config overrides for actor_logstd if they were specified on the command line.
- Parameters:
state_dict – Checkpoint state dictionary containing model and optimizer states.
- get_state_dict(state_dict)[source]¶
Get complete state dictionary for checkpointing.
Collects all agent state including model weights, training progress, and normalization statistics into a single dictionary for saving.
- Parameters:
state_dict – Existing state dict to update (typically empty dict).
- Returns:
Updated state dictionary containing all agent state.
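To make the relationship between get_state_dict() and load_parameters() concrete, here is a rough round-trip sketch; the key names and layout are illustrative assumptions, not the actual checkpoint schema:

```python
import torch
from torch import nn

actor, critic = nn.Linear(8, 2), nn.Linear(8, 1)
actor_opt = torch.optim.Adam(actor.parameters(), lr=2e-5)
critic_opt = torch.optim.Adam(critic.parameters(), lr=5e-4)

# Assemble a checkpoint dictionary (key names are illustrative)
state_dict = {
    "actor": actor.state_dict(),
    "critic": critic.state_dict(),
    "actor_optimizer": actor_opt.state_dict(),
    "critic_optimizer": critic_opt.state_dict(),
    "current_epoch": 0,
}
torch.save(state_dict, "checkpoint.pt")

# Later, restore the same pieces from the saved dictionary
restored = torch.load("checkpoint.pt")
actor.load_state_dict(restored["actor"])
critic.load_state_dict(restored["critic"])
actor_opt.load_state_dict(restored["actor_optimizer"])
critic_opt.load_state_dict(restored["critic_optimizer"])
```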
- register_algorithm_experience_buffer_keys()[source]¶
Register algorithm-specific keys in the experience buffer.
Subclasses override this to add custom keys to the experience buffer (e.g., AMP adds discriminator observations, ASE adds latent codes).
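A toy example of the override pattern (the buffer layout and key names below are stand-ins, not the package's actual experience-buffer API):

```python
import torch


class ToyPPO:
    """Toy base class illustrating the hook; not the real agent."""

    def __init__(self, num_steps: int, num_envs: int):
        self.buffer = {
            "obs": torch.zeros(num_steps, num_envs, 8),
            "actions": torch.zeros(num_steps, num_envs, 2),
        }
        self.register_algorithm_experience_buffer_keys()

    def register_algorithm_experience_buffer_keys(self):
        # Base PPO registers no extra keys
        pass


class ToyAMP(ToyPPO):
    def register_algorithm_experience_buffer_keys(self):
        # An AMP-style subclass adds discriminator observations to the buffer
        num_steps, num_envs = self.buffer["obs"].shape[:2]
        self.buffer["amp_obs"] = torch.zeros(num_steps, num_envs, 32)
```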
- record_rollout_step(next_obs_td, actions, rewards, dones, terminated, done_indices, extras, step)[source]¶
Record PPO-specific data: next value estimates for GAE computation.
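In spirit, this amounts to evaluating the value network on the post-step observations and keeping the result for the GAE pass (a simplified sketch; shapes and names are assumptions):

```python
import torch
from torch import nn

critic = nn.Linear(8, 1)            # stand-in value network
next_obs = torch.randn(4, 8)        # observations after this step, for 4 envs
with torch.no_grad():
    next_values = critic(next_obs)  # stored per step, later consumed by GAE
```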
- perform_optimization_step(batch_dict, batch_idx)[source]¶
Perform one PPO optimization step on a minibatch.
Computes actor and critic losses, performs backpropagation, clips gradients, and updates both networks.
- Parameters:
batch_dict – Dictionary containing minibatch data (obs, actions, advantages, etc.).
batch_idx – Index of current batch (unused but kept for compatibility).
- Returns:
Dictionary of training metrics (losses, clip fraction, etc.).
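A minimal sketch of the value-function half of such a step (plain PyTorch here; the actual agent routes the backward pass and gradient clipping through Fabric, and the actor half is sketched under actor_step() below):

```python
import torch
from torch import nn

critic = nn.Linear(8, 1)                                     # stand-in value network
critic_opt = torch.optim.Adam(critic.parameters(), lr=5e-4)
batch_dict = {"obs": torch.randn(64, 8), "returns": torch.randn(64, 1)}

# Regress predicted values toward the rollout returns
value_loss = nn.functional.mse_loss(critic(batch_dict["obs"]), batch_dict["returns"])
critic_opt.zero_grad()
value_loss.backward()
torch.nn.utils.clip_grad_norm_(critic.parameters(), max_norm=1.0)
critic_opt.step()
```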
- actor_step(batch_dict)[source]¶
Compute actor loss and perform policy update.
Computes PPO clipped surrogate objective plus optional bounds loss and extra algorithm-specific losses.
- Parameters:
batch_dict – Minibatch containing obs, actions, old neglogp, advantages.
- Returns:
Tuple of (actor_loss, log_dict) where:
actor_loss: Total actor loss for backprop.
log_dict: Dictionary of actor metrics for logging.
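For reference, a self-contained sketch of the clipped surrogate objective computed from the quantities listed above (the bounds loss and algorithm-specific extras are omitted; the function itself is illustrative, not the method's exact code):

```python
import torch


def clipped_surrogate_loss(new_neglogp, old_neglogp, advantages, e_clip=0.2):
    """PPO clipped surrogate objective from negative log-probabilities."""
    # ratio = pi_new(a|s) / pi_old(a|s) = exp(old_neglogp - new_neglogp)
    ratio = torch.exp(old_neglogp - new_neglogp)
    surr1 = advantages * ratio
    surr2 = advantages * torch.clamp(ratio, 1.0 - e_clip, 1.0 + e_clip)
    # Maximize the clipped objective, i.e. minimize its negation
    actor_loss = -torch.min(surr1, surr2).mean()
    # Fraction of clipped samples, a common diagnostic to log
    clip_frac = ((ratio - 1.0).abs() > e_clip).float().mean()
    return actor_loss, clip_frac
```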