@@ -207,6 +207,11 @@ class UniZeroMTPolicy(UniZeroPolicy):
),
),
# ****** common ******
+ # (bool): Indicates whether to perform an offline evaluation of the checkpoint (ckpt).
+ # If set to True, the checkpoint will be evaluated after the training process is complete.
+ # IMPORTANT: Setting eval_offline to True requires the checkpoint-saving interval to be aligned with the evaluation frequency.
+ # This is done automatically in train_muzero.py by setting learn.learner.hook.save_ckpt_after_iter to the same value as eval_freq.
+ eval_offline=False,
# (bool) whether to use rnd model.
use_rnd_model=False,
# (bool) Whether to use multi-gpu training.
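For readers replicating the checkpoint/evaluation alignment described above outside of train_muzero.py, a minimal sketch follows; the exact config key paths mirror the usual LightZero/DI-engine layout and should be treated as assumptions here.

```python
# Hypothetical sketch: keep checkpoint saving in step with evaluation so that
# eval_offline has a ckpt available for every evaluation point.
from easydict import EasyDict

policy_cfg = EasyDict(dict(
    eval_offline=True,       # evaluate saved ckpts after training is complete
    eval_freq=int(2e3),      # assumed evaluation interval
    learn=dict(learner=dict(hook=dict(
        save_ckpt_after_iter=int(2e3),  # should match eval_freq for offline eval
    ))),
))

# train_muzero.py performs this alignment automatically; a custom entry point
# would need the equivalent of:
policy_cfg.learn.learner.hook.save_ckpt_after_iter = policy_cfg.eval_freq
assert policy_cfg.learn.learner.hook.save_ckpt_after_iter == policy_cfg.eval_freq
```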
@@ -1144,27 +1149,35 @@ def _state_dict_learn(self) -> Dict[str, Any]:
}

# ========== TODO: original version: load all parameters ==========
- def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
- """
- Overview:
- Load the state_dict variable into policy learn mode.
- Arguments:
- - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before.
- """
- self._learn_model.load_state_dict(state_dict['model'])
- self._target_model.load_state_dict(state_dict['target_model'])
- self._optimizer_world_model.load_state_dict(state_dict['optimizer_world_model'])
+ # def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
+ # """
+ # Overview:
+ # Load the state_dict variable into policy learn mode.
+ # Arguments:
+ # - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before.
+ # """
+ # self._learn_model.load_state_dict(state_dict['model'])
+ # self._target_model.load_state_dict(state_dict['target_model'])
+ # self._optimizer_world_model.load_state_dict(state_dict['optimizer_world_model'])

- # # ========== TODO: pretrain-finetue version: only load encoder and transformer-backbone parameters, head use re init weight ==========
+ # # Only load the transformer_backbone parameters; the encoder, heads, and other parts keep their original initialized parameters.
# def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
# """
# Overview:
- # Load the state_dict variable into policy learn mode, excluding multi-task related parameters.
+ # Load the state_dict variable into policy learn mode,
+ # loading only the transformer_backbone parameters.
+ # The encoder, heads, and other parts retain their original initialized parameters.
# Arguments:
# - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved previously.
# """
+ # # Define the parameter prefixes to load (transformer_backbone only)
+ # include_prefixes = [
+ # '_orig_mod.world_model.transformer.'
+ # ]
+
# # Define the parameter prefixes to exclude
# exclude_prefixes = [
+ # '_orig_mod.world_model.tokenizer.',
# '_orig_mod.world_model.head_policy_multi_task.',
# '_orig_mod.world_model.head_value_multi_task.',
# '_orig_mod.world_model.head_rewards_multi_task.',
@@ -1179,25 +1192,33 @@ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
# # Add other specific parameter names to exclude
# ]

- # def filter_state_dict(state_dict_loader: Dict[str, Any], exclude_prefixes: list, exclude_keys: list = []) -> Dict[str, Any]:
+ # def filter_state_dict(state_dict_loader: Dict[str, Any], include_prefixes: list, exclude_prefixes: list = [], exclude_keys: list = []) -> Dict[str, Any]:
# """
- # Filter out the parameters that need to be excluded.
+ # Keep only the parameters that should be loaded, and exclude the unwanted ones.
# """
# filtered = {}
# for k, v in state_dict_loader.items():
+ # # Only keep parameters with the specified prefixes
+ # if not any(k.startswith(prefix) for prefix in include_prefixes):
+ # continue
+ # # Exclude parameters with the specified prefixes (if any)
# if any(k.startswith(prefix) for prefix in exclude_prefixes):
- # print(f"Excluding parameter: {k}")  # For debugging: see which parameters are excluded
+ # print(f"Excluding parameter by prefix: {k}")  # For debugging
# continue
+ # # Exclude parameters with the specified keys (if any)
# if k in exclude_keys:
# print(f"Excluding specific parameter: {k}")  # For debugging
# continue
# filtered[k] = v
# return filtered

- # # Filter and load the 'model' part
+ # # Filter and load the 'model' part (transformer_backbone only)
# if 'model' in state_dict:
# model_state_dict = state_dict['model']
- # filtered_model_state_dict = filter_state_dict(model_state_dict, exclude_prefixes, exclude_keys)
+ # # print(f'='*20)
+ # # print(f'model_state_dict:{model_state_dict.keys()}')
+ # # print(f'='*20)
+ # filtered_model_state_dict = filter_state_dict(model_state_dict, include_prefixes, exclude_prefixes, exclude_keys)
# missing_keys, unexpected_keys = self._learn_model.load_state_dict(filtered_model_state_dict, strict=False)
# if missing_keys:
# print(f"Missing keys when loading _learn_model: {missing_keys}")
@@ -1206,10 +1227,12 @@ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
# else:
# print("No 'model' key found in the state_dict.")

- # # Filter and load the 'target_model' part
+ # # No need to re-initialize the heads: they are not loaded and keep their original initialized parameters
+
+ # # Filter and load the 'target_model' part (transformer_backbone only)
# if 'target_model' in state_dict:
# target_model_state_dict = state_dict['target_model']
- # filtered_target_model_state_dict = filter_state_dict(target_model_state_dict, exclude_prefixes, exclude_keys)
+ # filtered_target_model_state_dict = filter_state_dict(target_model_state_dict, include_prefixes, exclude_prefixes, exclude_keys)
# missing_keys, unexpected_keys = self._target_model.load_state_dict(filtered_target_model_state_dict, strict=False)
# if missing_keys:
# print(f"Missing keys when loading _target_model: {missing_keys}")
@@ -1218,14 +1241,81 @@ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
# else:
# print("No 'target_model' key found in the state_dict.")

- # # Load the optimizer's state_dict; no filtering is needed since the optimizer usually does not contain model parameters
- # if 'optimizer_world_model' in state_dict:
- # optimizer_state_dict = state_dict['optimizer_world_model']
- # try:
- # self._optimizer_world_model.load_state_dict(optimizer_state_dict)
- # except Exception as e:
- # print(f"Error loading optimizer state_dict: {e}")
- # else:
- # print("No 'optimizer_world_model' key found in the state_dict.")
+ # # No need to re-initialize the target_model heads: they are not loaded and keep their original initialized parameters
+

- # # If needed, other parts such as the scheduler can also be loaded
+
+ # # ========== TODO: pretrain-finetune version: only load encoder and transformer-backbone parameters, heads use re-initialized weights ==========
+ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
+ """
+ Overview:
+ Load the state_dict variable into policy learn mode, excluding multi-task related parameters.
+ Arguments:
+ - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved previously.
+ """
+ # Define the parameter prefixes to exclude
+ exclude_prefixes = [
+ '_orig_mod.world_model.head_policy_multi_task.',
+ '_orig_mod.world_model.head_value_multi_task.',
+ '_orig_mod.world_model.head_rewards_multi_task.',
+ '_orig_mod.world_model.head_observations_multi_task.',
+ '_orig_mod.world_model.task_emb.'
+ ]
+
+ # Define specific parameters to exclude (for special cases, if any)
+ exclude_keys = [
+ '_orig_mod.world_model.task_emb.weight',
+ '_orig_mod.world_model.task_emb.bias',  # add if it exists
+ # Add other specific parameter names to exclude
+ ]
+
+ def filter_state_dict(state_dict_loader: Dict[str, Any], exclude_prefixes: list, exclude_keys: list = []) -> Dict[str, Any]:
+ """
+ Filter out the parameters that need to be excluded.
+ """
+ filtered = {}
+ for k, v in state_dict_loader.items():
+ if any(k.startswith(prefix) for prefix in exclude_prefixes):
+ print(f"Excluding parameter: {k}")  # For debugging: see which parameters are excluded
+ continue
+ if k in exclude_keys:
+ print(f"Excluding specific parameter: {k}")  # For debugging
+ continue
+ filtered[k] = v
+ return filtered
+
+ # Filter and load the 'model' part
+ if 'model' in state_dict:
+ model_state_dict = state_dict['model']
+ filtered_model_state_dict = filter_state_dict(model_state_dict, exclude_prefixes, exclude_keys)
+ missing_keys, unexpected_keys = self._learn_model.load_state_dict(filtered_model_state_dict, strict=False)
+ if missing_keys:
+ print(f"Missing keys when loading _learn_model: {missing_keys}")
+ if unexpected_keys:
+ print(f"Unexpected keys when loading _learn_model: {unexpected_keys}")
+ else:
+ print("No 'model' key found in the state_dict.")
+
+ # Filter and load the 'target_model' part
+ if 'target_model' in state_dict:
+ target_model_state_dict = state_dict['target_model']
+ filtered_target_model_state_dict = filter_state_dict(target_model_state_dict, exclude_prefixes, exclude_keys)
+ missing_keys, unexpected_keys = self._target_model.load_state_dict(filtered_target_model_state_dict, strict=False)
+ if missing_keys:
+ print(f"Missing keys when loading _target_model: {missing_keys}")
+ if unexpected_keys:
+ print(f"Unexpected keys when loading _target_model: {unexpected_keys}")
+ else:
+ print("No 'target_model' key found in the state_dict.")
+
+ # Load the optimizer's state_dict; no filtering is needed since the optimizer usually does not contain model parameters
+ # if 'optimizer_world_model' in state_dict:
+ # optimizer_state_dict = state_dict['optimizer_world_model']
+ # try:
+ # self._optimizer_world_model.load_state_dict(optimizer_state_dict)
+ # except Exception as e:
+ # print(f"Error loading optimizer state_dict: {e}")
+ # else:
+ # print("No 'optimizer_world_model' key found in the state_dict.")
+
+ # If needed, other parts such as the scheduler can also be loaded
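As a quick, standalone sanity check of the prefix/key filtering used by the new _load_state_dict_learn, the sketch below reuses the same logic on a toy state_dict; the keys are hypothetical and only mimic the `_orig_mod.world_model.*` naming, not the real checkpoint layout.

```python
from typing import Any, Dict

def filter_state_dict(state_dict_loader: Dict[str, Any], exclude_prefixes: list, exclude_keys: list = []) -> Dict[str, Any]:
    # Same filtering rule as in _load_state_dict_learn: drop multi-task heads and task embeddings.
    filtered = {}
    for k, v in state_dict_loader.items():
        if any(k.startswith(prefix) for prefix in exclude_prefixes):
            continue
        if k in exclude_keys:
            continue
        filtered[k] = v
    return filtered

# Toy checkpoint with made-up keys mimicking the '_orig_mod.world_model.*' naming.
toy_ckpt = {
    '_orig_mod.world_model.transformer.blocks.0.attn.weight': 0,
    '_orig_mod.world_model.tokenizer.encoder.weight': 1,
    '_orig_mod.world_model.head_policy_multi_task.0.weight': 2,
    '_orig_mod.world_model.task_emb.weight': 3,
}
kept = filter_state_dict(
    toy_ckpt,
    exclude_prefixes=[
        '_orig_mod.world_model.head_policy_multi_task.',
        '_orig_mod.world_model.task_emb.',
    ],
)
print(sorted(kept))  # backbone and tokenizer keys remain; multi-task heads and task_emb are dropped
```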