
Commit 7ec4110

committed: final?
1 parent 8dff25a commit 7ec4110

File tree: 20 files changed, +2099 -1709 lines


docs/source/reference/collectors_weightsync.rst
Lines changed: 6 additions & 0 deletions

@@ -198,6 +198,9 @@ Weight Senders
     :template: rl_template.rst

     WeightSender
+    MPWeightSender
+    RPCWeightSender
+    DistributedWeightSender
     RayModuleTransformSender

 Weight Receivers
@@ -208,6 +211,9 @@ Weight Receivers
     :template: rl_template.rst

     WeightReceiver
+    MPWeightReceiver
+    RPCWeightReceiver
+    DistributedWeightReceiver
     RayModuleTransformReceiver

 Transports

examples/collectors/multi_weight_updates.py
Lines changed: 1 addition & 1 deletion

@@ -25,7 +25,7 @@
 from torchrl.data import LazyTensorStorage, ReplayBuffer
 from torchrl.envs.libs.gym import GymEnv
 from torchrl.envs.transforms.module import ModuleTransform
-from torchrl.weight_update.weight_sync_schemes import MultiProcessWeightSyncScheme
+from torchrl.weight_update import MultiProcessWeightSyncScheme


 def make_module():
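
Only the import path changes in this example. For context, a minimal sketch (not part of the commit) of how such a scheme is typically attached to a multiprocess collector through the weight_sync_schemes argument handled in _multi_base.py below; the Pendulum environment and toy policy are illustrative placeholders:

# Sketch only: assumes gymnasium's Pendulum-v1 and a toy TensorDictModule policy.
import torch.nn as nn
from tensordict.nn import TensorDictModule
from torchrl.collectors import MultiSyncDataCollector
from torchrl.envs.libs.gym import GymEnv
from torchrl.weight_update import MultiProcessWeightSyncScheme

policy = TensorDictModule(
    nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
)
collector = MultiSyncDataCollector(
    [lambda: GymEnv("Pendulum-v1")] * 2,
    policy=policy,
    frames_per_batch=64,
    total_frames=256,
    weight_sync_schemes={"policy": MultiProcessWeightSyncScheme()},
)
for batch in collector:
    # ...optimizer step on the trainer-side copy of the policy...
    collector.update_policy_weights_()  # pushes new weights through the scheme
collector.shutdown()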

test/test_collector.py
Lines changed: 10 additions & 11 deletions

@@ -1558,8 +1558,6 @@ def create_env():
     ) # MultiSync has known indexing issues with SharedMem
     def test_update_weights_shared_mem(self, use_async):
         """Test shared memory weight synchronization scheme."""
-        from tensordict import TensorDict
-        from torchrl.weight_update.weight_sync_schemes import SharedMemWeightSyncScheme

         def create_env():
             return ContinuousActionVecMockEnv()
@@ -4117,16 +4115,17 @@ def test_start_update_policy(self, total_frames, cls, weight_sync_scheme):
             frames_per_batch=16,
             **kwargs,
         )
-        if not isinstance(collector, SyncDataCollector):
-            if weight_sync_scheme is not None:
-                assert isinstance(
-                    collector._weight_sync_schemes["policy"], weight_sync_scheme
-                )
-            else:
-                assert isinstance(
-                    collector._weight_sync_schemes["policy"], SharedMemWeightSyncScheme
-                )
         try:
+            if not isinstance(collector, SyncDataCollector):
+                if weight_sync_scheme is not None:
+                    assert isinstance(
+                        collector._weight_sync_schemes["policy"], weight_sync_scheme
+                    )
+                else:
+                    assert isinstance(
+                        collector._weight_sync_schemes["policy"],
+                        SharedMemWeightSyncScheme,
+                    )
             collector.start()
             for _ in range(10):
                 time.sleep(0.1)

test/test_weightsync.py
Lines changed: 4 additions & 2 deletions

@@ -17,8 +17,7 @@
 from tensordict.nn import TensorDictModule
 from torch import multiprocessing as mp
 from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector
-from torchrl.weight_update.weight_sync_schemes import (
-    _resolve_model,
+from torchrl.weight_update import (
     DistributedWeightSyncScheme,
     MPTransport,
     MultiProcessWeightSyncScheme,
@@ -27,6 +26,9 @@
     RayWeightSyncScheme,
     RPCWeightSyncScheme,
     SharedMemTransport,
+)
+from torchrl.weight_update.utils import _resolve_model
+from torchrl.weight_update.weight_sync_schemes import (
     SharedMemWeightSyncScheme,
     WeightStrategy,
 )

torchrl/collectors/_multi_base.py
Lines changed: 19 additions & 30 deletions

@@ -334,7 +334,8 @@ def __init__(
         policy_factory = self._setup_policy_factory(policy_factory)

         # Set up weight synchronization
-        weight_sync_schemes = {}
+        if weight_sync_schemes is None:
+            weight_sync_schemes = {}
         if (
             not any(policy_factory)
             and not weight_sync_schemes
@@ -516,13 +517,13 @@ def _setup_multi_policy_and_weights(
             weight_sync_policy = weight_sync_schemes.get("policy")
             if weight_sync_policy is None:
                 return
-            if weight_sync_policy._initialized_on_sender:
-                return
             if any(p is not None for p in policy_factory):
-                raise RuntimeError(
-                    f"the weight sync scheme must be initialized on sender ahead of time when passing a policy factory. Got {policy_factory=}"
-                )
-            weight_sync_policy.init_on_sender(model=policy, devices=self.policy_device)
+                if not weight_sync_policy._initialized_on_sender:
+                    raise RuntimeError(
+                        f"the weight sync scheme must be initialized on sender ahead of time when passing a policy factory. Got {policy_factory=}"
+                    )
+            # Weight sync scheme initialization happens in _run_processes
+            # where pipes and workers are available
         else:
             # Using legacy weight updater - extract weights and create stateful policies
             self._setup_multi_policy_and_weights_legacy(
@@ -821,19 +822,20 @@ def _run_processes(self) -> None:
         torch.set_num_threads(self.num_threads)
         queue_out = mp.Queue(self._queue_len)  # sends data from proc to main
         self.procs = []
-        self.pipes = []
         self._traj_pool = _TrajectoryPool(lock=True)

-        # Initialize weight sync schemes early for SharedMemWeightSyncScheme
-        # (queue created in __init__ will be pickled with scheme to workers)
-        # For MultiProcessWeightSyncScheme, we'll initialize after pipes are available
+        # Create all pipes upfront (needed for weight sync scheme initialization)
+        # Store as list of (parent, child) tuples for use in worker creation
+        pipe_pairs = [mp.Pipe() for _ in range(self.num_workers)]
+        # Extract parent pipes for external use (e.g., polling, receiving messages)
+        self.pipes = [pipe_parent for pipe_parent, _ in pipe_pairs]
+
+        # Initialize all weight sync schemes now that pipes are available
+        # Both SharedMemWeightSyncScheme (uses queues) and MultiProcessWeightSyncScheme (uses pipes)
+        # can be initialized here since all required resources exist
         if self._weight_sync_schemes:
             for model_id, scheme in self._weight_sync_schemes.items():
-                # Only initialize SharedMemWeightSyncScheme now (needs queue before workers)
-                # MultiProcessWeightSyncScheme will be initialized after workers are created
-                if isinstance(scheme, SharedMemWeightSyncScheme) and hasattr(
-                    scheme, "init_on_sender"
-                ):
+                if hasattr(scheme, "init_on_sender"):
                     scheme.init_on_sender(model_id=model_id, context=self)
                     self._weight_senders[model_id] = scheme.get_sender()

@@ -848,7 +850,7 @@ def _run_processes(self) -> None:
         for i, (env_fun, env_fun_kwargs) in enumerate(
             zip(self.create_env_fn, self.create_env_kwargs)
         ):
-            pipe_parent, pipe_child = mp.Pipe()  # send messages to procs
+            pipe_parent, pipe_child = pipe_pairs[i]  # use pre-created pipes
             if env_fun.__class__.__name__ != "EnvCreator" and not isinstance(
                 env_fun, EnvBase
             ):  # to avoid circular imports
@@ -966,7 +968,6 @@ def _run_processes(self) -> None:
             ) from err
             pipe_child.close()
             self.procs.append(proc)
-            self.pipes.append(pipe_parent)

         # Synchronize initial weights with workers AFTER starting processes but BEFORE waiting for "instantiated"
         # This must happen after proc.start() but before workers send "instantiated" to avoid deadlock:
@@ -1027,18 +1028,6 @@
             # Legacy string error message
             raise RuntimeError(msg)

-        # Initialize MultiProcessWeightSyncScheme now that workers are ready and pipes are available
-        # (SharedMemWeightSyncScheme was already initialized before workers)
-        if self._weight_sync_schemes:
-            for model_id, scheme in self._weight_sync_schemes.items():
-                # Only initialize non-SharedMem schemes here (need pipes)
-                if not isinstance(scheme, SharedMemWeightSyncScheme) and hasattr(
-                    scheme, "init_on_sender"
-                ):
-                    scheme.init_on_sender(model_id=model_id, context=self)
-                    # Get the initialized sender
-                    self._weight_senders[model_id] = scheme.get_sender()
-
         self.queue_out = queue_out
         self.closed = False

torchrl/collectors/_runner.py
Lines changed: 6 additions & 3 deletions

@@ -30,7 +30,7 @@


 def _make_policy_factory(
-    *, policy: Callable, policy_factory, weight_sync_scheme, worker_idx
+    *, policy: Callable, policy_factory, weight_sync_scheme, worker_idx, pipe=None
 ):
     if policy is not None and policy_factory is not None:
         raise ValueError("policy cannot be used with policy_factory")
@@ -40,7 +40,7 @@ def _make_policy_factory(
     if weight_sync_scheme is not None:
         # Initialize the receiver on the worker side
         weight_sync_scheme.init_on_worker(
-            model=policy, model_id="policy", worker_idx=worker_idx
+            model=policy, model_id="policy", worker_idx=worker_idx, pipe=pipe
         )
         # Get the receiver and synchronize initial weights
         receiver = weight_sync_scheme.get_receiver()
@@ -92,8 +92,11 @@ def _main_async_collector(
         _make_policy_factory,
         policy=policy,
         policy_factory=policy_factory,
-        weight_sync_scheme=weight_sync_schemes.get("policy"),
+        weight_sync_scheme=weight_sync_schemes.get("policy")
+        if weight_sync_schemes
+        else None,
         worker_idx=worker_idx,
+        pipe=pipe_child,
     )
     policy = None
     try:
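
The new pipe argument carries the worker's end of the control pipe into the receiver setup, which is why _run_processes above now creates all pipes before spawning workers. A self-contained illustration of that hand-off with plain multiprocessing (not torchrl code):

# Generic stand-in for init_on_worker(..., pipe=pipe_child) and the sender/receiver pair.
import multiprocessing as mp


def worker(pipe):
    # worker-side "receiver": block until the parent sends updated weights
    weights = pipe.recv()
    print("worker received:", weights)


if __name__ == "__main__":
    parent_end, child_end = mp.Pipe()  # created before the worker exists
    proc = mp.Process(target=worker, args=(child_end,))
    proc.start()
    parent_end.send({"policy": "new-weights"})  # parent-side "sender"
    proc.join()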

torchrl/collectors/distributed/generic.py
Lines changed: 1 addition & 3 deletions

@@ -570,9 +570,7 @@ def __init__(
         # Set up weight synchronization - prefer new schemes over legacy updater
         if weight_updater is None and weight_sync_schemes is None:
             # Default to Distributed weight sync scheme for distributed collectors
-            from torchrl.weight_update.weight_sync_schemes import (
-                DistributedWeightSyncScheme,
-            )
+            from torchrl.weight_update import DistributedWeightSyncScheme

             weight_sync_schemes = {
                 "policy": DistributedWeightSyncScheme(backend=backend, sync=self._sync)

torchrl/collectors/distributed/ray.py
Lines changed: 1 addition & 1 deletion

@@ -539,7 +539,7 @@ def check_list_length_consistency(*lists):
         # Set up weight synchronization - prefer new schemes over legacy updater
         if weight_updater is None and weight_sync_schemes is None:
             # Default to Ray weight sync scheme for Ray collectors
-            from torchrl.weight_update.weight_sync_schemes import RayWeightSyncScheme
+            from torchrl.weight_update import RayWeightSyncScheme

             weight_sync_schemes = {"policy": RayWeightSyncScheme()}

torchrl/collectors/distributed/rpc.py
Lines changed: 1 addition & 1 deletion

@@ -417,7 +417,7 @@ def __init__(
         # Set up weight synchronization - prefer new schemes over legacy updater
         if weight_updater is None and weight_sync_schemes is None:
             # Default to RPC weight sync scheme for RPC collectors
-            from torchrl.weight_update.weight_sync_schemes import RPCWeightSyncScheme
+            from torchrl.weight_update import RPCWeightSyncScheme

             weight_sync_schemes = {"policy": RPCWeightSyncScheme()}
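
Each distributed collector keeps its backend-specific default scheme; only the import location changes in generic.py, ray.py, and rpc.py. A sketch of the equivalent explicit construction (the backend and sync values below are illustrative, mirroring the generic.py snippet above):

from torchrl.weight_update import (
    DistributedWeightSyncScheme,
    RayWeightSyncScheme,
    RPCWeightSyncScheme,
)

# Defaults built when neither weight_updater nor weight_sync_schemes is passed;
# backend="gloo" and sync=True stand in for the collector's own backend/_sync values.
default_schemes = {
    "generic": DistributedWeightSyncScheme(backend="gloo", sync=True),
    "ray": RayWeightSyncScheme(),
    "rpc": RPCWeightSyncScheme(),
}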

torchrl/weight_update/__init__.py
Lines changed: 29 additions & 10 deletions

@@ -3,22 +3,30 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

-from .weight_sync_schemes import (
+from ._distributed import (
     DistributedTransport,
+    DistributedWeightReceiver,
+    DistributedWeightSender,
     DistributedWeightSyncScheme,
+)
+from ._mp import (
     MPTransport,
+    MPWeightReceiver,
+    MPWeightSender,
     MultiProcessWeightSyncScheme,
-    NoWeightSyncScheme,
+)
+from ._noupdate import NoWeightSyncScheme
+from ._ray import (
     RayActorTransport,
     RayModuleTransformReceiver,
     RayModuleTransformScheme,
     RayModuleTransformSender,
     RayTransport,
     RayWeightSyncScheme,
-    RPCTransport,
-    RPCWeightSyncScheme,
-    SharedMemTransport,
-    SharedMemWeightSyncScheme,
+)
+from ._rpc import RPCTransport, RPCWeightReceiver, RPCWeightSender, RPCWeightSyncScheme
+from ._shared import SharedMemTransport, SharedMemWeightSyncScheme
+from .weight_sync_schemes import (
     TransportBackend,
     WeightReceiver,
     WeightSender,
@@ -27,19 +35,30 @@
 )

 __all__ = [
+    # Base classes
     "TransportBackend",
+    "WeightStrategy",
+    "WeightSender",
+    "WeightReceiver",
+    "WeightSyncScheme",
+    # Transports
     "MPTransport",
     "SharedMemTransport",
     "RayTransport",
     "RayActorTransport",
     "RPCTransport",
     "DistributedTransport",
-    "WeightStrategy",
-    "WeightSender",
-    "WeightReceiver",
+    # Senders
+    "MPWeightSender",
+    "RPCWeightSender",
+    "DistributedWeightSender",
     "RayModuleTransformSender",
+    # Receivers
+    "MPWeightReceiver",
+    "RPCWeightReceiver",
+    "DistributedWeightReceiver",
     "RayModuleTransformReceiver",
-    "WeightSyncScheme",
+    # Schemes
     "MultiProcessWeightSyncScheme",
     "SharedMemWeightSyncScheme",
     "NoWeightSyncScheme",
