[Feature] Compressed storage gpu #3062
base: main
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/3062
Note: Links to docs will display an error until the docs builds have been completed.
❌ 13 New Failures, 1 Pending, 3 Unrelated Failures as of commit 0dbd233 with merge base db0e30d.
NEW FAILURES - The following jobs have failed:
FLAKY - The following job failed, but this was likely due to flakiness present on trunk:
BROKEN TRUNK - The following jobs failed, but the failures were already present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
When the tensor is on the CPU
Force-pushed from 74d85fa to 95f532e: … cursor logic to a view class. Passing all tests now.
Force-pushed from 95f532e to 5581cf6: …o_bytestream speed test.
Added some examples of compressing on the CPU and of batched decompression on the GPU. In my Atari rollout example, compressing on the CPU first, then transferring, and re-using the compressed observation for the next transition gets about double the Atari transitions per second.
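A minimal stdlib sketch of the reuse idea described above (zlib stands in for zstd/nvcomp, and the frame bytes are made up — this is not the PR's actual code):

```python
import zlib

def compress_rollout(frames: list[bytes]) -> list[tuple[bytes, bytes]]:
    """Compress each frame once on the CPU, then reuse the same
    compressed bytes for both transitions the frame appears in."""
    packed = [zlib.compress(f, 1) for f in frames]
    # Transition i stores (obs, next_obs); packed[i + 1] is shared with
    # transition i + 1, so no frame is ever compressed twice.
    return [(packed[i], packed[i + 1]) for i in range(len(packed) - 1)]
```

The same compressed buffers can then be transferred to the GPU once and decompressed there in a batch.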
@vmoens I think we're essentially done with this PR, except for a cleanup pass. Do we want CompressedListStorage to be mentioned in the documentation? Maybe have a page on compression to showcase the VRAM storage savings on the GPU?
Happy with this! Mainly nits and minor aesthetic comments on the examples and doc but otherwise good to go!
@@ -20,7 +20,7 @@ repos:
   - libcst == 0.4.7

 - repo: https://github.com/pycqa/flake8
-  rev: 4.0.1
+  rev: 6.0.0
Not against upgrades but can you comment why we need it?
4.0.1 has an issue that causes the error below:
flake8...................................................................Failed
- hook id: flake8
- exit code: 1
Traceback (most recent call last):
File "/home/adrian/.cache/pre-commit/repoxdgdrlah/py_env-python3.12/bin/flake8", line 8, in <module>
sys.exit(main())
^^^^^^
File "/home/adrian/.cache/pre-commit/repoxdgdrlah/py_env-python3.12/lib/python3.12/site-packages/flake8/main/cli.py", line 22, in main
app.run(argv)
File "/home/adrian/.cache/pre-commit/repoxdgdrlah/py_env-python3.12/lib/python3.12/site-packages/flake8/main/application.py", line 375, in run
self._run(argv)
File "/home/adrian/.cache/pre-commit/repoxdgdrlah/py_env-python3.12/lib/python3.12/site-packages/flake8/main/application.py", line 363, in _run
self.initialize(argv)
File "/home/adrian/.cache/pre-commit/repoxdgdrlah/py_env-python3.12/lib/python3.12/site-packages/flake8/main/application.py", line 343, in initialize
self.find_plugins(config_finder)
File "/home/adrian/.cache/pre-commit/repoxdgdrlah/py_env-python3.12/lib/python3.12/site-packages/flake8/main/application.py", line 157, in find_plugins
self.check_plugins = plugin_manager.Checkers(local_plugins.extension)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/adrian/.cache/pre-commit/repoxdgdrlah/py_env-python3.12/lib/python3.12/site-packages/flake8/plugins/manager.py", line 363, in __init__
self.manager = PluginManager(
^^^^^^^^^^^^^^
File "/home/adrian/.cache/pre-commit/repoxdgdrlah/py_env-python3.12/lib/python3.12/site-packages/flake8/plugins/manager.py", line 243, in __init__
self._load_entrypoint_plugins()
File "/home/adrian/.cache/pre-commit/repoxdgdrlah/py_env-python3.12/lib/python3.12/site-packages/flake8/plugins/manager.py", line 261, in _load_entrypoint_plugins
eps = importlib_metadata.entry_points().get(self.namespace, ())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'EntryPoints' object has no attribute 'get'
From what I understand from the issues that recommend upgrading, importlib-metadata deprecated and then removed the dict-style interface that flake8 4.x relied on (hence the `'EntryPoints' object has no attribute 'get'` error), and flake8 5.x.x or greater is fixed to use the new API.
As an aside, you might like the speed of Ruff; I use it for my own projects.
- **Data Integrity**: Maintains full data fidelity through lossless compression
- **Flexible Compression**: Supports custom compression algorithms or uses zstd by default
- **TensorDict Support**: Seamlessly works with TensorDict structures
- **Checkpointing**: Full support for saving and loading compressed data
Maybe add GPU support here
>>> import torch
>>> from torchrl.data import ReplayBuffer, CompressedListStorage
>>> from tensordict import TensorDict
>>>
>>> # Create a compressed storage for image data
>>> storage = CompressedListStorage(max_size=1000, compression_level=3)
>>> rb = ReplayBuffer(storage=storage, batch_size=32)
>>>
>>> # Add image data
>>> images = torch.randn(100, 3, 84, 84)  # Atari-like frames
>>> data = TensorDict({"obs": images}, batch_size=[100])
>>> rb.extend(data)
>>>
>>> # Sample data (automatically decompressed)
>>> sample = rb.sample(16)
>>> print(sample["obs"].shape)  # torch.Size([16, 3, 84, 84])
@AdrianOrenstein Can you check that this still makes sense and runs?
@@ -0,0 +1,199 @@
#!/usr/bin/env python3
We need Meta headers here
from torchrl.data.replay_buffers.storages import ListStorage

gym.register_envs(ale_py)
maybe after the imports?
@@ -0,0 +1,182 @@
#!/usr/bin/env python3
"""
Example demonstrating the use of CompressedStorage for memory-efficient replay buffers on the GPU.
more info?
Is it worth having 2 distinct examples? Not particularly unhappy with this, just more repeated code and could be harder for people to spot what the differences are between the two (hard to know where to focus your attention)
# === CompressedListStorage + ReplayBuffer with GPU compression ===
print("\n=== ListStorage + ReplayBuffer (GPU) Example ===\n")

codec = nvcomp.Codec(algorithm=algorithm, bitstream_kind=bitstream_kind)
I think this deserves a bit of explanation :)
Also I would check that the lib is installed beforehand and raise a warning/exception if it cannot be found:
has_nvcomp = importlib.util.find_spec("nvcomp") is not None
if not has_nvcomp:
    raise ImportError(...)
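A runnable version of that guard, generalized to any module name (the error message wording is just illustrative):

```python
import importlib.util

def require_module(name: str) -> None:
    """Fail fast with a clear message when an optional dependency is missing."""
    # find_spec returns None for a missing top-level module instead of raising
    if importlib.util.find_spec(name) is None:
        raise ImportError(
            f"{name} is required for GPU compression; please install it first."
        )
```

Calling `require_module("nvcomp")` at the top of the example script would then surface a clear error instead of a mid-script ImportError.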
do we need this one?
storage = CompressedListStorage(max_size=1000, compression_level=6)

# Create some sample data with different shapes and types
print("Creating sample data...")
we should remove the prints in favor of a logger
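Something like this sketch (the logger name is made up):

```python
import logging

logger = logging.getLogger("torchrl_compressed_example")

def make_sample_data(n: int) -> list[int]:
    # was: print("Creating sample data...")
    logger.info("Creating %d sample items...", n)
    return list(range(n))
```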
Thanks for the code review!
Description
Replay buffers store large amounts of data and feed neural networks with batched samples to learn from, so ideally this data should live as close as possible to where the network is being updated. These buffers often hold raw sensory observations such as images, audio, or text, which consume many gigabytes of precious memory. CPU memory and accelerator VRAM may be limited, and memory transfers between these devices may be costly, so this PR aims to streamline data compression to aid in efficient storage and memory transfer.
Mainly, creating a compressed storage object will aid in training state-of-the-art RL methods on benchmarks such as the Arcade Learning Environment. The
~torchrl.data.replay_buffers.storages.CompressedStorage
class provides the memory savings through compression.
closes #3058
closes #2983
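The core mechanism can be sketched in a few lines of stdlib Python (zlib stands in for the zstd default, and the class and method names here are illustrative, not TorchRL's API):

```python
import zlib

class TinyCompressedList:
    """Store each item as compressed bytes; decompress lazily on read."""

    def __init__(self, compression_level: int = 3):
        self._level = compression_level
        self._items: list[bytes] = []

    def append(self, raw: bytes) -> None:
        self._items.append(zlib.compress(raw, self._level))

    def __getitem__(self, index: int) -> bytes:
        return zlib.decompress(self._items[index])

    def compressed_nbytes(self) -> int:
        # Total memory actually held by the storage
        return sum(len(b) for b in self._items)
```

Redundancy-heavy observations (e.g. mostly-static Atari frames) compress very well, which is where the RAM/VRAM savings come from.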
Types of changes
What types of changes does your code introduce? Remove all that do not apply:
Checklist
Go over all the following points, and put an `x` in all the boxes that apply. If you are unsure about any of these, don't hesitate to ask. We are here to help!