@@ -343,7 +343,7 @@ def _init_learn(self) -> None:
343
343
wandb .watch (self ._learn_model .representation_network , log = "all" )
344
344
345
345
# TODO: ========
346
- self .accumulation_steps = 4 # 累积的步数
346
+ self .accumulation_steps = 1 # 累积的步数
347
347
348
348
# @profile
349
349
def _forward_learn (self , data : Tuple [torch .Tensor ]) -> Dict [str , Union [float , int ]]:
@@ -467,8 +467,11 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
467
467
# assert not torch.isinf(losses.loss_total).any(), "Loss contains Inf values"
468
468
469
469
# Core learn model update step
470
- if train_iter % self .accumulation_steps == 0 : # 每 accumulation_steps 步更新一次参数
471
- # print(f'train_iter:{train_iter}')
470
+ # print(f'train_iter:{train_iter}')
471
+ # 假设 train_iter 是从 0 开始计数
472
+ if (train_iter % self .accumulation_steps ) == 0 :
473
+ # 每个累计周期的第一个step时清零梯度
474
+ # print(f'train_iter:{train_iter} self._optimizer_world_model.zero_grad()')
472
475
self ._optimizer_world_model .zero_grad ()
473
476
474
477
weighted_total_loss = weighted_total_loss / self .accumulation_steps # 累积梯度时对 loss 进行缩放
@@ -481,16 +484,30 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
481
484
# if param.requires_grad:
482
485
# print(name, param.grad.norm())
483
486
484
- if self ._cfg .analysis_sim_norm :
485
- del self .l2_norm_before , self .l2_norm_after , self .grad_norm_before , self .grad_norm_after
486
- self .l2_norm_before , self .l2_norm_after , self .grad_norm_before , self .grad_norm_after = self ._learn_model .encoder_hook .analyze ()
487
- self ._target_model .encoder_hook .clear_data ()
488
-
489
- total_grad_norm_before_clip_wm = torch .nn .utils .clip_grad_norm_ (self ._learn_model .world_model .parameters (),
490
- self ._cfg .grad_clip_value )
487
+ # if self._cfg.analysis_sim_norm:
488
+ # del self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after
489
+ # self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after = self._learn_model.encoder_hook.analyze()
490
+ # self._target_model.encoder_hook.clear_data()
491
+
492
+ # total_grad_norm_before_clip_wm = torch.nn.utils.clip_grad_norm_(self._learn_model.world_model.parameters(),
493
+ # self._cfg.grad_clip_value)
494
+
495
+ # 判断是否完成了一个累计周期(例如:如果 accumulation_steps=4, 则在 4,8,12,... 次迭代时更新参数)
496
+ if (train_iter + 1 ) % self .accumulation_steps == 0 :
497
+ # print(f'train_iter:{train_iter} self._optimizer_world_model.step()')
498
+
499
+ # ========== 分析梯度模的代码 ==========
500
+ if self ._cfg .analysis_sim_norm :
501
+ # 删除上次的分析结果,防止累积过多内存
502
+ del self .l2_norm_before , self .l2_norm_after , self .grad_norm_before , self .grad_norm_after
503
+ self .l2_norm_before , self .l2_norm_after , self .grad_norm_before , self .grad_norm_after = self ._learn_model .encoder_hook .analyze ()
504
+ self ._target_model .encoder_hook .clear_data ()
505
+
506
+ # 对梯度进行裁剪
507
+ total_grad_norm_before_clip_wm = torch .nn .utils .clip_grad_norm_ (
508
+ self ._learn_model .world_model .parameters (), self ._cfg .grad_clip_value
509
+ )
491
510
492
- if train_iter % self .accumulation_steps == 0 : # 每 accumulation_steps 步更新一次参数
493
- # print(f'pos 2 train_iter:{train_iter}')
494
511
495
512
if self ._cfg .multi_gpu :
496
513
self .sync_gradients (self ._learn_model )
@@ -503,8 +520,10 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
503
520
# Core target model update step
504
521
self ._target_model .update (self ._learn_model .state_dict ())
505
522
506
- if self .accumulation_steps > 1 :
523
+ if self .accumulation_steps > 1 :
507
524
torch .cuda .empty_cache ()
525
+ else :
526
+ total_grad_norm_before_clip_wm = torch .tensor (0. )
508
527
509
528
if torch .cuda .is_available ():
510
529
torch .cuda .synchronize ()
0 commit comments