Source code for protomotions.agents.masked_mimic.model
# SPDX-FileCopyrightText: Copyright (c) 2025 The ProtoMotions Developers
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
from tensordict import TensorDict
from protomotions.utils.hydra_replacement import get_class
from protomotions.agents.common.common import SequentialModule, MultiOutputModule
from protomotions.agents.base_agent.model import BaseModel
# Import for type annotations - using TYPE_CHECKING to avoid circular imports
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from protomotions.agents.masked_mimic.config import MaskedMimicModelConfig
class FeedForwardModel(BaseModel):
"""Simple feedforward model for masked mimic without VAE."""
def __init__(self, config):
super().__init__(config)
self.config = config
        TrunkClass = get_class(self.config.trunk._target_)
        self._trunk: SequentialModule = TrunkClass(config=self.config.trunk)
# Set TensorDict keys
self.in_keys = self._trunk.in_keys
self.out_keys = ["action"]
def forward(self, tensordict: TensorDict) -> TensorDict:
"""Forward pass computing action.
Args:
tensordict: TensorDict containing observations.
Returns:
TensorDict with action added.
"""
tensordict = self._trunk(tensordict)
action = tensordict[self._trunk.config.out_key]
tensordict["action"] = action
return tensordict
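# A minimal usage sketch for FeedForwardModel (illustrative only; the config
# object, key names, and dimensions below are assumptions, not part of this
# module):
#
#   model = FeedForwardModel(config=ff_config)  # ff_config: hypothetical
#   batch = TensorDict(
#       {key: torch.randn(num_envs, obs_dim) for key in model.in_keys},
#       batch_size=[num_envs],
#   )
#   batch = model(batch)      # runs the trunk and writes "action"
#   action = batch["action"]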
class MaskedMimicModel(BaseModel):
"""MaskedMimic model architecture with Variational Autoencoder (VAE).
Combines a prior network (acting on sparse observations) and an encoder
network (acting on full/expert observations) to learn a latent space for
motion control.
The prior learns to predict the expert's latent distribution from sparse
data, enabling the agent to perform robust control when full state info
is unavailable during inference.
Args:
config: Configuration for the MaskedMimic model.
"""
config: "MaskedMimicModelConfig"
def __init__(self, config: "MaskedMimicModelConfig"):
"""Initialize the MaskedMimic model components."""
super().__init__(config)
        # Create the encoder, prior, and trunk networks
EncoderClass = get_class(self.config.encoder._target_)
self._encoder: MultiOutputModule = EncoderClass(config=self.config.encoder)
PriorClass = get_class(self.config.prior._target_)
self._prior: SequentialModule = PriorClass(config=self.config.prior)
TrunkClass = get_class(self.config.trunk._target_)
self._trunk: SequentialModule = TrunkClass(config=self.config.trunk)
        # Set TensorDict keys (collected from all components). vae_latent is
        # produced internally via the reparameterized prior/encoder, so it is
        # excluded from the trunk's external input requirements; vae_noise is
        # required as an input instead.
        trunk_in_keys_without_latents = [
            key for key in self._trunk.in_keys if key != "vae_latent"
        ]
self.in_keys = list(
set(
self._prior.in_keys
+ self._encoder.in_keys
+ ["vae_noise"]
+ trunk_in_keys_without_latents
)
)
self.out_keys = ["action", "privileged_action"]
@staticmethod
def reparameterization(mean, std, vae_noise):
"""Reparameterization trick: z = mu + std * noise"""
z = mean + std * vae_noise
return z
def forward(self, tensordict: TensorDict) -> TensorDict:
"""Forward pass through MaskedMimic model.
Always computes both prior and encoder for consistency and ONNX compatibility.
Expects vae_noise to be provided in tensordict (generated by agent).
Args:
tensordict: TensorDict containing observations and vae_noise.
Returns:
TensorDict with action and all VAE outputs.
"""
# Compute prior outputs
tensordict = self._prior(tensordict)
prior_mu = tensordict[self._prior.out_keys[0]]
prior_logvar = tensordict[self._prior.out_keys[1]]
# Reparameterization using external noise
std = torch.exp(0.5 * prior_logvar)
vae_noise = tensordict["vae_noise"]
z = self.reparameterization(
prior_mu, std, vae_noise
) # z is the latent code for the action
tensordict["vae_latent"] = z
# Compute non-privileged action (prior path)
tensordict = self._trunk(tensordict)
action = tensordict[self._trunk.out_keys[0]]
# Compute encoder outputs
tensordict = self._encoder(tensordict)
encoder_mu = tensordict[self._encoder.out_keys[0]]
encoder_logvar = tensordict[self._encoder.out_keys[1]]
# Combine: encoder mu is residual to prior mu
privileged_mu = prior_mu + encoder_mu
privileged_logvar = encoder_logvar # Use encoder's logvar directly
# Combine privileged mu and logvar to get privileged z
privileged_std = torch.exp(0.5 * privileged_logvar)
privileged_z = self.reparameterization(privileged_mu, privileged_std, vae_noise)
# Compute privileged action (prior + encoder path)
tensordict["vae_latent"] = privileged_z
tensordict = self._trunk(tensordict)
privileged_action = tensordict[self._trunk.out_keys[0]]
tensordict["action"] = action
tensordict["privileged_action"] = privileged_action
return tensordict
def kl_loss(self, tensordict: TensorDict):
"""Compute KL divergence between encoder and prior.
Args:
tensordict: TensorDict containing prior and encoder outputs.
Returns:
KL divergence tensor.
"""
return 0.5 * (
tensordict["prior_logvar"]
- tensordict["encoder_logvar"]
+ torch.exp(tensordict["encoder_logvar"])
/ torch.exp(tensordict["prior_logvar"])
+ tensordict["encoder_mu"] ** 2 / torch.exp(tensordict["prior_logvar"])
- 1
)
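# Sketch of how the KL term might enter a VAE-style training objective
# (illustrative; the batch keys, the reduction, and the kl_coeff weighting are
# assumptions, not taken from this module):
#
#   kl = model.kl_loss(batch).sum(dim=-1).mean()
#   loss = task_loss + kl_coeff * kl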