@@ -456,8 +456,9 @@ def check_list_length_consistency(*lists):
456
456
policy_weights = TensorDict .from_module (self ._local_policy )
457
457
policy_weights = policy_weights .data .lock_ ()
458
458
else :
459
- warnings .warn (_NON_NN_POLICY_WEIGHTS )
460
459
policy_weights = TensorDict (lock = True )
460
+ if remote_weights_updater is None :
461
+ warnings .warn (_NON_NN_POLICY_WEIGHTS )
461
462
self .policy_weights = policy_weights
462
463
self .collector_class = collector_class
463
464
self .collected_frames = 0
@@ -467,11 +468,6 @@ def check_list_length_consistency(*lists):
467
468
468
469
self .update_after_each_batch = update_after_each_batch
469
470
self .max_weight_update_interval = max_weight_update_interval
470
- self .remote_weights_updater = RayRemoteWeightUpdater (
471
- policy_weights = policy_weights ,
472
- remote_collectors = self .remote_collectors ,
473
- max_interval = self .max_weight_update_interval ,
474
- )
475
471
476
472
self .collector_kwargs = (
477
473
collector_kwargs if collector_kwargs is not None else [{}]
@@ -529,6 +525,14 @@ def check_list_length_consistency(*lists):
529
525
collector_kwargs ,
530
526
remote_configs ,
531
527
)
528
+ if remote_weights_updater is None :
529
+ remote_weights_updater = RayRemoteWeightUpdater (
530
+ policy_weights = policy_weights ,
531
+ remote_collectors = self .remote_collectors ,
532
+ max_interval = self .max_weight_update_interval ,
533
+ )
534
+ self .remote_weights_updater = remote_weights_updater
535
+ self .local_weights_updater = local_weights_updater
532
536
533
537
# Print info of all remote workers
534
538
pending_samples = [
0 commit comments