Skip to content

[Feature] Compressed storage gpu #3062

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jul 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/unittest/linux/scripts/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,4 @@ dependencies:
- transformers
- ninja
- timm
- safetensors
1 change: 1 addition & 0 deletions .github/workflows/benchmarks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ jobs:
python3.10 -m pip install ninja pytest pytest-benchmark mujoco dm_control "gym[accept-rom-license,atari]"
python3 -m pip install "pybind11[global]"
python3.10 -m pip install git+https://github.com/pytorch/tensordict
python3.10 -m pip install safetensors tqdm pandas numpy matplotlib
python3.10 setup.py develop

# test import
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/benchmarks_pr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ jobs:
python3.10 -m pip install ninja pytest pytest-benchmark mujoco dm_control "gym[accept-rom-license,atari]"
python3.10 -m pip install "pybind11[global]"
python3.10 -m pip install git+https://github.com/pytorch/tensordict
python3.10 -m pip install safetensors tqdm pandas numpy matplotlib
python3.10 setup.py develop
# python3.10 -m pip install git+https://github.com/pytorch/rl@$GITHUB_BRANCH

Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ repos:
- libcst == 0.4.7

- repo: https://github.com/pycqa/flake8
rev: 4.0.1
rev: 6.0.0
hooks:
- id: flake8
args: [--config=setup.cfg]
Expand Down
5 changes: 5 additions & 0 deletions benchmarks/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,7 @@
pytest-benchmark
tenacity
safetensors
tqdm
pandas
numpy
matplotlib
145 changes: 145 additions & 0 deletions benchmarks/test_compressed_storage_benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import io
import pickle

import pytest
import torch
try:
from safetensors.torch import save
except ImportError:
save = None

from torchrl.data import CompressedListStorage


class TestCompressedStorageBenchmark:
    """Benchmark tests for CompressedListStorage."""

    @staticmethod
    def make_compressible_mock_data(num_experiences: int, device=None) -> dict:
        """Easily compressible data for testing.

        Args:
            num_experiences: leading (batch) dimension of every tensor.
            device: torch device for the tensors; defaults to CPU.

        Returns:
            A dict of all-zero tensors (plus a ``batch_size`` entry) shaped
            like Atari-style transitions; constant data compresses extremely
            well, hence "compressible".
        """
        if device is None:
            device = torch.device("cpu")

        return {
            "observations": torch.zeros(
                (num_experiences, 4, 84, 84),
                dtype=torch.uint8,
                device=device,
            ),
            "actions": torch.zeros((num_experiences,), device=device),
            "rewards": torch.zeros((num_experiences,), device=device),
            "next_observations": torch.zeros(
                (num_experiences, 4, 84, 84),
                dtype=torch.uint8,
                device=device,
            ),
            "terminations": torch.zeros(
                (num_experiences,), dtype=torch.bool, device=device
            ),
            "truncations": torch.zeros(
                (num_experiences,), dtype=torch.bool, device=device
            ),
            "batch_size": [num_experiences],
        }

    @staticmethod
    def make_uncompressible_mock_data(num_experiences: int, device=None) -> dict:
        """Uncompressible data for testing.

        Same layout as :meth:`make_compressible_mock_data`, but filled with
        random values, which are close to incompressible.

        Args:
            num_experiences: leading (batch) dimension of every tensor.
            device: torch device for the tensors; defaults to CPU.
        """
        if device is None:
            device = torch.device("cpu")
        return {
            "observations": torch.randn(
                (num_experiences, 4, 84, 84),
                dtype=torch.float32,
                device=device,
            ),
            "actions": torch.randint(0, 10, (num_experiences,), device=device),
            "rewards": torch.randn(
                (num_experiences,), dtype=torch.float32, device=device
            ),
            "next_observations": torch.randn(
                (num_experiences, 4, 84, 84),
                dtype=torch.float32,
                device=device,
            ),
            "terminations": torch.rand((num_experiences,), device=device)
            < 0.2,  # ~20% True
            "truncations": torch.rand((num_experiences,), device=device)
            < 0.1,  # ~10% True
            "batch_size": [num_experiences],
        }

    @pytest.mark.benchmark(
        group="tensor_serialization_speed",
        min_time=0.1,
        max_time=0.5,
        min_rounds=5,
        disable_gc=True,
        warmup=False,
    )
    @pytest.mark.parametrize(
        "serialization_method",
        ["pickle", "torch.save", "untyped_storage", "numpy", "safetensors"],
    )
    def test_tensor_to_bytestream_speed(self, benchmark, serialization_method: str):
        """Benchmark the speed of different tensor serialization methods.

        TODO: we might need to also test which methods work on the gpu.
        pytest benchmarks/test_compressed_storage_benchmark.py::TestCompressedStorageBenchmark::test_tensor_to_bytestream_speed -v --benchmark-only --benchmark-sort='mean' --benchmark-columns='mean, ops'

        ------------------------ benchmark 'tensor_to_bytestream_speed': 5 tests -------------------------
        Name (time in us)                                      Mean (smaller is better)         OPS (bigger is better)
        --------------------------------------------------------------------------------------------------
        test_tensor_serialization_speed[numpy]                 2.3520 (1.0)      425,162.1779 (1.0)
        test_tensor_serialization_speed[safetensors]          14.7170 (6.26)      67,948.7129 (0.16)
        test_tensor_serialization_speed[pickle]               19.0711 (8.11)      52,435.3333 (0.12)
        test_tensor_serialization_speed[torch.save]           32.0648 (13.63)     31,186.8261 (0.07)
        test_tensor_serialization_speed[untyped_storage]  38,227.0224 (>1000.0)       26.1595 (0.00)
        --------------------------------------------------------------------------------------------------
        """

        def serialize_with_pickle(data: torch.Tensor) -> bytes:
            """Serialize tensor using pickle."""
            buffer = io.BytesIO()
            pickle.dump(data, buffer)
            return buffer.getvalue()

        def serialize_with_untyped_storage(data: torch.Tensor) -> bytes:
            """Serialize tensor by copying its raw untyped storage."""
            return bytes(data.untyped_storage())

        def serialize_with_numpy(data: torch.Tensor) -> bytes:
            """Serialize tensor using numpy."""
            return data.numpy().tobytes()

        def serialize_with_safetensors(data: torch.Tensor) -> bytes:
            """Serialize tensor using safetensors (requires the optional dep)."""
            return save({"0": data})

        def serialize_with_torch(data: torch.Tensor) -> bytes:
            """Serialize tensor using torch's built-in method."""
            buffer = io.BytesIO()
            torch.save(data, buffer)
            return buffer.getvalue()

        # Benchmark each serialization method
        if serialization_method == "pickle":
            serialize_fn = serialize_with_pickle
        elif serialization_method == "torch.save":
            serialize_fn = serialize_with_torch
        elif serialization_method == "untyped_storage":
            serialize_fn = serialize_with_untyped_storage
        elif serialization_method == "numpy":
            serialize_fn = serialize_with_numpy
        elif serialization_method == "safetensors":
            # The module-level import falls back to `save = None` when
            # safetensors is not installed; skip instead of raising TypeError.
            if save is None:
                pytest.skip("safetensors is not installed")
            serialize_fn = serialize_with_safetensors
        else:
            raise ValueError(f"Unknown serialization method: {serialization_method}")

        data = self.make_compressible_mock_data(1).get("observations")

        # Run the actual benchmark
        benchmark(serialize_fn, data)
62 changes: 62 additions & 0 deletions docs/source/reference/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,8 @@ using the following components:
:template: rl_template.rst


CompressedListStorage
CompressedListStorageCheckpointer
FlatStorageCheckpointer
H5StorageCheckpointer
ImmutableDatasetWriter
Expand Down Expand Up @@ -191,6 +193,66 @@ were found from rough benchmarking in https://github.com/pytorch/rl/tree/main/be
| :class:`LazyMemmapStorage` | 3.44x |
+-------------------------------+-----------+

Compressed Storage for Memory Efficiency
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

For applications where memory usage or memory bandwidth is a primary concern — especially when storing or transferring
large sensory observations such as images, audio, or text — the :class:`~torchrl.data.replay_buffers.storages.CompressedListStorage`
provides significant memory savings through compression.

The ``CompressedListStorage`` compresses data when storing and decompresses when retrieving,
achieving compression ratios of 2-10x for image data while maintaining full data fidelity.
It uses zstd compression by default but supports custom compression algorithms.

Key features:
- **Memory Efficiency**: Achieves significant memory savings through compression
- **Data Integrity**: Maintains full data fidelity through lossless compression
- **Flexible Compression**: Supports custom compression algorithms or uses zstd by default
- **TensorDict Support**: Seamlessly works with TensorDict structures
- **Checkpointing**: Full support for saving and loading compressed data

Example usage:

>>> import torch
>>> from torchrl.data import ReplayBuffer, CompressedListStorage
>>> from tensordict import TensorDict
>>>
>>> # Create a compressed storage for image data
>>> storage = CompressedListStorage(max_size=1000, compression_level=3)
>>> rb = ReplayBuffer(storage=storage, batch_size=32)
>>>
>>> # Add image data
>>> images = torch.randn(100, 3, 84, 84) # Atari-like frames
>>> data = TensorDict({"obs": images}, batch_size=[100])
>>> rb.extend(data)
>>>
>>> # Sample data (automatically decompressed)
>>> sample = rb.sample(16)
>>> print(sample["obs"].shape) # torch.Size([16, 3, 84, 84])

The compression level can be adjusted from 1 (fast, less compression) to 22 (slow, more compression),
with level 3 being a good default for most use cases.

For custom compression algorithms:

>>> def my_compress(tensor):
... return tensor.to(torch.uint8) # Simple example
>>>
>>> def my_decompress(compressed_tensor, metadata):
... return compressed_tensor.to(metadata["dtype"])
>>>
>>> storage = CompressedListStorage(
... max_size=1000,
... compression_fn=my_compress,
... decompression_fn=my_decompress
... )

.. note:: The ``CompressedListStorage`` requires the ``zstandard`` library for default compression.
   Install with: ``pip install zstandard``

.. note:: An example of how to use the CompressedListStorage is available in the
`examples/replay-buffers/compressed_replay_buffer_example.py <https://github.com/pytorch/rl/blob/main/examples/replay-buffers/compressed_replay_buffer_example.py>`_ file.

Sharing replay buffers across processes
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
Loading
Loading