@@ -250,6 +250,9 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
                 input_ids_forward_micro_batch = data["input_ids"][
                     forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
                 ]
+                old_action_log_probs_micro_batch = old_action_log_probs[
+                    forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
+                ]
                 attention_mask_forward_micro_batch = data["attention_mask"][
                     forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
                 ]
@@ -306,17 +309,22 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
306309 "action_mask" : action_mask_forward_micro_batch ,
307310 "advantages" : advantages_forward_micro_batch ,
308311 "loss_mask" : loss_mask_forward_micro_batch ,
312+ "old_action_log_probs" : old_action_log_probs_micro_batch ,
309313 "source" : self .rank ,
310314 }
311315 if reference_action_log_probs is not None :
312316 data_policy_forward ["reference_action_log_probs" ] = reference_action_log_probs
313317
314318 kl = []
315- policy_model_logits = torch .empty_like (input_ids_forward_micro_batch , device = self .device )
316319
317320 def _criterion (outputs , inputs ):
318321 action_logits = outputs .logits
319- policy_model_logits .copy_ (action_logits )
322+ mini_batch_entropies .append (
323+ (
324+ ((entropy_from_logits (action_logits [:, - num_action :]) * inputs ["action_mask" ]).sum (- 1 ))
325+ / inputs ["action_mask" ].sum (- 1 )
326+ ).detach ()
327+ )
320328 action_log_probs = memory_efficient_logprob (
321329 action_logits / self .generate_config ["temperature" ],
322330 inputs ["input_ids" ],
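
For context on the entropy bookkeeping moved into `_criterion` above: the appended value is the mask-weighted mean of per-token entropies over the last `num_action` (response) positions. A minimal standalone sketch follows; `entropy_from_logits` is imported elsewhere in this file, so the body given here is only one standard way to compute it and may not match the repository's implementation exactly, and `masked_mean_entropy` is an illustrative helper name, not part of the codebase.

    import torch

    def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
        # Entropy of the categorical distribution at each position:
        # H = logsumexp(logits) - sum(softmax(logits) * logits)
        logsumexp = torch.logsumexp(logits, dim=-1)
        probs = torch.softmax(logits, dim=-1)
        return logsumexp - (probs * logits).sum(dim=-1)

    def masked_mean_entropy(action_logits, action_mask, num_action):
        # Mirrors the expression appended to mini_batch_entropies above:
        # average token entropy over unmasked response tokens, per sequence.
        token_entropy = entropy_from_logits(action_logits[:, -num_action:])
        return ((token_entropy * action_mask).sum(-1) / action_mask.sum(-1)).detach()
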
@@ -339,7 +347,7 @@ def _criterion(outputs, inputs):
 
                         loss, _ = self.policy_loss_fn(
                             action_log_probs,
-                            action_log_probs,
+                            inputs["old_action_log_probs"],
                             inputs["advantages"].repeat_interleave(action_log_probs.size(-1), dim=-1),
                             per_token_kl,
                             inputs["action_mask"],
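
The second argument of `self.policy_loss_fn` is the behaviour-policy log-probability used to form the importance ratio; passing `action_log_probs` twice (as the old line did) makes that ratio identically 1, so the clipped surrogate degenerates. A minimal sketch of a PPO-style clipped objective is shown below for reference only: the repository's loss also takes a per-token KL term and a loss mask, which are omitted here, and `clipped_policy_loss` / `clip_eps` are illustrative names.

    import torch

    def clipped_policy_loss(log_probs, old_log_probs, advantages, action_mask, clip_eps=0.2):
        # Importance ratio pi_theta / pi_theta_old per token.
        ratio = torch.exp(log_probs - old_log_probs.detach())
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
        per_token_loss = -torch.min(surr1, surr2)
        # Mask out prompt/padding tokens and average over the response.
        return ((per_token_loss * action_mask).sum(-1) / action_mask.sum(-1)).mean()
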
@@ -363,20 +371,6 @@ def _criterion(outputs, inputs):
                         kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin).data
                         mean_kl.append(kl)
                     mean_loss.append(all_reduce_mean(loss, self.plugin).data)
-                    mini_batch_entropies.append(
-                        all_reduce_mean(
-                            (
-                                (
-                                    (
-                                        entropy_from_logits(policy_model_logits[:, -num_action:])
-                                        * action_mask_forward_micro_batch
-                                    ).sum(-1)
-                                )
-                                / action_mask_forward_micro_batch.sum(-1)
-                            ).detach(),
-                            self.plugin,
-                        )
-                    )
                 else:
                     policy_model_logits = self.policy_model(
                         input_ids=input_ids_forward_micro_batch,
@@ -415,7 +409,7 @@ def _criterion(outputs, inputs):
 
                     loss, _ = self.policy_loss_fn(
                         action_log_probs,
-                        old_action_log_probs,
+                        old_action_log_probs_micro_batch,
                         advantages_forward_micro_batch.repeat_interleave(action_log_probs.size(-1), dim=-1),
                         per_token_kl,
                         action_mask_forward_micro_batch,
@@ -455,7 +449,7 @@ def _criterion(outputs, inputs):
                 ans_acc = all_reduce_mean(ans_acc.mean(), self.plugin)
                 advantages = all_reduce_mean(advantages.mean(), self.plugin)
                 response_length = all_reduce_mean(response_length.mean(), self.plugin)
-                entropy = torch.cat(mini_batch_entropies, dim=0).mean()
+                entropy = all_reduce_mean(torch.cat(mini_batch_entropies, dim=0).mean(), self.plugin)
                 self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
                 self.accum_entropy.add_(entropy.data)
                 if self.policy_loss_fn.beta > 0:
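
`all_reduce_mean` averages a tensor across data-parallel ranks so every rank logs the same entropy value. An illustrative stand-in is sketched below under the assumption of a plain `torch.distributed` process group; the repository's helper takes the booster plugin rather than a group handle, and `all_reduce_mean_sketch` is a hypothetical name.

    import torch
    import torch.distributed as dist

    def all_reduce_mean_sketch(tensor: torch.Tensor, group=None) -> torch.Tensor:
        # Sum across ranks, then divide by the world size: every rank ends up
        # holding the same mean value, keeping logged metrics consistent.
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
        return tensor / dist.get_world_size(group=group)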