protomotions.agents.common.common module

Common neural network components and utilities for agents.

This module provides shared building blocks used across different agent architectures, including observation normalization, weight initialization, and specialized layers.

Key Classes:
  • NormObsBase: Base class for modules with observation normalization

  • Flatten: Flattening layer with flexible dimensions

  • Embedding: Embedding layer for discrete inputs

Key Functions:
  • weight_init: Initialize network weights

  • get_params: Extract parameters from optimizer groups

protomotions.agents.common.common.get_params(obj)[source]

Extract parameters from optimizer parameter groups.

Handles both flat lists of parameters and grouped parameters (as used by optimizers).

Parameters:

obj – Either a list of nn.Parameter or a list of parameter groups (dicts).

Returns:

Flat list of all parameters.

Return type:

List[nn.Parameter]
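To illustrate the two input shapes described above, here is a minimal sketch of the flattening logic. It is not the library's implementation: `flatten_param_groups` is a hypothetical name, plain objects stand in for `nn.Parameter`, and the sketch assumes each group is a dict holding its parameters under the `"params"` key, as `torch.optim` parameter groups do.

```python
from typing import Any, Dict, List, Union


def flatten_param_groups(obj: List[Union[Any, Dict[str, Any]]]) -> List[Any]:
    """Return a flat parameter list from a flat list or optimizer-style groups."""
    params: List[Any] = []
    for item in obj:
        if isinstance(item, dict):
            # Optimizer-style group: the parameters live under the "params" key.
            params.extend(item["params"])
        else:
            # Already a bare parameter.
            params.append(item)
    return params


# Plain objects stand in for nn.Parameter so the sketch runs without PyTorch.
w1, w2, w3 = object(), object(), object()

flat = flatten_param_groups([w1, w2, w3])
grouped = flatten_param_groups([
    {"params": [w1, w2], "lr": 1e-3},
    {"params": [w3], "lr": 1e-4},
])
```

Both calls yield the same three parameters in the same order, so downstream code does not need to know which form it was given.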

protomotions.agents.common.common.weight_init(m, orthogonal=False)[source]

Initialize weights for neural network modules.

Applies initialization to linear layers and other modules. Linear layers receive orthogonal initialization when orthogonal is True and default initialization otherwise, with biases set to zero.

Parameters:
  • m – Neural network module to initialize.

  • orthogonal – If True, use orthogonal initialization for linear layers.
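A function with this signature is typically applied to every submodule through `nn.Module.apply`. The sketch below shows that pattern with a hypothetical stand-in, `weight_init_sketch`, rather than the library function itself. It covers only `nn.Linear`, and it assumes that "default initialization" means leaving the weights the layer was constructed with.

```python
from functools import partial

import torch
from torch import nn


def weight_init_sketch(m: nn.Module, orthogonal: bool = False) -> None:
    """Hypothetical stand-in for weight_init; not the library's implementation."""
    if isinstance(m, nn.Linear):
        if orthogonal:
            nn.init.orthogonal_(m.weight)
        # Assumption: with orthogonal=False the constructor's weights are kept.
        if m.bias is not None:
            nn.init.zeros_(m.bias)


model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
# Module.apply calls the function on every submodule, then on the module itself.
model.apply(partial(weight_init_sketch, orthogonal=True))
```

`functools.partial` binds the `orthogonal` flag because `apply` passes only the module to the callback.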

class protomotions.agents.common.common.NormObsBase(config)[source]

Bases: nn.Module

Base class for modules with observation normalization.

Provides running mean/std normalization of observations using exponential moving averages. Normalization statistics are updated during training and frozen during evaluation.

Uses lazy initialization: the input shape is inferred on the first forward pass.

This is a simple tensor-to-tensor module. Subclasses handle TensorDict extraction/insertion.

Parameters:

config (NormObsBaseConfig) – Configuration specifying output dimensions and normalization parameters.

running_obs_norm

RunningMeanStd module for observation normalization (lazy).

num_out

Output dimension.

__init__(config)[source]

build_norm()[source]

forward(obs)[source]

Forward pass that normalizes observations.

Parameters:

obs (torch.Tensor) – Observation tensor to normalize.

Returns:

Normalized observation tensor.

Return type:

torch.Tensor
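The behaviour described for this class (lazy shape inference, statistics updated during training and frozen during evaluation) can be sketched in plain Python. This is not the library's RunningMeanStd: `RunningNormSketch`, its `momentum` and `eps` values, and the list-based observations are all assumptions made for illustration.

```python
import math
from typing import List, Optional


class RunningNormSketch:
    """Hypothetical EMA mean/variance normalizer with lazy shape inference."""

    def __init__(self, momentum: float = 0.1, eps: float = 1e-5) -> None:
        self.momentum = momentum  # assumed EMA step size
        self.eps = eps            # assumed numerical-stability constant
        self.training = True
        self.mean: Optional[List[float]] = None  # allocated on the first call
        self.var: Optional[List[float]] = None

    def __call__(self, batch: List[List[float]]) -> List[List[float]]:
        dim = len(batch[0])
        if self.mean is None:
            # Lazy initialization: the feature size comes from the first batch.
            self.mean = [0.0] * dim
            self.var = [1.0] * dim
        if self.training:
            n = len(batch)
            for j in range(dim):
                col = [row[j] for row in batch]
                b_mean = sum(col) / n
                b_var = sum((x - b_mean) ** 2 for x in col) / n
                # Exponential moving average toward the batch statistics.
                self.mean[j] += self.momentum * (b_mean - self.mean[j])
                self.var[j] += self.momentum * (b_var - self.var[j])
        return [
            [(x - m) / math.sqrt(v + self.eps)
             for x, m, v in zip(row, self.mean, self.var)]
            for row in batch
        ]


norm = RunningNormSketch()
norm([[1.0, 10.0], [3.0, 30.0]])  # training: statistics move toward the batch
mean_after_train = list(norm.mean)

norm.training = False
norm([[100.0, -100.0]])           # evaluation: statistics stay frozen
mean_after_eval = list(norm.mean)
```

The second call still normalizes its input, but because `training` is False the stored mean and variance are left untouched.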

protomotions.agents.common.common.apply_module_operations(obs, module_operations, forward_model, normalizer)[source]

class protomotions.agents.common.common.Flatten(*args, **kwargs)[source]

Bases: TensorDictModuleBase

Flatten layer with observation normalization for TensorDict inputs.

Flattens input tensor and optionally normalizes it.

Parameters:

config (FlattenConfig) – Configuration specifying obs_key, normalization parameters, etc.

norm

NormObsBase module for normalization.

flatten

Flatten layer.

in_keys

List containing obs_key.

out_keys

List containing out_key.

__init__(config)[source]

config: FlattenConfig

forward(tensordict, *args, **kwargs)[source]

Forward pass that flattens and normalizes observations.

Parameters:

tensordict (TensorDict) – TensorDict containing observations.

Returns:

TensorDict with flattened and normalized observations.

Return type:

TensorDict
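The read-from-in-key, write-to-out-key flow can be sketched with a plain dict standing in for a TensorDict. The helper names `flatten_features` and `flatten_step` and the keys `"obs"` and `"obs_flat"` are hypothetical. The sketch also assumes that the batch dimension is preserved and only the per-sample dimensions are flattened, and it omits the optional normalization step.

```python
from typing import Any, Dict, List


def flatten_features(x: Any) -> List[float]:
    """Recursively flatten a nested list of numbers into one flat list."""
    if isinstance(x, list):
        out: List[float] = []
        for item in x:
            out.extend(flatten_features(item))
        return out
    return [x]


def flatten_step(td: Dict[str, Any], obs_key: str, out_key: str) -> Dict[str, Any]:
    """Read td[obs_key], flatten each batch entry, and write it to td[out_key]."""
    td[out_key] = [flatten_features(sample) for sample in td[obs_key]]
    return td


# Batch of 2 samples, each of shape (2, 3); a plain dict stands in for a TensorDict.
td = {"obs": [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]}
td = flatten_step(td, obs_key="obs", out_key="obs_flat")
```

The original entry stays in place under `"obs"`, and the flattened copy is added under the output key.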

class protomotions.agents.common.common.SequentialModule(*args, **kwargs)[source]

Bases: TensorDictModuleBase

Sequential model with multiple input models and a trunk.

__init__(config)[source]

config: SequentialModuleConfig

forward(tensordict, *args, **kwargs)[source]

class protomotions.agents.common.common.MultiInputModule(*args, **kwargs)[source]

Bases: TensorDictModuleBase

__init__(config)[source]

config: MultiInputModuleConfig

forward(tensordict, *args, **kwargs)[source]

class protomotions.agents.common.common.MultiOutputModule(*args, **kwargs)[source]

Bases: TensorDictModuleBase

Takes a single input key and passes it to multiple output heads in parallel.

The opposite of MultiInputModule: one input, multiple outputs. Useful for multi-head architectures like ASE (discriminator + encoder heads).

__init__(config)[source]

config: MultiOutputModuleConfig

forward(tensordict, *args, **kwargs)[source]

Forward through all output heads in parallel.

Each head reads from in_keys and writes to its own out_key.
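As a rough illustration of the one-input, many-heads pattern, the sketch below uses a plain dict in place of a TensorDict and simple callables in place of networks. The function `multi_output_forward` and the key names are hypothetical, and the heads run one after another here rather than truly in parallel.

```python
from typing import Any, Callable, Dict


def multi_output_forward(
    td: Dict[str, Any],
    in_key: str,
    heads: Dict[str, Callable[[Any], Any]],
) -> Dict[str, Any]:
    """Feed td[in_key] to every head and store each result under its out key."""
    x = td[in_key]
    for out_key, head in heads.items():
        # Every head sees the same input and writes to its own key.
        td[out_key] = head(x)
    return td


# Hypothetical heads; in an ASE-style setup these would be the
# discriminator and encoder networks.
heads = {
    "disc_logits": lambda feats: sum(feats),
    "enc_latent": lambda feats: [2 * f for f in feats],
}
td = multi_output_forward({"features": [1.0, 2.0, 3.0]}, "features", heads)
```

Because each head writes to a distinct key, the outputs never overwrite one another and the shared input remains available to later modules.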