
Commit 1860f11

polish(pu): polish moco multigpu option

Author: puyuan (committed)
1 parent: 97988e2
3 files changed: +28 -47 lines changed

lzero/entry/train_muzero_multitask_segment_ddp.py (+3 -1)

@@ -379,7 +379,9 @@ def train_muzero_multitask_segment_ddp(
         )
         collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep)
 
-        if learner.train_iter == 0 or evaluator.should_eval(learner.train_iter):
+        # if learner.train_iter == 0 or evaluator.should_eval(learner.train_iter):
+        if learner.train_iter > 0 and evaluator.should_eval(learner.train_iter):
+
             print('=' * 20)
             print(f'Rank {rank} evaluating task_id: {cfg.policy.task_id}...')
 
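Switching the gate from `== 0 or` to `> 0 and` drops the forced evaluation at iteration 0, so a freshly initialized model is no longer evaluated before any training has happened. A minimal sketch of the two conditions, using a hypothetical should_eval with a fixed evaluation frequency (not the evaluator's actual logic):

    # Hypothetical stand-in for evaluator.should_eval: evaluate every eval_freq iterations.
    def should_eval(train_iter: int, eval_freq: int = 100) -> bool:
        return train_iter % eval_freq == 0

    for train_iter in (0, 50, 100):
        old_gate = train_iter == 0 or should_eval(train_iter)   # fires at iter 0
        new_gate = train_iter > 0 and should_eval(train_iter)   # skips iter 0
        print(train_iter, old_gate, new_gate)
    # 0 True False
    # 50 False False
    # 100 True True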
lzero/policy/sampled_unizero_multitask.py (+5 -26)

@@ -465,40 +465,19 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None) -> Dict[s
         # Core learn model update step
         self._optimizer_world_model.zero_grad()
 
+        # Assume the losses_list computed by each process is a list of tensors that support gradients,
+        # e.g. losses_list = [loss1, loss2, ...], where each loss_i is a (1,)-shaped tensor with requires_grad=True.
         if self._cfg.use_moco:
-            # If distributed is initialized and we are running on multiple GPUs, only rank 0 gathers the loss_list from the other GPUs.
-            if dist.is_initialized() and dist.get_world_size() > 1:
-                rank = dist.get_rank()
-                world_size = dist.get_world_size()
-                # Use distributed gather_object: only rank 0 provides a receive buffer.
-                if rank == 0:
-                    gathered_losses = [None for _ in range(world_size)]
-                else:
-                    gathered_losses = None  # Other processes do not need to receive.
-                # gather_object requires all processes to participate: every process sends its losses_list, and rank 0 receives.
-                dist.gather_object(losses_list, gathered_losses, dst=0)
-                if rank == 0:
-                    # Flatten the losses_list from each GPU into one global losses_list.
-                    all_losses_list = []
-                    for loss_list_tmp in gathered_losses:
-                        all_losses_list.extend(loss_list_tmp)
-                    losses_list = all_losses_list
-                else:
-                    # Set losses_list to None on non-rank-0 processes to prevent accidental use.
-                    losses_list = None
-
-            # Invoke the MoCo backward; gradient correction is implemented by the backward in grad_correct.
-            # Note: moco.backward checks whether the current rank is 0; only rank 0 computes gradients from losses_list,
-            # while the other ranks wait for the corrected, shared gradients to be broadcast.
-            lambd = self.grad_correct.backward(losses=losses_list, **self._cfg.grad_correct_params)
+            # Invoke the MoCo backward; gradient correction is implemented by the backward in grad_correct.
+            lambd, stats = self.grad_correct.backward(losses=losses_list, **self._cfg.grad_correct_params)
         else:
             # Without gradient correction, each rank runs its own backward pass.
             lambd = torch.tensor([0. for _ in range(self.task_num_for_current_rank)], device=self._cfg.device)
             weighted_total_loss.backward()
 
         total_grad_norm_before_clip_wm = torch.nn.utils.clip_grad_norm_(self._learn_model.world_model.parameters(), self._cfg.grad_clip_value)
 
-        if self._cfg.multi_gpu:
+        if self._cfg.multi_gpu and not self._cfg.use_moco:
             self.sync_gradients(self._learn_model)
 
         self._optimizer_world_model.step()
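The new guard `multi_gpu and not use_moco` avoids synchronizing twice: per the removed comments, the MoCo path already leaves corrected gradients shared across ranks inside grad_correct.backward (rank 0 computes them and broadcasts), so the explicit sync is only needed on the plain-backward path. As a rough sketch, a gradient-sync helper of this kind typically all-reduces each parameter's gradient; this is an assumption about sync_gradients, not necessarily its exact implementation:

    import torch
    import torch.distributed as dist

    def sync_gradients(model: torch.nn.Module) -> None:
        # Average each parameter's gradient over all ranks (plain-DDP-style sync).
        # Running this after MoCo's grad_correct.backward would be redundant, since
        # that call already leaves identical corrected gradients on every rank.
        world_size = dist.get_world_size()
        for param in model.parameters():
            if param.grad is not None:
                dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
                param.grad.div_(world_size)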

zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_8games_moco_config.py (+20 -20)

@@ -17,8 +17,8 @@ def create_config(env_id, observation_shape_list, action_space_size_list, collec
         action_space_size_list=action_space_size_list,
         from_pixels=False,
         # ===== only for debug =====
-        frame_skip=50,  # 100
-        # frame_skip=2,
+        # frame_skip=50,  # 100
+        frame_skip=2,
         continuous=True,  # Assuming all DMC tasks use continuous action spaces
         collector_env_num=collector_env_num,
         evaluator_env_num=evaluator_env_num,

@@ -156,7 +156,7 @@ def generate_configs(env_id_list: List[str],
     # TODO: debug
     # exp_name_prefix = f'data_suz_mt_20250113/ddp_8gpu_nlayer8_upc200_taskweight-eval1e3-10k-temp10-1_task-embed_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/'
 
-    exp_name_prefix = f'data_suz_mt_20250207_debug/ddp_2gpu-moco_nlayer8_upc200_notaskweight_no-task-embed_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/'
+    exp_name_prefix = f'data_suz_mt_20250207/ddp_8gpu-moco_nlayer8_upc200_notaskweight_no-task-embed_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/'
 
     # exp_name_prefix = f'data_suz_mt_20250113/ddp_3gpu_3games_nlayer8_upc200_notusp_notaskweight-symlog-01-05-eval1e3_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/'
 

@@ -208,7 +208,7 @@ def create_env_manager():
     Overview:
         This script should be executed with <nproc_per_node> GPUs.
         Run the following command to launch the script:
-        python -m torch.distributed.launch --nproc_per_node=2 --master_port=29500 ./zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_8games_moco_config.py
+        python -m torch.distributed.launch --nproc_per_node=2 --master_port=29501 ./zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_8games_moco_config.py
         torchrun --nproc_per_node=8 ./zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_config.py
     """

@@ -239,16 +239,16 @@ def create_env_manager():
     # ]
 
     # DMC 8games
-    # env_id_list = [
-    #     'acrobot-swingup',
-    #     'cartpole-balance',
-    #     'cartpole-balance_sparse',
-    #     'cartpole-swingup',
-    #     'cartpole-swingup_sparse',
-    #     'cheetah-run',
-    #     "ball_in_cup-catch",
-    #     "finger-spin",
-    # ]
+    env_id_list = [
+        'acrobot-swingup',
+        'cartpole-balance',
+        'cartpole-balance_sparse',
+        'cartpole-swingup',
+        'cartpole-swingup_sparse',
+        'cheetah-run',
+        "ball_in_cup-catch",
+        "finger-spin",
+    ]
 
     # DMC 18games
     # env_id_list = [

@@ -302,12 +302,12 @@ def create_env_manager():
     reanalyze_partition = 0.75
 
     # ======== TODO: only for debug ========
-    collector_env_num = 2
-    num_segments = 2
-    n_episode = 2
-    evaluator_env_num = 2
-    num_simulations = 1
-    batch_size = [4 for _ in range(len(env_id_list))]
+    # collector_env_num = 2
+    # num_segments = 2
+    # n_episode = 2
+    # evaluator_env_num = 2
+    # num_simulations = 1
+    # batch_size = [4 for _ in range(len(env_id_list))]
     # =======================================
 
     seed = 0  # You can iterate over multiple seeds if needed
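With the debug overrides commented out, the 8-game env_id_list active, and the experiment prefix renamed to ddp_8gpu-moco, the intended launch is presumably one process per GPU along the lines of the torchrun command already shown in the docstring, pointed at this config file:

    torchrun --nproc_per_node=8 ./zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_8games_moco_config.py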
