diff --git a/slime/backends/fsdp_utils/actor.py b/slime/backends/fsdp_utils/actor.py
index 4f85a6a80..1b141fb29 100644
--- a/slime/backends/fsdp_utils/actor.py
+++ b/slime/backends/fsdp_utils/actor.py
@@ -1,3 +1,4 @@
+import contextlib
 import logging
 import os
 import random
@@ -342,6 +343,7 @@ def _compute_log_prob(
             target_tokens=batch["tokens"],
             allow_compile=not self.args.true_on_policy_mode,
             temperature=self.args.rollout_temperature,
+            requires_entropy_grad=False,
         )
         batch[f"{store_prefix}log_probs"] = log_probs_result
         if store_prefix == "":
@@ -547,6 +549,7 @@ def _train_step(self, packed_batch, reported_accum, mbs_id, grad_accum):
         # Prepare model inputs
         model_args = self._get_model_inputs_args(packed_batch)
         logits = self.model(**model_args).logits.squeeze(0).float()
+        entropy_requires_grad = self.args.entropy_coef > 0
 
         # Compute log probs and entropy
         log_probs, entropy_result = get_logprob_and_entropy(
@@ -554,6 +557,7 @@
             target_tokens=packed_batch["tokens"],
             allow_compile=not self.args.true_on_policy_mode,
             temperature=self.args.rollout_temperature,
+            requires_entropy_grad=entropy_requires_grad,
         )
         packed_batch["cur_log_probs"] = log_probs
         packed_batch["entropy"] = entropy_result
@@ -861,6 +865,7 @@ def get_logprob_and_entropy(
     target_tokens: torch.Tensor,
     allow_compile: bool,
     temperature: float | None = None,
+    requires_entropy_grad: bool = True,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """Compute log probabilities and entropy.
 
@@ -878,9 +883,11 @@
     log_probs = gather_log_probs_packed(
         shifted_logits, target_tokens, allow_compile=allow_compile, temperature=temperature
     )
-    log_probs_full = torch.log_softmax(shifted_logits, dim=-1)
-    probs = torch.softmax(shifted_logits, dim=-1)
-    entropy = -(probs * log_probs_full).sum(dim=-1)
+    entropy_context = torch.no_grad() if not requires_entropy_grad else contextlib.nullcontext()
+    with entropy_context:
+        log_probs_full = torch.log_softmax(shifted_logits, dim=-1)
+        probs = torch.softmax(shifted_logits, dim=-1)
+        entropy = -(probs * log_probs_full).sum(dim=-1)
 
     return log_probs, entropy
 
diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py
index e5002a78e..d51775a09 100644
--- a/slime/backends/megatron_utils/loss.py
+++ b/slime/backends/megatron_utils/loss.py
@@ -117,6 +117,7 @@ def get_log_probs_and_entropy(
     total_lengths: list[int],
     response_lengths: list[int],
     with_entropy: bool = False,
+    requires_entropy_grad: bool = True,
     non_loss_data: bool = True,
     max_seq_lens: list[int] | None = None,
 ) -> dict[str, list[torch.Tensor]]:
@@ -159,6 +160,7 @@
                 mpu.get_tensor_model_parallel_group(),
                 with_entropy=with_entropy,
                 chunk_size=args.log_probs_chunk_size,
+                requires_entropy_grad=requires_entropy_grad,
             )
             log_probs_list.append(log_prob.squeeze(-1))
 
@@ -475,6 +477,7 @@ def policy_loss_function(
     response_lengths = batch["response_lengths"]
     total_lengths = batch["total_lengths"]
     max_seq_lens = batch.get("max_seq_lens", None)
+    entropy_requires_grad = args.entropy_coef > 0
 
     _, log_probs_and_entropy = get_log_probs_and_entropy(
         logits,
@@ -483,6 +486,7 @@
         total_lengths=total_lengths,
         response_lengths=response_lengths,
         with_entropy=True,
+        requires_entropy_grad=entropy_requires_grad,
         max_seq_lens=max_seq_lens,
     )
 
@@ -745,6 +749,7 @@ def sft_loss_function(
         total_lengths=total_lengths,
         response_lengths=response_lengths,
         with_entropy=False,
+        requires_entropy_grad=False,
         max_seq_lens=batch.get("max_seq_lens", None),
     )
 
diff --git a/slime/utils/ppo_utils.py b/slime/utils/ppo_utils.py
index 2404883ab..14dcac038 100644
--- a/slime/utils/ppo_utils.py
+++ b/slime/utils/ppo_utils.py
@@ -1,6 +1,7 @@
 # Adapt from https://github.com/OpenRLHF/OpenRLHF/blob/10c733694ed9fbb78a0a2ff6a05efc7401584d46/openrlhf/models/utils.py
 # and https://github.com/OpenRLHF/OpenRLHF/blob/10c733694ed9fbb78a0a2ff6a05efc7401584d46/openrlhf/trainer/ppo_utils/experience_maker.py
+import contextlib
 from argparse import Namespace
 
 import torch
 
@@ -646,7 +647,9 @@ def chunked_gae(
     return advantages, returns
 
 
-def calculate_log_probs_and_entropy(logits, tokens, tp_group, with_entropy: bool = False, chunk_size: int = -1):
+def calculate_log_probs_and_entropy(
+    logits, tokens, tp_group, with_entropy: bool = False, chunk_size: int = -1, requires_entropy_grad: bool = True
+):
     logits = logits.contiguous()
     # TODO: not sure why we need to clone the logits here.
     # Without the clone, the backward will trigger inplace edit error.
@@ -663,15 +666,19 @@
                 log_probs.append(log_prob)
             log_prob = torch.cat(log_probs, dim=0)
             if with_entropy:
-                entropys = []
-                for _, logits_chunk in zip(tokens_chunks, logits_chunks, strict=True):
-                    entropy = compute_entropy_from_logits(logits_chunk.clone(), tp_group)
-                    entropys.append(entropy)
-                entropy = torch.cat(entropys, dim=0)
+                entropy_context = torch.no_grad() if not requires_entropy_grad else contextlib.nullcontext()
+                with entropy_context:
+                    entropys = []
+                    for _, logits_chunk in zip(tokens_chunks, logits_chunks, strict=True):
+                        e = compute_entropy_from_logits(logits_chunk.clone(), tp_group)
+                        entropys.append(e)
+                    entropy = torch.cat(entropys, dim=0)
         else:
             log_prob = compute_log_probs(logits.clone(), tokens, tp_group)
            if with_entropy:
-                entropy = compute_entropy_from_logits(logits.clone(), tp_group)
+                entropy_context = torch.no_grad() if not requires_entropy_grad else contextlib.nullcontext()
+                with entropy_context:
+                    entropy = compute_entropy_from_logits(logits.clone(), tp_group)
     else:
         log_prob = logits.new_zeros((0,))
         if with_entropy:
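
Note: the shared pattern across all three files is to wrap the entropy computation in torch.no_grad() whenever entropy is only logged and never enters the loss (e.g. entropy_coef == 0), so the full-vocab softmax intermediates are not retained for backward. The following is a minimal standalone sketch of that pattern, not the repo's API; the helper names and tensor shapes are illustrative only.

import contextlib

import torch


def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
    # H = -sum_v p(v) * log p(v) over the vocabulary dimension.
    log_probs = torch.log_softmax(logits, dim=-1)
    return -(log_probs.exp() * log_probs).sum(dim=-1)


def logprob_and_entropy(
    logits: torch.Tensor, tokens: torch.Tensor, requires_entropy_grad: bool = True
) -> tuple[torch.Tensor, torch.Tensor]:
    # Per-token log-probs always keep their graph: they feed the policy loss.
    token_log_probs = torch.log_softmax(logits, dim=-1).gather(-1, tokens.unsqueeze(-1)).squeeze(-1)

    # Entropy is computed under no_grad when it is monitoring-only, so the
    # full-vocab softmax activations are not stored for the backward pass.
    ctx = contextlib.nullcontext() if requires_entropy_grad else torch.no_grad()
    with ctx:
        entropy = entropy_from_logits(logits)
    return token_log_probs, entropy


if __name__ == "__main__":
    logits = torch.randn(8, 32000, requires_grad=True)
    tokens = torch.randint(0, 32000, (8,))
    lp, ent = logprob_and_entropy(logits, tokens, requires_entropy_grad=False)
    assert lp.requires_grad and not ent.requires_grad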