Source code for protomotions.utils.config_utils

# SPDX-FileCopyrightText: Copyright (c) 2025 The ProtoMotions Developers
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# =============================================================================
# General Config Override Utilities
# =============================================================================

import logging
from typing import Dict, Any, Callable

log = logging.getLogger(__name__)


def import_experiment_relative_eval_overrides(
    relative_experiment_path: str,
) -> Callable:
    """
    Load an experiment module located relative to the caller and return its
    ``apply_inference_overrides`` function.

    The directory of the file that called this function is discovered through
    the ``inspect`` call stack, and ``relative_experiment_path`` is joined onto
    it. The file at that location is then executed as a module so that an
    evaluation script can reuse the eval overrides defined by its training
    experiment without hardcoding a path.

    Args:
        relative_experiment_path: Path of the experiment module, relative to
            the directory of the calling file. E.g. "mlp.py" for a sibling
            file, or "../other/experiment.py".

    Returns:
        The ``apply_inference_overrides`` attribute of the loaded module.

    Raises:
        FileNotFoundError: If nothing exists at the resolved path.
        ImportError: If no module spec/loader can be created for the path, or
            if executing the module raises any exception (the original
            exception is chained as the cause).
        AttributeError: If the loaded module has no
            ``apply_inference_overrides`` attribute.

    Example:
        # In examples/experiments/mimic/mlp_deploy.py
        apply_inference_overrides = import_experiment_relative_eval_overrides("mlp.py")
        # This loads apply_inference_overrides from examples/experiments/mimic/mlp.py
    """
    import importlib.util
    import inspect
    import os

    # Stack entry 0 is this function; entry 1 is whoever called it. This lookup
    # must stay in this function body, otherwise the "caller" frame would shift.
    caller_frame = inspect.stack()[1]
    caller_file_path = os.path.abspath(caller_frame.filename)

    # Resolve the experiment file against the caller's directory.
    _experiment_path = os.path.join(
        os.path.dirname(caller_file_path), relative_experiment_path
    )

    # Fail early, with the full resolution context, when the file is missing.
    if not os.path.exists(_experiment_path):
        raise FileNotFoundError(
            f"Experiment module not found: {_experiment_path}\n"
            f"Caller: {caller_file_path}\n"
            f"Relative path: {relative_experiment_path}\n"
            f"Make sure the experiment file exists and the relative path is correct."
        )

    module_spec = importlib.util.spec_from_file_location(
        "experiment_module", _experiment_path
    )
    module_loader = getattr(module_spec, "loader", None)
    if module_loader is None:
        # Covers both "no spec at all" and "spec without a loader".
        raise ImportError(f"Failed to load module spec from: {_experiment_path}")

    loaded_module = importlib.util.module_from_spec(module_spec)

    # Any failure while running the module body is reported as an ImportError
    # that names the file and keeps the original exception as the cause.
    try:
        module_loader.exec_module(loaded_module)
    except Exception as e:
        raise ImportError(
            f"Failed to execute experiment module: {_experiment_path}\n"
            f"Error: {type(e).__name__}: {e}\n"
            f"Make sure the module has no syntax errors and all imports are available."
        ) from e

    hook_name = "apply_inference_overrides"
    if hasattr(loaded_module, hook_name):
        return getattr(loaded_module, hook_name)

    # List the public names so a misspelled hook is easy to spot.
    public_names = [name for name in dir(loaded_module) if not name.startswith("_")]
    raise AttributeError(
        f"Module does not have 'apply_inference_overrides' function: {_experiment_path}\n"
        f"Available attributes: {public_names}\n"
        f"Make sure the experiment module defines an apply_inference_overrides function."
    )
def _pick_override_root(key, config_type, required_configs, optional_configs):
    """Return the config object named by the first segment of an override key."""
    if config_type in required_configs:
        return required_configs[config_type]

    if config_type not in optional_configs:
        raise ValueError(
            f"Unknown config type '{config_type}' in override key: '{key}'\n"
            f"Valid types: 'env', 'simulator', 'robot', 'agent', 'terrain', 'motion_lib', 'scene_lib'"
        )

    root = optional_configs[config_type]
    if root is None:
        message = (
            f"Cannot override '{key}': {config_type}_config not provided to apply_config_overrides()"
        )
        if config_type == "agent":
            # Only the agent config carries an extra explanatory line.
            message += "\nAgent config overrides are only supported in training, not evaluation."
        raise ValueError(message)
    return root


def _walk_to_override_parent(root, parts, key):
    """Follow every path segment except the last and return the container reached."""
    node = root
    # parts[0] is the config type and parts[-1] is the field to set, so only
    # the segments in between are containers to descend through.
    for depth, segment in enumerate(parts[1:-1]):
        if isinstance(node, dict):
            if segment in node:
                node = node[segment]
                continue
            reached = ".".join(parts[: depth + 2])
            raise ValueError(
                f"Key '{segment}' not found in dict at config path: '{key}'\n"
                f"Failed at: '{reached}'\n"
                f"Available keys: {list(node.keys())}"
            )

        if hasattr(node, segment):
            node = getattr(node, segment)
            continue
        reached = ".".join(parts[: depth + 2])
        raise ValueError(
            f"Field '{segment}' not found in config path: '{key}'\n"
            f"Failed at: '{reached}'\n"
            f"Check the config dataclass structure for valid field names."
        )
    return node


def _replace_leaf_value(parent, leaf, key, value, allowed_types):
    """Overwrite ``leaf`` on ``parent`` (dict key or attribute) and return the old value."""
    parent_is_dict = isinstance(parent, dict)

    if parent_is_dict:
        if leaf not in parent:
            raise ValueError(
                f"Key '{leaf}' not found in dict at config path: '{key}'\n"
                f"Available keys: {list(parent.keys())}\n"
                f"Check for typos in the override key."
            )
        previous = parent[leaf]
        label = f"Dict value '{leaf}'"
    else:
        if not hasattr(parent, leaf):
            if hasattr(parent, "__dataclass_fields__"):
                known_fields = list(parent.__dataclass_fields__.keys())
            else:
                known_fields = dir(parent)
            raise ValueError(
                f"Field '{leaf}' not found in config path: '{key}'\n"
                f"Available fields: {known_fields}\n"
                f"Check for typos in the override key."
            )
        previous = getattr(parent, leaf)
        label = f"Field '{leaf}'"

    # The guard inspects the type of the value being replaced, not the new
    # value: only leaves currently holding a simple scalar (or None) qualify.
    previous_type = type(previous)
    if previous_type not in allowed_types:
        raise ValueError(
            f"{label} is of type '{previous_type}' which is not allowed. "
            f"Allowed types are: {allowed_types}"
        )

    if parent_is_dict:
        parent[leaf] = value
    else:
        setattr(parent, leaf, value)
    return previous


def apply_config_overrides(
    overrides: Dict[str, Any],
    env_config,
    simulator_config,
    robot_config,
    agent_config=None,
    terrain_config=None,
    motion_lib_config=None,
    scene_lib_config=None,
) -> None:
    """
    Apply dot-notation overrides to the given config objects, in place.

    Each override key has the form "config_type.field[.subfield...]". The first
    segment selects one of the config objects passed in; the remaining segments
    are followed through object attributes and/or dictionary keys, and the last
    segment is assigned the override value. Every segment must already exist,
    and the value currently stored at the last segment must be an int, float,
    bool, str or None.

    Overrides are applied one at a time in iteration order. If one of them
    fails, the overrides processed before it have already been applied.

    Args:
        overrides: Mapping of {"config_type.field.subfield": value, ...}.
            An empty or None mapping is a no-op.
        env_config: Environment configuration to modify in-place
        simulator_config: Simulator configuration to modify in-place
        robot_config: Robot configuration to modify in-place
        agent_config: Optional agent configuration to modify in-place
        terrain_config: Optional terrain configuration to modify in-place
        motion_lib_config: Optional motion library configuration to modify in-place
        scene_lib_config: Optional scene library configuration to modify in-place

    Supported config types:
        - 'env': Environment config
        - 'simulator': Simulator config
        - 'robot': Robot config
        - 'agent': Agent config (training only)
        - 'terrain': Terrain config
        - 'motion_lib': Motion library config
        - 'scene_lib': Scene library config

    Raises:
        ValueError: If a key has fewer than two segments, names an unknown
            config type, targets an optional config that was not provided,
            refers to a field/key that does not exist, or targets a value
            whose current type is not one of the allowed simple types.

    Example::

        apply_config_overrides(
            {
                "env.max_episode_length": 1000,
                "simulator.num_envs": 4096,
                "env.reward_config.pow_rew.weight": 2e-6,  # dict key access
                "terrain.horizontal_scale": 0.1,
            },
            env_config, simulator_config, robot_config,
            terrain_config=terrain_config
        )
    """
    if not overrides:
        return

    log.info(f"Applying {len(overrides)} config override(s)...")

    # These three are always passed positionally and are used as given.
    required_configs = {
        "env": env_config,
        "simulator": simulator_config,
        "robot": robot_config,
    }
    # These may be None; targeting a missing one is reported as an error.
    optional_configs = {
        "agent": agent_config,
        "terrain": terrain_config,
        "motion_lib": motion_lib_config,
        "scene_lib": scene_lib_config,
    }
    allowed_field_types = [int, float, bool, str, type(None)]

    for key, value in overrides.items():
        parts = key.split(".")
        if len(parts) < 2:
            raise ValueError(
                f"Invalid override key format: '{key}'. "
                f"Expected 'config.field' or 'config.field.subfield'"
            )

        root = _pick_override_root(key, parts[0], required_configs, optional_configs)
        parent = _walk_to_override_parent(root, parts, key)
        old_value = _replace_leaf_value(
            parent, parts[-1], key, value, allowed_field_types
        )

        log.info(f" {key}: {old_value} -> {value}")
def parse_cli_overrides(override_strings: list) -> Dict[str, Any]:
    """
    Parse command-line override strings into a dictionary.

    Each entry has the form "key=value". The text is split on the first "=",
    both sides are stripped of surrounding whitespace, and the value is
    interpreted as a Python literal when possible:

    - Numbers: "env.max_episode_length=1000"
    - Floats: "agent.learning_rate=1e-5"
    - Booleans: "env.enable_terrain=True"
    - Strings: "env.terrain.type=flat"
    - None: "env.early_termination=None"

    Anything that is not a valid Python literal is kept as the stripped
    string. Entries without an "=" are skipped with a logged warning. If the
    same key appears more than once, the last occurrence wins.

    Args:
        override_strings: List of "key=value" strings. ``None`` is treated
            like an empty list.

    Returns:
        Dictionary of parsed overrides

    Example:
        parse_cli_overrides(["env.max_episode_length=1000", "simulator.num_envs=4096"])
        # Returns: {"env.max_episode_length": 1000, "simulator.num_envs": 4096}
    """
    # Imported once per call; the module's top-level imports do not include ast.
    import ast

    overrides: Dict[str, Any] = {}

    for override_str in override_strings or []:
        key, separator, value_str = override_str.partition("=")
        if not separator:
            log.warning(f"Invalid override format (missing '='): {override_str}")
            continue

        key = key.strip()
        value_str = value_str.strip()

        try:
            # Handles int, float, bool, None, lists, dicts, tuples, quoted strings...
            value = ast.literal_eval(value_str)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            # literal_eval is documented to raise any of these on malformed
            # input (e.g. TypeError for "{[1]: 2}", an unhashable dict key).
            # In every such case the raw text is the value the user typed.
            value = value_str

        overrides[key] = value

    return overrides