
Conversation

@Beichen-Ma

Motivation

This PR optimizes context parallel for the FSDP backend (#1062) by eliminating the all_gather of entropy and log_probs. Instead of gathering log_probs and entropy across all CP ranks and then redistributing them, each rank now computes and retains only its local portion, with the slicing handled during unpacking and loss computation.

Changes

  • get_logprob_and_entropy_with_cp: changed from all_gathering the full log_probs/entropy to returning local shards only
  • unpack_sequences_with_cp: handles the tensor slicing, extracting only the portion of each sequence's response that overlaps with the current rank's shard
  • all_reduce the final loss across CP ranks (a rough sketch of these pieces follows this list)
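
A minimal sketch of that data flow, assuming PyTorch with packed sequences; the signatures, shard offsets, and metadata arguments below are illustrative, not the project's actual API:

```python
# Illustrative sketch only: signatures, shard offsets, and sequence metadata
# are assumptions, not the actual project API.
import torch


def get_logprob_and_entropy_with_cp(logits_local, labels_local):
    """Compute log_probs/entropy for this CP rank's token shard only.

    Previously the per-rank results were all_gathered into full-length
    tensors; now the local shards are returned directly.
    """
    log_probs_all = torch.log_softmax(logits_local.float(), dim=-1)
    log_probs = log_probs_all.gather(-1, labels_local.unsqueeze(-1)).squeeze(-1)
    entropy = -(log_probs_all.exp() * log_probs_all).sum(dim=-1)
    return log_probs, entropy  # local shards, no all_gather


def unpack_sequences_with_cp(local_vals, seq_starts, seq_ends, shard_start, shard_end):
    """Slice out, per packed sequence, only the part of its response that
    overlaps this CP rank's token range [shard_start, shard_end)."""
    pieces = []
    for start, end in zip(seq_starts, seq_ends):
        lo, hi = max(start, shard_start), min(end, shard_end)
        if lo < hi:
            pieces.append(local_vals[lo - shard_start:hi - shard_start])
        else:
            pieces.append(local_vals.new_empty(0))  # sequence not on this rank
    return pieces
```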

Impact

During training, get_logprob_and_entropy_with_cp is called three times, to obtain the reference log_probs, the actor log_probs, and the current log_probs. The three all_gather operations on full-length response log_probs and entropy are replaced by a single scalar all_reduce.

Concerns

  • The sum_of_sample_mean is now computed as a sum of per-rank local contributions rather than over the fully gathered tensors.
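
To make this concrete, here is a hedged sketch of that reduction; `cp_group` and the per-sample global response lengths are assumptions about the surrounding code, and the normalization comment describes one way to keep the result equal to the previous gathered computation, not necessarily what the PR does:

```python
# Sketch of the final reduction, assuming a process group `cp_group` over the
# CP ranks; names and shapes are illustrative.
import torch
import torch.distributed as dist


def sum_of_sample_mean_with_cp(per_token_loss_pieces, global_sample_lengths, cp_group):
    """Sum of per-sample means, accumulated from local shards.

    Each rank holds only the slice of each sample's response that lives on
    it. Dividing each local token sum by the sample's *global* response
    length (not the local slice length) keeps the all_reduce'd result equal
    to the single-rank sum_of_sample_mean.
    """
    local = torch.zeros((), device=per_token_loss_pieces[0].device)
    for piece, n_tokens in zip(per_token_loss_pieces, global_sample_lengths):
        if piece.numel() > 0:
            local = local + piece.sum() / n_tokens
    dist.all_reduce(local, op=dist.ReduceOp.SUM, group=cp_group)  # one scalar collective
    return local
```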

@yueming-yuan
Collaborator

(This might not be compatible with the new FSDP impl...)

@Beichen-Ma Beichen-Ma closed this Jan 13, 2026