
A few questions about GRPO #2

@rlczddl


# Forward pass through current policy (for old_log_probs - no gradients needed)
# Note: old_log_probs represent the policy at data collection time (frozen)
with torch.no_grad():
    current_log_probs = self._get_log_probs(prompts, responses, use_ref=False)
# Forward pass through current policy (for new policy - needs gradients)
# Note: new_log_probs are used for ratio computation and need gradients
new_log_probs = self._get_log_probs(prompts, responses, use_ref=False)
# Forward pass through reference policy (frozen - no gradients)
# Note: ref_log_probs are used for KL divergence and should be frozen
ref_log_probs = self._get_log_probs(prompts, responses, use_ref=True)
# Compute loss
loss_dict = self.compute_loss(
    new_log_probs,
    ref_log_probs,
    advantages,
    current_log_probs
)
The way the replay data, the old policy, and the new policy are computed here also seems somewhat problematic.
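
One possible reading of the concern, sketched with plain Python scalars rather than the repository's tensors: in the snippet, `current_log_probs` and `new_log_probs` both come from `self._get_log_probs(..., use_ref=False)` on the same weights, so the ratio `exp(new - old)` evaluates to 1 and the clip range never engages. The helper `clipped_surrogate`, the `clip_eps` value, and all numbers below are hypothetical and only illustrate the arithmetic; they are not taken from the repository.

```python
import math

def clipped_surrogate(new_log_probs, old_log_probs, advantages, clip_eps=0.2):
    """Return (ratios, loss) for a PPO/GRPO-style clipped objective (hypothetical helper)."""
    # Importance ratio pi_new / pi_old, computed from log-probabilities.
    ratios = [math.exp(n - o) for n, o in zip(new_log_probs, old_log_probs)]
    terms = []
    for r, a in zip(ratios, advantages):
        clipped = min(max(r, 1.0 - clip_eps), 1.0 + clip_eps)
        # Pessimistic choice between the unclipped and clipped terms.
        terms.append(min(r * a, clipped * a))
    loss = -sum(terms) / len(terms)
    return ratios, loss

advantages = [1.0, -0.5, 2.0]  # made-up values

# Case 1: "old" log-probs recomputed from the same weights as the new ones.
same_policy = [-1.2, -0.7, -2.3]
ratios_same, loss_same = clipped_surrogate(same_policy, same_policy, advantages)

# Case 2: old log-probs cached at rollout time; the policy has been updated since.
cached_old = [-1.2, -0.7, -2.3]
after_update = [-0.9, -0.8, -1.9]
ratios_moved, loss_moved = clipped_surrogate(after_update, cached_old, advantages)

print(ratios_same)   # every ratio is exactly 1.0
print(ratios_moved)  # ratios differ from 1.0, so clipping can apply
```

If the trainer takes a single optimizer step per rollout batch, the two cases coincide numerically and recomputing is harmless, since the gradient still flows through `new_log_probs`. The distinction matters once several updates reuse the same batch: in that setting a common approach is to store the log-probs once at generation time and pass the stored values in as the old policy.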
