13 changes: 10 additions & 3 deletions slime/backends/fsdp_utils/actor.py
@@ -1,3 +1,4 @@
import contextlib
import logging
import os
import random
@@ -342,6 +343,7 @@ def _compute_log_prob(
target_tokens=batch["tokens"],
allow_compile=not self.args.true_on_policy_mode,
temperature=self.args.rollout_temperature,
requires_entropy_grad=False,
)
batch[f"{store_prefix}log_probs"] = log_probs_result
if store_prefix == "":
@@ -547,13 +549,15 @@ def _train_step(self, packed_batch, reported_accum, mbs_id, grad_accum):
# Prepare model inputs
model_args = self._get_model_inputs_args(packed_batch)
logits = self.model(**model_args).logits.squeeze(0).float()
entropy_requires_grad = self.args.entropy_coef > 0

# Compute log probs and entropy
log_probs, entropy_result = get_logprob_and_entropy(
logits=logits,
target_tokens=packed_batch["tokens"],
allow_compile=not self.args.true_on_policy_mode,
temperature=self.args.rollout_temperature,
requires_entropy_grad=entropy_requires_grad,
)
packed_batch["cur_log_probs"] = log_probs
packed_batch["entropy"] = entropy_result
@@ -861,6 +865,7 @@ def get_logprob_and_entropy(
target_tokens: torch.Tensor,
allow_compile: bool,
temperature: float | None = None,
requires_entropy_grad: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute log probabilities and entropy.

@@ -878,9 +883,11 @@ def get_logprob_and_entropy(
log_probs = gather_log_probs_packed(
shifted_logits, target_tokens, allow_compile=allow_compile, temperature=temperature
)
log_probs_full = torch.log_softmax(shifted_logits, dim=-1)
probs = torch.softmax(shifted_logits, dim=-1)
entropy = -(probs * log_probs_full).sum(dim=-1)
entropy_context = torch.no_grad() if not requires_entropy_grad else contextlib.nullcontext()
with entropy_context:
log_probs_full = torch.log_softmax(shifted_logits, dim=-1)
probs = torch.softmax(shifted_logits, dim=-1)
entropy = -(probs * log_probs_full).sum(dim=-1)
return log_probs, entropy


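The core of the FSDP change is the conditional gradient context around the entropy term in `get_logprob_and_entropy`. Below is a minimal, self-contained sketch of that pattern; the function name, tensor shapes, and the plain `gather`-based log-prob step are illustrative stand-ins for `gather_log_probs_packed` and the packed-batch handling in `actor.py`, not the repo's API.

```python
import contextlib

import torch


def logprob_and_entropy_sketch(
    logits: torch.Tensor,          # [seq_len, vocab]
    target_tokens: torch.Tensor,   # [seq_len]
    requires_entropy_grad: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Logits at position t score the token at position t + 1.
    shifted_logits = logits[:-1]
    targets = target_tokens[1:]

    # Per-token log probs always keep their graph: they feed the policy loss.
    log_probs = (
        torch.log_softmax(shifted_logits, dim=-1)
        .gather(-1, targets.unsqueeze(-1))
        .squeeze(-1)
    )

    # Entropy only needs a graph when an entropy bonus is actually trained
    # (entropy_coef > 0); otherwise compute it under no_grad so the softmax
    # activations are not retained for backward.
    entropy_context = contextlib.nullcontext() if requires_entropy_grad else torch.no_grad()
    with entropy_context:
        log_probs_full = torch.log_softmax(shifted_logits, dim=-1)
        probs = log_probs_full.exp()
        entropy = -(probs * log_probs_full).sum(dim=-1)
    return log_probs, entropy
```

With `requires_entropy_grad=False` the returned entropy tensor has `requires_grad == False`, so logging it adds no backward memory, which is also why `_compute_log_prob` (reference/old log probs) can pass `requires_entropy_grad=False` unconditionally.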
5 changes: 5 additions & 0 deletions slime/backends/megatron_utils/loss.py
@@ -117,6 +117,7 @@ def get_log_probs_and_entropy(
total_lengths: list[int],
response_lengths: list[int],
with_entropy: bool = False,
requires_entropy_grad: bool = True,
non_loss_data: bool = True,
max_seq_lens: list[int] | None = None,
) -> dict[str, list[torch.Tensor]]:
@@ -159,6 +160,7 @@ def get_log_probs_and_entropy(
mpu.get_tensor_model_parallel_group(),
with_entropy=with_entropy,
chunk_size=args.log_probs_chunk_size,
requires_entropy_grad=requires_entropy_grad,
)

log_probs_list.append(log_prob.squeeze(-1))
@@ -475,6 +477,7 @@ def policy_loss_function(
response_lengths = batch["response_lengths"]
total_lengths = batch["total_lengths"]
max_seq_lens = batch.get("max_seq_lens", None)
entropy_requires_grad = args.entropy_coef > 0

_, log_probs_and_entropy = get_log_probs_and_entropy(
logits,
@@ -483,6 +486,7 @@ def policy_loss_function(
total_lengths=total_lengths,
response_lengths=response_lengths,
with_entropy=True,
requires_entropy_grad=entropy_requires_grad,
max_seq_lens=max_seq_lens,
)

@@ -745,6 +749,7 @@ def sft_loss_function(
total_lengths=total_lengths,
response_lengths=response_lengths,
with_entropy=False,
requires_entropy_grad=False,
max_seq_lens=batch.get("max_seq_lens", None),
)

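On the Megatron side the flag is derived once per loss function (`args.entropy_coef > 0` in `policy_loss_function`, hard-coded `False` in `sft_loss_function`) and threaded through `get_log_probs_and_entropy` down to `calculate_log_probs_and_entropy`. The sketch below is a hypothetical, simplified view of why the gate is safe rather than the repo's actual loss assembly: the entropy term only enters the loss when the coefficient is positive; otherwise it is logging-only.

```python
import torch


def assemble_policy_loss(pg_loss: torch.Tensor, entropy: torch.Tensor, entropy_coef: float) -> torch.Tensor:
    """Illustrative only; the real assembly lives in policy_loss_function."""
    if entropy_coef > 0:
        # On this branch the entropy term contributes gradients, so it must be
        # computed with requires_entropy_grad=True.
        return pg_loss - entropy_coef * entropy.mean()
    # With entropy_coef == 0, entropy is reported as a metric but never enters
    # the loss, so an entropy computed under torch.no_grad() trains identically.
    return pg_loss
```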
21 changes: 14 additions & 7 deletions slime/utils/ppo_utils.py
@@ -1,6 +1,7 @@
# Adapt from https://github.com/OpenRLHF/OpenRLHF/blob/10c733694ed9fbb78a0a2ff6a05efc7401584d46/openrlhf/models/utils.py
# and https://github.com/OpenRLHF/OpenRLHF/blob/10c733694ed9fbb78a0a2ff6a05efc7401584d46/openrlhf/trainer/ppo_utils/experience_maker.py

import contextlib
from argparse import Namespace

import torch
@@ -646,7 +647,9 @@ def chunked_gae(
return advantages, returns


def calculate_log_probs_and_entropy(logits, tokens, tp_group, with_entropy: bool = False, chunk_size: int = -1):
def calculate_log_probs_and_entropy(
logits, tokens, tp_group, with_entropy: bool = False, chunk_size: int = -1, requires_entropy_grad: bool = True
):
logits = logits.contiguous()
# TODO: not sure why we need to clone the logits here.
# Without the clone, the backward will trigger inplace edit error.
@@ -663,15 +666,19 @@ def calculate_log_probs_and_entropy(logits, tokens, tp_group, with_entropy: bool
log_probs.append(log_prob)
log_prob = torch.cat(log_probs, dim=0)
if with_entropy:
entropys = []
for _, logits_chunk in zip(tokens_chunks, logits_chunks, strict=True):
entropy = compute_entropy_from_logits(logits_chunk.clone(), tp_group)
entropys.append(entropy)
entropy = torch.cat(entropys, dim=0)
entropy_context = torch.no_grad() if not requires_entropy_grad else contextlib.nullcontext()
with entropy_context:
entropys = []
for _, logits_chunk in zip(tokens_chunks, logits_chunks, strict=True):
e = compute_entropy_from_logits(logits_chunk.clone(), tp_group)
entropys.append(e)
entropy = torch.cat(entropys, dim=0)
else:
log_prob = compute_log_probs(logits.clone(), tokens, tp_group)
if with_entropy:
entropy = compute_entropy_from_logits(logits.clone(), tp_group)
entropy_context = torch.no_grad() if not requires_entropy_grad else contextlib.nullcontext()
with entropy_context:
entropy = compute_entropy_from_logits(logits.clone(), tp_group)
else:
log_prob = logits.new_zeros((0,))
if with_entropy:
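The `ppo_utils` change applies the same gate to both the chunked and the unchunked entropy paths of `calculate_log_probs_and_entropy`. Below is a standalone, single-process illustration of the effect; it drops the tensor-parallel group, so the repo's `compute_entropy_from_logits` is replaced here by a plain softmax entropy, and the names are assumptions for the sake of the example.

```python
import contextlib

import torch


def entropy_from_logits(logits: torch.Tensor, requires_entropy_grad: bool) -> torch.Tensor:
    """Single-GPU stand-in for the gated entropy path; no TP reduction here."""
    ctx = contextlib.nullcontext() if requires_entropy_grad else torch.no_grad()
    with ctx:
        log_p = torch.log_softmax(logits, dim=-1)
        return -(log_p.exp() * log_p).sum(dim=-1)


if __name__ == "__main__":
    logits = torch.randn(4, 128, requires_grad=True)
    with_graph = entropy_from_logits(logits, requires_entropy_grad=True)
    without_graph = entropy_from_logits(logits, requires_entropy_grad=False)
    # The gated version is detached from the autograd graph entirely.
    print(with_graph.requires_grad, without_graph.requires_grad)  # True False
```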