take top bottom of generating n #369

Open · wants to merge 2 commits into main
109 changes: 72 additions & 37 deletions open_instruct/online_dpo_vllm_thread.py
@@ -175,6 +175,9 @@ class Args:
"""the beta value of the RLHF objective (KL coefficient)"""
num_generation_per_prompt: int = 2
"""the number of generations per prompt (currently only support 2)"""
take_top_bottom_generation: bool = False
"""when generating num_generation_per_prompt > 2 completions, train on only one pair per prompt:
the top- and bottom-scoring completions"""
loss_type: Literal["sigmoid", "ipo"] = "sigmoid"
"""the loss type for the DPO algorithm"""

@@ -255,11 +258,18 @@ def calculate_runtime_args_and_accelerator(args: Args, model_config: ModelConfig
args.num_training_steps = args.total_episodes // args.batch_size
args.eval_freq = max(1, args.num_training_steps // args.num_evals)
# DPO logic: repeats the same prompt `num_generation_per_prompt` times
# if take_top_bottom_generation, only 2 of them (the top/bottom pair) are used for training
# otherwise, all num_generation_per_prompt generations are used for training
args.num_training_samples_per_prompt = 2 if args.take_top_bottom_generation else args.num_generation_per_prompt
if args.num_generation_per_prompt > 2 and not args.take_top_bottom_generation:
raise NotImplementedError("generating more than 2 completions per prompt currently requires take_top_bottom_generation")

args.local_dataloader_batch_size = exact_div(
args.local_batch_size,
args.num_generation_per_prompt,
args.num_training_samples_per_prompt,
"`local_batch_size` must be a multiple of `num_training_samples_per_prompt`",
)
args.generated_batch_size = args.local_dataloader_batch_size * args.num_generation_per_prompt * args.world_size
if args.push_to_hub:
if args.hf_repo_id is None: # auto-generate one
args.hf_repo_id = "open_instruct_dev"
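A worked example of the new batch-size bookkeeping (hypothetical numbers, for illustration only):

```python
world_size = 2
local_batch_size = 8                # training samples per rank
num_generation_per_prompt = 4
take_top_bottom_generation = True

num_training_samples_per_prompt = 2 if take_top_bottom_generation else num_generation_per_prompt  # 2
local_dataloader_batch_size = local_batch_size // num_training_samples_per_prompt   # 4 unique prompts per rank
generated_batch_size = local_dataloader_batch_size * num_generation_per_prompt * world_size  # 32 completions overall
```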
@@ -587,7 +597,7 @@ def repeat_generator():
thread.start()
torch.cuda.set_device(device)

g_vllm_responses = torch.zeros((args.batch_size, args.response_length), device=device, dtype=torch.long)
g_vllm_responses = torch.zeros((args.generated_batch_size, args.response_length), device=device, dtype=torch.long)

# set up the metrics and initial states
stats_shape = (args.num_epochs, args.num_mini_batches, args.gradient_accumulation_steps)
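The gather buffer for vLLM outputs is resized to match: every generated completion has to fit, not just those kept for training. With the hypothetical numbers above, `generated_batch_size` is 32 while `args.batch_size` (training samples across ranks) is only 16, so the old `(args.batch_size, args.response_length)` allocation would be too small once `num_generation_per_prompt > 2`.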
@@ -657,9 +667,7 @@ def repeat_generator():
training_time_start = time.time()
with torch.no_grad():
context_length = queries.shape[1]
responses = []
postprocessed_responses = []
ref_logprobs = []
scores = []
sequence_lengths = []
if accelerator.is_main_process:
@@ -687,15 +695,6 @@
query = queries[i : i + args.local_rollout_forward_batch_size]
query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
response = query_response[:, context_length:]

ref_output = forward(ref_model, query_response, tokenizer.pad_token_id)
ref_logits = ref_output.logits[:, context_length - 1 : -1]
ref_logits /= args.temperature + 1e-7
ref_all_logprob = F.log_softmax(ref_logits, dim=-1)
ref_logprob = torch.gather(ref_all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
del ref_output, ref_logits, ref_all_logprob
torch.cuda.empty_cache()

# Response Processing 1. truncate response after the first occurrence of `stop_token_id`
postprocessed_response = response
if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
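The reference-model forward pass deleted here reappears below, after the top/bottom filtering, which looks like the main compute saving: the ref model only scores completions that survive filtering. Roughly, assuming n = 8:

```python
num_generation_per_prompt = 8   # hypothetical
kept_for_training = 2           # top + bottom with take_top_bottom_generation
ref_forward_reduction = num_generation_per_prompt / kept_for_training  # 4x fewer ref passes
```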
@@ -710,20 +709,14 @@
reward_model, postprocessed_query_response, tokenizer.pad_token_id, context_length
)

responses.append(response)
postprocessed_responses.append(postprocessed_response)
ref_logprobs.append(ref_logprob)
sequence_lengths.append(sequence_length)
scores.append(score)
responses = torch.cat(responses, 0)
postprocessed_responses = torch.cat(postprocessed_responses, 0)
ref_logprobs = torch.cat(ref_logprobs, 0)
sequence_lengths = torch.cat(sequence_lengths, 0)
scores = torch.cat(scores, 0)
accelerator.gather(scores)
del (ref_logprob, score)
gc.collect()
torch.cuda.empty_cache()
del score

# Response Processing 3. filter response. Ensure that the sample contains stop_token_id
# responses not passing that filter will receive a low (fixed) score
@@ -737,42 +730,84 @@
contain_stop_token, scores, torch.full_like(scores, args.penalty_reward_value)
)

# num_examples is the number of unique prompts for which we get num_generation_per_prompt completions
num_examples = scores.size(0) // args.num_generation_per_prompt
scores_reshaped = scores.reshape(args.num_generation_per_prompt, num_examples).t()

# Get the max scores and their local indices
chosen_scores, chosen_local_indices = torch.max(scores_reshaped, dim=1)

# Get the min scores and their local indices
rejected_scores, rejected_local_indices = torch.min(scores_reshaped, dim=1)

scores_margin = chosen_scores - rejected_scores

# Calculate the global indices
num_examples_range = torch.arange(num_examples).to(scores.device)
chosen_indices = chosen_local_indices * num_examples + num_examples_range
rejected_indices = rejected_local_indices * num_examples + num_examples_range
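# Layout note: the reshape above implies completions are ordered generation-major,
# e.g. with num_examples=2 prompts and num_generation_per_prompt=3 the flat order is
# [p0g0, p1g0, p0g1, p1g1, p0g2, p1g2], so generation g of prompt e sits at
# flat index g * num_examples + e, matching the arithmetic in these two lines.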

if args.take_top_bottom_generation:
# reduce query_responses from (num_examples * num_generation_per_prompt) to (num_examples * 2)
filtered_indices = torch.cat((chosen_indices, rejected_indices), 0)
# put all chosen completions first, then all rejected ones
query_responses = query_responses[filtered_indices]
queries = queries[filtered_indices]
postprocessed_responses = postprocessed_responses[filtered_indices]
sequence_lengths = sequence_lengths[filtered_indices]
scores = scores[filtered_indices]

chosen_indices = torch.arange(num_examples, device=scores.device)
rejected_indices = torch.arange(num_examples, device=scores.device) + num_examples
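# After this reindexing the filtered batch is ordered [all chosen, all rejected],
# so the DPO pair for prompt e is (row e, row e + num_examples).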

ref_logprobs = []
responses = []
for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
query = queries[i : i + args.local_rollout_forward_batch_size]
query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
response = query_response[:, context_length:]
ref_output = forward(ref_model, query_response, tokenizer.pad_token_id)
ref_logits = ref_output.logits[:, context_length - 1 : -1]
ref_logits /= args.temperature + 1e-7
ref_all_logprob = F.log_softmax(ref_logits, dim=-1)
ref_logprob = torch.gather(ref_all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
del ref_output, ref_logits, ref_all_logprob
torch.cuda.empty_cache()

ref_logprobs.append(ref_logprob)
responses.append(response)

ref_logprobs = torch.cat(ref_logprobs, 0)
responses = torch.cat(responses, 0)
del ref_logprob
gc.collect()
torch.cuda.empty_cache()
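# The gather in the loop above picks out each sampled token's log-probability:
# ref_all_logprob has shape (batch, response_len, vocab), so indexing it with the
# response token ids yields per-token log p_ref(y_t | x, y_<t), which the DPO loss
# compares against the policy's logprobs.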

# be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw
response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1)
padding_mask = response_idxs > sequence_lengths.unsqueeze(1)
ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB)
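# Mask illustration (assumed toy shapes): with sequence_lengths = [2, 0] and a
# response length of 4, response_idxs is [[0, 1, 2, 3], [0, 1, 2, 3]] and
# padding_mask is [[F, F, F, T], [F, T, T, T]]; masked positions are set to
# INVALID_LOGPROB so padding tokens never contribute to the loss.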

# num_examples is the number of chosen/rejected pairs, i.e. args.local_batch_size // 2
num_examples = scores.size(0) // 2
first_half = scores[:num_examples]
second_half = scores[num_examples:]

num_examples_range = torch.arange(num_examples).to(scores.device)
chosen_indices = torch.where(
first_half >= second_half, num_examples_range.clone(), num_examples_range.clone() + num_examples
)
rejected_indices = torch.where(
first_half < second_half, num_examples_range.clone(), num_examples_range.clone() + num_examples
)
scores_margin = scores[chosen_indices] - scores[rejected_indices]
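# Pairing illustration: with num_examples=2 and scores = [0.3, 0.9, 0.5, 0.1],
# first_half = [0.3, 0.9] and second_half = [0.5, 0.1], giving
# chosen_indices = [2, 1] and rejected_indices = [0, 3]; ties go to the
# first completion because the comparison is `>=`.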


logprobs = []
concat_indices = []
# Do multiple epochs of training on on-policy data (PPO-style), with a fresh random shuffle in each epoch
for epoch_idx in range(args.num_epochs):
b_inds = np.random.permutation(args.local_batch_size // args.num_generation_per_prompt)
b_inds = np.random.permutation(args.local_batch_size // args.num_training_samples_per_prompt)
minibatch_idx = 0
for mini_batch_start in range(
0,
args.local_batch_size // args.num_generation_per_prompt,
args.local_mini_batch_size // args.num_generation_per_prompt,
args.local_batch_size // args.num_training_samples_per_prompt,
args.local_mini_batch_size // args.num_training_samples_per_prompt,
):
mini_batch_end = mini_batch_start + args.local_mini_batch_size // args.num_generation_per_prompt
mini_batch_end = mini_batch_start + args.local_mini_batch_size // args.num_training_samples_per_prompt
mini_batch_inds = b_inds[mini_batch_start:mini_batch_end]
gradient_accumulation_idx = 0
for micro_batch_start in range(
0,
args.local_mini_batch_size // args.num_generation_per_prompt,
args.local_mini_batch_size // args.num_training_samples_per_prompt,
args.per_device_train_batch_size,
):
with accelerator.accumulate(model):
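With the filtering in place, the epoch/minibatch/microbatch bookkeeping divides by `num_training_samples_per_prompt` throughout, so the loop indices count prompt pairs rather than raw generations. A worked sketch with hypothetical per-rank sizes:

```python
local_batch_size = 8
local_mini_batch_size = 4
per_device_train_batch_size = 1
num_training_samples_per_prompt = 2  # always 2 once take_top_bottom_generation keeps one pair

num_pairs = local_batch_size // num_training_samples_per_prompt                  # 4 chosen/rejected pairs
pairs_per_minibatch = local_mini_batch_size // num_training_samples_per_prompt   # 2
for mini_batch_start in range(0, num_pairs, pairs_per_minibatch):                # 2 minibatches
    for micro_batch_start in range(0, pairs_per_minibatch, per_device_train_batch_size):
        pass  # each microbatch trains on per_device_train_batch_size pairs
```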