Source code for protomotions.agents.utils.replay_buffer

# SPDX-FileCopyrightText: Copyright (c) 2025 The ProtoMotions Developers
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Replay buffer for off-policy learning.

This module provides a circular replay buffer used in AMP and ASE for storing
agent transitions. The discriminator trains on batches sampled from this buffer.

Key Classes:
    - ReplayBuffer: Circular buffer with random sampling
"""

import torch
from torch import nn


class ReplayBuffer(nn.Module):
    """Circular replay buffer for storing and sampling transitions.

    Stores agent transitions in a circular buffer and provides random sampling
    for discriminator training in AMP/ASE. Automatically handles buffer
    overflow by overwriting the oldest data.

    Args:
        buffer_size: Maximum number of transitions to store.
        device: PyTorch device for tensors.

    Attributes:
        _head: Current write position in the buffer.
        _is_full: Whether the buffer has wrapped around.

    Example:
        >>> buffer = ReplayBuffer(buffer_size=10000, device=torch.device("cuda"))
        >>> buffer.store({"obs": observations, "actions": actions})
        >>> samples = buffer.sample(256)  # Sample 256 transitions
    """

    def __init__(self, buffer_size, device: torch.device):
        super().__init__()
        self._head = 0
        self._is_full = False
        self._buffer_size = buffer_size
        self._buffer_keys = []
        self._device = device

    def reset(self):
        self._head = 0
        self._is_full = False

    def get_buffer_size(self):
        return self._buffer_size

    def __len__(self) -> int:
        # Number of transitions currently stored (capacity once full).
        return self._buffer_size if self._is_full else self._head

    def store(self, data_dict):
        """Store a batch of transitions, overwriting the oldest data on wrap."""
        self._maybe_init_data_buf(data_dict)

        n = next(iter(data_dict.values())).shape[0]
        buffer_size = self.get_buffer_size()
        assert n <= buffer_size

        for key in self._buffer_keys:
            curr_buf = getattr(self, key)
            curr_n = data_dict[key].shape[0]
            assert n == curr_n

            end = self._head + n
            if end >= self._buffer_size:
                # Wrap around: fill up to the end of the buffer, then spill
                # the remainder over the start.
                diff = self._buffer_size - self._head
                curr_buf[self._head :] = data_dict[key][:diff].clone()
                curr_buf[: n - diff] = data_dict[key][diff:].clone()
                self._is_full = True
            else:
                curr_buf[self._head : end] = data_dict[key].clone()

        self._head = (self._head + n) % buffer_size

    def sample(self, n):
        """Sample ``n`` stored transitions uniformly at random (with replacement)."""
        indices = torch.randint(0, len(self), (n,), device=self.device)
        samples = dict()
        for k in self._buffer_keys:
            v = getattr(self, k)
            samples[k] = v[indices].clone()
        return samples

    def _maybe_init_data_buf(self, data_dict):
        """Lazily allocate a storage buffer for each key on first store."""
        buffer_size = self.get_buffer_size()
        for k, v in data_dict.items():
            if not hasattr(self, k):
                v_shape = v.shape[1:]
                self.register_buffer(
                    k,
                    torch.zeros(
                        (buffer_size,) + v_shape, dtype=v.dtype, device=self.device
                    ),
                    persistent=False,
                )
                self._buffer_keys.append(k)

    @property
    def device(self) -> torch.device:
        """Get the current device."""
        return self._device
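

# ---------------------------------------------------------------------------
# A minimal usage sketch (not part of the library): it exercises the circular
# wraparound by storing more transitions than the buffer holds, then samples
# a batch. The key names ("obs", "actions"), shapes, and sizes below are
# illustrative assumptions, not values taken from ProtoMotions.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    device = torch.device("cpu")
    buffer = ReplayBuffer(buffer_size=8, device=device)

    # Store three batches of 4 transitions each (12 total) into 8 slots,
    # forcing one wraparound.
    for _ in range(3):
        buffer.store(
            {
                "obs": torch.randn(4, 16, device=device),
                "actions": torch.randn(4, 6, device=device),
            }
        )

    assert len(buffer) == 8  # full: the oldest 4 transitions were overwritten

    samples = buffer.sample(5)
    print(samples["obs"].shape)  # torch.Size([5, 16])
    print(samples["actions"].shape)  # torch.Size([5, 6])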