Source code for protomotions.agents.ppo.model
# SPDX-FileCopyrightText: Copyright (c) 2025 The ProtoMotions Developers
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""PPO model implementation with actor-critic architecture.
This module implements the neural network models for Proximal Policy Optimization.
The actor outputs a Gaussian policy distribution, and the critic estimates state values.
Key Classes:
- PPOActor: Policy network with Gaussian action distribution
- PPOModel: Complete actor-critic model for PPO
"""
import torch
from torch import distributions, nn
from protomotions.utils.hydra_replacement import get_class
from tensordict import TensorDict
from tensordict.nn import TensorDictModuleBase
from protomotions.agents.ppo.config import PPOActorConfig, PPOModelConfig
from protomotions.agents.base_agent.model import BaseModel


class PPOActor(TensorDictModuleBase):
    """PPO policy network (actor).

    Self-contained policy that computes distribution parameters, samples actions,
    and computes log probabilities all in a single forward pass.

    Args:
        config: Actor configuration, including the network architecture and initial log std.

    Attributes:
        logstd: Log standard deviation parameter (typically fixed during training).
        mu: Neural network that outputs action means.
        in_keys: List of input keys, inherited from the mu model.
        out_keys: List of output keys (action, mean_action, neglogp).
    """

    def __init__(self, config: PPOActorConfig):
        super().__init__()
        self.config = config
        self.logstd = nn.Parameter(
            torch.ones(self.config.num_out) * self.config.actor_logstd,
            requires_grad=False,
        )
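        # The log std is state-independent: one value per action dimension,
        # shared across the batch and held fixed via requires_grad=False.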
        MuClass = get_class(self.config.mu_model._target_)
        self.mu: TensorDictModuleBase = MuClass(config=self.config.mu_model)

        self.in_keys = self.config.in_keys
        self.out_keys = self.config.out_keys
        for key in ["action", "mean_action", "neglogp"]:
            assert (
                key in self.out_keys
            ), f"PPOActor output key {key} not in out_keys {self.out_keys}"

    def forward(self, tensordict: TensorDict) -> TensorDict:
        """Forward pass: compute mu/std, sample an action, and compute its neglogp.

        The actor is self-contained: this single method produces all policy outputs.

        Args:
            tensordict: TensorDict containing observations.

        Returns:
            TensorDict with action, mean_action, and neglogp added.
        """
        # Compute distribution parameters
        tensordict = self.mu(tensordict)
        mu = tensordict[self.config.mu_key]
        std = torch.exp(self.logstd)

        # Sample an action from the Gaussian distribution
        dist = distributions.Normal(mu, std)
        action = dist.sample()

        # Compute the negative log probability of the sampled action
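        # Summing per-dimension log-probs treats the policy as a diagonal
        # Gaussian: log pi(a|s) = sum_i log N(a_i; mu_i, sigma_i).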
        neglogp = -dist.log_prob(action).sum(dim=-1)

        # Store all outputs
        tensordict["action"] = action
        tensordict["mean_action"] = mu
        tensordict["neglogp"] = neglogp
        return tensordict
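
# Usage sketch (the observation key is hypothetical): given a TensorDict that
# holds the observations named in actor.in_keys, e.g.
# td = TensorDict({"obs": obs}, batch_size=[N]), calling actor(td) populates
# td["action"], td["mean_action"], and td["neglogp"].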


class PPOModel(BaseModel):
    """Complete PPO model with actor and critic networks.

    A pure forward function that computes all model outputs in a TensorDict.
    The forward pass adds action distribution parameters and value estimates.

    Args:
        config: Model configuration specifying the actor and critic architectures.

    Attributes:
        _actor: Policy network.
        _critic: Value network.
    """

    config: PPOModelConfig

    def __init__(self, config: PPOModelConfig):
        super().__init__(config)
        # Create the actor and critic networks
        ActorClass = get_class(self.config.actor._target_)
        self._actor: PPOActor = ActorClass(config=self.config.actor)
        CriticClass = get_class(self.config.critic._target_)
        self._critic: TensorDictModuleBase = CriticClass(config=self.config.critic)

        # The model's keys are the union of the actor's and critic's keys;
        # verify that the config declares every one of them.
        actor_critic_in_keys = list(set(self._actor.in_keys + self._critic.in_keys))
        actor_critic_out_keys = list(set(self._actor.out_keys + self._critic.out_keys))
        for key in actor_critic_out_keys:
            assert (
                key in self.config.out_keys
            ), f"PPOModel output key {key} not in out_keys {self.config.out_keys}"
        for key in actor_critic_in_keys:
            assert (
                key in self.config.in_keys
            ), f"PPOModel input key {key} not in in_keys {self.config.in_keys}"
        self.in_keys = self.config.in_keys
        self.out_keys = self.config.out_keys

    def forward(self, tensordict: TensorDict) -> TensorDict:
        """Forward pass through the actor and the critic.

        This is the main interface for the model. It computes all outputs:

        - action: Sampled action
        - mean_action: Deterministic action (the distribution mean)
        - neglogp: Negative log probability of the sampled action
        - value: State value estimate

        Args:
            tensordict: TensorDict containing observations.

        Returns:
            TensorDict with all model outputs added.
        """
        # Actor forward: adds action, mean_action, and neglogp
        tensordict = self._actor(tensordict)
        # Critic forward: adds the value estimate
        tensordict = self._critic(tensordict)
        return tensordict
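

# A minimal, self-contained sketch of the distribution math used by
# PPOActor.forward. It does not construct the configs or networks above; the
# batch size, action dimension, and logstd value are hypothetical. It verifies
# that summing per-dimension Normal log-probs equals the joint log-density of
# a diagonal-covariance Gaussian.
if __name__ == "__main__":
    batch, num_out = 4, 3
    mu = torch.zeros(batch, num_out)        # stand-in for the mu network output
    logstd = torch.full((num_out,), -1.0)   # fixed log standard deviation
    dist = distributions.Normal(mu, torch.exp(logstd))
    action = dist.sample()
    neglogp = -dist.log_prob(action).sum(dim=-1)

    # Same quantity computed with an explicit diagonal multivariate Normal.
    mvn = distributions.MultivariateNormal(
        mu, scale_tril=torch.diag(torch.exp(logstd))
    )
    assert torch.allclose(neglogp, -mvn.log_prob(action), atol=1e-5)
    print("neglogp per sample:", neglogp)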