Description
Motivation
Currently, training paradigms that require dynamically changing environment configuration parameters during a single training run (e.g., curriculum learning, adaptive difficulty, or switching between environment variants/tasks within a bandit framework) face significant performance bottlenecks when using MultiaSyncDataCollector.
My specific problem is that I need to periodically update a configuration parameter inside the custom Gymnasium environment instances running on the worker processes. Currently, the only reliable way to do this is to shut down the entire MultiaSyncDataCollector and create a new one with an updated create_env_fn. Each configuration switch costs tens of seconds in my case, which makes training loops with frequent updates impractically slow. The core computation (environment steps and policy inference) is fast; it is the infrastructure management (process shutdown and restart) that dominates wall-clock time during these transitions.
Solution
I propose adding a mechanism to torchrl, specifically for MultiaSyncDataCollector (and potentially other parallel collectors), that allows users to:
- Broadcast Parameter Updates: Send new configuration parameter values from the main process to all active worker processes without terminating them.
- Invoke Environment Logic: Trigger specific methods on the environment instances in the workers to apply these new parameters (e.g., calling an internal env.update_config(new_param=value) method).
- Trigger Transform Re-initialization: Critically, provide a way to signal stateful transforms (such as ObservationNorm, including ones nested inside a Compose) within the worker environments to re-initialize their state based on the new environment configuration. This might involve letting users specify which transform methods (e.g., transform.init_stats() or a custom transform.reset_state()) should be called after the environment parameters are updated.
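To make the requested semantics concrete, here is a minimal, torchrl-free sketch of what a worker-side handler could do when it receives an update. Every name here (ToyEnv, ToyNorm, Worker, the "update_config" command, the env_kwargs and transform_methods payload keys) is hypothetical and is not an existing torchrl API. ToyNorm only loosely mimics a stateful transform such as ObservationNorm. The point is the ordering: apply the new environment parameters first, then call the user-selected transform methods so their state is refit under the new configuration.

```python
from typing import Any, Dict, List


class ToyEnv:
    """Hypothetical stand-in for a custom Gymnasium env with one tunable parameter."""

    def __init__(self, difficulty: float = 0.0) -> None:
        self.difficulty = difficulty

    def update_config(self, **kwargs: Any) -> None:
        # Hypothetical in-place config hook, as proposed in the issue.
        for name, value in kwargs.items():
            if not hasattr(self, name):
                raise AttributeError(f"unknown config field: {name}")
            setattr(self, name, value)

    def sample_obs(self) -> float:
        # Observation scale depends on the config, so statistics fitted
        # under the old config become stale after an update.
        return 1.0 + 10.0 * self.difficulty


class ToyNorm:
    """Hypothetical stateful transform, loosely modeled on ObservationNorm."""

    def __init__(self) -> None:
        self.loc = 0.0

    def init_stats(self, env: ToyEnv, num_iter: int = 4) -> None:
        # Refit the statistic from samples drawn under the *current* config.
        samples = [env.sample_obs() for _ in range(num_iter)]
        self.loc = sum(samples) / len(samples)


class Worker:
    """Hypothetical worker-side dispatcher; names are illustrative only."""

    def __init__(self, env: ToyEnv, transforms: List[ToyNorm]) -> None:
        self.env = env
        self.transforms = transforms

    def handle(self, cmd: str, payload: Dict[str, Any]) -> str:
        if cmd != "update_config":
            raise ValueError(f"unknown command: {cmd}")
        # Step 1: apply the new parameters to the live env (no restart).
        self.env.update_config(**payload.get("env_kwargs", {}))
        # Step 2: only then call the requested transform methods, so their
        # state reflects the new configuration.
        for method_name in payload.get("transform_methods", []):
            for transform in self.transforms:
                getattr(transform, method_name)(self.env)
        return "ok"


worker = Worker(ToyEnv(difficulty=0.0), [ToyNorm()])
worker.handle("update_config", {"transform_methods": ["init_stats"]})
print(worker.transforms[0].loc)  # 1.0
worker.handle(
    "update_config",
    {"env_kwargs": {"difficulty": 0.5}, "transform_methods": ["init_stats"]},
)
print(worker.transforms[0].loc)  # 6.0
```

Passing transform method names, rather than hard-coding init_stats, mirrors the proposal's idea of letting users choose which re-initialization hooks run after the environment update.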
This would allow for efficient, in-place updates of the environment setup across all workers, eliminating the costly shutdown/restart cycle.
Alternatives
- Current Workaround (Shutdown & Recreate): The primary alternative, and the method that currently works, is to call collector.shutdown() and instantiate a new MultiaSyncDataCollector with the updated configuration in create_env_fn.
  - Pro: Guarantees correctness, including proper initialization of environments and transforms.
  - Con: Extremely slow due to process management overhead.
- Using collector.reset() with Arguments: I attempted to pass the new parameter via collector.reset(my_param=new_value), hoping it would be forwarded to env.reset() in the workers.
  - Con: This is not supported (TypeError: reset() got an unexpected keyword argument 'my_param'), and, more importantly, it would not address the re-initialization requirement for stateful transforms like ObservationNorm.
- Manual IPC via collector.pipes: One could theoretically try sending custom messages through collector.pipes to the workers.
  - Con: This requires modifying torchrl's internal worker loop logic to handle the custom messages, which is brittle, hard to maintain, and breaks encapsulation. It also requires careful handling of synchronization and transform state.
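The synchronization concern in the last alternative can be illustrated with a small, self-contained sketch: the main process sends the same update over every worker's pipe, then blocks until each worker acknowledges before resuming collection. This is not torchrl code. Threads stand in for worker processes so the example runs in a single process, and the ("update_config", payload) / ("ack", config) message tuples are a hypothetical protocol; in torchrl, the real worker loop would have to recognize such messages, which is exactly the modification the Con above warns about. The sketch also ignores rollouts already in flight, which an asynchronous collector would need to drain or discard.

```python
import threading
from multiprocessing import Pipe
from multiprocessing.connection import Connection
from typing import Any, Dict, List


def worker_loop(conn: Connection, config: Dict[str, Any]) -> None:
    """Toy worker: serve commands on its pipe until told to close."""
    while True:
        cmd, payload = conn.recv()
        if cmd == "update_config":
            config.update(payload)  # apply in place; the worker keeps running
            conn.send(("ack", dict(config)))
        elif cmd == "close":
            conn.send(("closed", None))
            return
        else:
            conn.send(("error", f"unknown command: {cmd}"))


def broadcast_update(pipes: List[Connection], payload: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Send one update to every worker, then block until all have acknowledged."""
    for conn in pipes:
        conn.send(("update_config", payload))
    acks = []
    for conn in pipes:
        status, reply = conn.recv()
        if status != "ack":
            raise RuntimeError(f"worker failed to apply update: {reply}")
        acks.append(reply)
    return acks


configs = [{"difficulty": 0.0} for _ in range(3)]
pipes, threads = [], []
for cfg in configs:
    parent_conn, child_conn = Pipe()
    t = threading.Thread(target=worker_loop, args=(child_conn, cfg))
    t.start()
    pipes.append(parent_conn)
    threads.append(t)

acks = broadcast_update(pipes, {"difficulty": 0.5})
print(acks)  # [{'difficulty': 0.5}, {'difficulty': 0.5}, {'difficulty': 0.5}]

# Shut the toy workers down cleanly.
for conn in pipes:
    conn.send(("close", None))
    conn.recv()
for t in threads:
    t.join()
```

Waiting for every acknowledgment before continuing is what guarantees that no worker keeps stepping with the old configuration after the switch.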
Additional context
The need for dynamic updates is particularly relevant for complex training procedures in which the environment's characteristics change over time based on agent progress or predefined schedules. Correct handling of stateful transforms (ObservationNorm being a key example) is essential for stability: using stale normalization statistics after a parameter change produces incorrectly scaled observations and can lead to poor learning. The error message TypeError: MultiaSyncDataCollector.reset() got an unexpected keyword argument '...' confirms that the current reset approach cannot be used to pass parameters.
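A toy arithmetic example shows how large the stale-statistics error can be. The sample values below are made up for illustration (they are not measurements from this issue), and the normalization is a plain (x - loc) / scale, in the spirit of ObservationNorm rather than its actual implementation.

```python
import statistics
from typing import List, Tuple


def fit_norm(samples: List[float]) -> Tuple[float, float]:
    """Fit loc/scale from samples (population std, for simplicity)."""
    return statistics.mean(samples), statistics.pstdev(samples)


def normalize(x: float, loc: float, scale: float) -> float:
    # Standard loc/scale normalization: (x - loc) / scale.
    return (x - loc) / scale


# Made-up observation samples before and after a config switch.
old_samples = [1.0, 3.0, 1.0, 3.0]      # mean 2, std 1
new_samples = [12.0, 16.0, 12.0, 16.0]  # mean 14, std 2

stale_loc, stale_scale = fit_norm(old_samples)
fresh_loc, fresh_scale = fit_norm(new_samples)

x = 14.0  # a typical observation under the new config
stale = normalize(x, stale_loc, stale_scale)
fresh = normalize(x, fresh_loc, fresh_scale)
print(stale, fresh)  # 12.0 0.0
```

With stale statistics, a perfectly average observation under the new configuration is fed to the policy as a 12-sigma outlier; refitting the statistics maps it back to 0.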
Checklist
- I have checked that there is no similar issue in the repo (required)