protomotions.agents.utils.replay_buffer module

Replay buffer for off-policy learning.

This module provides a circular replay buffer used in AMP and ASE for storing agent transitions. The discriminator trains on batches sampled from this buffer.

Key Classes:
  • ReplayBuffer: Circular buffer with random sampling
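
A rough usage sketch for AMP-style discriminator training follows. The stand-in discriminator, the optimizer, the "amp_obs" key, and the assumption that sample() returns a dictionary keyed like the stored data are illustrative assumptions, not guarantees of this module's API:

>>> import torch
>>> from protomotions.agents.utils.replay_buffer import ReplayBuffer
>>> discriminator = torch.nn.Linear(64, 1)   # stand-in for the AMP discriminator
>>> optimizer = torch.optim.Adam(discriminator.parameters(), lr=1e-4)
>>> buffer = ReplayBuffer(buffer_size=100000, device=torch.device("cpu"))
>>> for step in range(100):
...     amp_obs = torch.randn(512, 64)           # latest agent transitions [num_envs, obs_dim]
...     buffer.store({"amp_obs": amp_obs})
...     batch = buffer.sample(256)               # assumed to return a dict of tensors
...     logits = discriminator(batch["amp_obs"])
...     loss = torch.nn.functional.binary_cross_entropy_with_logits(
...         logits, torch.zeros_like(logits))    # agent samples labeled as "fake"
...     optimizer.zero_grad()
...     loss.backward()
...     optimizer.step()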

class protomotions.agents.utils.replay_buffer.ReplayBuffer(buffer_size, device)[source]

Circular replay buffer for storing and sampling transitions.

Stores agent transitions in a circular buffer and provides random sampling for discriminator training in AMP/ASE. When the buffer is full, new data overwrites the oldest entries.

Parameters:
  • buffer_size – Maximum number of transitions to store.

  • device (torch.device) – PyTorch device for tensors.

_head

Current write position in buffer.

_is_full

Whether buffer has wrapped around.
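
A stripped-down sketch of the wrap-around bookkeeping that _head and _is_full track. This is a conceptual illustration only, not the actual implementation:

>>> import torch
>>> capacity = 5
>>> storage = torch.zeros(capacity)
>>> head, is_full = 0, False
>>> for value in [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]:
...     storage[head] = value            # write at the current head position
...     head = (head + 1) % capacity     # advance and wrap around
...     if head == 0:
...         is_full = True               # the buffer has wrapped at least once
>>> storage                              # oldest entries (1, 2) were overwritten
tensor([6., 7., 3., 4., 5.])
>>> head, is_full
(2, True)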

Example

>>> buffer = ReplayBuffer(buffer_size=10000, device=torch.device("cuda"))
>>> buffer.store({"obs": observations, "actions": actions})
>>> samples = buffer.sample(256)  # Sample 256 transitions
__init__(buffer_size, device)[source]
reset()[source]
get_buffer_size()[source]
store(data_dict)[source]
sample(n)[source]
property device: torch.device

Get the current device.
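
A small sketch of using this property to keep incoming data on the buffer's device before storing; the tensor names are illustrative:

>>> import torch
>>> from protomotions.agents.utils.replay_buffer import ReplayBuffer
>>> buffer = ReplayBuffer(buffer_size=10000, device=torch.device("cpu"))
>>> obs = torch.randn(32, 64)
>>> buffer.store({"obs": obs.to(buffer.device)})  # move data to the buffer's device first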