Source code for protomotions.agents.base_agent.config
# SPDX-FileCopyrightText: Copyright (c) 2025 The ProtoMotions Developers
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Configuration classes for base agent.
This module defines the configuration dataclasses used by the base agent and all
derived agents. These configurations specify training parameters, optimization
settings, and evaluation parameters.
Key Classes:
- BaseAgentConfig: Main agent configuration
- BaseModelConfig: Model architecture configuration
- OptimizerConfig: Optimizer parameters
- MaxEpisodeLengthManagerConfig: Episode length curriculum
"""
from dataclasses import dataclass, field
from typing import Optional, List
from protomotions.utils.config_builder import ConfigBuilder
from protomotions.agents.evaluators.config import EvaluatorConfig
@dataclass
class MaxEpisodeLengthManagerConfig(ConfigBuilder):
"""Configuration for managing max episode length during training."""
    # Example configuration for an agent that slowly increases the max episode length:
    # max_episode_length_manager:
    #     start_length: 5
    #     end_length: 300
    #     transition_epochs: 100000
    start_length: int = 5
    end_length: int = 300
    transition_epochs: int = 100000
    def current_max_episode_length(self, current_epoch: int) -> int:
        """Return the current max episode length based on linear interpolation.

        Args:
            current_epoch: Current training epoch.

        Returns:
            Interpolated max episode length.
        """
        if self.transition_epochs == 0:
            # No interpolation, return the fixed value
            return self.start_length
        # Linear interpolation between start and end values
        progress = min(current_epoch / self.transition_epochs, 1.0)
        return int(self.start_length + progress * (self.end_length - self.start_length))
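
# Illustrative sketch (not part of the original module): a hypothetical helper
# showing how the linear episode-length curriculum ramps up over training,
# assuming the default start_length=5, end_length=300, transition_epochs=100000.
def _example_episode_length_curriculum() -> None:
    cfg = MaxEpisodeLengthManagerConfig()
    assert cfg.current_max_episode_length(0) == 5            # start of training
    assert cfg.current_max_episode_length(50_000) == 152     # halfway through the ramp
    assert cfg.current_max_episode_length(100_000) == 300    # ramp finished
    assert cfg.current_max_episode_length(200_000) == 300    # clamped at end_length
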
@dataclass
class OptimizerConfig(ConfigBuilder):
"""Configuration for optimizers."""
_target_: str = "torch.optim.Adam"
lr: float = 1e-4
weight_decay: float = 0.0
eps: float = 1e-8
betas: tuple = field(default_factory=lambda: (0.9, 0.999))
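
# Illustrative sketch (not part of the original module): OptimizerConfig mirrors the
# keyword arguments of the optimizer named by ``_target_`` (torch.optim.Adam by
# default). A hypothetical helper that builds the optimizer directly from the
# dataclass fields might look like this; in practice the config system is expected
# to resolve ``_target_`` itself.
def _example_build_optimizer(params, cfg: OptimizerConfig):
    import torch

    return torch.optim.Adam(
        params,
        lr=cfg.lr,
        betas=cfg.betas,
        eps=cfg.eps,
        weight_decay=cfg.weight_decay,
    )
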
@dataclass
class BaseModelConfig(ConfigBuilder):
"""Configuration for PPO Model (Actor-Critic)."""
_target_: str = "protomotions.agents.base_agent.model.BaseModel"
in_keys: List[str] = field(default_factory=list)
out_keys: List[str] = field(default_factory=list)
@dataclass
class BaseAgentConfig(ConfigBuilder):
"""Main configuration class for PPO Agent."""
    batch_size: int
    training_max_steps: int
    _target_: str = "protomotions.agents.base_agent.agent.BaseAgent"

    # Model configuration
    model: BaseModelConfig = field(default_factory=BaseModelConfig)

    # Base agent hyperparameters
    num_steps: int = 32
    gradient_clip_val: float = 0.0
    fail_on_bad_grads: bool = False
    check_grad_mag: bool = True
    gamma: float = 0.99

    # Bounds and regularization
    # Default policy uses tanh outputs, so we don't need the bounds loss.
    bounds_loss_coef: float = 0.0

    # Training configuration
    task_reward_w: float = 1.0
    num_mini_epochs: int = 1
    training_early_termination: Optional[int] = None

    # Checkpoint saving configuration
    save_epoch_checkpoint_every: Optional[int] = 1000  # Save epoch_xxx.ckpt every N epochs (None = disabled)
    save_last_checkpoint_every: int = 10  # Save/overwrite last.ckpt every K epochs

    # Episode length management
    max_episode_length_manager: Optional[MaxEpisodeLengthManagerConfig] = None

    # Evaluator configuration
    evaluator: EvaluatorConfig = field(default_factory=EvaluatorConfig)

    # Reward normalization
    normalize_rewards: bool = True
    normalized_reward_clamp_value: float = 5.0
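
# Illustrative sketch (not part of the original module): assembling a BaseAgentConfig
# with an episode-length curriculum. The field values and the ``in_keys``/``out_keys``
# names below are arbitrary examples, not project defaults.
def _example_agent_config() -> BaseAgentConfig:
    return BaseAgentConfig(
        batch_size=4096,
        training_max_steps=1_000_000,
        model=BaseModelConfig(in_keys=["obs"], out_keys=["action"]),
        max_episode_length_manager=MaxEpisodeLengthManagerConfig(
            start_length=5,
            end_length=300,
            transition_epochs=100_000,
        ),
    )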