
Commit ba43867
committed Mar 13, 2025

Update

[ghstack-poisoned]

1 parent 677e543 commit ba43867

File tree

3 files changed (+14 -9 lines):
torchrl/collectors/distributed/ray.py
torchrl/collectors/distributed/rpc.py
torchrl/collectors/utils.py

torchrl/collectors/distributed/ray.py (+10 -6)

@@ -456,8 +456,9 @@ def check_list_length_consistency(*lists):
             policy_weights = TensorDict.from_module(self._local_policy)
             policy_weights = policy_weights.data.lock_()
         else:
-            warnings.warn(_NON_NN_POLICY_WEIGHTS)
             policy_weights = TensorDict(lock=True)
+            if remote_weights_updater is None:
+                warnings.warn(_NON_NN_POLICY_WEIGHTS)
         self.policy_weights = policy_weights
         self.collector_class = collector_class
         self.collected_frames = 0
@@ -467,11 +468,6 @@ def check_list_length_consistency(*lists):
 
         self.update_after_each_batch = update_after_each_batch
         self.max_weight_update_interval = max_weight_update_interval
-        self.remote_weights_updater = RayRemoteWeightUpdater(
-            policy_weights=policy_weights,
-            remote_collectors=self.remote_collectors,
-            max_interval=self.max_weight_update_interval,
-        )
 
         self.collector_kwargs = (
             collector_kwargs if collector_kwargs is not None else [{}]
@@ -529,6 +525,14 @@ def check_list_length_consistency(*lists):
             collector_kwargs,
             remote_configs,
         )
+        if remote_weights_updater is None:
+            remote_weights_updater = RayRemoteWeightUpdater(
+                policy_weights=policy_weights,
+                remote_collectors=self.remote_collectors,
+                max_interval=self.max_weight_update_interval,
+            )
+        self.remote_weights_updater = remote_weights_updater
+        self.local_weights_updater = local_weights_updater
 
         # Print info of all remote workers
         pending_samples = [
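
Taken together, these hunks defer construction of the default RayRemoteWeightUpdater until self.remote_collectors exists, and build it only when the caller did not supply one. A rough usage sketch of the new injection point follows; the environment, policy, and updater are illustrative placeholders, and only the remote_weights_updater argument and RayRemoteWeightUpdater come from this diff.

    # Sketch only: make_env, policy and my_updater are placeholders, not part of this commit.
    from torchrl.collectors.distributed.ray import RayCollector
    from torchrl.envs.libs.gym import GymEnv

    def make_env():
        return GymEnv("Pendulum-v1")  # hypothetical example environment

    def policy(td):
        return td  # stand-in for a plain-callable policy (not an nn.Module)

    my_updater = ...  # any object implementing the remote weight-update interface

    collector = RayCollector(
        create_env_fn=[make_env],
        policy=policy,
        frames_per_batch=256,
        total_frames=10_000,
        remote_weights_updater=my_updater,  # new: skips the default RayRemoteWeightUpdater
    )                                       # and silences the _NON_NN_POLICY_WEIGHTS warning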

torchrl/collectors/distributed/rpc.py (+2 -2)

@@ -873,7 +873,7 @@ def update_weights(
         if workers is None:
             workers = list(range(self.num_workers))
         futures = []
-        weights = self.policy_weights.data if weights is None else weights
+        weights = self.policy_weights if weights is None else weights
         for i in workers:
             if self._VERBOSE:
                 torchrl_logger.info(f"calling update on worker {i}")
@@ -884,7 +884,7 @@ def update_weights(
                     args=(self.collector_rrefs[i], weights),
                 )
             )
-        if kwargs.get("wait"):
+        if kwargs.get("wait", True):
             for i in workers:
                 if self._VERBOSE:
                     torchrl_logger.info(f"waiting for worker {i}")
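
Two behavior fixes here: update_weights no longer takes .data off self.policy_weights (which already holds the locked parameter data), and waiting on the RPC futures becomes the default instead of only happening when wait=True was passed explicitly (kwargs.get("wait") returned None, i.e. falsy, when the key was absent). Assuming a constructed RPCDataCollector named collector, the new default behaves like this sketch:

    collector.update_weights()            # now blocks until every worker has received the weights
    collector.update_weights(wait=False)  # restores the old fire-and-forget behavior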

torchrl/collectors/utils.py (+2 -1)

@@ -13,7 +13,8 @@
 
 _NON_NN_POLICY_WEIGHTS = (
     "The policy is not an nn.Module. TorchRL will assume that the parameter set is empty and "
-    "update_policy_weights_ will be a no-op."
+    "update_policy_weights_ will be a no-op. Consider passing a local/remote_weight_updater object "
+    "to your collector to handle the weight updates."
 )
 
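
The warning text now points at the local/remote weight-updater escape hatch introduced in ray.py above. For context, a standalone sketch of the two cases it distinguishes, using only the tensordict calls that already appear in this diff:

    import torch
    from tensordict import TensorDict

    module_policy = torch.nn.Linear(4, 2)
    weights = TensorDict.from_module(module_policy)  # nn.Module: parameters become a TensorDict
    print(sorted(weights.keys()))                    # ['bias', 'weight']

    empty = TensorDict(lock=True)                    # non-nn.Module policy: empty, locked set
    print(len(list(empty.keys())))                   # 0, so update_policy_weights_ is a no-op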
