# Source code for protomotions.agents.utils.data

# SPDX-FileCopyrightText: Copyright (c) 2025 The ProtoMotions Developers
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Data utilities for experience management and batching.

This module provides utilities for managing experience buffers and creating
minibatch datasets for training. Handles efficient storage and retrieval of
rollout data collected during environment interaction.

Key Classes:
    - ExperienceBuffer: Buffer for storing rollout experience
    - DictDataset: Dataset for creating minibatches from experience

Key Functions:
    - swap_and_flatten01: Reshape tensors for batching
    - make_dict: Extract dictionary view of experience buffer
"""

import torch
from torch import Tensor, nn
from torch.utils.data import Dataset
from typing import Dict
import numpy as np


def swap_and_flatten01(arr: Tensor):
    """Merge the leading (step, env) dimensions of a rollout tensor.

    Converts a ``(num_steps, num_envs, ...)`` tensor into a
    ``(num_steps * num_envs, ...)`` batch — the layout expected when
    batching experience from parallel environments. ``None`` passes
    through unchanged.

    Args:
        arr: Tensor with at least two leading dimensions, or ``None``.

    Returns:
        Tensor with the first two dimensions flattened, or ``None``
        if ``arr`` was ``None``.
    """
    if arr is None:
        return arr
    shape = arr.size()
    merged = shape[0] * shape[1]
    # transpose first so each environment's trajectory stays contiguous
    # in the flattened batch; reshape copies if needed.
    return arr.transpose(0, 1).reshape(merged, *shape[2:])
class ExperienceBuffer(nn.Module):
    """Rollout storage for experience collected from parallel environments.

    Each registered key holds a ``(num_steps, num_envs, ...)`` tensor stored
    as a non-persistent PyTorch buffer, so device moves via ``.to()`` come for
    free. A per-key counter in ``store_dict`` records write progress:
    ``update_data`` at step ``i`` adds ``i + 1``, so a full pass over steps
    ``0..num_steps-1`` sums to ``total_sum()``. ``make_dict`` asserts that
    invariant for every key before handing the data out.

    Args:
        num_envs: Number of parallel environments.
        num_steps: Number of steps per rollout.
        device: Device on which buffers are allocated.

    Attributes:
        store_dict: Maps each registered key to its write-progress counter.

    Example:
        >>> buffer = ExperienceBuffer(num_envs=1024, num_steps=16, device=torch.device("cpu"))
        >>> buffer.register_key("obs", shape=(128,))
        >>> buffer.update_data("obs", index=0, data=observations)
        >>> data_dict = buffer.make_dict()
    """

    def __init__(self, num_envs: int, num_steps: int, device: torch.device):
        super().__init__()
        self.num_envs = num_envs
        self.num_steps = num_steps
        # key -> accumulated write counter, validated in make_dict().
        self.store_dict = {}
        self._device = device

    def register_key(self, key: str, shape=(), dtype=torch.float):
        """Allocate a zeroed ``(num_steps, num_envs, *shape)`` buffer under ``key``."""
        assert not hasattr(self, key), key
        storage = torch.zeros(
            (self.num_steps, self.num_envs) + shape,
            dtype=dtype,
            device=self._device,
        )
        # Non-persistent: transient rollout data must not leak into checkpoints.
        self.register_buffer(key, storage, persistent=False)
        self.store_dict[key] = 0

    def update_data(self, key: str, index: int, data: Tensor):
        """Write one step of detached data and advance the key's counter."""
        assert not data.requires_grad
        getattr(self, key)[index] = data
        # Each step contributes index + 1; a complete rollout sums to total_sum().
        self.store_dict[key] += index + 1

    def total_sum(self):
        """Expected counter after every step is written once (sum of 1..num_steps)."""
        return (self.num_steps + 1) * (self.num_steps / 2)

    def batch_update_data(self, key: str, data: Tensor):
        """Overwrite the entire rollout for ``key`` with one detached tensor."""
        assert not data.requires_grad
        getattr(self, key)[:] = data
        # Mark the key as fully written in a single shot.
        self.store_dict[key] = self.total_sum()

    def make_dict(self):
        """Flatten all buffers to ``(num_steps * num_envs, ...)`` and reset counters."""
        flattened = {}
        for name, buf in self.named_buffers():
            flattened[name] = swap_and_flatten01(buf)
        for k, v in self.store_dict.items():
            # Every key must have been written exactly once per step.
            assert v == self.total_sum(), f"Problem with '{k}', {v}, {self.total_sum()}"
            self.store_dict[k] = 0
        return flattened
class DictDataset(Dataset):
    """Minibatch dataset over a dictionary of equal-length tensors.

    Slices a dictionary of tensors (all sharing the same leading dimension)
    into fixed-size minibatches, optionally through a shuffled index
    permutation. Used to build minibatch iterators from collected
    experience buffers.

    Args:
        batch_size: Size of each minibatch; must divide the tensor length.
        tensor_dict: Dictionary of tensors, all with equal length in dim 0.
        shuffle: Whether to shuffle indices at construction time.

    Example:
        >>> data = {"obs": obs_tensor, "actions": action_tensor}
        >>> dataset = DictDataset(batch_size=256, tensor_dict=data, shuffle=True)
        >>> for batch in dataset:
        >>>     train_on_batch(batch)
    """

    def __init__(
        self,
        batch_size: int,
        tensor_dict: Dict[str, Tensor],
        shuffle=False,
    ):
        assert len(tensor_dict) > 0
        lengths_dict = {name: len(t) for name, t in tensor_dict.items()}
        unique_lengths = set(lengths_dict.values())
        assert (
            len(unique_lengths) == 1
        ), f"All tensors must have the same length. Found: {lengths_dict}"
        self.num_tensors = unique_lengths.pop()
        self.batch_size = batch_size
        # Require an exact split so every batch is full-size.
        assert (
            self.num_tensors % self.batch_size == 0
        ), f"{self.num_tensors} {self.batch_size}"
        self.tensor_dict = tensor_dict
        self.do_shuffle = shuffle
        # Identity permutation until (and unless) shuffle() is called.
        self.shuffled_to_original = np.arange(self.num_tensors)
        if shuffle:
            self.shuffle()

    def shuffle(self):
        """Draw a fresh random permutation of the sample indices."""
        self.shuffled_to_original = np.random.permutation(self.num_tensors)

    def num_batches(self):
        """Number of full minibatches available."""
        return self.num_tensors // self.batch_size

    def __len__(self):
        return self.num_batches()

    def __getitem__(self, index):
        assert index < len(self), f"{index} {len(self)}"
        lo = index * self.batch_size
        hi = min(lo + self.batch_size, self.num_tensors)
        selection = self.shuffled_to_original[lo:hi]
        return {name: tensor[selection] for name, tensor in self.tensor_dict.items()}