protomotions.agents.utils.data module¶
Data utilities for experience management and batching.
This module provides utilities for managing experience buffers and creating minibatch datasets for training. It handles efficient storage and retrieval of rollout data collected during environment interaction.
- Key Classes:
ExperienceBuffer: Buffer for storing rollout experience
DictDataset: Dataset for creating minibatches from experience
- Key Functions:
swap_and_flatten01: Reshape tensors for batching
get_dict: Extract dictionary view of experience buffer
- protomotions.agents.utils.data.swap_and_flatten01(arr)[source]¶
Swap and flatten first two dimensions of a tensor.
Converts (num_steps, num_envs, …) to (num_steps * num_envs, …). Commonly used to batch experience from parallel environments.
- Parameters:
arr (torch.Tensor) – Tensor with at least 2 dimensions.
- Returns:
Tensor with first two dimensions flattened.
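A minimal sketch of this reshape, assuming PyTorch tensors. The swap-then-flatten order (suggested by the `01` in the name) keeps each environment's steps contiguous in the flattened batch; the actual implementation may order rows differently:

```python
import torch

def swap_and_flatten01(arr):
    # Swap (num_steps, num_envs, ...) -> (num_envs, num_steps, ...),
    # then merge the first two dimensions into one batch dimension.
    shape = arr.shape
    return arr.transpose(0, 1).reshape(shape[0] * shape[1], *shape[2:])

rollout = torch.zeros(16, 1024, 128)  # (num_steps, num_envs, obs_dim)
flat = swap_and_flatten01(rollout)
print(flat.shape)  # torch.Size([16384, 128])
```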
- class protomotions.agents.utils.data.ExperienceBuffer(num_envs, num_steps, device)[source]¶
Buffer for storing rollout experience from parallel environments.
Collects observations, actions, rewards, and other data during environment rollouts. Provides efficient storage and batching for on-policy algorithms. Uses PyTorch buffers for automatic device management.
- Parameters:
  - num_envs (int) – Number of parallel environments.
  - num_steps (int) – Number of rollout steps stored per environment.
  - device – Device on which buffers are allocated.
- store_dict¶
Dictionary tracking which keys have been populated.
Example

>>> buffer = ExperienceBuffer(num_envs=1024, num_steps=16, device="cpu")
>>> buffer.register_key("obs", shape=(128,))
>>> buffer.update_data("obs", step=0, data=observations)
>>> data_dict = buffer.get_dict()
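The buffer behavior described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the actual implementation: the default dtype, the flattening order in `get_dict`, and the use of a plain dict rather than registered PyTorch buffers are all guesses:

```python
import torch

class ExperienceBuffer:
    """Minimal sketch: preallocates a (num_steps, num_envs, *shape) tensor per key."""

    def __init__(self, num_envs, num_steps, device="cpu"):
        self.num_envs = num_envs
        self.num_steps = num_steps
        self.device = device
        self.tensor_dict = {}

    def register_key(self, key, shape=(), dtype=torch.float32):
        # Allocate storage for one value per (step, env) pair.
        self.tensor_dict[key] = torch.zeros(
            self.num_steps, self.num_envs, *shape, dtype=dtype, device=self.device
        )

    def update_data(self, key, step, data):
        # Write one step's worth of data for all environments at once.
        self.tensor_dict[key][step] = data

    def get_dict(self):
        # Flatten (num_steps, num_envs, ...) -> (num_steps * num_envs, ...),
        # swapping first so each environment's steps stay contiguous.
        return {
            k: v.transpose(0, 1).reshape(self.num_steps * self.num_envs, *v.shape[2:])
            for k, v in self.tensor_dict.items()
        }
```

Preallocating the full rollout up front avoids per-step allocations and keeps all data on the target device for the duration of collection.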
- class protomotions.agents.utils.data.DictDataset(batch_size, tensor_dict, shuffle=False)[source]¶
PyTorch Dataset for dictionary of tensors with minibatching.
Creates minibatches from a dictionary of tensors. Supports shuffling and automatic batching for training. Used to create minibatch iterators from collected experience buffers.
- Parameters:
  - batch_size (int) – Number of samples per minibatch.
  - tensor_dict (dict) – Dictionary mapping keys to tensors with a matching first dimension.
  - shuffle (bool) – Whether to shuffle samples before batching. Defaults to False.
Example
>>> data = {"obs": obs_tensor, "actions": action_tensor}
>>> dataset = DictDataset(batch_size=256, tensor_dict=data, shuffle=True)
>>> for batch in dataset:
...     train_on_batch(batch)
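The iteration pattern in the example can be sketched like this. It is an illustrative stand-in, not the library's implementation: the handling of a trailing partial batch and the shuffling mechanism are assumptions:

```python
import torch

class DictDataset:
    """Minimal sketch: yields dict minibatches from a dict of equal-length tensors."""

    def __init__(self, batch_size, tensor_dict, shuffle=False):
        self.batch_size = batch_size
        self.tensor_dict = tensor_dict
        self.shuffle = shuffle
        # All tensors are assumed to share the same first-dimension length.
        self.length = next(iter(tensor_dict.values())).shape[0]

    def __len__(self):
        # Number of full minibatches (trailing remainder dropped in this sketch).
        return self.length // self.batch_size

    def __iter__(self):
        # One permutation per epoch when shuffling; identity order otherwise.
        idx = torch.randperm(self.length) if self.shuffle else torch.arange(self.length)
        for i in range(len(self)):
            batch_idx = idx[i * self.batch_size : (i + 1) * self.batch_size]
            yield {k: v[batch_idx] for k, v in self.tensor_dict.items()}
```

Indexing every tensor with the same index vector keeps corresponding observations and actions aligned within each minibatch.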