Source code for protomotions.agents.common.config
# SPDX-FileCopyrightText: Copyright (c) 2025 The ProtoMotions Developers
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from dataclasses import dataclass, field
from typing import Any, List, Dict, Optional, Union
from protomotions.utils.config_builder import ConfigBuilder
# =============================================================================
# Base Configuration for Normalized Observations
# =============================================================================
[docs]
@dataclass
class NormObsBaseConfig(ConfigBuilder):
    """Base configuration for modules that support optional observation normalization.

    This class holds only the normalization settings; it defines no input or
    output size fields. Subclasses that need an output size declare it
    themselves (for example ``num_out`` on ``MLPWithConcatConfig``). Per the
    original design note, input sizes are inferred at runtime via LazyLinear
    layers, so no input size is configured here. Individual TensorDictModule
    configs add their own key fields (``in_keys``/``out_keys``) as needed.

    Attributes:
        normalize_obs: Whether observation normalization is enabled.
            Defaults to ``False``.
        norm_clamp_value: Clamp value associated with normalization
            (presumably bounds normalized values to +/- this value; confirm
            in the consuming module). Defaults to ``5.0``.
    """

    normalize_obs: bool = False
    norm_clamp_value: float = 5.0
# =============================================================================
# Common Module Configurations (from common.py)
# =============================================================================
[docs]
@dataclass
class ModuleOperationConfig(ConfigBuilder):
    """Common base type for a single module operation entry.

    Declares no fields of its own. Concrete operation configs subclass it,
    and lists of them populate the ``module_operations`` field of module
    configs such as ``FlattenConfig`` and ``MLPWithConcatConfig``.
    """
[docs]
@dataclass
class ModuleOperationForwardConfig(ModuleOperationConfig):
    """Field-less operation entry selecting the module's forward step.

    It is the sole default entry of ``module_operations`` in both
    ``FlattenConfig`` and ``MLPWithConcatConfig``.
    """
[docs]
@dataclass
class ModuleOperationPermuteConfig(ModuleOperationConfig):
    """Operation entry describing a permute step.

    Attributes:
        new_order: Required list of integers; presumably the target ordering
            of tensor dimensions (TODO confirm against the consuming module).
    """

    new_order: List[int]
[docs]
@dataclass
class ModuleOperationReshapeConfig(ModuleOperationConfig):
    """Operation entry describing a reshape step.

    Attributes:
        new_shape: Required target shape. Entries may be ints or strings;
            how string entries are interpreted is decided by the consuming
            module (not visible here).
    """

    new_shape: List[Union[int, str]]
[docs]
@dataclass
class ModuleOperationSqueezeConfig(ModuleOperationConfig):
    """Operation entry describing a squeeze step.

    Attributes:
        squeeze_dim: Required integer index of the dimension to squeeze.
    """

    squeeze_dim: int
[docs]
@dataclass
class ModuleOperationUnsqueezeConfig(ModuleOperationConfig):
    """Operation entry describing an unsqueeze step.

    Attributes:
        unsqueeze_dim: Required integer index at which a dimension is inserted.
    """

    unsqueeze_dim: int
[docs]
@dataclass
class ModuleOperationExpandConfig(ModuleOperationConfig):
    """Operation entry describing an expand step.

    Attributes:
        expand_shape: Required list of integers giving the shape to expand to.
    """

    expand_shape: List[int]
[docs]
@dataclass
class ModuleOperationSphereProjectionConfig(ModuleOperationConfig):
    """Field-less operation entry for projecting onto the unit sphere.

    Per the original description, the projection is an L2 normalization.
    """
[docs]
@dataclass
class FlattenConfig(NormObsBaseConfig):
    """Settings used to build the ``Flatten`` module.

    Normalization options (``normalize_obs``, ``norm_clamp_value``) are
    inherited from ``NormObsBaseConfig``.

    Attributes:
        _target_: Import path of the class this config instantiates.
        in_keys: Input keys; empty by default.
        out_keys: Output keys; empty by default.
        module_operations: Operation entries; by default a single
            ``ModuleOperationForwardConfig``.
    """

    _target_: str = "protomotions.agents.common.common.Flatten"

    # Each instance gets its own fresh list (no shared mutable default).
    in_keys: List[str] = field(default_factory=list)
    out_keys: List[str] = field(default_factory=list)

    module_operations: List[ModuleOperationConfig] = field(
        default_factory=lambda: [ModuleOperationForwardConfig()],
    )
# =============================================================================
# MLP Configurations (from mlp.py)
# =============================================================================
[docs]
@dataclass
class MLPLayerConfig(ConfigBuilder):
    """Settings for one layer of an MLP.

    Attributes:
        units: Number of units in the layer. Defaults to 512.
        activation: Activation given by name; how the name is resolved is
            up to the consuming module. Defaults to ``"relu"``.
        use_layer_norm: Whether layer normalization is enabled for this
            layer. Defaults to ``False``.
    """

    units: int = 512
    activation: str = "relu"
    use_layer_norm: bool = False
[docs]
@dataclass
class MLPWithConcatConfig(NormObsBaseConfig):
    """Configuration for a Multi-Layer Perceptron with optional input normalization.

    Normalization options are inherited from ``NormObsBaseConfig``;
    ``normalize_obs`` defaults to ``False``, so normalization is opt-in.
    ``in_keys``/``out_keys`` may be left empty in the config; per the original
    design note they are validated by the MLP module itself (not visible here).

    ``num_out`` and ``layers`` are required. They default to ``None`` only
    because dataclass inheritance forbids non-default fields after the
    defaulted fields inherited from ``NormObsBaseConfig``; ``__post_init__``
    rejects a config where either one is missing.

    Example ``layers`` value::

        [
            MLPLayerConfig(units=1024, activation="relu", use_layer_norm=False),
            MLPLayerConfig(units=1024, activation="relu", use_layer_norm=False),
            MLPLayerConfig(units=512, activation="relu", use_layer_norm=False),
        ]

    Raises:
        AssertionError: On construction, if ``num_out`` or ``layers`` is ``None``.
    """

    num_out: Optional[int] = None
    layers: Optional[List[MLPLayerConfig]] = None
    _target_: str = "protomotions.agents.common.mlp.MLPWithConcat"
    in_keys: List[str] = field(default_factory=list)
    out_keys: List[str] = field(default_factory=list)
    output_activation: Optional[str] = None
    module_operations: List[ModuleOperationConfig] = field(
        default_factory=lambda: [ModuleOperationForwardConfig()]
    )

    def __post_init__(self):
        # Raise explicitly instead of using ``assert`` so the check still runs
        # under ``python -O``. AssertionError is kept (rather than ValueError)
        # so existing callers that catch it keep working.
        if self.num_out is None:
            raise AssertionError("num_out must be provided")
        if self.layers is None:
            raise AssertionError("layers must be provided")
[docs]
@dataclass
class SequentialModuleConfig(ConfigBuilder):
    """Settings used to build a ``SequentialModule``.

    Attributes:
        input_models: Required list of sub-model entries. Typed as ``Any``;
            presumably each entry is another module config (TODO confirm).
        _target_: Import path of the class this config instantiates.
        in_keys: Input keys; empty by default.
        out_keys: Output keys; empty by default.
    """

    # No default: must be supplied, and must come before the defaulted fields.
    input_models: List[Any]

    _target_: str = "protomotions.agents.common.common.SequentialModule"
    in_keys: List[str] = field(default_factory=list)
    out_keys: List[str] = field(default_factory=list)
[docs]
@dataclass
class MultiOutputModuleConfig(ConfigBuilder):
    """Settings used to build a ``MultiOutputModule`` (one input, many outputs).

    Attributes:
        output_models: Required list of output-model entries. Typed as
            ``Any``; presumably each entry is a module config (TODO confirm).
        _target_: Import path of the class this config instantiates.
        in_keys: Input keys; empty by default.
        out_keys: Output keys; empty by default.
    """

    # No default: must be supplied, and must come before the defaulted fields.
    output_models: List[Any]

    _target_: str = "protomotions.agents.common.common.MultiOutputModule"
    in_keys: List[str] = field(default_factory=list)
    out_keys: List[str] = field(default_factory=list)
# =============================================================================
# Transformer Configurations (from transformer.py)
# =============================================================================