Source code for protomotions.agents.masked_mimic.config
# SPDX-FileCopyrightText: Copyright (c) 2025 The ProtoMotions Developers
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from dataclasses import dataclass, field
from typing import Union, Optional
from enum import Enum
from protomotions.utils.config_builder import ConfigBuilder
from protomotions.agents.common.config import (
MultiOutputModuleConfig,
SequentialModuleConfig,
)
from protomotions.agents.base_agent.config import (
OptimizerConfig,
BaseAgentConfig,
BaseModelConfig,
)
[docs]
@dataclass
class KLDScheduleConfig(ConfigBuilder):
"""Configuration for KL divergence scheduling in VAE training."""
init_kld_coeff: float = 0.0001
end_kld_coeff: float = 0.01
start_epoch: int = 3000
end_epoch: int = 6000
[docs]
class VaeNoiseType(Enum):
NORMAL = "normal"
UNIFORM = "uniform"
ZEROS = "zeros"
[docs]
@classmethod
def from_str(cls, value: str) -> "VaeNoiseType":
"""Create enum from string, case-insensitive."""
try:
return next(
member for member in cls if member.value.lower() == value.lower()
)
except StopIteration:
raise ValueError(
f"'{value}' is not a valid {cls.__name__}. "
f"Valid values are: {[e.value for e in cls]}"
)
return cls(value)
[docs]
@dataclass
class VaeConfig(ConfigBuilder):
"""Configuration for VAE-specific parameters."""
kld_schedule: KLDScheduleConfig = field(default_factory=KLDScheduleConfig)
vae_latent_dim: int = 64
vae_noise_type: VaeNoiseType = VaeNoiseType.NORMAL
[docs]
@dataclass
class FeedForwardModelConfig(BaseModelConfig):
"""Configuration for FeedForwardModel."""
_target_: str = "protomotions.agents.masked_mimic.model.FeedForwardModel"
trunk: SequentialModuleConfig = field(default_factory=SequentialModuleConfig)
[docs]
@dataclass
class MaskedMimicModelConfig(BaseModelConfig):
"""Configuration for MaskedMimic Model (VAE-based imitation learning)."""
_target_: str = "protomotions.agents.masked_mimic.model.MaskedMimicModel"
# VAE components
encoder: MultiOutputModuleConfig = field(default_factory=MultiOutputModuleConfig)
prior: MultiOutputModuleConfig = field(default_factory=MultiOutputModuleConfig)
trunk: SequentialModuleConfig = field(default_factory=SequentialModuleConfig)
vae: VaeConfig = field(default_factory=VaeConfig)
# Optimizer
optimizer: OptimizerConfig = field(default_factory=lambda: OptimizerConfig(lr=2e-5))
[docs]
@dataclass
class MaskedMimicAgentConfig(BaseAgentConfig):
"""Main configuration class for MaskedMimic Agent."""
_target_: str = "protomotions.agents.masked_mimic.agent.MaskedMimic"
# Model configuration
model: Union[MaskedMimicModelConfig, FeedForwardModelConfig] = field(
default_factory=MaskedMimicModelConfig
)
expert_model_path: Optional[str] = None