-
Notifications
You must be signed in to change notification settings - Fork 1.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Use MPS device when available #951
base: master
Are you sure you want to change the base?
Conversation
@qgallouedec could you test this PR (do We should probably add a warning in the doc about the minimum pytorch version? (or in the code) |
Not only the pytest failed, but it caused a Python Fatal Error:
Don't know what it is. I will investigate. |
Well, I'm pretty sure the problem comes from the fact that the observation is transposed before passing into the CNN of the feature extractor, and this seems to cause some more bugs: pytorch/pytorch#81557 To reproduce: from stable_baselines3 import A2C
from stable_baselines3.common.envs import FakeImageEnv
env = FakeImageEnv()
model = A2C("CnnPolicy", env).learn(250) It causes fatal error in this line:
without traceback, but with this error message:
But more generally, there are still some features missing, such as support for the multinomial distribution (pytorch/pytorch#80760) for SB3 to work fully on the mps device So we still have to be a bit patient. |
Thanks for testing =) |
Pytorch 1.13 is out. MPS is still not fully supported and causes bugs in SB3. |
@qgallouedec , can you please provide which Ops are missing ? |
Is this still happening in latest nightly cc @qgallouedec ? |
With the latest nightly: % /Users/quentingallouedec/stable-baselines3/env/bin/python /Users/quentingallouedec/stable-baselines3/test_mps.py
[W NNPACK.cpp:64] Could not initialize NNPACK! Reason: Unsupported hardware.
Traceback (most recent call last):
File "/Users/quentingallouedec/stable-baselines3/test_mps.py", line 5, in <module>
model = A2C("CnnPolicy", env).learn(250)
File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/a2c/a2c.py", line 193, in learn
return super().learn(
File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/on_policy_algorithm.py", line 248, in learn
continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/on_policy_algorithm.py", line 166, in collect_rollouts
actions, values, log_probs = self.policy(obs_tensor)
File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1427, in _call_impl
return forward_call(*input, **kwargs)
File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/policies.py", line 576, in forward
log_prob = distribution.log_prob(actions)
File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/distributions.py", line 279, in log_prob
return self.distribution.log_prob(actions)
File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/distributions/categorical.py", line 123, in log_prob
self._validate_sample(value)
File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/distributions/distribution.py", line 298, in _validate_sample
valid = support.check(value)
File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/distributions/constraints.py", line 257, in check
return (value % 1 == 0) & (self.lower_bound <= value) & (value <= self.upper_bound)
NotImplementedError: The operator 'aten::remainder.Tensor_out' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS. EDIT: tested with PyTorch 2.0.0.dev20221220 |
Its in PR. Will try to priortize the merge. |
Is there any progress on this? Is mps usable in any way already? |
@BasLaa you can already give it a try by passing |
@qgallouedec how is the support with PyTorch 2.1.0? |
The number of errors decreases. Here's one a them:
Is double precision a feature of sb3 or should single precision be forced systematically? |
I think we don't really support float64... (mainly to avoid issues when using CUDA) |
If you need someone to test something please tell me I could with my Mac because this PR is there for a while now and nobody comes with a solution or a review ... |
@tty666 thank you for the proposal. Feel free to test and provide your feedback if any. As far as I remember, there are still some issues related to dtype (float64 instead of float32), see #951 (comment). As soon as all the CI passes, we can consider this PR as ready to be merged |
Any news regarding this PR? Is someone working on it? |
|
Hello! I just tried this out, out of curiosity and it seems to work. The small snippet above and another project I have been working on recently work very similarly with and without MPS. I can see GPU going to 100% with asitop and no crashes. Performance-wise it's not as good as we might expect but that might related to my particular use-case. |
Hi. I see the tests are still failing. I'll try to give a bit more details on my setup. First, I'm running a MacBook Pro M1 Pro. The test from yesterday was running with Python 3.12. This morning, I cloned the repo, switched to the feat/mps-support branch, created a Python 3.11 venv and ran
|
hi 👋 i would like to help move this pr forward, i see there hasnt been much progress in past few months, i have an m1 mac studio where i'm testing this branch with this setup:
can someone point me in the right direction for the changes that i need to do to make the tests pass? i seen in this pr only 3 files have been changed but i didn't find examples fixes of these issues edit: i tried my best to do things with common sense and fixed all tests, have a look at this pr #2005 |
Description
Add support for MPS device (uses it if available) and save cloudpickle version (important to debug saving/loading issues).
DO NOT MERGE: this PR must be tested on a MPS device first
closes #914
Motivation and Context
Types of changes
Checklist:
make format
(required)make check-codestyle
andmake lint
(required)make pytest
andmake type
both pass. (required)make doc
(required)Note: You can run most of the checks using
make commit-checks
.Note: we are using a maximum length of 127 characters per line