# SPDX-FileCopyrightText: Copyright (c) 2025 The ProtoMotions Developers
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import inspect
import types
from dataclasses import dataclass
from enum import Enum
from typing import Dict, Any, Type, TypeVar, Union, get_origin, get_args, get_type_hints

import torch
from omegaconf import DictConfig

T = TypeVar("T")

# ``types.UnionType`` is the runtime type of PEP 604 unions (``X | None``). It only
# exists on Python 3.10+, so fall back to ``None`` on older interpreters.
_UNION_TYPE = getattr(types, "UnionType", None)


def _unwrap_optional(tp: Any) -> Any:
    """Return ``X`` for ``Optional[X]``, ``Union[X, None]``, ``Union[None, X]`` or ``X | None``.

    Any other type, including unions with more than one non-None member, is
    returned unchanged.
    """
    origin = get_origin(tp)
    if origin is Union or (_UNION_TYPE is not None and origin is _UNION_TYPE):
        non_none = [arg for arg in get_args(tp) if arg is not type(None)]
        if len(non_none) == 1:
            return non_none[0]
    return tp


def _is_mapping(value: Any) -> bool:
    """Return True for plain dicts and OmegaConf ``DictConfig`` nodes."""
    # The plain-dict test comes first, so ``DictConfig`` is only consulted for non-dicts.
    return isinstance(value, dict) or isinstance(value, DictConfig)


def _is_enum_type(tp: Any) -> bool:
    """Return True when ``tp`` is an ``Enum`` subclass."""
    return isinstance(tp, type) and issubclass(tp, Enum)


def _needs_conversion(tp: Any) -> bool:
    """Return whether values annotated as ``tp`` are rewritten by ``_convert_value``.

    Containers whose element type needs no conversion are passed through as the
    same object (e.g. ``List[int]`` values keep their identity).
    """
    tp = _unwrap_optional(tp)
    if tp is torch.Tensor or hasattr(tp, "from_dict") or _is_enum_type(tp):
        return True
    origin = get_origin(tp)
    args = get_args(tp)
    if origin is dict and len(args) == 2:
        return _needs_conversion(args[1])
    if origin is list and len(args) == 1:
        return _needs_conversion(args[0])
    return False


def _convert_element(value: Any, elem_type: Any) -> Any:
    """Convert one element of a ``List[...]`` item or ``Dict[..., ...]`` value."""
    if value is None:
        return None
    if _unwrap_optional(elem_type) is torch.Tensor:
        # Tensor elements are always tensorised (scalars included). An existing
        # tensor is copied as ``torch.tensor(t)`` would do, without its UserWarning.
        if isinstance(value, torch.Tensor):
            return value.detach().clone()
        return torch.tensor(value)
    return _convert_value(value, elem_type)


def _convert_value(value: Any, target_type: Any) -> Any:
    """Coerce a raw value from a config dict to the field's annotated type.

    Handled in order:
    1. Nested config classes (anything with ``from_dict``) given a mapping.
    2. ``Enum`` types.
    3. ``torch.Tensor`` given a list or tuple.
    4. ``Dict[K, V]`` / ``List[V]`` whose ``V`` needs one of these conversions
       (applied per element).

    ``Optional`` wrappers are stripped first. Anything else is returned unchanged.
    """
    if value is None:
        return None
    target_type = _unwrap_optional(target_type)
    # 1. Nested config class built from a mapping.
    if hasattr(target_type, "from_dict") and _is_mapping(value):
        return target_type.from_dict(value)
    # 2. Enum: members pass through unchanged. Otherwise parse with ``from_str``
    #    when the enum defines it, else look the member up by value (the form
    #    ``to_dict`` emits).
    if _is_enum_type(target_type):
        if isinstance(value, target_type):
            return value
        from_str = getattr(target_type, "from_str", None)
        return from_str(value) if from_str is not None else target_type(value)
    # 3. Tensor from a (possibly nested) list/tuple.
    if target_type is torch.Tensor and isinstance(value, (list, tuple)):
        return torch.tensor(value)
    origin = get_origin(target_type)
    args = get_args(target_type)
    # 4. Containers whose element type needs converting.
    if origin is dict and len(args) == 2 and _needs_conversion(args[1]):
        return {key: _convert_element(item, args[1]) for key, item in value.items()}
    if origin is list and len(args) == 1 and _needs_conversion(args[0]):
        return [_convert_element(item, args[0]) for item in value]
    # Containers of primitives and all other values are returned as-is.
    return value


def _to_serializable(value: Any) -> Any:
    """Recursively convert a field value into plain Python data for ``to_dict``.

    Conversions:
    - Nested configs (anything with ``to_dict``) become dicts.
    - Enums become their ``.value``.
    - Tensors become nested lists.
    - Lists, tuples and dict values are processed element-wise.

    A list/tuple whose elements all come back unchanged is returned as the same
    object, so tuples stay tuples. One that needed conversion becomes a new list.
    """
    if value is None:
        return None
    if hasattr(value, "to_dict"):
        return value.to_dict()
    if isinstance(value, Enum):
        return value.value
    if isinstance(value, torch.Tensor):
        return value.tolist()
    if isinstance(value, (list, tuple)):
        converted = [_to_serializable(item) for item in value]
        if all(new is old for new, old in zip(converted, value)):
            return value
        return converted
    if isinstance(value, dict):
        return {key: _to_serializable(item) for key, item in value.items()}
    return value


@dataclass
class ConfigBuilder:
    """Mixin class providing dictionary conversion functionality."""

    @classmethod
    def from_dict(cls: Type[T], config_dict: Dict[str, Any]) -> T:
        """Create an instance from a dictionary, converting values to the annotated field types.

        Each value is coerced according to its field annotation, with
        ``Optional`` / ``X | None`` unwrapped first:
        - Mappings become nested config instances (via ``from_dict``).
        - Enum fields are parsed (``from_str`` when defined, else by value).
        - Lists become tensors for ``torch.Tensor`` fields.
        - ``List``/``Dict`` fields are converted element-wise when their element
          type needs it.

        Two kinds of keys are reported with a printed note and ignored:
        - keys that are not annotated fields;
        - keys that ``__init__`` does not accept (e.g. ``field(init=False)`` or
          ``ClassVar`` entries, which ``to_dict`` emits).

        Args:
            config_dict: Dictionary containing configuration values.

        Returns:
            Instance of the class with values from the dictionary.

        Raises:
            TypeError: If the class cannot be instantiated from the processed
                values. Diagnostics are printed before the error is re-raised.
        """
        field_types = get_type_hints(cls)
        init_params = inspect.signature(cls).parameters
        # With a ``**kwargs`` __init__, every annotated field may be passed through.
        takes_any_kwarg = any(
            param.kind is inspect.Parameter.VAR_KEYWORD for param in init_params.values()
        )
        processed_dict = {}
        for key, value in config_dict.items():
            if key not in field_types:
                print(
                    f"Note: '{key}' in config_dict is not a field in {cls.__name__}, it will be ignored."
                )
                continue
            if key not in init_params and not takes_any_kwarg:
                # Passing these to __init__ would raise TypeError
                # ("unexpected keyword argument").
                print(
                    f"Note: '{key}' in config_dict is not an __init__ argument of {cls.__name__}, it will be ignored."
                )
                continue
            processed_dict[key] = _convert_value(value, field_types[key])
        try:
            return cls(**processed_dict)
        except TypeError as e:
            print(
                f"Error instantiating {cls.__name__} with processed_dict. Keys in dict: {list(processed_dict.keys())}"
            )
            print(f"Original error: {str(e)}")
            # Only parameters without a default are actually required.
            missing_params = [
                name
                for name, param in init_params.items()
                if param.default is inspect.Parameter.empty
                and param.kind
                not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
                and name not in processed_dict
            ]
            if missing_params:
                print(f"Missing required parameters: {missing_params}")
            # Show field types for debugging.
            for k_detail, v_detail in processed_dict.items():
                expected_type_detail = field_types.get(k_detail)
                print(
                    f" Field '{k_detail}': Expected {expected_type_detail}, Got {type(v_detail)} = {v_detail}"
                )
            raise

    def to_dict(self) -> Dict[str, Any]:
        """Convert the config to a dictionary of plain Python data.

        Every entry in ``__dataclass_fields__`` is emitted. The following are
        converted, recursively through lists, tuples and dict values:
        - nested configs (anything with ``to_dict``) become dicts;
        - enums become their ``.value``;
        - tensors become nested lists.

        Returns:
            Dictionary representation of the config.
        """
        return {
            field_name: _to_serializable(getattr(self, field_name))
            for field_name in self.__dataclass_fields__
        }

    def __getitem__(self, key: str) -> Any:
        """Make configs behave like dicts for compatibility with external libraries."""
        return self.to_dict()[key]

    def __contains__(self, key: str) -> bool:
        """Support 'in' operator for compatibility with external libraries."""
        return key in self.to_dict()

    def get(self, key: str, default: Any = None) -> Any:
        """Dict-like get method for compatibility with external libraries."""
        return self.to_dict().get(key, default)


def build_standard_configs(
    args,
    terrain_config_fn,
    scene_lib_config_fn,
    motion_lib_config_fn,
    env_config_fn,
    configure_robot_and_simulator_fn=None,
    agent_config_fn=None,
):
    """Assemble the standard set of experiment configs in dependency order.

    Helper that removes boilerplate from experiment files. Every config is built
    with training defaults; eval overrides are applied separately via
    apply_inference_overrides(). Configs are produced in the same order as the
    parameters: robot -> simulator -> terrain -> scene_lib -> motion_lib -> env -> agent.

    Args:
        args: Command line arguments. ``robot_name``, ``simulator``, ``headless``,
            ``num_envs`` and ``experiment_name`` are read here.
        terrain_config_fn: Required. Called as ``terrain_config_fn(args)``; returns
            a TerrainConfig, or None for no terrain.
        scene_lib_config_fn: Required. Called as ``scene_lib_config_fn(args)``;
            returns a SceneLibConfig (scene_file may be None for an empty scene).
        motion_lib_config_fn: Required. Called as ``motion_lib_config_fn(args)``;
            returns a MotionLibConfig (motion_file may be None for empty).
        env_config_fn: Required. Called as ``env_config_fn(robot_config, args)``;
            returns the env config.
        configure_robot_and_simulator_fn: Optional. Called as
            ``fn(robot_config, simulator_config, args)`` once both exist. Its
            return value is ignored, so presumably it adjusts them in place.
        agent_config_fn: Optional. Called as ``fn(robot_config, env_config, args)``;
            returns the agent config.

    Returns:
        Dict with keys robot, simulator, terrain, scene_lib, motion_lib, env and
        agent (agent is None when ``agent_config_fn`` is not given).
    """
    from protomotions.robot_configs.factory import robot_config as make_robot_config
    from protomotions.simulator.factory import (
        simulator_config as make_simulator_config,
    )

    robot = make_robot_config(args.robot_name)
    simulator = make_simulator_config(
        args.simulator, robot, args.headless, args.num_envs, args.experiment_name
    )
    # Experiment-specific adjustments to the robot/simulator pair, if requested.
    if configure_robot_and_simulator_fn is not None:
        configure_robot_and_simulator_fn(robot, simulator, args)

    configs = {"robot": robot, "simulator": simulator}
    # These builders see only ``args``; terrain may legitimately be None.
    configs["terrain"] = terrain_config_fn(args)
    configs["scene_lib"] = scene_lib_config_fn(args)
    configs["motion_lib"] = motion_lib_config_fn(args)
    # The env depends on the robot; the agent depends on both robot and env.
    env = env_config_fn(robot, args)
    configs["env"] = env
    if agent_config_fn is None:
        configs["agent"] = None
    else:
        configs["agent"] = agent_config_fn(robot, env, args)
    return configs