Commit 97988e2

puyuan committed
polish(pu): add moco multigpu support
1 parent 8fe1a6d commit 97988e2

File tree

2 files changed: +60 -31 lines changed

lzero/policy/sampled_unizero_multitask.py (+31 -5)
@@ -27,6 +27,8 @@
 )
 from lzero.policy.unizero import UniZeroPolicy
 from .utils import configure_optimizers_nanogpt
+import torch.nn.functional as F
+import torch.distributed as dist
 import sys
 sys.path.append('/mnt/afs/niuyazhe/code/LibMTL/')
 from LibMTL.weighting.MoCo_unizero import MoCo as GradCorrect

@@ -64,7 +66,7 @@ def parameters(self):
             list(self.tokenizer.parameters()) +
             list(self.transformer.parameters()) +
             list(self.pos_emb.parameters()) +
-            list(self.task_emb.parameters()) +
+            # list(self.task_emb.parameters()) +
             list(self.act_embedding_table.parameters())
         )

@@ -73,7 +75,7 @@ def zero_grad(self, set_to_none=False):
         self.tokenizer.zero_grad(set_to_none=set_to_none)
         self.transformer.zero_grad(set_to_none=set_to_none)
         self.pos_emb.zero_grad(set_to_none=set_to_none)
-        self.task_emb.zero_grad(set_to_none=set_to_none)
+        # self.task_emb.zero_grad(set_to_none=set_to_none)
         self.act_embedding_table.zero_grad(set_to_none=set_to_none)

@@ -308,7 +310,8 @@ def _init_learn(self) -> None:
         # TODO
         # If needed, gradient-correction methods (e.g. MoCo, CAGrad) can be initialized here
         # self.grad_correct = GradCorrect(wrapped_model, self.task_num, self._cfg.device)
-        self.grad_correct = GradCorrect(wrapped_model, self._cfg.task_num, self._cfg.device)  # only compatible with 1-GPU training
+        # self.grad_correct = GradCorrect(wrapped_model, self._cfg.task_num, self._cfg.device, self._cfg.multi_gpu)
+        self.grad_correct = GradCorrect(wrapped_model, self._cfg.total_task_num, self._cfg.device, self._cfg.multi_gpu)  # now also supports multi-GPU training via the multi_gpu flag

         self.grad_correct.init_param()
         self.grad_correct.rep_grad = False
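Note: for readers without access to the LibMTL fork on the hard-coded path above, the calls visible in this hunk imply roughly the following interface for GradCorrect (MoCo). This is a minimal sketch inferred from the constructor arguments and method calls in the diff, not the actual LibMTL implementation; the class name and all method bodies are placeholder assumptions.

# Hypothetical interface sketch of the MoCo gradient-correction wrapper as used by the policy.
# Only the constructor signature, init_param(), rep_grad and backward() appear in the diff;
# everything inside the bodies is a placeholder assumption.
import torch


class GradCorrectSketch:
    def __init__(self, wrapped_model, total_task_num: int, device, multi_gpu: bool):
        self.model = wrapped_model          # shared world model whose gradients are corrected
        self.task_num = total_task_num      # global number of tasks across all ranks
        self.device = device
        self.multi_gpu = multi_gpu          # new in this commit: enables the rank-0 gather path
        self.rep_grad = False               # LibMTL-style flag: correct parameter grads, not representation grads

    def init_param(self):
        # allocate per-task gradient buffers / statistics used by MoCo (placeholder)
        pass

    def backward(self, losses, **grad_correct_params):
        # with multi_gpu=True, rank 0 is expected to receive the full per-task loss list,
        # compute the corrected update direction, and share it with the other ranks;
        # returns the per-task weighting lambda (placeholder value here)
        return torch.zeros(self.task_num, device=self.device)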
@@ -463,10 +466,33 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None) -> Dict[s
         self._optimizer_world_model.zero_grad()

         if self._cfg.use_moco:
-            # Gradient-correction methods such as MoCo or CAGrad can be integrated here; with 1 GPU, the gradients of all tasks are known locally
+            # In the multi-GPU case (once the process group is initialized), only rank 0 collects the loss_list from the other GPUs
+            if dist.is_initialized() and dist.get_world_size() > 1:
+                rank = dist.get_rank()
+                world_size = dist.get_world_size()
+                # Use distributed gather_object: only rank 0 provides a receive buffer
+                if rank == 0:
+                    gathered_losses = [None for _ in range(world_size)]
+                else:
+                    gathered_losses = None  # other ranks do not receive
+                # gather_object requires all ranks to participate: each rank sends its losses_list, rank 0 receives
+                dist.gather_object(losses_list, gathered_losses, dst=0)
+                if rank == 0:
+                    # Flatten the per-GPU losses_lists into one global losses_list
+                    all_losses_list = []
+                    for loss_list_tmp in gathered_losses:
+                        all_losses_list.extend(loss_list_tmp)
+                    losses_list = all_losses_list
+                else:
+                    # Set to None on non-rank-0 processes to prevent accidental use
+                    losses_list = None
+
+            # Call the MoCo backward; the gradient correction is performed inside grad_correct.backward
+            # Note: moco.backward checks whether the current rank is 0; only rank 0 computes gradients from losses_list,
+            # the other ranks simply wait for the broadcast of the corrected, shared gradients
             lambd = self.grad_correct.backward(losses=losses_list, **self._cfg.grad_correct_params)
         else:
-            # Case without gradient correction
+            # Case without gradient correction: each rank runs its own backward pass
             lambd = torch.tensor([0. for _ in range(self.task_num_for_current_rank)], device=self._cfg.device)
             weighted_total_loss.backward()
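The gather-then-broadcast pattern described in the comments above can be exercised in isolation. Below is a minimal, self-contained sketch (not code from this repository) of the same idea: every rank contributes its local per-task losses via dist.gather_object, rank 0 forms the global list and would run the MoCo correction, and the resulting shared gradients are then broadcast from rank 0 so that all ranks apply the same optimizer step. Function and variable names are illustrative; the broadcast step reflects what the comments say happens inside moco.backward, not its actual code.

# Minimal sketch of the multi-GPU loss gathering and gradient broadcast assumed above.
# Intended to run under torchrun; `model` and `local_losses` are illustrative stand-ins.
import torch
import torch.distributed as dist


def gather_losses_to_rank0(local_losses):
    """Collect every rank's list of per-task losses onto rank 0; other ranks get None."""
    if not (dist.is_initialized() and dist.get_world_size() > 1):
        return local_losses  # single-process fallback: nothing to gather
    world_size = dist.get_world_size()
    buckets = [None] * world_size if dist.get_rank() == 0 else None
    dist.gather_object(local_losses, buckets, dst=0)  # collective: every rank must call it
    if dist.get_rank() == 0:
        return [loss for per_rank in buckets for loss in per_rank]  # flatten in rank order
    return None


def broadcast_corrected_grads(model, src=0):
    """After rank `src` has written the corrected gradients, share them with every rank."""
    for p in model.parameters():
        if p.grad is None:
            p.grad = torch.zeros_like(p)  # ensure a buffer exists to receive the broadcast
        dist.broadcast(p.grad, src=src)

Because gather_object returns the lists in rank order, the flattened global losses_list is ordered by (rank, local task index); the global task ordering used by the correction method therefore has to match how tasks are assigned to ranks.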

zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_8games_moco_config.py (+29 -26)
@@ -17,8 +17,8 @@ def create_config(env_id, observation_shape_list, action_space_size_list, collec
         action_space_size_list=action_space_size_list,
         from_pixels=False,
         # ===== only for debug =====
-        # frame_skip=100, # 100
-        frame_skip=2,
+        frame_skip=50,  # 100
+        # frame_skip=2,
         continuous=True,  # Assuming all DMC tasks use continuous action spaces
         collector_env_num=collector_env_num,
         evaluator_env_num=evaluator_env_num,
@@ -38,6 +38,7 @@ def create_config(env_id, observation_shape_list, action_space_size_list, collec
             calpha=0.5, rescale=1,
         ),
         use_moco=True,  # ==============TODO==============
+        total_task_num=len(env_id_list),
         task_num=len(env_id_list),
         task_id=0,  # To be set per task
         model=dict(
@@ -54,6 +55,7 @@ def create_config(env_id, observation_shape_list, action_space_size_list, collec
             # use_shared_projection=True,  # TODO
             use_shared_projection=False,
             # use_task_embed=True,  # TODO
+            task_embed_option=None,  # ==============TODO: none ==============
             use_task_embed=False,  # ==============TODO==============
             num_unroll_steps=num_unroll_steps,
             policy_entropy_weight=5e-2,
@@ -90,6 +92,7 @@ def create_config(env_id, observation_shape_list, action_space_size_list, collec
                 num_experts_of_moe_in_transformer=4,
             ),
         ),
+        use_task_exploitation_weight=False,  # TODO
         # task_complexity_weight=True,  # TODO
         task_complexity_weight=False,  # TODO
         total_batch_size=total_batch_size,
@@ -153,7 +156,7 @@ def generate_configs(env_id_list: List[str],
     # TODO: debug
     # exp_name_prefix = f'data_suz_mt_20250113/ddp_8gpu_nlayer8_upc200_taskweight-eval1e3-10k-temp10-1_task-embed_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/'

-    exp_name_prefix = f'data_suz_mt_20250113/ddp_1gpu-moco_nlayer8_upc80_notaskweight-eval1e3-10k-temp10-1_no-task-embed_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/'
+    exp_name_prefix = f'data_suz_mt_20250207_debug/ddp_2gpu-moco_nlayer8_upc200_notaskweight_no-task-embed_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/'

     # exp_name_prefix = f'data_suz_mt_20250113/ddp_3gpu_3games_nlayer8_upc200_notusp_notaskweight-symlog-01-05-eval1e3_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_tbs{total_batch_size}_seed{seed}/'
@@ -205,7 +208,7 @@ def create_env_manager():
    Overview:
        This script should be executed with <nproc_per_node> GPUs.
        Run the following command to launch the script:
-        python -m torch.distributed.launch --nproc_per_node=8 --master_port=29500 ./zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_8games_config.py
+        python -m torch.distributed.launch --nproc_per_node=2 --master_port=29500 ./zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_8games_moco_config.py
        torchrun --nproc_per_node=8 ./zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_config.py
    """
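Since torch.distributed.launch is deprecated in recent PyTorch releases, an equivalent torchrun invocation for the 2-GPU setup shown above would be (assuming the same working directory):

torchrun --nproc_per_node=2 ./zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_8games_moco_config.py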

@@ -236,16 +239,16 @@ def create_env_manager():
    # ]

    # DMC 8games
-    env_id_list = [
-        'acrobot-swingup',
-        'cartpole-balance',
-        'cartpole-balance_sparse',
-        'cartpole-swingup',
-        'cartpole-swingup_sparse',
-        'cheetah-run',
-        "ball_in_cup-catch",
-        "finger-spin",
-    ]
+    # env_id_list = [
+    #     'acrobot-swingup',
+    #     'cartpole-balance',
+    #     'cartpole-balance_sparse',
+    #     'cartpole-swingup',
+    #     'cartpole-swingup_sparse',
+    #     'cheetah-run',
+    #     "ball_in_cup-catch",
+    #     "finger-spin",
+    # ]

    # DMC 18games
    # env_id_list = [
@@ -278,18 +281,18 @@ def create_env_manager():
    n_episode = 8
    evaluator_env_num = 3
    num_simulations = 50
-    # max_env_step = int(5e5)
-    max_env_step = int(1e6)
+    max_env_step = int(5e5)
+    # max_env_step = int(1e6)

    reanalyze_ratio = 0.0

-    # nlayer=4
+    # nlayer=4/8
    total_batch_size = 512
    batch_size = [int(min(64, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))]

-    # nlayer=8/12
-    total_batch_size = 256
-    batch_size = [int(min(32, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))]
+    # # nlayer=12
+    # total_batch_size = 256
+    # batch_size = [int(min(64, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))]

    num_unroll_steps = 5
    infer_context_length = 2
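As a quick check of the per-task batch sizing above: with the 8 DMC tasks and total_batch_size = 512, each task gets min(64, 512 / 8) = 64 samples, i.e. the cap and the even split coincide. A standalone illustration:

# Illustration of the per-task batch-size rule used in the config above.
env_id_list = ['acrobot-swingup', 'cartpole-balance', 'cartpole-balance_sparse', 'cartpole-swingup',
               'cartpole-swingup_sparse', 'cheetah-run', 'ball_in_cup-catch', 'finger-spin']
total_batch_size = 512
batch_size = [int(min(64, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))]
print(batch_size)  # [64, 64, 64, 64, 64, 64, 64, 64]

Note that the debug block further down overrides this with batch_size = [4, ...] per task.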
@@ -299,12 +302,12 @@ def create_env_manager():
    reanalyze_partition = 0.75

    # ======== TODO: only for debug ========
-    # collector_env_num = 2
-    # num_segments = 2
-    # n_episode = 2
-    # evaluator_env_num = 2
-    # num_simulations = 1
-    # batch_size = [4 for _ in range(len(env_id_list))]
+    collector_env_num = 2
+    num_segments = 2
+    n_episode = 2
+    evaluator_env_num = 2
+    num_simulations = 1
+    batch_size = [4 for _ in range(len(env_id_list))]
    # =======================================

    seed = 0  # You can iterate over multiple seeds if needed
