Commit 376a27d

[CI] Upgrade doc python version (#3222)
1 parent 530f772 commit 376a27d

23 files changed (+55, -46 lines)

.github/workflows/docs.yml

Lines changed: 2 additions & 2 deletions

@@ -26,7 +26,7 @@ jobs:
   build-docs:
     strategy:
       matrix:
-        python_version: [ "3.9" ]
+        python_version: [ "3.12" ]
         cuda_arch_version: [ "12.8" ]
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
     with:
@@ -60,7 +60,7 @@ jobs:
       bash ./miniconda.sh -b -f -p "${conda_dir}"
      eval "$(${conda_dir}/bin/conda shell.bash hook)"
      printf "* Creating a test environment\n"
-      conda create --prefix "${env_dir}" -y python=3.9
+      conda create --prefix "${env_dir}" -y python=3.12
      printf "* Activating\n"
      conda activate "${env_dir}"

docs/requirements.txt

Lines changed: 1 addition & 1 deletion

@@ -16,7 +16,7 @@ sphinx_design
 torchvision
 dm_control
 mujoco<3.3.6
-gym[classic_control,accept-rom-license,ale-py,atari]
+gymnasium[classic_control,atari]
 pygame
 tqdm
 ipython
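
The swap from gym to gymnasium also drops the accept-rom-license and ale-py extras; gymnasium's atari extra pulls in ale-py itself, and recent ale-py releases bundle the Atari ROMs, so the separate license extra is presumably no longer needed. A minimal smoke test for the new dependency, assuming only the standard gymnasium API (not part of this commit):

    import gymnasium as gym

    # "CartPole-v1" comes from the classic_control extra pinned above
    env = gym.make("CartPole-v1")
    obs, info = env.reset(seed=0)
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    env.close()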

docs/source/reference/config.rst

Lines changed: 1 addition & 1 deletion

@@ -507,7 +507,7 @@ Training and Optimization Configurations
     SparseAdamConfig
 
 Logging Configurations
-~~~~~~~~~~~~~~~~~~~~~
+~~~~~~~~~~~~~~~~~~~~~~
 
 .. currentmodule:: torchrl.trainers.algorithms.configs.logging
 

docs/source/reference/llms.rst

Lines changed: 3 additions & 3 deletions

@@ -118,9 +118,9 @@ Usage
 Adding Custom Templates
 ^^^^^^^^^^^^^^^^^^^^^^^
 
-You can add custom chat templates for new model families using the :func:`torchrl.data.llm.chat.add_chat_template` function.
+You can add custom chat templates for new model families using the :func:`torchrl.data.llm.add_chat_template` function.
 
-.. autofunction:: torchrl.data.llm.chat.add_chat_template
+.. autofunction:: torchrl.data.llm.add_chat_template
 
 Usage Examples
 ^^^^^^^^^^^^^^
@@ -130,7 +130,7 @@ Adding a Llama Template
 
 .. code-block:: python
 
-    >>> from torchrl.data.llm.chat import add_chat_template, History
+    >>> from torchrl.data.llm import add_chat_template, History
     >>> from transformers import AutoTokenizer
     >>>
     >>> # Define the Llama chat template

docs/source/reference/utils.rst

Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
 .. currentmodule:: torchrl
 
 torchrl._utils package
-====================
+======================
 
 Set of utility methods that are used internally by the library.
 

pyproject.toml

Lines changed: 5 additions & 0 deletions

@@ -143,3 +143,8 @@ exclude = [
 
 [tool.usort]
 first_party_detection = false
+
+[project.entry-points."vllm.general_plugins"]
+# Ensure FP32 overrides are registered in all vLLM processes (main, workers, and
+# the registry subprocess) before resolving model classes.
+fp32_overrides = "torchrl.modules.llm.backends.vllm.vllm_plugin:register_fp32_overrides"
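
For context, vLLM discovers "vllm.general_plugins" entry points through the standard importlib.metadata mechanism and invokes each hook at startup in every process it spawns. A sketch of that discovery loop (illustrative only; this is the standard entry-point protocol the stanza above relies on, not vLLM's actual code):

    from importlib.metadata import entry_points

    # Each entry point is a "module:function" string, e.g. the fp32_overrides
    # hook declared above. Loading and calling it registers the overrides
    # before any model class is resolved in that process.
    for ep in entry_points(group="vllm.general_plugins"):
        hook = ep.load()
        hook()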

test/llm/test_vllm.py

Lines changed: 2 additions & 2 deletions

@@ -40,7 +40,7 @@ class TestAsyncVLLMIntegration:
     @pytest.mark.slow
     def test_vllm_api_compatibility(self, sampling_params):
         """Test that AsyncVLLM supports the same inputs as vLLM.LLM.generate()."""
-        from torchrl.modules.llm.backends.vllm_async import AsyncVLLM
+        from torchrl.modules.llm.backends import AsyncVLLM
 
         # Create AsyncVLLM service
         service = AsyncVLLM.from_pretrained(
@@ -113,7 +113,7 @@ def test_vllm_api_compatibility(self, sampling_params):
     def test_weight_updates_with_transformer(self, sampling_params):
         """Test weight updates using vLLMUpdater with a real transformer model."""
         from torchrl.collectors.llm.weight_update.vllm import vLLMUpdater
-        from torchrl.modules.llm.backends.vllm_async import AsyncVLLM
+        from torchrl.modules.llm.backends import AsyncVLLM
         from torchrl.modules.llm.policies.transformers_wrapper import (
             TransformersWrapper,
         )

test/llm/test_wrapper.py

Lines changed: 1 addition & 1 deletion

@@ -18,7 +18,7 @@
 from tensordict.utils import _zip_strict
 from torchrl.data.llm import History
 from torchrl.envs.llm.transforms.kl import KLComputation, RetrieveKL, RetrieveLogProb
-from torchrl.modules.llm.backends.vllm_async import AsyncVLLM
+from torchrl.modules.llm import AsyncVLLM
 from torchrl.modules.llm.policies.common import (
     _batching,
     ChatHistory,
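
Both test files make the same move: AsyncVLLM is now imported from the public torchrl.modules.llm and torchrl.modules.llm.backends packages rather than the private vllm_async module. A sketch of the public surface this exercises (the model name below is hypothetical; the diff only shows that AsyncVLLM.from_pretrained exists):

    from torchrl.modules.llm import AsyncVLLM

    # Hypothetical checkpoint; substitute any model vLLM can load.
    service = AsyncVLLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")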

torchrl/collectors/collectors.py

Lines changed: 11 additions & 10 deletions

@@ -283,12 +283,13 @@ def async_shutdown(
     ) -> None:
         """Shuts down the collector when started asynchronously with the `start` method.
 
-        Arg:
+        Args:
             timeout (float, optional): The maximum time to wait for the collector to shutdown.
             close_env (bool, optional): If True, the collector will close the contained environment.
                 Defaults to `True`.
 
         .. seealso:: :meth:`~.start`
+
         """
         return self.shutdown(timeout=timeout, close_env=close_env)
 
@@ -440,7 +441,7 @@ class SyncDataCollector(DataCollectorBase):
       - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
 
     .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
-        pickled directly), the :arg:`policy_factory` should be used instead.
+        pickled directly), the ``policy_factory`` should be used instead.
 
     Keyword Args:
         policy_factory (Callable[[], Callable], optional): a callable that returns
@@ -1784,7 +1785,7 @@ class _MultiDataCollector(DataCollectorBase):
       ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
 
     .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
-        pickled directly), the :arg:`policy_factory` should be used instead.
+        pickled directly), the ``policy_factory`` should be used instead.
 
     Keyword Args:
         policy_factory (Callable[[], Callable], list of Callable[[], Callable], optional): a callable
@@ -2749,8 +2750,8 @@ class MultiSyncDataCollector(_MultiDataCollector):
         ...     if i == 2:
         ...         print(data)
         ...         break
-        >>> collector.shutdown()
-        >>> del collector
+        ...     collector.shutdown()
+        ...     del collector
         TensorDict(
             fields={
                 action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
@@ -3130,8 +3131,8 @@ class MultiaSyncDataCollector(_MultiDataCollector):
         ...     if i == 2:
         ...         print(data)
         ...         break
-        ... collector.shutdown()
-        ... del collector
+        ...     collector.shutdown()
+        ...     del collector
         TensorDict(
             fields={
                 action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
@@ -3366,7 +3367,7 @@ class aSyncDataCollector(MultiaSyncDataCollector):
      - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
 
    .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
-        pickled directly), the :arg:`policy_factory` should be used instead.
+        pickled directly), the ``policy_factory`` should be used instead.
 
    Keyword Args:
        policy_factory (Callable[[], Callable], optional): a callable that returns
@@ -3380,8 +3381,8 @@ class aSyncDataCollector(MultiaSyncDataCollector):
            total number of frames returned by the collector
            during its lifespan. If the ``total_frames`` is not divisible by
            ``frames_per_batch``, an exception is raised.
-        Endless collectors can be created by passing ``total_frames=-1``.
-        Defaults to ``-1`` (never ending collector).
+            Endless collectors can be created by passing ``total_frames=-1``.
+            Defaults to ``-1`` (never ending collector).
        device (int, str or torch.device, optional): The generic device of the
            collector. The ``device`` args fills any non-specified device: if
            ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or
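
The policy_factory pattern these docstrings describe can be sketched as follows (illustrative only; it assumes the public SyncDataCollector API and a RandomPolicy stand-in, neither of which this diff touches):

    from torchrl.collectors import SyncDataCollector
    from torchrl.envs import GymEnv
    from torchrl.envs.utils import RandomPolicy

    def make_policy():
        # Built lazily by the collector, so the policy object itself
        # never has to be serialized/pickled.
        return RandomPolicy(GymEnv("CartPole-v1").action_spec)

    collector = SyncDataCollector(
        create_env_fn=lambda: GymEnv("CartPole-v1"),
        policy_factory=make_policy,
        frames_per_batch=64,
        total_frames=-1,  # endless collector, as the docstring notes
    )
    for i, data in enumerate(collector):
        if i == 2:
            break
    collector.shutdown()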

torchrl/collectors/distributed/generic.py

Lines changed: 3 additions & 3 deletions

@@ -282,7 +282,7 @@ class DistributedDataCollector(DataCollectorBase):
      - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
 
    .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
-        pickled directly), the :arg:`policy_factory` should be used instead.
+        pickled directly), the ``policy_factory`` should be used instead.
 
    Keyword Args:
        policy_factory (Callable[[], Callable], list of Callable[[], Callable], optional): a callable
@@ -296,8 +296,8 @@ class DistributedDataCollector(DataCollectorBase):
            number of frames returned by the collector
            during its lifespan. If the ``total_frames`` is not divisible by
            ``frames_per_batch``, an exception is raised.
-        Endless collectors can be created by passing ``total_frames=-1``.
-        Defaults to ``-1`` (endless collector).
+            Endless collectors can be created by passing ``total_frames=-1``.
+            Defaults to ``-1`` (endless collector).
        device (int, str or torch.device, optional): The generic device of the
            collector. The ``device`` args fills any non-specified device: if
            ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or
