@@ -334,7 +334,8 @@ def __init__(
334334 policy_factory = self ._setup_policy_factory (policy_factory )
335335
336336 # Set up weight synchronization
337- weight_sync_schemes = {}
337+ if weight_sync_schemes is None :
338+ weight_sync_schemes = {}
338339 if (
339340 not any (policy_factory )
340341 and not weight_sync_schemes
@@ -516,13 +517,13 @@ def _setup_multi_policy_and_weights(
516517 weight_sync_policy = weight_sync_schemes .get ("policy" )
517518 if weight_sync_policy is None :
518519 return
519- if weight_sync_policy ._initialized_on_sender :
520- return
521520 if any (p is not None for p in policy_factory ):
522- raise RuntimeError (
523- f"the weight sync scheme must be initialized on sender ahead of time when passing a policy factory. Got { policy_factory = } "
524- )
525- weight_sync_policy .init_on_sender (model = policy , devices = self .policy_device )
521+ if not weight_sync_policy ._initialized_on_sender :
522+ raise RuntimeError (
523+ f"the weight sync scheme must be initialized on sender ahead of time when passing a policy factory. Got { policy_factory = } "
524+ )
525+ # Weight sync scheme initialization happens in _run_processes,
526+ # once the pipes have been created (before the worker processes are started)
526527 else :
527528 # Using legacy weight updater - extract weights and create stateful policies
528529 self ._setup_multi_policy_and_weights_legacy (
@@ -821,19 +822,20 @@ def _run_processes(self) -> None:
821822 torch .set_num_threads (self .num_threads )
822823 queue_out = mp .Queue (self ._queue_len ) # sends data from proc to main
823824 self .procs = []
824- self .pipes = []
825825 self ._traj_pool = _TrajectoryPool (lock = True )
826826
827- # Initialize weight sync schemes early for SharedMemWeightSyncScheme
828- # (queue created in __init__ will be pickled with scheme to workers)
829- # For MultiProcessWeightSyncScheme, we'll initialize after pipes are available
827+ # Create all pipes upfront (needed for weight sync scheme initialization)
828+ # Store as list of (parent, child) tuples for use in worker creation
829+ pipe_pairs = [mp .Pipe () for _ in range (self .num_workers )]
830+ # Extract parent pipes for external use (e.g., polling, receiving messages)
831+ self .pipes = [pipe_parent for pipe_parent , _ in pipe_pairs ]
832+
833+ # Initialize all weight sync schemes now that pipes are available
834+ # Both SharedMemWeightSyncScheme (uses queues) and MultiProcessWeightSyncScheme (uses pipes)
835+ # can be initialized here since all required resources exist
830836 if self ._weight_sync_schemes :
831837 for model_id , scheme in self ._weight_sync_schemes .items ():
832- # Only initialize SharedMemWeightSyncScheme now (needs queue before workers)
833- # MultiProcessWeightSyncScheme will be initialized after workers are created
834- if isinstance (scheme , SharedMemWeightSyncScheme ) and hasattr (
835- scheme , "init_on_sender"
836- ):
838+ if hasattr (scheme , "init_on_sender" ):
837839 scheme .init_on_sender (model_id = model_id , context = self )
838840 self ._weight_senders [model_id ] = scheme .get_sender ()
839841
@@ -848,7 +850,7 @@ def _run_processes(self) -> None:
848850 for i , (env_fun , env_fun_kwargs ) in enumerate (
849851 zip (self .create_env_fn , self .create_env_kwargs )
850852 ):
851- pipe_parent , pipe_child = mp . Pipe () # send messages to procs
853+ pipe_parent , pipe_child = pipe_pairs [ i ] # use pre-created pipes
852854 if env_fun .__class__ .__name__ != "EnvCreator" and not isinstance (
853855 env_fun , EnvBase
854856 ): # to avoid circular imports
@@ -966,7 +968,6 @@ def _run_processes(self) -> None:
966968 ) from err
967969 pipe_child .close ()
968970 self .procs .append (proc )
969- self .pipes .append (pipe_parent )
970971
971972 # Synchronize initial weights with workers AFTER starting processes but BEFORE waiting for "instantiated"
972973 # This must happen after proc.start() but before workers send "instantiated" to avoid deadlock:
@@ -1027,18 +1028,6 @@ def _run_processes(self) -> None:
10271028 # Legacy string error message
10281029 raise RuntimeError (msg )
10291030
1030- # Initialize MultiProcessWeightSyncScheme now that workers are ready and pipes are available
1031- # (SharedMemWeightSyncScheme was already initialized before workers)
1032- if self ._weight_sync_schemes :
1033- for model_id , scheme in self ._weight_sync_schemes .items ():
1034- # Only initialize non-SharedMem schemes here (need pipes)
1035- if not isinstance (scheme , SharedMemWeightSyncScheme ) and hasattr (
1036- scheme , "init_on_sender"
1037- ):
1038- scheme .init_on_sender (model_id = model_id , context = self )
1039- # Get the initialized sender
1040- self ._weight_senders [model_id ] = scheme .get_sender ()
1041-
10421031 self .queue_out = queue_out
10431032 self .closed = False
10441033
0 commit comments