
Commit 8099be9

fix(pu): fix update_per_collect in ddp setting (#321)

* fix(pu): fix ddp config when update_per_collect is None in config
* polish(pu): polish update_per_collect in ddp setting
* fix(pu): fix typo

Co-authored-by: PaParaZz1 <[email protected]>

1 parent 8a142a9 · commit 8099be9

15 files changed: +266 -60 lines

lzero/config/utils.py
+5 -2

@@ -13,6 +13,9 @@ def lz_to_ddp_config(cfg: EasyDict) -> EasyDict:
         - cfg (:obj:`EasyDict`): The converted config
     """
     w = get_world_size()
-    cfg.policy.batch_size = int(np.ceil(cfg.policy.batch_size / w))
-    cfg.policy.n_episode = int(np.ceil(cfg.policy.n_episode) / w)
+    # Generalized handling for multiple keys
+    keys_to_scale = ['batch_size', 'n_episode', 'num_segments']
+    for key in keys_to_scale:
+        if key in cfg.policy:
+            cfg.policy[key] = int(np.ceil(cfg.policy[key] / w))
     return cfg
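
Note: as a quick illustration of the generalized scaling above, here is a minimal standalone sketch for a hypothetical 2-GPU run; the config values are invented for the example and are not part of this commit:

    import numpy as np
    from easydict import EasyDict

    # Hypothetical global (pre-DDP) settings.
    cfg = EasyDict(dict(policy=dict(batch_size=256, n_episode=8, num_segments=8)))
    w = 2  # world size, i.e. number of DDP processes

    # Same per-key scaling as lz_to_ddp_config: each key that is present is divided by the world size.
    for key in ['batch_size', 'n_episode', 'num_segments']:
        if key in cfg.policy:
            cfg.policy[key] = int(np.ceil(cfg.policy[key] / w))

    print(cfg.policy)  # {'batch_size': 128, 'n_episode': 4, 'num_segments': 4}

Scaling every key that is present, instead of hard-coding batch_size and n_episode, also covers configs that define num_segments, which the segment-based entry points below rely on.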

lzero/entry/train_muzero.py
+4 -7

@@ -19,7 +19,7 @@
 from lzero.policy.random_policy import LightZeroRandomPolicy
 from lzero.worker import MuZeroCollector as Collector
 from lzero.worker import MuZeroEvaluator as Evaluator
-from .utils import random_collect
+from .utils import random_collect, calculate_update_per_collect


 def train_muzero(
@@ -186,12 +186,9 @@ def train_muzero(

         # Collect data by default config n_sample/n_episode.
         new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
-        if cfg.policy.update_per_collect is None:
-            # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the replay_ratio.
-            # The length of game_segment (i.e., len(game_segment.action_segment)) can be smaller than cfg.policy.game_segment_length if it represents the final segment of the game.
-            # On the other hand, its length will be less than cfg.policy.game_segment_length + padding_length when it is not the last game segment. Typically, padding_length is the sum of unroll_steps and td_steps.
-            collected_transitions_num = sum(min(len(game_segment), cfg.policy.game_segment_length) for game_segment in new_data[0])
-            update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio)
+
+        # Determine updates per collection
+        update_per_collect = calculate_update_per_collect(cfg, new_data)

         # save returned new_data collected by the collector
         replay_buffer.push_game_segments(new_data)

lzero/entry/train_muzero_segment.py
+4 -7

@@ -19,7 +19,7 @@
 from lzero.policy.random_policy import LightZeroRandomPolicy
 from lzero.worker import MuZeroEvaluator as Evaluator
 from lzero.worker import MuZeroSegmentCollector as Collector
-from .utils import random_collect
+from .utils import random_collect, calculate_update_per_collect

 timer = EasyTimer()

@@ -180,13 +180,10 @@ def train_muzero_segment(

         # Collect data by default config n_sample/n_episode.
         new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
-        if cfg.policy.update_per_collect is None:
-            # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the replay_ratio.
-            # The length of game_segment (i.e., len(game_segment.action_segment)) can be smaller than cfg.policy.game_segment_length if it represents the final segment of the game.
-            # On the other hand, its length will be less than cfg.policy.game_segment_length + padding_length when it is not the last game segment. Typically, padding_length is the sum of unroll_steps and td_steps.
-            collected_transitions_num = sum(min(len(game_segment), cfg.policy.game_segment_length) for game_segment in new_data[0])
-            update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio)

+        # Determine updates per collection
+        update_per_collect = calculate_update_per_collect(cfg, new_data)
+
         # save returned new_data collected by the collector
         replay_buffer.push_game_segments(new_data)
         # remove the oldest data if the replay buffer is full.

lzero/entry/train_rezero.py
+3 -7

@@ -17,7 +17,7 @@
 from lzero.policy.random_policy import LightZeroRandomPolicy
 from lzero.worker import MuZeroCollector as Collector
 from lzero.worker import MuZeroEvaluator as Evaluator
-from .utils import random_collect
+from .utils import random_collect, calculate_update_per_collect


 def train_rezero(
@@ -152,12 +152,8 @@ def train_rezero(
             collect_with_pure_policy=cfg.policy.collect_with_pure_policy
         )

-        if update_per_collect is None:
-            # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the replay_ratio.
-            # The length of game_segment (i.e., len(game_segment.action_segment)) can be smaller than cfg.policy.game_segment_length if it represents the final segment of the game.
-            # On the other hand, its length will be less than cfg.policy.game_segment_length + padding_length when it is not the last game segment. Typically, padding_length is the sum of unroll_steps and td_steps.
-            collected_transitions_num = sum(min(len(game_segment), cfg.policy.game_segment_length) for game_segment in new_data[0])
-            update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio)
+        # Determine updates per collection
+        update_per_collect = calculate_update_per_collect(cfg, new_data)

         # Update replay buffer
         replay_buffer.push_game_segments(new_data)

lzero/entry/train_unizero.py
+2 -8

@@ -20,7 +20,7 @@
 from lzero.policy.random_policy import LightZeroRandomPolicy
 from lzero.worker import MuZeroEvaluator as Evaluator
 from lzero.worker import MuZeroCollector as Collector
-from .utils import random_collect
+from .utils import random_collect, calculate_update_per_collect


 def train_unizero(
@@ -154,13 +154,7 @@ def train_unizero(
         new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)

         # Determine updates per collection
-        update_per_collect = cfg.policy.update_per_collect
-        if update_per_collect is None:
-            # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the replay_ratio.
-            # The length of game_segment (i.e., len(game_segment.action_segment)) can be smaller than cfg.policy.game_segment_length if it represents the final segment of the game.
-            # On the other hand, its length will be less than cfg.policy.game_segment_length + padding_length when it is not the last game segment. Typically, padding_length is the sum of unroll_steps and td_steps.
-            collected_transitions_num = sum(min(len(game_segment), cfg.policy.game_segment_length) for game_segment in new_data[0])
-            update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio)
+        update_per_collect = calculate_update_per_collect(cfg, new_data)

         # Update replay buffer
         replay_buffer.push_game_segments(new_data)

lzero/entry/train_unizero_segment.py
+2 -8

@@ -20,7 +20,7 @@
 from lzero.policy.random_policy import LightZeroRandomPolicy
 from lzero.worker import MuZeroEvaluator as Evaluator
 from lzero.worker import MuZeroSegmentCollector as Collector
-from .utils import random_collect
+from .utils import random_collect, calculate_update_per_collect

 timer = EasyTimer()

@@ -151,13 +151,7 @@ def train_unizero_segment(
         new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)

         # Determine updates per collection
-        update_per_collect = cfg.policy.update_per_collect
-        if update_per_collect is None:
-            # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the replay_ratio.
-            # The length of game_segment (i.e., len(game_segment.action_segment)) can be smaller than cfg.policy.game_segment_length if it represents the final segment of the game.
-            # On the other hand, its length will be less than cfg.policy.game_segment_length + padding_length when it is not the last game segment. Typically, padding_length is the sum of unroll_steps and td_steps.
-            collected_transitions_num = sum(min(len(game_segment), cfg.policy.game_segment_length) for game_segment in new_data[0])
-            update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio)
+        update_per_collect = calculate_update_per_collect(cfg, new_data)

         # Update replay buffer
         replay_buffer.push_game_segments(new_data)

lzero/entry/utils.py
+83 -1

@@ -2,11 +2,93 @@
 from typing import Optional, Callable

 import psutil
+import torch
+import torch.distributed as dist
 from pympler.asizeof import asizeof
 from tensorboardX import SummaryWriter
-from typing import Optional, Callable
+
+
 import torch
+import torch.distributed as dist
+
+
+def is_ddp_enabled():
+    """
+    Check if Distributed Data Parallel (DDP) is enabled by verifying if
+    PyTorch's distributed package is available and initialized.
+    """
+    return dist.is_available() and dist.is_initialized()
+
+
+def ddp_synchronize():
+    """
+    Perform a barrier synchronization across all processes in DDP mode.
+    Ensures all processes reach this point before continuing.
+    """
+    if is_ddp_enabled():
+        dist.barrier()
+
+
+def ddp_all_reduce_sum(tensor):
+    """
+    Perform an all-reduce operation (sum) on the given tensor across
+    all processes in DDP mode. Returns the reduced tensor.

+    Arguments:
+        - tensor (:obj:`torch.Tensor`): The input tensor to be reduced.
+
+    Returns:
+        - torch.Tensor: The reduced tensor, summed across all processes.
+    """
+    if is_ddp_enabled():
+        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
+    return tensor
+
+
+def calculate_update_per_collect(cfg, new_data):
+    """
+    Calculate the number of updates to perform per data collection in a
+    Distributed Data Parallel (DDP) setting. This ensures that all GPUs
+    compute the same `update_per_collect` value, synchronized across processes.
+
+    Arguments:
+        - cfg: Configuration object containing policy settings.
+        - new_data (list): The newly collected data segments.
+
+    Returns:
+        - int: The number of updates to perform per collection.
+    """
+    # Retrieve the update_per_collect setting from the configuration
+    update_per_collect = cfg.policy.update_per_collect
+
+    if update_per_collect is None:
+        # If update_per_collect is not explicitly set, calculate it based on
+        # the number of collected transitions and the replay ratio.
+
+        # The length of game_segment (i.e., len(game_segment.action_segment)) can be smaller than cfg.policy.game_segment_length if it represents the final segment of the game.
+        # On the other hand, its length will be less than cfg.policy.game_segment_length + padding_length when it is not the last game segment. Typically, padding_length is the sum of unroll_steps and td_steps.
+        collected_transitions_num = sum(
+            min(len(game_segment), cfg.policy.game_segment_length)
+            for game_segment in new_data[0]
+        )
+
+        if torch.cuda.is_available():
+            # Convert the collected transitions count to a GPU tensor for DDP operations.
+            collected_transitions_tensor = torch.tensor(
+                collected_transitions_num, dtype=torch.int64, device='cuda'
+            )
+
+            # Synchronize the collected transitions count across all GPUs using all-reduce.
+            total_collected_transitions = ddp_all_reduce_sum(
+                collected_transitions_tensor
+            ).item()
+
+            # Calculate update_per_collect based on the total synchronized transitions count.
+            update_per_collect = int(total_collected_transitions * cfg.policy.replay_ratio)

+            # Ensure the computed update_per_collect is positive.
+            assert update_per_collect > 0, "update_per_collect must be positive"
+        else:
+            # If not using DDP, calculate update_per_collect directly from the local count.
+            update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio)
+
+    return update_per_collect

 def initialize_zeros_batch(observation_shape, batch_size, device):
     """

lzero/mcts/buffer/game_buffer_unizero.py
+8 -9

@@ -522,15 +522,14 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
             m_output = model.initial_inference(batch_obs, batch_action)
             # ======================================================================

-            if not model.training:
-                # if not in training, obtain the scalars of the value/reward
-                [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
-                    [
-                        m_output.latent_state,
-                        inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
-                        m_output.policy_logits
-                    ]
-                )
+            # if not in training, obtain the scalars of the value/reward
+            [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
+                [
+                    m_output.latent_state,
+                    inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
+                    m_output.policy_logits
+                ]
+            )

             network_output.append(m_output)
lzero/policy/unizero.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -728,11 +728,10 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1
728728
network_output = self._eval_model.initial_inference(self.last_batch_obs, self.last_batch_action, data)
729729
latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output)
730730

731-
if not self._eval_model.training:
732-
# if not in training, obtain the scalars of the value/reward
733-
pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1)
734-
latent_state_roots = latent_state_roots.detach().cpu().numpy()
735-
policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape(B, A)
731+
# if not in training, obtain the scalars of the value/reward
732+
pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1)
733+
latent_state_roots = latent_state_roots.detach().cpu().numpy()
734+
policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape(B, A)
736735

737736
legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_eval_env_num)]
738737
if self._cfg.mcts_ctree:

zoo/atari/config/atari_efficientzero_multigpu_ddp_config.py → zoo/atari/config/atari_efficientzero_ddp_config.py (renamed)
+1 -1

@@ -84,7 +84,7 @@
     Overview:
         This script should be executed with <nproc_per_node> GPUs.
         Run the following command to launch the script:
-        python -m torch.distributed.launch --nproc_per_node=2 ./LightZero/zoo/atari/config/atari_efficientzero_multigpu_ddp_config.py
+        torchrun --nproc_per_node=2 ./LightZero/zoo/atari/config/atari_efficientzero_ddp_config.py
     """
     from ding.utils import DDPContext
     from lzero.entry import train_muzero

zoo/atari/config/atari_muzero_multigpu_ddp_config.py → zoo/atari/config/atari_muzero_ddp_config.py (renamed)
+1 -1

@@ -100,7 +100,7 @@
     Overview:
        This script should be executed with <nproc_per_node> GPUs.
        Run the following command to launch the script:
-        python -m torch.distributed.launch --nproc_per_node=2 ./zoo/atari/config/atari_muzero_multigpu_ddp_config.py
+        torchrun --nproc_per_node=2 ./zoo/atari/config/atari_muzero_ddp_config.py
     """
     from ding.utils import DDPContext
     from lzero.entry import train_muzero

zoo/atari/config/atari_unizero_multigpu_ddp_config.py → zoo/atari/config/atari_unizero_ddp_config.py (renamed)
+1 -2

@@ -103,8 +103,7 @@
     Overview:
        This script should be executed with <nproc_per_node> GPUs.
        Run the following command to launch the script:
-        python -m torch.distributed.launch --nproc_per_node=2 ./zoo/atari/config/atari_unizero_multigpu_ddp_config.py
-        torchrun --nproc_per_node=2 ./zoo/atari/config/atari_unizero_multigpu_ddp_config.py
+        torchrun --nproc_per_node=2 ./zoo/atari/config/atari_unizero_ddp_config.py

     """
     from ding.utils import DDPContext
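
Note on the launch command change: python -m torch.distributed.launch is deprecated in favor of torchrun, and both populate the env:// rendezvous variables (RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT) that torch.distributed reads. The helper below is an illustrative sketch of how a worker can build the process group from those variables; it is not part of this commit (the configs above import DDPContext from ding.utils instead):

    import os
    import torch
    import torch.distributed as dist

    def init_ddp_from_torchrun():
        # torchrun exports RANK / LOCAL_RANK / WORLD_SIZE / MASTER_ADDR / MASTER_PORT,
        # so the process group can be created directly from the environment.
        backend = 'nccl' if torch.cuda.is_available() else 'gloo'
        dist.init_process_group(backend=backend, init_method='env://')
        # Bind this process to its local GPU before creating any CUDA tensors.
        local_rank = int(os.environ.get('LOCAL_RANK', 0))
        if torch.cuda.is_available():
            torch.cuda.set_device(local_rank)
        return dist.get_rank(), dist.get_world_size()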
