 import wandb
 from coati.distributed.consumer import BaseConsumer
 from coati.distributed.loss import PolicyLoss
-from coati.distributed.utils import memory_efficient_logprob
+from coati.distributed.utils import entropy_from_logits, memory_efficient_logprob
 from coati.trainer.utils import all_reduce_mean, all_reduce_sum
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
@@ -75,6 +75,7 @@ def __init__(
         self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6))
         self.accum_loss = torch.zeros(1, device=self.device)
         self.accum_kl = torch.zeros(1, device=self.device)
+        self.accum_entropy = torch.zeros(1, device=self.device)
         self.accum_advantages = torch.zeros(1, device=self.device)
         self.raw_train_batch_reward = []
         self.raw_train_batch_format_acc = []
@@ -244,6 +245,7 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
             else self.booster.no_sync(self.policy_model, self.optimizer)
         )
         with ctx:
+            mini_batch_entropies = []
             for forward_micro_batch_start in range(0, data["input_ids"].size(0), train_microbatch_size):
                 input_ids_forward_micro_batch = data["input_ids"][
                     forward_micro_batch_start : forward_micro_batch_start + train_microbatch_size
@@ -310,9 +312,11 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
                         data_policy_forward["reference_action_log_probs"] = reference_action_log_probs
 
                     kl = []
+                    policy_model_logits = torch.empty_like(input_ids_forward_micro_batch, device=self.device)
 
                     def _criterion(outputs, inputs):
                         action_logits = outputs.logits
+                        policy_model_logits.copy_(action_logits)
                         action_log_probs = memory_efficient_logprob(
                             action_logits / self.generate_config["temperature"],
                             inputs["input_ids"],
@@ -359,6 +363,20 @@ def _criterion(outputs, inputs):
                             kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin).data
                             mean_kl.append(kl)
                         mean_loss.append(all_reduce_mean(loss, self.plugin).data)
+                        mini_batch_entropies.append(
+                            all_reduce_mean(
+                                (
+                                    (
+                                        (
+                                            entropy_from_logits(policy_model_logits[:, -num_action:])
+                                            * action_mask_forward_micro_batch
+                                        ).sum(-1)
+                                    )
+                                    / action_mask_forward_micro_batch.sum(-1)
+                                ).detach(),
+                                self.plugin,
+                            )
+                        )
                 else:
                     policy_model_logits = self.policy_model(
                         input_ids=input_ids_forward_micro_batch,
@@ -412,6 +430,20 @@ def _criterion(outputs, inputs):
                         kl = all_reduce_mean(kl.mean(), self.plugin)
                         mean_kl.append(kl.data)
                     mean_loss.append(loss.data)
+                    mini_batch_entropies.append(
+                        all_reduce_mean(
+                            (
+                                (
+                                    (
+                                        entropy_from_logits(policy_model_logits[:, -num_action:])
+                                        * action_mask_forward_micro_batch
+                                    ).sum(-1)
+                                )
+                                / action_mask_forward_micro_batch.sum(-1)
+                            ).detach(),
+                            self.plugin,
+                        )
+                    )
             if not self.plugin.pp_size > 1 or (
                 self.plugin.pp_size > 1
                 and self.booster.plugin.stage_manager.is_last_stage()
@@ -423,7 +455,9 @@ def _criterion(outputs, inputs):
                 ans_acc = all_reduce_mean(ans_acc.mean(), self.plugin)
                 advantages = all_reduce_mean(advantages.mean(), self.plugin)
                 response_length = all_reduce_mean(response_length.mean(), self.plugin)
+                entropy = torch.cat(mini_batch_entropies, dim=0).mean()
                 self.accum_loss.add_(sum(mean_loss) / len(mean_loss))
+                self.accum_entropy.add_(entropy.data)
                 if self.policy_loss_fn.beta > 0:
                     self.accum_kl.add_(sum(mean_kl) / len(mean_kl))
                 self.accum_advantages.add_(advantages.data)
@@ -464,6 +498,7 @@ def _criterion(outputs, inputs):
                     f"Response Length: {raw_batch_response_len_mean:.4f}",
                     f"Sample_utilization: {sample_utilization:.4f}",
                     f"Overlength samples ratio: {overlength_samples_ratio:.4f}",
+                    f"Entropy: {self.accum_entropy.item() / self.accum_count:.4f}",
                 ] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
                 print("\n".join(to_log_msg))
                 metrics = {
469504 metrics = {
@@ -475,6 +510,7 @@ def _criterion(outputs, inputs):
475510 "train/advantages" : self .accum_advantages .item () / self .accum_count ,
476511 "train/learning_rate" : self .lr_scheduler .get_last_lr ()[0 ],
477512 "train/sample_utilization" : sample_utilization ,
513+ "train/entropy" : self .accum_entropy .item () / self .accum_count ,
478514 "train/overlength_samples_ratio" : overlength_samples_ratio ,
479515 "rollout/temperature" : data ["temperature" ].cpu ().numpy ()[0 ][0 ],
480516 }
@@ -484,6 +520,7 @@ def _criterion(outputs, inputs):
                     self.wandb_run.log(metrics)
                 self.accum_loss.zero_()
                 self.accum_kl.zero_()
+                self.accum_entropy.zero_()
                 self.accum_advantages.zero_()
                 self.accum_count = 0
             return loss_scalar
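
Note: the new metric relies on entropy_from_logits from coati.distributed.utils, which is imported above but whose body is not part of this diff. Below is a minimal sketch of what such a helper typically computes (an illustrative assumption, not the actual coati implementation): the per-token entropy of the softmax distribution over the vocabulary, which the diff then averages over valid response tokens with the action mask and detaches so the metric stays out of the autograd graph.

import torch

def entropy_from_logits_sketch(logits: torch.Tensor) -> torch.Tensor:
    # Per-token entropy of the categorical distribution defined by `logits`,
    # using H = logsumexp(logits) - sum(softmax(logits) * logits) over the vocab dim.
    probs = torch.softmax(logits, dim=-1)
    return torch.logsumexp(logits, dim=-1) - (probs * logits).sum(dim=-1)

# Masked per-sequence average over response tokens, mirroring the aggregation in the diff:
logits = torch.randn(2, 5, 32)      # (batch, num_action, vocab_size), dummy values
action_mask = torch.ones(2, 5)      # 1 for valid response tokens, 0 for padding
per_token_entropy = entropy_from_logits_sketch(logits)                                  # (2, 5)
per_sequence_entropy = (per_token_entropy * action_mask).sum(-1) / action_mask.sum(-1)  # (2,)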