# SPDX-FileCopyrightText: Copyright (c) 2025 The ProtoMotions Developers
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Proximal Policy Optimization (PPO) agent implementation.
This module implements the PPO algorithm for reinforcement learning. PPO is an
on-policy algorithm that uses clipped surrogate objectives for stable policy updates.
It collects experience through environment interaction and performs multiple epochs
of minibatch updates using Generalized Advantage Estimation (GAE).
Key Classes:
- PPO: Main PPO agent class extending BaseAgent
References:
Schulman et al. "Proximal Policy Optimization Algorithms" (2017)
"""
import torch
from torch import Tensor
from tensordict import TensorDict
import logging
from pathlib import Path
from typing import Optional, Tuple, Dict
from lightning.fabric import Fabric
from protomotions.utils.hydra_replacement import instantiate, get_class
from protomotions.agents.ppo.model import PPOModel
from protomotions.agents.common.common import weight_init
from protomotions.envs.base_env.env import BaseEnv
from protomotions.agents.utils.normalization import combine_moments
from protomotions.agents.ppo.utils import discount_values
from protomotions.agents.ppo.config import PPOAgentConfig
from protomotions.agents.base_agent.agent import BaseAgent
from protomotions.agents.utils.training import bounds_loss, handle_model_grad_clipping
log = logging.getLogger(__name__)


class PPO(BaseAgent):
    """Proximal Policy Optimization (PPO) agent.

    Implements the PPO algorithm for training reinforcement learning policies.
    PPO uses clipped surrogate objectives to enable stable policy updates while
    maintaining sample efficiency. This implementation supports an actor-critic
    architecture with separate optimizers for the policy and value networks.

    The agent collects experience through environment interaction, computes
    advantages using Generalized Advantage Estimation (GAE), and performs
    multiple epochs of minibatch updates on the collected data.

    Args:
        fabric: Lightning Fabric instance for distributed training.
        env: Environment instance to train on.
        config: PPO-specific configuration including learning rates, clip parameters, etc.
        root_dir: Optional root directory for saving outputs.

    Attributes:
        tau: GAE lambda parameter for advantage estimation.
        e_clip: PPO clipping parameter for policy updates.
        actor: Policy network.
        critic: Value network.

    Example:
        >>> fabric = Fabric(devices=4)
        >>> env = Steering(config, robot_config, simulator_config, device)
        >>> agent = PPO(fabric, env, config)
        >>> agent.setup()
        >>> agent.train()
    """

    # -----------------------------
    # Initialization and Setup
    # -----------------------------
    config: PPOAgentConfig
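    # Config fields consumed in this file: tau, e_clip, bounds_loss_coef,
    # clip_critic_loss, model.actor_optimizer / model.critic_optimizer, and the
    # advantage_normalization group (enabled, use_ema, ema_alpha, min_std,
    # clamp_range, shift_mean).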

    def __init__(
        self,
        fabric: Fabric,
        env: BaseEnv,
        config: PPOAgentConfig,
        root_dir: Optional[Path] = None,
    ):
        super().__init__(fabric, env, config, root_dir)
        self.tau: float = self.config.tau
        self.e_clip: float = self.config.e_clip
        # Initialize EMA statistics for advantage normalization
        if (
            self.config.advantage_normalization.enabled
            and self.config.advantage_normalization.use_ema
        ):
            self.adv_mean_ema = torch.zeros(1, device=self.device, dtype=torch.float32)
            self.adv_std_ema = torch.ones(1, device=self.device, dtype=torch.float32)
        else:
            self.adv_mean_ema = None
            self.adv_std_ema = None

    def create_model(self):
        """Create the PPO actor-critic model.

        Instantiates the PPO model with actor and critic networks, applies
        weight initialization, and returns the model.

        Returns:
            PPOModel instance with initialized weights.
        """
        PPOModelClass = get_class(self.config.model._target_)
        model: PPOModel = PPOModelClass(config=self.config.model)
        model.apply(weight_init)
        return model

    def create_optimizers(self, model: PPOModel):
        """Create separate optimizers for the actor and the critic.

        Sets up the optimizers specified in the config for the policy and value
        networks, with independent learning rates, and uses Fabric to prepare
        them for distributed training.

        Args:
            model: PPOModel with actor and critic networks.
        """
        actor_optimizer = instantiate(
            self.config.model.actor_optimizer,
            params=list(model._actor.parameters()),
        )
        self.actor, self.actor_optimizer = self.fabric.setup(
            model._actor, actor_optimizer
        )
        critic_optimizer = instantiate(
            self.config.model.critic_optimizer,
            params=list(model._critic.parameters()),
        )
        self.critic, self.critic_optimizer = self.fabric.setup(
            model._critic, critic_optimizer
        )

    def load_parameters(self, state_dict):
        """Load PPO-specific parameters from a checkpoint.

        Loads actor, critic, and optimizer states. Preserves config overrides
        for actor_logstd if specified at the command line.

        Args:
            state_dict: Checkpoint state dictionary containing model and optimizer states.
        """
        # Save the current logstd value to preserve config overrides
        current_logstd = self.actor.logstd.data.clone()
        current_actor_logstd_config = self.config.model.actor.actor_logstd
        super().load_parameters(state_dict)
        # Restore logstd if it was overridden in the config (i.e. differs from the checkpoint)
        checkpoint_logstd = self.actor.logstd.data
        if not torch.allclose(current_logstd, checkpoint_logstd, atol=1e-6):
            log.info(
                f"Preserving overridden actor_logstd: {current_actor_logstd_config} "
                f"(checkpoint had: {checkpoint_logstd[0].item():.3f})"
            )
            self.actor.logstd.data = current_logstd
        self.actor_optimizer.load_state_dict(state_dict["actor_optimizer"])
        self.critic_optimizer.load_state_dict(state_dict["critic_optimizer"])
        # Load EMA state if available
        if (
            self.config.advantage_normalization.enabled
            and self.config.advantage_normalization.use_ema
        ):
            if "adv_mean_ema" in state_dict:
                self.adv_mean_ema.copy_(state_dict["adv_mean_ema"])
            if "adv_std_ema" in state_dict:
                self.adv_std_ema.copy_(state_dict["adv_std_ema"])

    # -----------------------------
    # Model Saving and State Dict
    # -----------------------------
    def get_state_dict(self, state_dict):
        """Add PPO-specific optimizer and EMA state on top of the base state dict."""
        extra_state_dict = super().get_state_dict(state_dict)
        extra_state_dict.update(
            {
                "actor_optimizer": self.actor_optimizer.state_dict(),
                "critic_optimizer": self.critic_optimizer.state_dict(),
            }
        )
        # Save EMA state
        if (
            self.config.advantage_normalization.enabled
            and self.config.advantage_normalization.use_ema
        ):
            extra_state_dict["adv_mean_ema"] = self.adv_mean_ema
            extra_state_dict["adv_std_ema"] = self.adv_std_ema
        state_dict.update(extra_state_dict)
        return state_dict

    # -----------------------------
    # Experience Buffer and Training Loop
    # -----------------------------
    def register_algorithm_experience_buffer_keys(self):
        """Register PPO-specific experience-buffer keys (next_value, returns, advantages)."""
        super().register_algorithm_experience_buffer_keys()
        # PPO-specific keys that are computed outside the model forward pass
        self.experience_buffer.register_key(
            "next_value", shape=self.experience_buffer.value.shape[2:]
        )  # Computed in record_rollout_step
        self.experience_buffer.register_key(
            "returns"
        )  # Computed in pre_process_dataset
        self.experience_buffer.register_key(
            "advantages"
        )  # Computed in pre_process_dataset
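        # Layout note: "value" (and hence "next_value") carries a trailing feature dim
        # of 1, which is squeezed away in pre_process_dataset; "returns" and
        # "advantages" are registered without a shape, i.e. one scalar per step.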

    # -----------------------------
    # Environment Interaction Helpers
    # -----------------------------
    def record_rollout_step(
        self,
        next_obs_td: TensorDict,
        actions,
        rewards,
        dones,
        terminated,
        done_indices,
        extras,
        step,
    ):
        """Record PPO-specific data: next-state value estimates for GAE computation."""
        super().record_rollout_step(
            next_obs_td, actions, rewards, dones, terminated, done_indices, extras, step
        )
        # Query the critic for the value of the next observation
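        # Masking with (1 - terminated) zeroes the bootstrap value at true episode ends,
        # so GAE does not propagate value estimates past a termination; steps that are
        # done but not terminated (e.g. time-outs) still bootstrap from the next value.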
        next_output_td = self.model._critic(next_obs_td)
        next_value = next_output_td["value"] * (1 - terminated.float()).unsqueeze(-1)
        self.experience_buffer.update_data("next_value", step, next_value)

    # -----------------------------
    # Optimization
    # -----------------------------
    def actor_step(self, batch_dict) -> Tuple[Tensor, Dict]:
        """Compute the actor loss for a policy update.

        Computes the PPO clipped surrogate objective plus an optional bounds loss
        and extra algorithm-specific losses.

        Args:
            batch_dict: Minibatch containing obs, actions, old neglogp, and advantages.

        Returns:
            Tuple of (actor_loss, log_dict) where:
                - actor_loss: Total actor loss for backprop
                - log_dict: Dictionary of actor metrics for logging
        """
        # Forward through the actor to get the current policy's distribution
        batch_td = TensorDict(batch_dict, batch_size=batch_dict["action"].shape[0])
        batch_td = self.actor(batch_td)
        mean_action = batch_td["mean_action"]
        # Recompute neglogp for the actions that were actually taken (from the
        # experience buffer): we need the current policy's evaluation of those
        # actions, not the neglogp of a freshly sampled action.
        mu = mean_action  # Already tanh-bounded
        std = torch.exp(self.actor.logstd)
        dist = torch.distributions.Normal(mu, std)
        current_neglogp = -dist.log_prob(batch_dict["action"]).sum(dim=-1)
        # Compute the probability ratio between the new and old policy
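        # Since neglogp = -log pi(a|s), exp(neglogp_old - neglogp_new) equals
        # pi_new(a|s) / pi_old(a|s), i.e. the importance ratio r_t of the clipped
        # surrogate objective:
        #   L = E[max(-r_t * A_t, -clip(r_t, 1 - e_clip, 1 + e_clip) * A_t)]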
        ratio = torch.exp(batch_dict["neglogp"] - current_neglogp)
        surr1 = batch_dict["advantages"] * ratio
        surr2 = batch_dict["advantages"] * torch.clamp(
            ratio, 1.0 - self.e_clip, 1.0 + self.e_clip
        )
        ppo_loss = torch.max(-surr1, -surr2)
        # Fraction of samples whose ratio fell outside the clip range (diagnostic only)
        clipped = torch.abs(ratio - 1.0) > self.e_clip
        clipped = clipped.detach().float().mean()
        if self.config.bounds_loss_coef > 0:
            b_loss: Tensor = bounds_loss(mean_action) * self.config.bounds_loss_coef
        else:
            b_loss = torch.zeros_like(ppo_loss)
        actor_ppo_loss = ppo_loss.mean()
        b_loss = b_loss.mean()
        extra_loss, extra_actor_log_dict = self.calculate_extra_actor_loss(batch_td)
        actor_loss = actor_ppo_loss + b_loss + extra_loss
        log_dict = {
            "actor/ppo_loss": actor_ppo_loss.detach(),
            "actor/bounds_loss": b_loss.detach(),
            "actor/extra_loss": extra_loss.detach(),
            "actor/clip_frac": clipped,
            "losses/actor_loss": actor_loss.detach(),
        }
        log_dict.update(extra_actor_log_dict)
        return actor_loss, log_dict

    def critic_step(self, batch_dict) -> Tuple[Tensor, Dict]:
        """Compute the critic (value) loss, optionally with PPO-style value clipping."""
        # Convert to a TensorDict for model processing
        batch_td = TensorDict(batch_dict, batch_size=batch_dict["action"].shape[0])
        batch_td = self.critic(batch_td)
        values = batch_td["value"]
        if self.config.clip_critic_loss:
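            # Value clipping mirrors the policy clipping: also evaluate a loss in which
            # the new prediction may move at most +/- e_clip from the rollout-time
            # prediction, then take the elementwise worse (max) of the two losses to
            # discourage large critic jumps within a single update.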
            critic_loss_unclipped = (values - batch_dict["returns"].unsqueeze(-1)).pow(2)
            v_clipped = batch_dict["value"] + torch.clamp(
                values - batch_dict["value"],
                -self.config.e_clip,
                self.config.e_clip,
            )
            critic_loss_clipped = (v_clipped - batch_dict["returns"].unsqueeze(-1)).pow(2)
            critic_loss = torch.max(critic_loss_unclipped, critic_loss_clipped).mean()
        else:
            critic_loss = (batch_dict["returns"].unsqueeze(-1) - values).pow(2).mean()
        log_dict = {"losses/critic_loss": critic_loss.detach()}
        return critic_loss, log_dict

    # -----------------------------
    # Optimization Override
    # -----------------------------
    def optimize_model(self) -> Dict:
        """Run the base optimization loop and merge in advantage-normalization logs."""
        # Reset the epoch-level actor skip flag
        self._skip_actor_for_epoch = False
        training_log_dict = super().optimize_model()
        # Merge advantage normalization logs if available
        if hasattr(self, "_adv_norm_log"):
            training_log_dict.update(self._adv_norm_log)
        return training_log_dict

    # -----------------------------
    # Helper Functions
    # -----------------------------
    @torch.no_grad()
    def pre_process_dataset(self):
        """Pre-process the dataset to compute advantages and returns.

        This is called before process_dataset.
        """
        advantages = discount_values(
            self.experience_buffer.dones,
            self.experience_buffer.value.squeeze(-1),
            self.experience_buffer.total_rewards,
            self.experience_buffer.next_value.squeeze(-1),
            self.gamma,
            self.tau,
        )
        returns = advantages + self.experience_buffer.value.squeeze(-1)
        assert torch.all(torch.isfinite(returns)), f"Returns are not finite: {returns}"
        self.experience_buffer.batch_update_data("returns", returns)
        adv_norm_log = {}
        # Log raw advantage statistics
        adv_norm_log["adv_norm/raw_mean"] = advantages.mean().item()
        adv_norm_log["adv_norm/raw_std"] = advantages.std().item()
        adv_norm_log["adv_norm/raw_min"] = advantages.min().item()
        adv_norm_log["adv_norm/raw_max"] = advantages.max().item()
        if self.config.advantage_normalization.enabled:
            if self.config.advantage_normalization.use_ema:
                # EMA-based advantage normalization with clamping.
                # Apply a safety minimum std to prevent extreme normalization.
                adv_std_safe = torch.clamp(
                    self.adv_std_ema, min=self.config.advantage_normalization.min_std
                )
                # Normalize advantages using the EMA stats (from the previous iteration)
                if self.config.advantage_normalization.shift_mean:
                    advantages_normalized = (
                        advantages - self.adv_mean_ema
                    ) / adv_std_safe
                else:
                    advantages_normalized = advantages / adv_std_safe
                # Clamp the normalized advantages (z-scores)
                clamp_range = self.config.advantage_normalization.clamp_range
                advantages_clamped = torch.clamp(
                    advantages_normalized, -clamp_range, clamp_range
                )
                # Compute the clamp fraction for logging
                clamp_frac = (
                    (torch.abs(advantages_normalized) > clamp_range).float().mean()
                )
                # De-normalize the clamped values to recover the actual values used
                if self.config.advantage_normalization.shift_mean:
                    advantages_denorm = (
                        advantages_clamped * adv_std_safe + self.adv_mean_ema
                    )
                else:
                    advantages_denorm = advantages_clamped * adv_std_safe
                # Update the EMA on rank 0 using the de-normalized clamped advantages
                if self.fabric.global_rank == 0:
                    batch_mean = advantages_denorm.mean()
                    batch_std = advantages_denorm.std() + 1e-8
                    # EMA update: new = alpha * batch + (1 - alpha) * old
                    ema_alpha = self.config.advantage_normalization.ema_alpha
                    self.adv_mean_ema = (
                        ema_alpha * batch_mean + (1 - ema_alpha) * self.adv_mean_ema
                    )
                    self.adv_std_ema = (
                        ema_alpha * batch_std + (1 - ema_alpha) * self.adv_std_ema
                    )
                # Broadcast the EMA stats from rank 0 to all ranks
                if self.fabric.world_size > 1 and torch.distributed.is_initialized():
                    torch.distributed.broadcast(self.adv_mean_ema, src=0)
                    torch.distributed.broadcast(self.adv_std_ema, src=0)
                # Store advantage normalization stats for logging
                adv_norm_log["adv_norm/mean_ema"] = self.adv_mean_ema.item()
                adv_norm_log["adv_norm/std_ema"] = self.adv_std_ema.item()
                adv_norm_log["adv_norm/clamp_frac"] = clamp_frac.item()
                advantages = advantages_clamped
            else:
                # Original batch-based advantage normalization (no extra logging)
                mean = advantages.mean()
                var = advantages.var()
                count = advantages.numel()
                if self.fabric.world_size > 1:
                    all_means = self.fabric.all_gather(mean)
                    all_vars = self.fabric.all_gather(var)
                    all_counts = self.fabric.all_gather(count)
                    if self.fabric.global_rank == 0:
                        mean, var, count = combine_moments(
                            all_means, all_vars, all_counts
                        )
                # Fabric's broadcast returns the tensor from the source rank, so move
                # the result to the current rank's device.
                updated_mean = self.fabric.broadcast(mean, src=0).to(self.device)
                updated_var = self.fabric.broadcast(var, src=0).to(self.device)
                if self.config.advantage_normalization.shift_mean:
                    advantages = (advantages - updated_mean) / (
                        torch.sqrt(updated_var) + 1e-8
                    )
                else:
                    advantages = advantages / (torch.sqrt(updated_var) + 1e-8)
        assert torch.all(
            torch.isfinite(advantages)
        ), f"Advantages are not finite: {advantages}"
        self.experience_buffer.batch_update_data("advantages", advantages)
        # Store the logs; merged into the training logs in optimize_model
        self._adv_norm_log = adv_norm_log