Add optional r1-style thinking reward (#551)
* Add optional r1-style thinking reward

* quick change

* test

* quick change

* push latest change

* fix

* tested it to work

* Push changes

* push
vwxyzjn authored Feb 6, 2025
1 parent 76c4f48 commit 1ff4692
Showing 5 changed files with 125 additions and 7 deletions.
21 changes: 21 additions & 0 deletions docs/algorithms/grpo.md
@@ -0,0 +1,21 @@
# Grouped Relative Policy Optimization (GRPO)

GRPO is an online RL method used in the [DeepSeek R1 paper](https://arxiv.org/abs/2501.12948); it first appeared in [DeepSeekMath](https://arxiv.org/abs/2402.03300).

`open_instruct/grpo_vllm_thread_ray_gtrl.py` contains an implementation of GRPO.


## Get started


Here is a command to run GRPO with Llama 3.1 8B on [ai2-adapt-dev/rlvr_gsm8k_zs](https://huggingface.co/datasets/ai2-adapt-dev/rlvr_gsm8k_zs), which is simply a zero-shot version of the RLVR GSM8K dataset.


```bash
bash scripts/train/rlvr/grpo_llama3.1-8b.sh
```

The results look quite reasonable: the format score and the overall score both go up, the KL does not explode, and the sequence length appears stable (at least initially).


![GRPO training curves for Llama 3.1 8B on rlvr_gsm8k_zs](grpo_8b.png)
18 changes: 18 additions & 0 deletions open_instruct/dataset_transformation.py
@@ -182,6 +182,24 @@ def visualize_token(tokens: list[int], tokenizer: PreTrainedTokenizer):
"{% endif %}"
"{% endfor %}"
),
"r1_simple_chat_postpend_think": (
"A conversation between User and Assistant. "
"The user asks a question, and the Assistant solves it. "
"The assistant first thinks about the reasoning process in "
"the mind and then provides the user with the answer. "
"The reasoning process and answer are enclosed within <think> </think> "
"and <answer> </answer> tags, respectively, "
"i.e., <think> reasoning process here </think> "
"<answer> answer here </answer>."
"\n\n"
"{% for message in messages %}"
"{{ '\n\n' if not loop.first else '' }}"
"{{ message['role'].capitalize() + ': ' + message['content'] + '\n' }}"
"{% if loop.last and add_generation_prompt %}"
"{{ 'Assistant: <think>' }}"
"{% endif %}"
"{% endfor %}"
),
}
# flake8: noqa

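For reference, here is a minimal sketch of how the new `r1_simple_chat_postpend_think` template renders a prompt. It assumes the enclosing dict is importable as `CHAT_TEMPLATES`, which the diff context suggests but does not show; the example question is illustrative.

```python
# Minimal sketch, not part of the commit. Assumes the template dict shown above
# is exposed as CHAT_TEMPLATES in open_instruct/dataset_transformation.py.
from jinja2 import Template

from open_instruct.dataset_transformation import CHAT_TEMPLATES

template = Template(CHAT_TEMPLATES["r1_simple_chat_postpend_think"])
messages = [{"role": "user", "content": "What is 7 * 6?"}]
prompt = template.render(messages=messages, add_generation_prompt=True)
print(prompt)
# The rendered prompt ends with "Assistant: <think>", so the model's completion
# starts inside the reasoning block and is expected to close it with </think>
# before emitting <answer> ... </answer>.
```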
29 changes: 23 additions & 6 deletions open_instruct/ground_truth_utils.py
@@ -1,12 +1,21 @@
'''
"""
Collection of 'ground truth rewards' for different datasets/tasks.
Used to give feedback to the model based on the ground truth answer.
'''
import re
"""

import json
import re
import string
from open_instruct.math_utils import last_boxed_only_string, remove_boxed, get_unnormalized_answer, normalize_final_answer, is_equiv, hendrycks_is_equiv

from open_instruct.if_functions import IF_FUNCTIONS_MAP
from open_instruct.math_utils import (
get_unnormalized_answer,
hendrycks_is_equiv,
is_equiv,
last_boxed_only_string,
normalize_final_answer,
remove_boxed,
)


def verify_gsm8k_sample(model_output, ground_truth_answer):
@@ -138,11 +147,19 @@ def verify_flan_sample(model_output, ground_truth_answer):
return normalize_answer(answer_string) == normalize_answer(ground_truth_answer)


def soft_format_reward_func(responses: list[str], reward_scale: float = 1.0) -> list[float]:
"""Reward function that checks if the completion has a specific format."""
pattern = r".*?</think>\s*<answer>.*?</answer>"
matches = [re.match(pattern, r, re.DOTALL) for r in responses]
return [reward_scale if match else 0.0 for match in matches]


# debug code
if __name__ == "__main__":
from datasets import load_dataset

ds = load_dataset("ai2-adapt-dev/prompts_with_constraints_for_ground_truth")
test_model_output = "<|assistant|>\nThe answer is $\\boxed{3.14}$"
for sample in ds['train']:
for sample in ds["train"]:
print(sample)
verify_ifeval_sample(test_model_output, sample['ground_truth'])
verify_ifeval_sample(test_model_output, sample["ground_truth"])
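A quick sanity check of the new `soft_format_reward_func` (an illustration, not part of the commit): the regex only requires a closing `</think>` tag because the chat template above already appends the opening `<think>` to the prompt.

```python
# Illustration only: exercise soft_format_reward_func on a well-formed and a
# malformed completion. The expected output is reward_scale for the first
# response and 0.0 for the second.
from open_instruct.ground_truth_utils import soft_format_reward_func

responses = [
    "7 * 6 = 42 </think>\n<answer> 42 </answer>",  # closes the reasoning block, then answers
    "The answer is 42.",                           # no tags at all
]
print(soft_format_reward_func(responses, reward_scale=1.0))  # [1.0, 0.0]
```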
24 changes: 23 additions & 1 deletion open_instruct/grpo_vllm_thread_ray_gtrl.py
@@ -47,6 +47,7 @@
TokenizerConfig,
get_cached_dataset_rlvr,
)
from open_instruct.ground_truth_utils import soft_format_reward_func

os.environ["NCCL_CUMEM_ENABLE"] = "0" # NOQA

@@ -234,6 +235,10 @@ class Args:
"""the reward model multiplier, for down/upscaling the reward model output"""
verification_reward: float = 10.0
"""the reward value for verifiable responses"""
add_r1_style_format_reward: bool = False
"""whether to add the R1 style format reward"""
r1_style_format_reward: float = 1.0
"""the reward value for R1 style format reward"""

# async setting
async_mode: bool = True
@@ -1058,7 +1063,9 @@ def vllm_generate(
sequence_lengths = []
if accelerator.is_main_process:
g_response_token_ids = response_ids_Q.get()
DUMMY_PAD_TOKEN = 0 # we can't use tokenizer.pad_token_id because it's outside vocab and `torch.gather(all_logprob, 2, response.unsqueeze(-1))` will error out
DUMMY_PAD_TOKEN = (
args.stop_token_id
) # we can't use tokenizer.pad_token_id because it's outside vocab and `torch.gather(all_logprob, 2, response.unsqueeze(-1))` will error out
g_padded_response_ids = [
response + [DUMMY_PAD_TOKEN] * (args.response_length - len(response))
for response in g_response_token_ids
@@ -1071,6 +1078,12 @@
]
# print(f"{local_vllm_responses.shape=}, {local_vllm_responses=}")
query_responses = torch.cat((queries, local_vllm_responses), 1)

if args.add_r1_style_format_reward:
decoded_response = tokenizer.batch_decode(local_vllm_responses)
format_scores = torch.tensor(
soft_format_reward_func(decoded_response, args.r1_style_format_reward), device=device
)
for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
# print(f"get reward stuff starts {i=}")
query = queries[i : i + args.local_rollout_forward_batch_size]
@@ -1122,6 +1135,9 @@
else:
verifiable_count = torch.tensor([0.0], device=device).float()

if args.add_r1_style_format_reward:
score += format_scores[i : i + args.local_rollout_forward_batch_size]

responses.append(response)
postprocessed_responses.append(postprocessed_response)
logprobs.append(logprob)
@@ -1140,6 +1156,10 @@
verifiable_counts = torch.cat(verifiable_counts, 0)
verifiable_correct_rate = verifiable_counts.sum() / queries.shape[0]
# print(f"get reward stuff finished")
if self.rank == 0:
print(f"{sequence_lengths=}")
print(f"{postprocessed_responses[0]=}")
print(f"{tokenizer.decode(postprocessed_responses[0])=}")
del (logprob, ref_logprob, score)
gc.collect()
torch.cuda.empty_cache()
@@ -1294,6 +1314,8 @@ def vllm_generate(
local_metrics.add("val/ratio", ratio_stats.mean())
local_metrics.add("val/ratio_var", ratio_stats.var())
local_metrics.add("val/stop_token_rate", contain_stop_token.float().mean())
if args.add_r1_style_format_reward:
local_metrics.add("val/format_scores", format_scores.float().mean())

metrics = {
"episode": episode,
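With `--add_r1_style_format_reward`, each response's score is the verifiable reward plus the format bonus; a toy sketch of what the `score += format_scores[i : i + args.local_rollout_forward_batch_size]` line above amounts to (the values are the defaults from `Args`, the three responses are hypothetical).

```python
# Toy illustration, not part of the commit: combining the verifiable reward
# (args.verification_reward, default 10.0) with the format bonus
# (args.r1_style_format_reward, default 1.0) for three hypothetical responses.
import torch

verifiable_reward = torch.tensor([10.0, 0.0, 10.0])  # correct, wrong, correct
format_scores = torch.tensor([1.0, 1.0, 0.0])        # well-formed, well-formed, malformed
score = verifiable_reward + format_scores
print(score)  # tensor([11.,  1., 10.])
```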
40 changes: 40 additions & 0 deletions scripts/train/rlvr/grpo_llama3.1-8b.sh
@@ -0,0 +1,40 @@
python open_instruct/grpo_vllm_thread_ray_gtrl.py \
--exp_name $exp_name \
--output_dir /weka/oe-adapt-default/costah/models/$exp_name \
--dataset_mixer_list ai2-adapt-dev/rlvr_gsm8k_zs 1.0 \
--dataset_mixer_list_splits train \
--dataset_mixer_eval_list ai2-adapt-dev/rlvr_gsm8k_zs 1.0 \
--dataset_mixer_eval_list_splits train \
--max_token_length 2048 \
--max_prompt_token_length 2048 \
--response_length 2048 \
--number_samples_per_prompt 4 \
--model_name_or_path meta-llama/Llama-3.1-8B \
--stop_strings '"</answer>"' \
--add_r1_style_format_reward \
--non_stop_penalty False \
--stop_token eos \
--penalty_reward_value 0.0 \
--temperature 0.7 \
--ground_truths_key ground_truth \
--chat_template_name r1_simple_chat_postpend_think \
--sft_messages_key messages \
--learning_rate 3e-7 \
--total_episodes 1000000 \
--deepspeed_stage 3 \
--per_device_train_batch_size 1 \
--local_rollout_forward_batch_size 1 \
--local_mini_batch_size 16 \
--local_rollout_batch_size 16 \
--num_epochs 1 \
--actor_num_gpus_per_node 6 \
--vllm_tensor_parallel_size 2 \
--beta 0.01 \
--apply_verifiable_reward true \
--seed 3 \
--num_evals 100 \
--save_freq 100 \
--reward_model_multiplier 0.0 \
--no_try_launch_beaker_eval_jobs \
--gradient_checkpointing \
--with_tracking
