Add warning when using PPO on GPU and update doc (#2017)
* Update documentation

Added a note, with sample code, to the PPO documentation stating that PPO should primarily be run on the CPU unless a CNN is used. Added a warning for both PPO and A2C that recommends the CPU when the user runs on the GPU without a CNN policy; see Issue #1245.

* Add warning to base class and add test

---------

Co-authored-by: Antonin RAFFIN <[email protected]>
Dev1nW and araffin authored Oct 7, 2024
1 parent 512eea9 commit 56c153f
Showing 5 changed files with 56 additions and 4 deletions.
4 changes: 3 additions & 1 deletion docs/misc/changelog.rst
@@ -3,7 +3,7 @@
Changelog
==========

-Release 2.4.0a9 (WIP)
+Release 2.4.0a10 (WIP)
--------------------------

.. note::
@@ -60,12 +60,14 @@ Others:
- Fixed various typos (@cschindlbeck)
- Remove unnecessary SDE noise resampling in PPO update (@brn-dev)
- Updated PyTorch version on CI to 2.3.1
- Added a warning to recommend using CPU with on policy algorithms (A2C/PPO) and ``MlpPolicy``

Bug Fixes:
^^^^^^^^^^

Documentation:
^^^^^^^^^^^^^^
- Updated PPO doc to recommend using CPU with ``MlpPolicy``

Release 2.3.2 (2024-04-27)
--------------------------
17 changes: 17 additions & 0 deletions docs/modules/ppo.rst
@@ -88,6 +88,23 @@ Train a PPO agent on ``CartPole-v1`` using 4 environments.
vec_env.render("human")
.. note::

  PPO is meant to be run primarily on the CPU, especially when you are not using a CNN. To improve CPU utilization, try turning off the GPU and using ``SubprocVecEnv`` instead of the default ``DummyVecEnv``:

  .. code-block::

    from stable_baselines3 import PPO
    from stable_baselines3.common.env_util import make_vec_env
    from stable_baselines3.common.vec_env import SubprocVecEnv

    if __name__ == "__main__":
        env = make_vec_env("CartPole-v1", n_envs=8, vec_env_cls=SubprocVecEnv)
        model = PPO("MlpPolicy", env, device="cpu")
        model.learn(total_timesteps=25_000)

  For more information, see :ref:`Vectorized Environments <vec_env>`, `Issue #1245 <https://github.com/DLR-RM/stable-baselines3/issues/1245#issuecomment-1435766949>`_ or the `Multiprocessing notebook <https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/multiprocessing_rl.ipynb>`_.
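The note's advice to turn off the GPU can also be applied from inside a script. The following is a minimal sketch, assuming the usual CUDA behaviour that an empty ``CUDA_VISIBLE_DEVICES`` hides every device from the process. The helper name ``hide_gpus`` is hypothetical and is not part of Stable-Baselines3.

```python
import os


def hide_gpus() -> None:
    """Hide all CUDA devices from this process (hypothetical helper)."""
    # An empty value means no GPU is visible to CUDA-based libraries.
    os.environ["CUDA_VISIBLE_DEVICES"] = ""


# Call this before importing torch or stable_baselines3, so that the
# variable is already set when CUDA is first initialized.
hide_gpus()
```

Passing ``device="cpu"`` to the constructor, as in the documented example, affects only that model and does not depend on import order, so it is usually the simpler choice.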

Results
-------

23 changes: 23 additions & 0 deletions stable_baselines3/common/on_policy_algorithm.py
@@ -1,5 +1,6 @@
import sys
import time
import warnings
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union

import numpy as np
@@ -135,6 +136,28 @@ def _setup_model(self) -> None:
            self.observation_space, self.action_space, self.lr_schedule, use_sde=self.use_sde, **self.policy_kwargs
        )
        self.policy = self.policy.to(self.device)
        # Warn when not using CPU with MlpPolicy
        self._maybe_recommend_cpu()

    def _maybe_recommend_cpu(self, mlp_class_name: str = "ActorCriticPolicy") -> None:
        """
        Recommend using the CPU when running A2C/PPO with MlpPolicy.
        :param mlp_class_name: The name of the class for the default MlpPolicy.
        """
        policy_class_name = self.policy_class.__name__
        if self.device != th.device("cpu") and policy_class_name == mlp_class_name:
            warnings.warn(
                f"You are trying to run {self.__class__.__name__} on the GPU, "
                "but it is primarily intended to run on the CPU when not using a CNN policy "
                f"(you are using {policy_class_name} which should be a MlpPolicy). "
                "See https://github.com/DLR-RM/stable-baselines3/issues/1245 "
                "for more info. "
                "You can pass `device='cpu'` or `export CUDA_VISIBLE_DEVICES=` to force using the CPU. "
                "Note: The model will train, but the GPU utilization will be poor and "
                "the training might take longer than on CPU.",
                UserWarning,
            )

    def collect_rollouts(
        self,
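The condition in `_maybe_recommend_cpu` compares the policy class name by exact equality, so only the default `ActorCriticPolicy` on a non-CPU device triggers the warning. The sketch below reproduces that decision without PyTorch, under the assumption that the device can be represented as a plain string. The function names `should_recommend_cpu` and `maybe_recommend_cpu` are hypothetical stand-ins and not the library's API.

```python
import warnings


def should_recommend_cpu(
    device: str,
    policy_class_name: str,
    mlp_class_name: str = "ActorCriticPolicy",
) -> bool:
    """Return True when the CPU recommendation applies (torch-free sketch)."""
    # Assumption: the device is a plain string here, not a torch.device.
    # The name comparison is exact, so other policy classes never match.
    return device != "cpu" and policy_class_name == mlp_class_name


def maybe_recommend_cpu(algo_name: str, device: str, policy_class_name: str) -> None:
    """Emit a shortened version of the warning when the check applies."""
    if should_recommend_cpu(device, policy_class_name):
        warnings.warn(
            f"You are trying to run {algo_name} on the GPU, "
            "but it is primarily intended to run on the CPU when not using a CNN policy.",
            UserWarning,
        )
```

In this sketch, `should_recommend_cpu("cuda", "ActorCriticPolicy")` is true, while the same call with `"cpu"` or with `"ActorCriticCnnPolicy"` is false.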
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
@@ -1 +1 @@
-2.4.0a9
+2.4.0a10
14 changes: 12 additions & 2 deletions tests/test_run.py
@@ -1,6 +1,7 @@
import gymnasium as gym
import numpy as np
import pytest
import torch as th

from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3
from stable_baselines3.common.env_util import make_vec_env
@@ -211,8 +212,11 @@ def test_warn_dqn_multi_env():


def test_ppo_warnings():
-    """Test that PPO warns and errors correctly on
-    problematic rollout buffer sizes"""
+    """
+    Test that PPO warns and errors correctly on
+    problematic rollout buffer sizes,
+    and recommends using the CPU.
+    """

    # Only 1 step: advantage normalization will return NaN
    with pytest.raises(AssertionError):
@@ -234,3 +238,9 @@ def test_ppo_warnings():
    loss = model.logger.name_to_value["train/loss"]
    assert loss > 0
    assert not np.isnan(loss)  # check not nan (since nan does not equal nan)

    with pytest.warns(UserWarning, match="You are trying to run PPO on the GPU"):
        model = PPO("MlpPolicy", "Pendulum-v1")
        # Pretend to be on the GPU
        model.device = th.device("cuda")
        model._maybe_recommend_cpu()
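The new test relies on `pytest.warns(..., match=...)`, which treats `match` as a regular expression searched for in the warning message. A rough standard-library equivalent is sketched below to make that behaviour explicit. The helper `warns_matching` and the stand-in `fake_recommend_cpu` are hypothetical names used only for illustration, and the stand-in emits a shortened message rather than the real one.

```python
import re
import warnings


def warns_matching(func, category, pattern: str) -> bool:
    """Return True if func() emits a warning of `category` matching `pattern`."""
    with warnings.catch_warnings(record=True) as caught:
        # Record every warning, even one that was already shown earlier.
        warnings.simplefilter("always")
        func()
    return any(
        issubclass(w.category, category) and re.search(pattern, str(w.message))
        for w in caught
    )


def fake_recommend_cpu() -> None:
    # Hypothetical stand-in for model._maybe_recommend_cpu() on a non-CPU device.
    warnings.warn("You are trying to run PPO on the GPU, but ...", UserWarning)
```

Because the pattern is searched rather than matched against the whole message, the test only needs the stable prefix of the warning and keeps passing if the rest of the text is reworded.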
