protomotions.agents.base_agent.agent module

Base agent implementation for reinforcement learning.

This module provides the core agent class that all RL algorithms extend. It handles the complete training lifecycle, including rollout collection, experience buffering, optimization, checkpointing, evaluation, and distributed training coordination.

Key Classes:
  • BaseAgent: Abstract base class for all RL agents

Key Features:
  • Distributed training with Lightning Fabric

  • Experience buffer management

  • Automatic checkpoint saving/loading

  • Periodic evaluation during training

  • Reward normalization

  • Episode statistics tracking

class protomotions.agents.base_agent.agent.BaseAgent(fabric, env, config, root_dir=None)[source]

Bases: object

Base class for reinforcement learning agents.

Provides the core training infrastructure that all algorithm implementations extend. Handles experience collection, optimization loop, checkpointing, and evaluation. Subclasses must implement model creation and algorithm-specific training logic.

Parameters:
  • fabric (Fabric) – Lightning Fabric instance for distributed training.

  • env (BaseEnv) – Environment instance for interaction.

  • config (BaseAgentConfig) – Agent configuration with hyperparameters.

  • root_dir (Path | None) – Directory for saving checkpoints and logs (optional, uses logger dir if available).

model

Neural network model (created by subclass).

optimizer

Optimizer for model parameters.

experience_buffer

Buffer for storing rollout data.

evaluator

Evaluator for computing performance metrics.
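A minimal subclass sketch (not part of the API itself): the abstract methods below must be provided, with bodies that are entirely algorithm-specific.

    from protomotions.agents.base_agent.agent import BaseAgent

    class MyAgent(BaseAgent):
        def create_model(self):
            # Build and return the policy/value network for this algorithm.
            ...

        def create_optimizers(self, model):
            # Build and return the optimizer(s) over the model's parameters.
            ...

        def perform_optimization_step(self, batch_dict):
            # Compute losses on one minibatch and step the optimizer.
            ...

        def collect_rollout_step(self, obs_td, step):
            # Select actions from observations and record per-step training data.
            ...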

__init__(fabric, env, config, root_dir=None)[source]

Initialize the base agent.

Sets up distributed training infrastructure, initializes tracking metrics, and creates the evaluator. Subclasses should call super().__init__() first.

Parameters:
  • fabric (Fabric) – Lightning Fabric for distributed training and device management.

  • env (BaseEnv) – Environment instance for agent-environment interaction.

  • config (BaseAgentConfig) – Configuration containing hyperparameters and training settings.

  • root_dir (Path | None) – Optional directory for saving outputs (uses logger dir if None).
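As a hedged sketch, a subclass constructor that follows the note above, calling super().__init__() before adding its own state (the extra attribute is illustrative only):

    class MyAgent(BaseAgent):
        def __init__(self, fabric, env, config, root_dir=None):
            # Base class sets up distributed training, tracking metrics, and the evaluator.
            super().__init__(fabric, env, config, root_dir=root_dir)
            # Algorithm-specific state is created afterwards (illustrative attribute).
            self._extra_loss_coef = getattr(config, "extra_loss_coef", 1.0)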

property should_stop
setup()[source]
abstractmethod create_model()[source]
abstractmethod create_optimizers(model)[source]
load(checkpoint, load_env=True)[source]
load_parameters(state_dict)[source]

Load agent parameters from state dictionary.

Restores training state including epoch counter, step count, timing info, best scores, normalization statistics, and model weights.

Parameters:

state_dict – Dictionary containing saved agent state from checkpoint. Expected keys: epoch, step_count, run_start_time, best_evaluated_score, running_reward_norm (if normalization enabled), model.
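For illustration, a state dictionary shaped like the documented keys (values are placeholders standing in for a saved run):

    state_dict = {
        "epoch": 100,
        "step_count": 1_600_000,
        "run_start_time": 1700000000.0,
        "best_evaluated_score": 0.92,
        "model": agent.model.state_dict(),   # model weights
        # "running_reward_norm": ...,        # present only if reward normalization is enabled
    }
    agent.load_parameters(state_dict)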

get_state_dict(state_dict)[source]

Get complete state dictionary for checkpointing.

Collects all agent state including model weights, training progress, and normalization statistics into a single dictionary for saving.

Parameters:

state_dict – Existing state dict to update (typically empty dict).

Returns:

Updated state dictionary containing all agent state.
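Typical usage: pass an empty dict and receive the full agent state back.

    state = agent.get_state_dict({})
    # state now holds model weights, training progress, and normalization statistics.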

save(checkpoint_name='last.ckpt', new_high_score=False)[source]

Save model checkpoint and environment state.

Parameters:
  • checkpoint_name (str) – Name of checkpoint file (e.g., “last.ckpt” or “epoch_100.ckpt”)

  • new_high_score (bool) – Whether this is a new high score (will also save as score_based.ckpt)
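Typical calls, assuming a constructed and set-up agent:

    agent.save()                                   # writes last.ckpt
    agent.save(checkpoint_name="epoch_100.ckpt")   # named checkpoint
    agent.save(new_high_score=True)                # additionally writes score_based.ckpt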

register_algorithm_experience_buffer_keys()[source]

Register algorithm-specific keys in the experience buffer.

Subclasses override this to add custom keys to the experience buffer (e.g., AMP adds discriminator observations, ASE adds latent codes).
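An override might look like the sketch below. The register_key call is a hypothetical stand-in for whatever interface experience_buffer actually exposes; the key names mirror the AMP/ASE examples above.

    class MyAMPAgent(BaseAgent):
        def register_algorithm_experience_buffer_keys(self):
            super().register_algorithm_experience_buffer_keys()
            # Hypothetical buffer API: reserve per-step storage for extra tensors.
            self.experience_buffer.register_key("amp_obs")
            self.experience_buffer.register_key("latents")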

collect_rollout_step(obs_td, step)[source]

Collect experience data during rollout at current timestep.

Called once per timestep during the data collection (rollout) phase. Subclasses implement this to:

  1. Query the policy to select actions from observations

  2. Store intermediate training data in the experience buffer (e.g., values, log probs)

  3. Return actions to apply to the environment

Parameters:
  • obs_td – Observations from the environment, as a TensorDict

  • step – Current timestep index in the rollout [0, num_steps)

Returns:

Actions tensor to apply to environment [num_envs, action_dim]
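A hedged override sketch. The model call and the buffer method names are assumptions; the contract is the documented one: read obs_td, record per-step training data, and return an actions tensor of shape [num_envs, action_dim].

    class MyAgent(BaseAgent):
        def collect_rollout_step(self, obs_td, step):
            # Query the policy (the exact model interface is an assumption).
            outs = self.model(obs_td)
            # Store intermediate training data (buffer method name is an assumption).
            self.experience_buffer.update_data("values", step, outs["values"])
            self.experience_buffer.update_data("neglogp", step, outs["neglogp"])
            # The returned actions are applied to the environment by the rollout loop.
            return outs["actions"]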

fit()[source]

Main training loop for the agent.

Executes the complete training process, including:

  1. Experience buffer setup (auto-registers keys from model outputs)

  2. Environment rollouts (data collection)

  3. Model optimization

  4. Periodic evaluation

  5. Checkpoint saving

  6. Logging and metrics

The loop runs for max_epochs epochs, where each epoch collects num_steps of experience from num_envs parallel environments, then performs multiple optimization steps on the collected data.

Note

This is the main entry point for training. Call setup() before fit().
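The intended calling order, per the note above:

    agent = MyAgent(fabric, env, config)
    agent.setup()   # must be called before fit()
    agent.fit()     # runs the training loop for max_epochs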

add_agent_info_to_obs(obs)[source]

Add agent-specific observations to the environment observations.

This can be used to augment observations from both reset() and step() with agent-specific information (e.g., latent codes, discriminator obs).
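A sketch of an override. Whether the base implementation mutates the dict in place or returns it is not shown here; the sketch assumes a returned dict, and the latent-code attribute is illustrative.

    class MyASEAgent(BaseAgent):
        def add_agent_info_to_obs(self, obs):
            obs = super().add_agent_info_to_obs(obs)
            # Illustrative: expose the current latent codes to the policy input.
            obs["latents"] = self._current_latents
            return obs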

obs_dict_to_tensordict(obs_dict)[source]

Convert observation dict to TensorDict.

Parameters:

obs_dict (Dict) – Dictionary of observation tensors from environment.

Returns:

TensorDict with same keys and values.

Return type:

TensorDict
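A small usage sketch; the observation key and shapes are illustrative.

    import torch

    obs_td = agent.obs_dict_to_tensordict({"self_obs": torch.zeros(8, 64)})
    # obs_td carries the same keys and values, with TensorDict semantics.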

post_env_step_modifications(dones, terminated, extras)[source]

Allow subclasses to modify dones/terminated after env.step().

This hook allows algorithm-specific modifications (e.g., AMP discriminator termination).

Parameters:
  • dones – Reset flags from environment

  • terminated – Termination flags from environment

  • extras – Info dictionary from environment

Returns:

Modified (dones, terminated, extras) tuple
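A hedged override sketch in the spirit of the AMP example above; the extras key and the threshold logic are purely illustrative.

    class MyAMPAgent(BaseAgent):
        def post_env_step_modifications(self, dones, terminated, extras):
            dones, terminated, extras = super().post_env_step_modifications(
                dones, terminated, extras
            )
            # Illustrative: terminate episodes the discriminator scores as implausible.
            disc_reward = extras.get("disc_reward")
            if disc_reward is not None:
                early = disc_reward < 0.05
                terminated = terminated | early
                dones = dones | early
            return dones, terminated, extras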

check_obs_for_nans(obs, action)[source]
record_rollout_step(next_obs_td, actions, rewards, dones, terminated, done_indices, extras, step)[source]

Record metrics and store data after environment step during rollout.

Called once per timestep during the data collection phase, after the environment has been stepped. Handles:

  1. Episode statistics tracking (rewards, lengths)

  2. Environment extras logging

  3. Experience buffer updates (actions, rewards, dones)

  4. Reward normalization, if enabled

Parameters:
  • next_obs_td – Observations after the environment step

  • actions (Tensor) – Actions that were applied

  • rewards (Tensor) – Rewards from environment step

  • dones (Tensor) – Reset flags indicating episode completion

  • terminated (Tensor) – Termination flags indicating early termination

  • done_indices (Tensor) – Indices of environments that finished episodes

  • extras (Dict) – Additional info dictionary from environment

  • step (int) – Current timestep index in the rollout [0, num_steps)

get_combined_experience_buffer_rewards()[source]
optimize_model()[source]
abstractmethod perform_optimization_step(batch_dict)[source]
post_epoch_logging(training_log_dict)[source]
eval()[source]

Set the model to evaluation mode.

Disables training-specific behaviors like dropout and batch normalization updates. Typically called before collecting experience or during evaluation.

train()[source]
max_num_batches()[source]

Calculate maximum number of minibatches per epoch.

Returns:

Integer number of minibatches needed to process all collected experience.
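One plausible way such a count arises (an assumption, shown only to make the bookkeeping concrete): with num_envs environments, num_steps per rollout, and minibatches of batch_size samples, the count is the ceiling of their ratio.

    import math

    num_envs, num_steps, batch_size = 4096, 32, 16384            # illustrative config values
    num_batches = math.ceil(num_envs * num_steps / batch_size)   # 8 minibatches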

get_step_count_increment()[source]

Calculate step count increment for distributed training.

Accounts for multiple GPUs and nodes in step counting.

Returns:

Number of environment steps per training iteration across all processes.
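Analogously, a sketch of the distributed accounting (names and numbers are illustrative): each process contributes num_envs * num_steps environment steps, scaled by the number of processes.

    num_envs, num_steps = 4096, 32
    world_size = 4                                    # GPUs per node * nodes
    increment = num_envs * num_steps * world_size     # 524,288 steps per iteration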

terminate_early()[source]

Request early termination of training.

Sets a flag that will cause the training loop to exit gracefully after the current epoch completes.
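A usage sketch: wiring the flag to SIGINT so that Ctrl+C ends training cleanly at an epoch boundary instead of mid-rollout.

    import signal

    def _handle_sigint(signum, frame):
        agent.terminate_early()   # loop exits after the current epoch

    signal.signal(signal.SIGINT, _handle_sigint)
    agent.fit()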