Source code for protomotions.utils.hydra_replacement

# SPDX-FileCopyrightText: Copyright (c) 2025 The ProtoMotions Developers
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Simple replacement for hydra.utils functions to avoid the heavy hydra-core dependency.
Provides get_class and instantiate functions with compatible APIs.
"""

import importlib
from typing import Any


def get_class(path: str) -> type:
    """
    Import and return a class from a string path.

    The path is split at its last dot: everything before it is imported as a
    module with ``importlib.import_module`` and the final component is looked
    up on that module with ``getattr``.

    Args:
        path: Fully qualified class path, e.g., "torch.optim.Adam"

    Returns:
        The attribute named by the last path component of the imported module.

    Raises:
        ValueError: If ``path`` has no module part (no dot, or a leading dot).
        ImportError: If the module part cannot be imported.
        AttributeError: If the module has no attribute with the given name.

    Example:
        >>> Adam = get_class("torch.optim.Adam")
        >>> optimizer = Adam(params, lr=0.001)
    """
    module_path, dot, class_name = path.rpartition(".")
    # Without a module part there is nothing to import; fail with a message
    # that names the offending path instead of a tuple-unpacking error.
    if not dot or not module_path:
        raise ValueError(
            f"Expected a fully qualified path like 'package.module.ClassName', got {path!r}"
        )
    module = importlib.import_module(module_path)
    return getattr(module, class_name)
def instantiate(config, **kwargs) -> Any:
    """
    Instantiate a class from a config object.

    Args:
        config: Either a dict with a '_target_' key, or an object with a
            _target_ attribute. The target is a fully qualified class path
            (resolved with ``get_class``) or an already-imported
            class/callable, which is used as-is.
            For a dict, every key other than '_target_' is passed to the
            constructor. For an object, only its instance attributes
            (``vars(config)``) are passed, and names starting with an
            underscore are skipped.
        **kwargs: Additional keyword arguments to pass to the constructor,
            overriding config values

    Returns:
        Instance of the specified class

    Raises:
        ValueError: If the config has no '_target_' (or it is None).

    Example:
        >>> config = {"_target_": "fractions.Fraction", "numerator": 1}
        >>> instantiate(config, denominator=2)
        Fraction(1, 2)
    """
    is_dict_config = isinstance(config, dict)
    if is_dict_config:
        target = config.get("_target_")
    else:
        target = getattr(config, "_target_", None)

    # Validate the target before touching the remaining fields, so a config
    # without '_target_' always gets this error rather than a failure from
    # vars() on an object that has no __dict__.
    if target is None:
        raise ValueError(
            "Config must have a '_target_' attribute or key specifying the class path"
        )

    if is_dict_config:
        config_dict = {k: v for k, v in config.items() if k != "_target_"}
    else:
        # The underscore filter also excludes '_target_' itself.
        config_dict = {
            k: v for k, v in vars(config).items() if not k.startswith("_")
        }

    # A target that is already a class/callable needs no import step.
    cls = target if callable(target) else get_class(target)

    # Merge config and kwargs (kwargs take precedence)
    merged_kwargs = {**config_dict, **kwargs}
    return cls(**merged_kwargs)