protomotions.agents.base_agent.agent module
Base agent implementation for reinforcement learning.
This module provides the core agent class that all RL algorithms extend. It handles the complete training lifecycle including rollout collection, experience buffering, optimization, checkpointing, evaluation, and distributed training coordination.
- Key Classes:
BaseAgent: Abstract base class for all RL agents
- Key Features:
Distributed training with Lightning Fabric
Experience buffer management
Automatic checkpoint saving/loading
Periodic evaluation during training
Reward normalization
Episode statistics tracking
- class protomotions.agents.base_agent.agent.BaseAgent(fabric, env, config, root_dir=None)
Bases: object
Base class for reinforcement learning agents.
Provides the core training infrastructure that all algorithm implementations extend. Handles experience collection, optimization loop, checkpointing, and evaluation. Subclasses must implement model creation and algorithm-specific training logic.
- Parameters:
fabric (Fabric) – Lightning Fabric instance for distributed training.
env (BaseEnv) – Environment instance for interaction.
config (BaseAgentConfig) – Agent configuration with hyperparameters.
root_dir (Path | None) – Directory for saving checkpoints and logs (optional, uses logger dir if available).
- model
Neural network model (created by subclass).
- optimizer
Optimizer for model parameters.
- experience_buffer
Buffer for storing rollout data.
- evaluator
Evaluator for computing performance metrics.
- __init__(fabric, env, config, root_dir=None)
Initialize the base agent.
Sets up distributed training infrastructure, initializes tracking metrics, and creates the evaluator. Subclasses should call super().__init__() first.
- Parameters:
fabric (Fabric) – Lightning Fabric for distributed training and device management.
env (BaseEnv) – Environment instance for agent-environment interaction.
config (BaseAgentConfig) – Configuration containing hyperparameters and training settings.
root_dir (Path | None) – Optional directory for saving outputs (uses logger dir if None).
- property should_stop
- load_parameters(state_dict)
Load agent parameters from state dictionary.
Restores training state including epoch counter, step count, timing info, best scores, normalization statistics, and model weights.
- Parameters:
state_dict – Dictionary containing saved agent state from checkpoint. Expected keys: epoch, step_count, run_start_time, best_evaluated_score, running_reward_norm (if normalization enabled), model.
- get_state_dict(state_dict)
Get complete state dictionary for checkpointing.
Collects all agent state including model weights, training progress, and normalization statistics into a single dictionary for saving.
- Parameters:
state_dict – Existing state dict to update (typically empty dict).
- Returns:
Updated state dictionary containing all agent state.
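The checkpoint round trip described by get_state_dict and load_parameters can be sketched with a small stand-in class. ToyAgent and its attribute names are hypothetical; only the dictionary keys (epoch, step_count, run_start_time, best_evaluated_score, model) come from the documentation above, and the real BaseAgent likely stores more than this.

```python
import time


class ToyAgent:
    """Hypothetical stand-in for BaseAgent; attribute names are assumptions."""

    def __init__(self):
        self.current_epoch = 0
        self.step_count = 0
        self.run_start_time = time.time()
        self.best_evaluated_score = None
        self.model_weights = {"w": 0.0}  # stands in for the model's weights

    def get_state_dict(self, state_dict):
        # Update the dict that was passed in and return it.
        state_dict.update(
            {
                "epoch": self.current_epoch,
                "step_count": self.step_count,
                "run_start_time": self.run_start_time,
                "best_evaluated_score": self.best_evaluated_score,
                "model": dict(self.model_weights),
            }
        )
        return state_dict

    def load_parameters(self, state_dict):
        # Restore progress counters first, then the weights.
        self.current_epoch = state_dict["epoch"]
        self.step_count = state_dict["step_count"]
        self.run_start_time = state_dict["run_start_time"]
        self.best_evaluated_score = state_dict["best_evaluated_score"]
        self.model_weights = dict(state_dict["model"])


trained = ToyAgent()
trained.current_epoch, trained.step_count = 7, 1000
trained.model_weights["w"] = 0.5

checkpoint = trained.get_state_dict({})  # typically called with an empty dict
restored = ToyAgent()
restored.load_parameters(checkpoint)
```

Passing the dictionary in, rather than building it inside the method, lets a subclass call the parent implementation and then add its own entries to the same object.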
- save(checkpoint_name='last.ckpt', new_high_score=False)
Save model checkpoint and environment state.
- register_algorithm_experience_buffer_keys()
Register algorithm-specific keys in the experience buffer.
Subclasses override this to add custom keys to the experience buffer (e.g., AMP adds discriminator observations, ASE adds latent codes).
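As a rough illustration of this override pattern, the sketch below uses a toy buffer. ToyExperienceBuffer, its register_key method, and the "latents" key are all hypothetical names chosen for the example, not the ProtoMotions API.

```python
class ToyExperienceBuffer:
    """Hypothetical buffer mapping each registered key to per-step storage."""

    def __init__(self, num_steps):
        self.num_steps = num_steps
        self.storage = {}

    def register_key(self, name):
        if name in self.storage:
            raise KeyError(f"key {name!r} is already registered")
        self.storage[name] = [None] * self.num_steps


class ToyBaseAgent:
    def __init__(self, num_steps):
        self.experience_buffer = ToyExperienceBuffer(num_steps)
        # Keys every algorithm needs.
        for key in ("actions", "rewards", "dones"):
            self.experience_buffer.register_key(key)
        # Hook for algorithm-specific keys; a no-op in the base class.
        self.register_algorithm_experience_buffer_keys()

    def register_algorithm_experience_buffer_keys(self):
        pass


class ToyLatentAgent(ToyBaseAgent):
    def register_algorithm_experience_buffer_keys(self):
        # An ASE-like subclass could reserve space for latent codes here.
        self.experience_buffer.register_key("latents")


base_keys = sorted(ToyBaseAgent(num_steps=4).experience_buffer.storage)
latent_keys = sorted(ToyLatentAgent(num_steps=4).experience_buffer.storage)
```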
- collect_rollout_step(obs_td, step)
Collect experience data during rollout at current timestep.
Called once per timestep during the data collection (rollout) phase. Subclasses implement this to:
1. Query the policy to select actions from observations
2. Store intermediate training data in experience buffer (e.g., values, log probs)
3. Return actions to apply to the environment
- Parameters:
obs_td – Dictionary of observations from environment
step – Current timestep index in the rollout [0, num_steps)
- Returns:
Actions tensor to apply to environment [num_envs, action_dim]
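The three responsibilities listed for collect_rollout_step might look roughly like this. It is a minimal sketch in which nested lists stand in for tensors; the toy policy, the buffer layout, and the "obs" key are assumptions made for illustration.

```python
import random


class ToyRolloutAgent:
    """Hypothetical agent; nested lists stand in for tensors."""

    def __init__(self, num_steps, action_dim, seed=0):
        self.action_dim = action_dim
        self.rng = random.Random(seed)
        # One slot per rollout step for each intermediate quantity.
        self.buffer = {
            "value": [None] * num_steps,
            "mean_action": [None] * num_steps,
        }

    def collect_rollout_step(self, obs_td, step):
        obs = obs_td["obs"]  # shape [num_envs][obs_dim]
        # 1. Query the (toy) policy: the action mean is the average observation.
        means = [sum(o) / len(o) for o in obs]
        actions = [
            [m + self.rng.gauss(0.0, 0.1) for _ in range(self.action_dim)]
            for m in means
        ]
        # 2. Store intermediate data that the optimizer needs later.
        self.buffer["value"][step] = [sum(o) for o in obs]
        self.buffer["mean_action"][step] = means
        # 3. Return actions with shape [num_envs][action_dim].
        return actions


agent = ToyRolloutAgent(num_steps=4, action_dim=2)
obs_td = {"obs": [[0.0, 1.0, 2.0], [3.0, 3.0, 3.0]]}  # 2 envs, 3 features each
actions = agent.collect_rollout_step(obs_td, step=0)
```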
- fit()
Main training loop for the agent.
Executes the complete training process including:
1. Experience buffer setup (auto-registers keys from model outputs)
2. Environment rollouts (data collection)
3. Model optimization
4. Periodic evaluation
5. Checkpoint saving
6. Logging and metrics
The loop runs for max_epochs epochs, where each epoch collects num_steps of experience from num_envs parallel environments, then performs multiple optimization steps on the collected data.
Note
This is the main entry point for training. Call setup() before fit().
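One plausible way the per-step hooks documented on this page fit into a single epoch is sketched below. The call order is inferred from the method descriptions rather than taken from the source code. TracingAgent, ToyEnv, and the optimize method name are hypothetical.

```python
class TracingAgent:
    """Hypothetical agent that records which hook ran, in order."""

    def __init__(self):
        self.trace = []

    def collect_rollout_step(self, obs, step):
        self.trace.append("collect")
        return [0.0 for _ in obs["obs"]]

    def post_env_step_modifications(self, dones, terminated, extras):
        self.trace.append("post_env_step")
        return dones, terminated, extras

    def record_rollout_step(self, next_obs, actions, rewards, dones,
                            terminated, done_indices, extras, step):
        self.trace.append("record")

    def optimize(self):  # hypothetical name for the optimization phase
        self.trace.append("optimize")


class ToyEnv:
    """Hypothetical two-environment simulator returning constant data."""

    def reset(self):
        return {"obs": [[0.0], [0.0]]}

    def step(self, actions):
        obs = {"obs": [[1.0], [1.0]]}
        return obs, [1.0, 1.0], [False, False], [False, False], {}


def fit_sketch(agent, env, max_epochs, num_steps):
    obs = env.reset()
    for _ in range(max_epochs):
        # Rollout phase: num_steps interactions with all parallel envs.
        for step in range(num_steps):
            actions = agent.collect_rollout_step(obs, step)
            obs, rewards, dones, terminated, extras = env.step(actions)
            dones, terminated, extras = agent.post_env_step_modifications(
                dones, terminated, extras
            )
            done_indices = [i for i, d in enumerate(dones) if d]
            agent.record_rollout_step(
                obs, actions, rewards, dones, terminated, done_indices, extras, step
            )
        # Optimization phase; evaluation, checkpointing and logging would follow.
        agent.optimize()


agent = TracingAgent()
fit_sketch(agent, ToyEnv(), max_epochs=1, num_steps=2)
```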
- add_agent_info_to_obs(obs)
Add agent-specific observations to the environment observations.
This can be used to augment observations from both reset() and step() with agent-specific information (e.g., latent codes, discriminator obs).
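A minimal sketch of such an augmentation, assuming observations are a plain dictionary and the agent keeps one latent code per environment. The class name and the "self_obs" and "latents" keys are hypothetical.

```python
class ToyLatentAgent:
    """Hypothetical agent holding one latent code per environment."""

    def __init__(self, latents):
        self.latents = latents

    def add_agent_info_to_obs(self, obs):
        # Copy so the environment's own dict is left untouched. This is a
        # choice made for the sketch; the real method may modify obs in place.
        augmented = dict(obs)
        augmented["latents"] = self.latents
        return augmented


agent = ToyLatentAgent(latents=[[0.1, 0.9], [0.7, 0.3]])
env_obs = {"self_obs": [[0.0], [1.0]]}
agent_obs = agent.add_agent_info_to_obs(env_obs)
```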
- obs_dict_to_tensordict(obs_dict)
Convert observation dict to TensorDict.
- Parameters:
obs_dict (Dict) – Dictionary of observation tensors from environment.
- Returns:
TensorDict with same keys and values.
- Return type:
TensorDict
- post_env_step_modifications(dones, terminated, extras)
Allow subclasses to modify dones/terminated after env.step().
This hook allows algorithm-specific modifications (e.g., AMP discriminator termination).
- Parameters:
dones – Reset flags from environment
terminated – Termination flags from environment
extras – Info dictionary from environment
- Returns:
Modified (dones, terminated, extras) tuple
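As an illustration of this hook, the sketch below ends episodes early when a per-environment score drops below a threshold, loosely in the spirit of discriminator-based termination. The "disc_score" entry in extras, the threshold, and the class name are assumptions, and Python lists stand in for tensors.

```python
class ToyDiscTerminationAgent:
    """Hypothetical agent that terminates envs with a low discriminator score."""

    def __init__(self, threshold):
        self.threshold = threshold

    def post_env_step_modifications(self, dones, terminated, extras):
        scores = extras["disc_score"]  # assumed: one score per environment
        low = [s < self.threshold for s in scores]
        # A low score marks the env as terminated and also requests a reset.
        terminated = [t or flag for t, flag in zip(terminated, low)]
        dones = [d or flag for d, flag in zip(dones, low)]
        # Return a new extras dict with a count for logging.
        extras = dict(extras, num_disc_terminations=sum(low))
        return dones, terminated, extras


agent = ToyDiscTerminationAgent(threshold=0.3)
dones, terminated, extras = agent.post_env_step_modifications(
    dones=[False, False, True],
    terminated=[False, False, False],
    extras={"disc_score": [0.9, 0.1, 0.5]},
)
```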
- record_rollout_step(next_obs_td, actions, rewards, dones, terminated, done_indices, extras, step)
Record metrics and store data after environment step during rollout.
Called once per timestep during the data collection phase, after the environment has been stepped. Handles:
1. Episode statistics tracking (rewards, lengths)
2. Environment extras logging
3. Experience buffer updates (actions, rewards, dones)
4. Reward normalization if enabled
- Parameters:
next_obs_td – Observations after environment step
actions (Tensor) – Actions that were applied
rewards (Tensor) – Rewards from environment step
dones (Tensor) – Reset flags indicating episode completion
terminated (Tensor) – Termination flags indicating early termination
done_indices (Tensor) – Indices of environments that finished episodes
extras (Dict) – Additional info dictionary from environment
step (int) – Current timestep index in the rollout [0, num_steps)
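Two of the responsibilities above, episode statistics and reward normalization, can be sketched in plain Python. The running mean and variance scheme shown here (Welford's algorithm) is one common choice and is an assumption; the normalization actually used by the library may differ. Both class names are hypothetical.

```python
import math


class RunningNorm:
    """Running mean/variance via Welford's algorithm (assumed scheme)."""

    def __init__(self, eps=1e-8):
        self.count = 0
        self.mean = 0.0
        self.m2 = 0.0  # sum of squared deviations from the mean
        self.eps = eps

    def update(self, values):
        for x in values:
            self.count += 1
            delta = x - self.mean
            self.mean += delta / self.count
            self.m2 += delta * (x - self.mean)

    def normalize(self, values):
        var = self.m2 / self.count if self.count > 0 else 1.0
        std = math.sqrt(var) + self.eps
        return [(x - self.mean) / std for x in values]


class EpisodeTracker:
    """Accumulates per-environment return and length until an episode ends."""

    def __init__(self, num_envs):
        self.returns = [0.0] * num_envs
        self.lengths = [0] * num_envs
        self.finished = []  # (return, length) of completed episodes

    def record(self, rewards, done_indices):
        for i, r in enumerate(rewards):
            self.returns[i] += r
            self.lengths[i] += 1
        for i in done_indices:
            self.finished.append((self.returns[i], self.lengths[i]))
            self.returns[i] = 0.0
            self.lengths[i] = 0


norm = RunningNorm()
norm.update([1.0, 2.0, 3.0])
normalized = norm.normalize([2.0, 3.0])

tracker = EpisodeTracker(num_envs=2)
tracker.record([1.0, 2.0], done_indices=[])
tracker.record([1.0, 2.0], done_indices=[1])  # env 1 finishes its episode
```

Statistics are tracked on the raw rewards here, so logged episode returns stay comparable across runs even when normalization is enabled.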
- eval()
Set the model to evaluation mode.
Disables training-specific behaviors like dropout and batch normalization updates. Typically called before collecting experience or during evaluation.
- max_num_batches()
Calculate maximum number of minibatches per epoch.
- Returns:
Integer number of minibatches needed to process all collected experience.
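The exact formula is not given here. A plausible reading, assuming each epoch holds num_envs * num_steps samples and that a hypothetical batch_size setting controls the minibatch size, is a ceiling division:

```python
def max_num_batches_sketch(num_envs, num_steps, batch_size):
    """Assumed formula: minibatches needed to cover one epoch of experience."""
    total_samples = num_envs * num_steps
    # Integer ceiling division, so a partial final minibatch still counts.
    return -(-total_samples // batch_size)


even_split = max_num_batches_sketch(num_envs=4096, num_steps=32, batch_size=16384)
with_remainder = max_num_batches_sketch(num_envs=10, num_steps=3, batch_size=8)
```

With 4096 environments and 32 steps there are 131072 samples, which divide evenly into 8 minibatches of 16384. With 30 samples and a batch size of 8, the last minibatch is partial and the count rounds up to 4.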