Source code for protomotions.agents.common.transformer

# SPDX-FileCopyrightText: Copyright (c) 2025 The ProtoMotions Developers
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Transformer architecture for sequential modeling.

This module implements transformer-based networks for processing temporal information
in reinforcement learning. Used primarily in motion tracking and MaskedMimic agents
for handling sequential observations.

Key Classes:
    - Transformer: Main transformer model with positional encoding
    - PositionalEncoding: Sinusoidal positional encodings for sequence position

Key Features:
    - Multi-head self-attention for temporal dependencies
    - Multiple input heads with different encoders
    - Positional encoding for sequence awareness
    - Flexible output heads (single or multi-headed)
"""

import torch
from torch import nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModuleBase

from protomotions.agents.utils.training import get_activation_func
from protomotions.agents.common.config import TransformerConfig


class Transformer(TensorDictModuleBase):
    """Transformer network for sequential observation processing.

    Processes multi-modal sequential inputs through separate encoders,
    combines them into a sequence of tokens, and applies transformer layers
    for temporal modeling. Used in motion tracking agents to process future
    reference poses.

    Args:
        config: Transformer configuration specifying architecture parameters.

    Attributes:
        input_models: Dictionary of input encoders for different observation types.
        sequence_pos_encoder: Positional encoding layer.
        seqTransEncoder: Stack of transformer encoder layers.
        in_keys: List of input keys collected from all input models.
        out_keys: List containing the output key.

    Example:
        >>> config = TransformerConfig()
        >>> model = Transformer(config)
        >>> output_td = model(tensordict)
    """
    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config

        # Set TensorDict keys
        self.in_keys = self.config.in_keys
        self.out_keys = self.config.out_keys

        self.output_activation = None
        if self.config.output_activation is not None:
            self.output_activation = get_activation_func(
                self.config.output_activation
            )

        # Extract all input tokens that aren't masks.
        token_input_keys = []
        mask_keys = (
            list(self.config.input_and_mask_mapping.values())
            if self.config.input_and_mask_mapping
            else []
        )
        for in_key in self.in_keys:
            if in_key not in mask_keys:
                token_input_keys.append(in_key)
        self._token_input_keys = token_input_keys

        # Transformer layers
        seqTransEncoderLayer = nn.TransformerEncoderLayer(
            d_model=self.config.latent_dim,
            nhead=self.config.num_heads,
            dim_feedforward=self.config.ff_size,
            dropout=self.config.dropout,
            activation=get_activation_func(
                self.config.activation, return_type="functional"
            ),
            batch_first=True,
        )
        self.seqTransEncoder = nn.TransformerEncoder(
            seqTransEncoderLayer, num_layers=self.config.num_layers
        )
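
    # Mask convention (illustrative note): observation masks stored in the
    # TensorDict use 1 for valid tokens and 0 for padding, whereas
    # nn.TransformerEncoder's ``src_key_padding_mask`` expects True at
    # positions to ignore. ``forward`` therefore inverts each mask, e.g.:
    #
    #     valid = torch.tensor([[True, True, False]])  # 1 = valid token
    #     padding_mask = valid.logical_not()           # True = ignore
    #
    # Inputs without an entry in ``input_and_mask_mapping`` get an all-False
    # (fully valid) padding row.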
    def forward(self, tensordict: TensorDict) -> TensorDict:
        """Forward pass through the transformer.

        Args:
            tensordict: TensorDict containing all input observations.

        Returns:
            TensorDict with the transformer output added at ``self.out_keys[0]``.
        """
        # Collect token inputs, promoting [batch, features] entries to
        # single-token sequences of shape [batch, 1, features].
        all_tokens = []
        for in_key in self._token_input_keys:
            if tensordict[in_key].dim() == 2:
                all_tokens.append(tensordict[in_key].unsqueeze(1))
            else:
                all_tokens.append(tensordict[in_key])
        all_tokens = torch.cat(all_tokens, dim=1)

        # Build the key-padding mask for the concatenated token sequence.
        all_masks = []
        for in_key in self._token_input_keys:
            if (
                self.config.input_and_mask_mapping
                and in_key in self.config.input_and_mask_mapping
            ):
                mask_key = self.config.input_and_mask_mapping[in_key]
                # Our mask is 1 for valid and 0 for invalid.
                # The transformer expects the mask to be 0 for valid and 1 for invalid.
                mask = tensordict[mask_key].logical_not()
                if tensordict[mask_key].dim() == 1:
                    all_masks.append(mask.unsqueeze(1))
                else:
                    all_masks.append(mask)
            else:
                # No mask was provided for this input, so all of its tokens are valid.
                if tensordict[in_key].dim() == 2:
                    all_masks.append(
                        torch.zeros(
                            tensordict.batch_size[0],
                            1,
                            dtype=torch.bool,
                            device=tensordict[in_key].device,
                        )
                    )
                else:
                    all_masks.append(
                        torch.zeros(
                            tensordict.batch_size[0],
                            tensordict[in_key].shape[1],
                            dtype=torch.bool,
                            device=tensordict[in_key].device,
                        )
                    )
        all_masks = torch.cat(all_masks, dim=1)

        output = self.seqTransEncoder(
            all_tokens, src_key_padding_mask=all_masks
        )  # [batch, seq_len, features]
        output = output[:, 0, :]  # [batch, features] - take the first token

        if self.output_activation is not None:
            output = self.output_activation(output)

        tensordict[self.out_keys[0]] = output
        return tensordict
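
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). The
# keyword-argument constructor for TransformerConfig below is an assumption
# inferred from the fields this class reads (in_keys, out_keys, latent_dim,
# num_heads, ff_size, dropout, activation, num_layers, output_activation,
# input_and_mask_mapping); the "relu" activation name is likewise assumed to
# be accepted by get_activation_func.
if __name__ == "__main__":
    config = TransformerConfig(  # hypothetical constructor call
        in_keys=["ref_tokens", "ref_tokens_mask"],
        out_keys=["transformer_out"],
        latent_dim=64,
        num_heads=4,
        ff_size=256,
        dropout=0.0,
        activation="relu",
        num_layers=2,
        output_activation=None,
        input_and_mask_mapping={"ref_tokens": "ref_tokens_mask"},
    )
    model = Transformer(config)

    batch, seq = 8, 5
    td = TensorDict(
        {
            # Tokens are assumed to be pre-embedded to latent_dim.
            "ref_tokens": torch.randn(batch, seq, config.latent_dim),
            # 1 = valid token, 0 = padding (inverted internally).
            "ref_tokens_mask": torch.ones(batch, seq, dtype=torch.bool),
        },
        batch_size=[batch],
    )
    out = model(td)
    print(out["transformer_out"].shape)  # torch.Size([8, 64])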