
Commit d55ea52

puyuan1996 authored and puyuan committed
tmp: sync code
1 parent c007917 commit d55ea52

3 files changed: +142, -43 lines

lzero/policy/unizero_multitask.py (+119 -29)
@@ -207,6 +207,11 @@ class UniZeroMTPolicy(UniZeroPolicy):
             ),
         ),
         # ****** common ******
+        # (bool): Whether to perform an offline evaluation of the checkpoint (ckpt).
+        # If set to True, the checkpoint is evaluated after the training process is complete.
+        # IMPORTANT: eval_offline=True requires the checkpoint-saving interval to align with the evaluation frequency.
+        # train_muzero.py does this automatically by setting learn.learner.hook.save_ckpt_after_iter to the same value as eval_freq.
+        eval_offline=False,
         # (bool) whether to use rnd model.
         use_rnd_model=False,
         # (bool) Whether to use multi-gpu training.
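Note on the eval_offline comment above: the intent is that the checkpoint-saving interval and the evaluation interval coincide, so every evaluation point has a matching ckpt to score offline. A minimal sketch of that pairing, using hypothetical config keys that mirror the comment (the real config layout may differ):

from easydict import EasyDict

# Hypothetical keys, illustrative only; they follow the wording of the comment above.
cfg = EasyDict(dict(
    policy=dict(
        eval_offline=True,   # evaluate saved ckpts after training finishes
        eval_freq=int(2e3),  # evaluation interval
        learn=dict(learner=dict(hook=dict(
            # train_muzero.py is said to set this equal to eval_freq automatically;
            # the assertion below just makes the required alignment explicit.
            save_ckpt_after_iter=int(2e3),
        ))),
    ),
))

assert cfg.policy.learn.learner.hook.save_ckpt_after_iter == cfg.policy.eval_freq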
@@ -1144,27 +1149,35 @@ def _state_dict_learn(self) -> Dict[str, Any]:
         }
 
     # ========== TODO: original version: load all parameters ==========
-    def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
-        """
-        Overview:
-            Load the state_dict variable into policy learn mode.
-        Arguments:
-            - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before.
-        """
-        self._learn_model.load_state_dict(state_dict['model'])
-        self._target_model.load_state_dict(state_dict['target_model'])
-        self._optimizer_world_model.load_state_dict(state_dict['optimizer_world_model'])
+    # def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
+    #     """
+    #     Overview:
+    #         Load the state_dict variable into policy learn mode.
+    #     Arguments:
+    #         - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before.
+    #     """
+    #     self._learn_model.load_state_dict(state_dict['model'])
+    #     self._target_model.load_state_dict(state_dict['target_model'])
+    #     self._optimizer_world_model.load_state_dict(state_dict['optimizer_world_model'])
 
-    # # ========== TODO: pretrain-finetune version: only load encoder and transformer-backbone parameters, heads use re-initialized weights ==========
+    # # Only load the transformer_backbone parameters; the encoder, heads, and other parts keep their original initialized parameters.
     # def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
     #     """
     #     Overview:
-    #         Load the state_dict variable into policy learn mode, excluding multi-task related parameters.
+    #         Load the state_dict variable into policy learn mode,
+    #         loading only the transformer_backbone parameters.
+    #         The encoder, head, and other parts retain their original initialized parameters.
     #     Arguments:
     #         - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved previously.
     #     """
+    #     # Define the parameter prefixes to load (transformer_backbone only)
+    #     include_prefixes = [
+    #         '_orig_mod.world_model.transformer.'
+    #     ]
+
     #     # Define the parameter prefixes to exclude
     #     exclude_prefixes = [
+    #         '_orig_mod.world_model.tokenizer.',
     #         '_orig_mod.world_model.head_policy_multi_task.',
     #         '_orig_mod.world_model.head_value_multi_task.',
     #         '_orig_mod.world_model.head_rewards_multi_task.',
@@ -1179,25 +1192,33 @@ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
     #         # Add other specific parameter names to exclude
     #     ]
 
-    #     def filter_state_dict(state_dict_loader: Dict[str, Any], exclude_prefixes: list, exclude_keys: list = []) -> Dict[str, Any]:
+    #     def filter_state_dict(state_dict_loader: Dict[str, Any], include_prefixes: list, exclude_prefixes: list = [], exclude_keys: list = []) -> Dict[str, Any]:
     #         """
-    #         Filter out the parameters that need to be excluded.
+    #         Keep only the parameters that should be loaded and filter out the unwanted ones.
     #         """
     #         filtered = {}
     #         for k, v in state_dict_loader.items():
+    #             # Only include parameters with the specified prefixes
+    #             if not any(k.startswith(prefix) for prefix in include_prefixes):
+    #                 continue
+    #             # Exclude parameters with the specified prefixes (if any)
     #             if any(k.startswith(prefix) for prefix in exclude_prefixes):
-    #                 print(f"Excluding parameter: {k}")  # for debugging: shows which parameters are excluded
+    #                 print(f"Excluding parameter by prefix: {k}")  # for debugging
     #                 continue
+    #             # Exclude parameters with the specified keys (if any)
     #             if k in exclude_keys:
     #                 print(f"Excluding specific parameter: {k}")  # for debugging
     #                 continue
     #             filtered[k] = v
     #         return filtered
 
-    #     # Filter and load the 'model' part
+    #     # Filter and load the 'model' part (transformer_backbone only)
     #     if 'model' in state_dict:
     #         model_state_dict = state_dict['model']
-    #         filtered_model_state_dict = filter_state_dict(model_state_dict, exclude_prefixes, exclude_keys)
+    #         # print(f'='*20)
+    #         # print(f'model_state_dict:{model_state_dict.keys()}')
+    #         # print(f'='*20)
+    #         filtered_model_state_dict = filter_state_dict(model_state_dict, include_prefixes, exclude_prefixes, exclude_keys)
     #         missing_keys, unexpected_keys = self._learn_model.load_state_dict(filtered_model_state_dict, strict=False)
     #         if missing_keys:
     #             print(f"Missing keys when loading _learn_model: {missing_keys}")
@@ -1206,10 +1227,12 @@ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
     #     else:
     #         print("No 'model' key found in the state_dict.")
 
-    #     # Filter and load the 'target_model' part
+    #     # No need to re-initialize the heads: they were not loaded and keep their original initialized parameters
+
+    #     # Filter and load the 'target_model' part (transformer_backbone only)
     #     if 'target_model' in state_dict:
     #         target_model_state_dict = state_dict['target_model']
-    #         filtered_target_model_state_dict = filter_state_dict(target_model_state_dict, exclude_prefixes, exclude_keys)
+    #         filtered_target_model_state_dict = filter_state_dict(target_model_state_dict, include_prefixes, exclude_prefixes, exclude_keys)
     #         missing_keys, unexpected_keys = self._target_model.load_state_dict(filtered_target_model_state_dict, strict=False)
     #         if missing_keys:
     #             print(f"Missing keys when loading _target_model: {missing_keys}")
@@ -1218,14 +1241,81 @@ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
     #     else:
     #         print("No 'target_model' key found in the state_dict.")
 
-    #     # Load the optimizer state_dict; no filtering is needed, since the optimizer usually does not contain model parameters
-    #     if 'optimizer_world_model' in state_dict:
-    #         optimizer_state_dict = state_dict['optimizer_world_model']
-    #         try:
-    #             self._optimizer_world_model.load_state_dict(optimizer_state_dict)
-    #         except Exception as e:
-    #             print(f"Error loading optimizer state_dict: {e}")
-    #     else:
-    #         print("No 'optimizer_world_model' key found in the state_dict.")
+    #     # No need to re-initialize the target_model heads: they were not loaded and keep their original initialized parameters
+
 
-    #     # If needed, other parts such as the scheduler can also be loaded
+
+    # # ========== TODO: pretrain-finetune version: only load encoder and transformer-backbone parameters, heads use re-initialized weights ==========
+    def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
+        """
+        Overview:
+            Load the state_dict variable into policy learn mode, excluding multi-task related parameters.
+        Arguments:
+            - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved previously.
+        """
+        # Define the parameter prefixes to exclude
+        exclude_prefixes = [
+            '_orig_mod.world_model.head_policy_multi_task.',
+            '_orig_mod.world_model.head_value_multi_task.',
+            '_orig_mod.world_model.head_rewards_multi_task.',
+            '_orig_mod.world_model.head_observations_multi_task.',
+            '_orig_mod.world_model.task_emb.'
+        ]
+
+        # Define specific parameters to exclude (for special cases, if any)
+        exclude_keys = [
+            '_orig_mod.world_model.task_emb.weight',
+            '_orig_mod.world_model.task_emb.bias',  # add if it exists
+            # Add other specific parameter names to exclude
+        ]
+
+        def filter_state_dict(state_dict_loader: Dict[str, Any], exclude_prefixes: list, exclude_keys: list = []) -> Dict[str, Any]:
+            """
+            Filter out the parameters that need to be excluded.
+            """
+            filtered = {}
+            for k, v in state_dict_loader.items():
+                if any(k.startswith(prefix) for prefix in exclude_prefixes):
+                    print(f"Excluding parameter: {k}")  # for debugging: shows which parameters are excluded
+                    continue
+                if k in exclude_keys:
+                    print(f"Excluding specific parameter: {k}")  # for debugging
+                    continue
+                filtered[k] = v
+            return filtered
+
+        # Filter and load the 'model' part
+        if 'model' in state_dict:
+            model_state_dict = state_dict['model']
+            filtered_model_state_dict = filter_state_dict(model_state_dict, exclude_prefixes, exclude_keys)
+            missing_keys, unexpected_keys = self._learn_model.load_state_dict(filtered_model_state_dict, strict=False)
+            if missing_keys:
+                print(f"Missing keys when loading _learn_model: {missing_keys}")
+            if unexpected_keys:
+                print(f"Unexpected keys when loading _learn_model: {unexpected_keys}")
+        else:
+            print("No 'model' key found in the state_dict.")
+
+        # Filter and load the 'target_model' part
+        if 'target_model' in state_dict:
+            target_model_state_dict = state_dict['target_model']
+            filtered_target_model_state_dict = filter_state_dict(target_model_state_dict, exclude_prefixes, exclude_keys)
+            missing_keys, unexpected_keys = self._target_model.load_state_dict(filtered_target_model_state_dict, strict=False)
+            if missing_keys:
+                print(f"Missing keys when loading _target_model: {missing_keys}")
+            if unexpected_keys:
+                print(f"Unexpected keys when loading _target_model: {unexpected_keys}")
+        else:
+            print("No 'target_model' key found in the state_dict.")
+
+        # Load the optimizer state_dict; no filtering is needed, since the optimizer usually does not contain model parameters
+        # if 'optimizer_world_model' in state_dict:
+        #     optimizer_state_dict = state_dict['optimizer_world_model']
+        #     try:
+        #         self._optimizer_world_model.load_state_dict(optimizer_state_dict)
+        #     except Exception as e:
+        #         print(f"Error loading optimizer state_dict: {e}")
+        # else:
+        #     print("No 'optimizer_world_model' key found in the state_dict.")
+
+        # If needed, other parts such as the scheduler can also be loaded
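As a standalone illustration of the prefix-based filtering that the new _load_state_dict_learn relies on, here is a small self-contained sketch with made-up keys (a simplified variant of the helper above, not the exact code): entries whose names start with the multi-task head or task_emb prefixes are dropped, so those modules keep their fresh initialization when the remainder is later loaded with strict=False.

from typing import Any, Dict, List

def filter_state_dict(state_dict: Dict[str, Any],
                      exclude_prefixes: List[str],
                      exclude_keys: List[str] = []) -> Dict[str, Any]:
    # Keep every entry whose key neither starts with an excluded prefix
    # nor appears explicitly in exclude_keys.
    return {
        k: v for k, v in state_dict.items()
        if not any(k.startswith(p) for p in exclude_prefixes) and k not in exclude_keys
    }

# Toy checkpoint with made-up keys that mirror the prefixes used above.
ckpt = {
    '_orig_mod.world_model.transformer.blocks.0.weight': 'backbone',
    '_orig_mod.world_model.head_policy_multi_task.0.weight': 'head',
    '_orig_mod.world_model.task_emb.weight': 'embedding',
}
kept = filter_state_dict(ckpt, [
    '_orig_mod.world_model.head_policy_multi_task.',
    '_orig_mod.world_model.task_emb.',
])
print(list(kept))  # only the transformer key survives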

zoo/atari/config/atari_unizero_multitask_segment_8games_ddp_config.py (+10 -8)
@@ -43,7 +43,9 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu
                     device='cuda',
                     action_space_size=action_space_size,
                     # With batch_size=64 and 8 games, training uses roughly 12*3 = 36 GB of CUDA memory per GPU
-                    num_layers=12,
+                    # num_layers=12,
+                    # num_heads=24,
+                    num_layers=8,
                     num_heads=24,
                     embed_dim=768,
                     obs_type='image',
@@ -91,7 +93,7 @@ def generate_configs(env_id_list, action_space_size, collector_env_num, n_episod
                      norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition,
                      num_segments, total_batch_size):
     configs = []
-    exp_name_prefix = f'data_unizero_mt_ddp-8gpu_20241226/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_seed{seed}/'
+    exp_name_prefix = f'data_unizero_atari_mt_20250216/{len(env_id_list)}games_nlayer8_bs64_brf{buffer_reanalyze_freq}_seed{seed}/'
 
     for task_id, env_id in enumerate(env_id_list):
         config = create_config(
@@ -155,12 +157,12 @@ def create_env_manager():
     reanalyze_partition = 0.75
 
     # ======== TODO: only for debug ========
-    collector_env_num = 2
-    num_segments = 2
-    n_episode = 2
-    evaluator_env_num = 2
-    num_simulations = 2
-    batch_size = [4, 4, 4, 4, 4, 4, 4, 4]
+    # collector_env_num = 2
+    # num_segments = 2
+    # n_episode = 2
+    # evaluator_env_num = 2
+    # num_simulations = 2
+    # batch_size = [4, 4, 4, 4, 4, 4, 4, 4]
 
 
     for seed in [0]:

zoo/atari/config/atari_unizero_multitask_segment_finetune_config.py (+13 -6)
@@ -50,8 +50,10 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu
                     use_normal_head=True,
                     use_softmoe_head=False,
                     moe_in_transformer=False,
-                    multiplication_moe_in_transformer=False,
                     num_experts_of_moe_in_transformer=4,
+                    multiplication_moe_in_transformer=False,
+                    num_experts_in_moe_head=4,
+                    use_moe_head=False,
                 ),
             ),
             total_batch_size=total_batch_size,
@@ -83,7 +85,8 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu
 
 def generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, total_batch_size):
     configs = []
-    exp_name_prefix = f'data_unizero_mt_ddp-2gpu_1201/finetune_pong/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_1-encoder-{norm_type}-res2-channel256_gsl20_{len(env_id_list)}-pred-head_lsd768-nlayer8-nh24_mbs-512-bs64_upc80_seed{seed}/'
+    exp_name_prefix = f'data_unizero_mt_ddp-2gpu_1201/finetune_amidar_load-encoder-backbone/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_1-encoder-{norm_type}-res2-channel256_gsl20_{len(env_id_list)}-pred-head_lsd768-nlayer8-nh24_mbs-512-bs64_upc80_seed{seed}/'
+    # exp_name_prefix = f'data_unizero_mt_ddp-2gpu_1201_debug/finetune_amidar_load-backbone/{len(env_id_list)}games_brf{buffer_reanalyze_freq}_1-encoder-{norm_type}-res2-channel256_gsl20_{len(env_id_list)}-pred-head_lsd768-nlayer8-nh24_mbs-512-bs64_upc80_seed{seed}/'
 
     for task_id, env_id in enumerate(env_id_list):
         config = create_config(
@@ -127,22 +130,26 @@ def create_env_manager():
     Overview:
         This script should be executed with <nproc_per_node> GPUs.
         Run the following command to launch the script:
-        python -m torch.distributed.launch --nproc_per_node=8 --master_port=29501 ./zoo/atari/config/atari_unizero_multitask_segment_finetune_config.py
+        python -m torch.distributed.launch --nproc_per_node=1 --master_port=29503 ./zoo/atari/config/atari_unizero_multitask_segment_finetune_config.py
         torchrun --nproc_per_node=8 ./zoo/atari/config/atari_unizero_multitask_segment_finetune_config.py
     """
 
     from lzero.entry import train_unizero_multitask_segment_ddp
     from ding.utils import DDPContext
     from easydict import EasyDict
 
-    env_id_list = ['PongNoFrameskip-v4'] # Debug setup
+    # env_id_list = ['PongNoFrameskip-v4'] # Debug setup
+    env_id_list = ['AmidarNoFrameskip-v4'] # Debug setup
+
     action_space_size = 18
 
     # NCCL environment setup
     import os
     os.environ["NCCL_TIMEOUT"] = "3600000000"
 
-    for seed in [0, 1, 2]:
+    # for seed in [0, 1, 2]:
+    for seed in [0]:
+
         collector_env_num = 8
         num_segments = 8
         n_episode = 8
@@ -163,7 +170,7 @@ def create_env_manager():
 
         configs = generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, total_batch_size)
 
-        pretrained_model_path = '/mnt/afs/niuyazhe/code/LightZero/data_unizero_mt_ddp-8gpu_1127/8games_brf0.02_nlayer8-nhead24_seed1/8games_brf0.02_1-encoder-LN-res2-channel256_gsl20_8-pred-head_lsd768-nlayer8-nh24_mbs-512-bs64_upc80_seed1/Pong_unizero-mt_seed1/ckpt/iteration_200000.pth.tar'
+        pretrained_model_path = '/mnt/afs/niuyazhe/code/LightZero/data_unizero_mt_ddp-8gpu_1127/8games_brf0.02_nlayer8-nhead24_seed1/8games_brf0.02_1-encoder-LN-res2-channel256_gsl20_8-pred-head_lsd768-nlayer8-nh24_mbs-512-bs64_upc80_seed1/Pong_unizero-mt_seed1/ckpt/ckpt_best.pth.tar'
 
         with DDPContext():
             train_unizero_multitask_segment_ddp(configs, seed=seed, model_path=pretrained_model_path, max_env_step=max_env_step)
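The fine-tune config above passes pretrained_model_path into train_unizero_multitask_segment_ddp, and the filtered loading in the first file relies on PyTorch's strict=False behaviour. A tiny PyTorch sketch (toy module, unrelated to the UniZero model) of what such a partial load reports:

import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 4)  # stands in for the transformer backbone
        self.head = nn.Linear(4, 2)      # stands in for a task-specific head

model = Toy()
# Pretend the checkpoint only kept backbone weights (heads were filtered out).
partial_ckpt = {k: v for k, v in Toy().state_dict().items() if k.startswith('backbone.')}
result = model.load_state_dict(partial_ckpt, strict=False)
print(result.missing_keys)     # ['head.weight', 'head.bias'] -> head keeps its random init
print(result.unexpected_keys)  # []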
