protomotions.agents.utils.data module

Data utilities for experience management and batching.

This module provides utilities for managing experience buffers and creating minibatch datasets for training. It handles efficient storage and retrieval of rollout data collected during environment interaction.

Key Classes:
  • ExperienceBuffer: Buffer for storing rollout experience

  • DictDataset: Dataset for creating minibatches from experience

Key Functions:
  • swap_and_flatten01: Reshape tensors for batching

  • ExperienceBuffer.make_dict: Extract a dictionary view of the experience buffer

protomotions.agents.utils.data.swap_and_flatten01(arr)

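Swap and flatten the first two dimensions of a tensor.

Converts (num_steps, num_envs, …) to (num_steps * num_envs, …). Because the two leading dimensions are swapped before they are merged, the output is grouped by environment: all steps of environment 0 come first, then all steps of environment 1, and so on. Commonly used to batch experience from parallel environments.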

Parameters:

arr (torch.Tensor) – Tensor with at least 2 dimensions.

Returns:

Tensor with the first two dimensions flattened into one.
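For intuition, here is a minimal sketch of the transform, assuming the common transpose-then-reshape implementation; it is an illustration, not this module's source. The test values encode 10 * step + env so that the env-major ordering of the output is visible.

```python
import torch


def swap_and_flatten01(arr: torch.Tensor) -> torch.Tensor:
    """Sketch: (num_steps, num_envs, ...) -> (num_steps * num_envs, ...), env-major."""
    if arr.dim() < 2:
        raise ValueError("expected a tensor with at least 2 dimensions")
    s = arr.shape
    # Swap the step and env axes, then merge them into a single batch axis.
    # reshape (not view) is required because transpose yields a non-contiguous tensor.
    return arr.transpose(0, 1).reshape(s[0] * s[1], *s[2:])


# 3 steps x 2 envs x 1 feature; each value is 10 * step + env.
steps = torch.arange(3).view(3, 1, 1) * 10
envs = torch.arange(2).view(1, 2, 1)
x = steps + envs  # shape (3, 2, 1)

flat = swap_and_flatten01(x)
print(flat.shape)                # torch.Size([6, 1])
print(flat.squeeze(-1).tolist())  # [0, 10, 20, 1, 11, 21]
```

All steps of environment 0 appear before any step of environment 1, which keeps each environment's trajectory contiguous in the flattened batch.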

class protomotions.agents.utils.data.ExperienceBuffer(num_envs, num_steps, device)

Bases: torch.nn.Module

Buffer for storing rollout experience from parallel environments.

Collects observations, actions, rewards, and other data during environment rollouts. Provides efficient storage and batching for on-policy algorithms. Uses PyTorch buffers for automatic device management.
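"PyTorch buffers" presumably refers to tensors registered with nn.Module.register_buffer. The hypothetical BufferDemo class below (not part of this module) sketches what that provides: a registered buffer is converted by Module.to() and saved in state_dict(), while a plain tensor attribute is ignored by both. A dtype cast is used so the sketch runs without a GPU; .to("cuda") moves buffers in the same way.

```python
import torch
from torch import nn


class BufferDemo(nn.Module):
    """Hypothetical module contrasting a registered buffer with a plain attribute."""

    def __init__(self, num_steps: int, num_envs: int):
        super().__init__()
        # Registered buffer: tracked by the module, follows .to()/.cuda(), saved in state_dict.
        self.register_buffer("rewards", torch.zeros(num_steps, num_envs))
        # Plain tensor attribute: invisible to nn.Module, never moved or cast.
        self.plain_rewards = torch.zeros(num_steps, num_envs)


demo = BufferDemo(num_steps=16, num_envs=4)
demo.to(torch.float64)  # casts floating-point buffers; demo.to("cuda") would move them
print(demo.rewards.dtype)        # torch.float64
print(demo.plain_rewards.dtype)  # torch.float32
print(list(demo.state_dict()))   # ['rewards']
```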

Parameters:
  • num_envs (int) – Number of parallel environments.

  • num_steps (int) – Number of steps per rollout.

  • device – Device on which the buffer tensors are allocated.

store_dict

Dictionary tracking which keys have been populated.

Example

>>> buffer = ExperienceBuffer(num_envs=1024, num_steps=16, device="cpu")
>>> buffer.register_key("obs", shape=(128,))
>>> buffer.update_data("obs", index=0, data=observations)
>>> data_dict = buffer.make_dict()

__init__(num_envs, num_steps, device)
register_key(key, shape=(), dtype=<Mock object>)
update_data(key, index, data)
total_sum()
batch_update_data(key, data)
make_dict()

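The class below is a hypothetical MiniExperienceBuffer, a minimal sketch of how the methods listed above could fit together; the real ExperienceBuffer may differ. It assumes that each key is stored as a registered buffer of shape (num_steps, num_envs, *shape), that update_data writes one step for all environments, that batch_update_data overwrites a whole key, and that make_dict flattens every key in env-major order as swap_and_flatten01 does. total_sum is omitted because its behavior is not documented here.

```python
from typing import Dict, Tuple

import torch
from torch import nn


class MiniExperienceBuffer(nn.Module):
    """Hypothetical stand-in for ExperienceBuffer; the storage layout is an assumption."""

    def __init__(self, num_envs: int, num_steps: int, device: str = "cpu"):
        super().__init__()
        self.num_envs = num_envs
        self.num_steps = num_steps
        self.device = device
        self.store_dict: Dict[str, bool] = {}  # key -> has it been written yet?

    def register_key(self, key: str, shape: Tuple[int, ...] = (), dtype=torch.float32):
        # One slot per (step, env) pair; float32 default is an assumption.
        buf = torch.zeros(self.num_steps, self.num_envs, *shape, dtype=dtype, device=self.device)
        self.register_buffer(key, buf)
        self.store_dict[key] = False

    def update_data(self, key: str, index: int, data: torch.Tensor):
        # Write one rollout step (num_envs, *shape) for every environment at once.
        getattr(self, key)[index] = data
        self.store_dict[key] = True

    def batch_update_data(self, key: str, data: torch.Tensor):
        # Overwrite the whole (num_steps, num_envs, *shape) block for a key.
        getattr(self, key)[:] = data
        self.store_dict[key] = True

    def make_dict(self) -> Dict[str, torch.Tensor]:
        # Flatten (num_steps, num_envs, ...) into one env-major batch axis per key.
        out = {}
        for key in self.store_dict:
            t = getattr(self, key)
            out[key] = t.transpose(0, 1).reshape(self.num_envs * self.num_steps, *t.shape[2:])
        return out


buf = MiniExperienceBuffer(num_envs=4, num_steps=3)
buf.register_key("obs", shape=(2,))
buf.register_key("rewards")
for step in range(3):
    buf.update_data("obs", step, torch.full((4, 2), float(step)))
    buf.update_data("rewards", step, torch.ones(4))

batch = buf.make_dict()
print(batch["obs"].shape)      # torch.Size([12, 2])
print(batch["rewards"].shape)  # torch.Size([12])
```

Registering each key as a buffer keeps all rollout storage on the module's device, so a single .to(device) call relocates the whole buffer.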
class protomotions.agents.utils.data.DictDataset(batch_size, tensor_dict, shuffle=False)

Bases: torch.utils.data.Dataset

PyTorch Dataset over a dictionary of tensors, with minibatching.

Creates minibatches from a dictionary of tensors. Supports shuffling and automatic batching for training. Used to create minibatch iterators from collected experience buffers.

Parameters:
  • batch_size (int) – Size of each minibatch.

  • tensor_dict (Dict[str, torch.Tensor]) – Dictionary of tensors to batch; all tensors must have the same length along dim 0.

  • shuffle (bool) – Whether to shuffle indices before batching. Defaults to False.

Example

>>> data = {"obs": obs_tensor, "actions": action_tensor}
>>> dataset = DictDataset(batch_size=256, tensor_dict=data, shuffle=True)
>>> for batch in dataset:
...     train_on_batch(batch)

__init__(batch_size, tensor_dict, shuffle=False)
shuffle()
num_batches()
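A minimal sketch of the minibatching logic, written as a hypothetical MiniDictDataset; the real DictDataset may differ. It assumes one random permutation per shuffle() call, a smaller final batch when the length is not divisible by batch_size, and __getitem__ returning one whole minibatch per index (raising IndexError past the end), so that a plain for loop iterates over minibatches.

```python
import math
from typing import Dict

import torch
from torch.utils.data import Dataset


class MiniDictDataset(Dataset):
    """Hypothetical minibatcher over a dict of tensors sharing dim-0 length."""

    def __init__(self, batch_size: int, tensor_dict: Dict[str, torch.Tensor], shuffle: bool = False):
        lengths = {t.shape[0] for t in tensor_dict.values()}
        if len(lengths) != 1:
            raise ValueError("all tensors must have the same length along dim 0")
        self.batch_size = batch_size
        self.tensor_dict = tensor_dict
        self.length = lengths.pop()
        self.indices = torch.arange(self.length)
        if shuffle:
            self.shuffle()

    def shuffle(self):
        # Draw a fresh permutation; call again between epochs to reshuffle.
        self.indices = torch.randperm(self.length)

    def num_batches(self) -> int:
        # The final batch is smaller when length % batch_size != 0.
        return math.ceil(self.length / self.batch_size)

    def __len__(self) -> int:
        return self.num_batches()

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        if idx < 0 or idx >= self.num_batches():
            raise IndexError(idx)  # ends iteration in a plain for loop
        sel = self.indices[idx * self.batch_size:(idx + 1) * self.batch_size]
        # The same index selection is applied to every key, keeping rows aligned.
        return {k: v[sel] for k, v in self.tensor_dict.items()}


data = {"obs": torch.arange(10.0).unsqueeze(1), "actions": torch.arange(10.0).unsqueeze(1) * 2}
dataset = MiniDictDataset(batch_size=4, tensor_dict=data, shuffle=True)
sizes = [batch["obs"].shape[0] for batch in dataset]
print(dataset.num_batches(), sizes)  # 3 [4, 4, 2]
```

In this sketch the constructor's shuffle flag is not stored as self.shuffle, because an instance attribute with that name would shadow the shuffle() method.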