[GRPO] generate with prompt containing the first <think> tag #283

Open · wants to merge 11 commits into main

Conversation

kashif (Collaborator) commented Feb 11, 2025

qgallouedec (Member): Allow bootstrapping directly in GRPO huggingface/trl#2829
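For context, a minimal sketch of the bootstrapping idea this PR tests (a hypothetical standalone example, not code from the PR; it assumes a tokenizer whose chat template supports continue_final_message, e.g. the DeepSeek-R1 distills): the assistant turn is pre-filled with the opening <think> tag and left open so generation continues from inside the reasoning block.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
messages = [
    {"role": "user", "content": "What is 2 + 2?"},
    # Pre-filled assistant prefix: generation should continue after "<think>"
    {"role": "assistant", "content": "Let me solve this step by step.\n<think>"},
]
# continue_final_message=True leaves the assistant turn open instead of
# appending an end-of-turn token, so the model resumes from the prefix.
prompt = tokenizer.apply_chat_template(messages, tokenize=False, continue_final_message=True)
print(prompt)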

kashif (Collaborator, Author) commented Feb 11, 2025

"""Script to test format rewards for different models using vLLM."""

import argparse
from typing import List

import torch
from datasets import load_dataset

from open_r1.grpo import SYSTEM_PROMPT
from open_r1.rewards import format_reward
from vllm import LLM, SamplingParams


def format_prompt(question: str) -> List[dict]:
    """Format the prompt as a conversation."""
    return [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": question},
        {"role": "assistant", "content": "Let me solve this step by step.\n<think>"},
    ]


def apply_chat_template(messages: List[dict], tokenizer) -> str:
    """Apply the model's chat template if available, otherwise use our fixed template."""
    if getattr(tokenizer, "chat_template", None) is not None:
        # continue_final_message=True keeps the final assistant turn open so the
        # model keeps generating after the pre-filled "<think>" prefix.
        return tokenizer.apply_chat_template(messages, tokenize=False, continue_final_message=True)

    # Fallback to simple template if no tokenizer chat template support
    formatted = ""
    for msg in messages:
        if msg["role"] == "system":
            formatted += f"System: {msg['content']}\n\n"
        elif msg["role"] == "user":
            formatted += f"User: {msg['content']}\n"
        elif msg["role"] == "assistant":
            formatted += f"Assistant: {msg['content']}\n"
    return formatted


def main():
    parser = argparse.ArgumentParser()
    # Model arguments
    parser.add_argument(
        "--model", type=str, default="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", help="Model ID or path"
    )
    parser.add_argument("--model_revision", type=str, default="main", help="Model revision to use")
    parser.add_argument(
        "--torch_dtype", type=str, default="bfloat16", help="PyTorch dtype (float16, bfloat16, float32)"
    )

    # Dataset arguments
    parser.add_argument("--dataset_name", type=str, default="open-r1/LIMO", help="Dataset to use for testing")
    parser.add_argument("--dataset_split", type=str, default="test", help="Dataset split to use")
    parser.add_argument("--num_samples", type=int, default=10, help="Number of samples to test")

    # Generation arguments
    parser.add_argument("--max_tokens", type=int, default=4096, help="Maximum number of tokens to generate")
    parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature")
    parser.add_argument("--max_prompt_length", type=int, default=768, help="Maximum length for prompts")

    # vLLM arguments
    parser.add_argument("--vllm_device", type=str, default="auto", help="Device to use for vLLM")
    parser.add_argument(
        "--vllm_gpu_memory_utilization", type=float, default=0.7, help="GPU memory utilization for vLLM"
    )

    args = parser.parse_args()

    # Set torch dtype
    if args.torch_dtype == "bfloat16":
        torch_dtype = torch.bfloat16
    elif args.torch_dtype == "float16":
        torch_dtype = torch.float16
    else:
        torch_dtype = torch.float32

    # Initialize vLLM
    print(f"Loading model {args.model}...")
    llm = LLM(
        model=args.model,
        revision=args.model_revision,
        dtype=torch_dtype,
        gpu_memory_utilization=args.vllm_gpu_memory_utilization,
        device=args.vllm_device,
    )
    tokenizer = llm.get_tokenizer()

    sampling_params = SamplingParams(
        temperature=args.temperature,
        max_tokens=args.max_tokens,
    )

    # Load dataset
    print(f"Loading dataset {args.dataset_name}...")
    dataset = load_dataset(args.dataset_name, split=args.dataset_split)

    # Sample questions from dataset
    if args.num_samples:
        dataset = dataset.select(range(min(args.num_samples, len(dataset))))

    # Test each question
    for example in dataset:
        print("\n" + "=" * 80)
        print(f"Question: {example['problem']}")
        print("-" * 80)

        # Format prompt
        messages = format_prompt(example["problem"])
        prompt = apply_chat_template(messages, tokenizer)

        # Generate completion
        outputs = llm.generate(prompt, sampling_params)
        completion = outputs[0].outputs[0].text
        print(f"Completion:\n{completion}")

        # Check format reward
        reward = format_reward([[{"content": completion}]])
        print(f"\nFormat reward: {reward[0]}")

        # Print analysis
        if reward[0] == 0:
            print("\nAnalysis: Format check failed. Looking for pattern:")
            print("- Must contain <think>...</think> followed by <answer>...</answer>")
            print("- Check if tags are properly closed and in correct order")
        else:
            print("\nAnalysis: Format check passed!")

        print(f"\nGround truth answer: {example.get('answer', 'N/A')}")


if __name__ == "__main__":
    main()

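To run the script (assuming it is saved locally, e.g. as test_format_reward.py; the filename is illustrative), an invocation along these lines exercises the defaults above on a small sample:

python test_format_reward.py \
    --model deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \
    --num_samples 5 \
    --temperature 0.7 \
    --max_tokens 4096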
Moving the script to a comment and removing it from the PR.

kashif requested a review from lewtun on February 11, 2025 at 19:13