diff --git a/cookbooks/training_judge_model/grpo/README.md b/cookbooks/training_judge_model/grpo/README.md index c81c0e1e8..0f0a05eaf 100644 --- a/cookbooks/training_judge_model/grpo/README.md +++ b/cookbooks/training_judge_model/grpo/README.md @@ -338,6 +338,13 @@ export N_NODES=2 bash pointwise/run_pointwise.sh ``` +### GRPO training using OpenJudge grader dataset +```bash +export RAY_ADDRESS="http://:8265" +export N_NODES=2 +bash pointwise/run_pointwise_grader.sh +``` + ### Resource Requirements | Configuration | GPUs | Memory per GPU | Recommended Model Size | diff --git a/cookbooks/training_judge_model/grpo/grader_rl_dataset.py b/cookbooks/training_judge_model/grpo/grader_rl_dataset.py new file mode 100644 index 000000000..5ec8ea984 --- /dev/null +++ b/cookbooks/training_judge_model/grpo/grader_rl_dataset.py @@ -0,0 +1,650 @@ +import copy +import logging +import os +from dataclasses import dataclass, field +from pprint import pprint +from typing import Any, Dict, List, Union + +import datasets +import verl.utils.torch_functional as verl_F +from omegaconf import DictConfig +from torch.utils.data import Dataset +from transformers import PreTrainedTokenizer +from verl.utils.model import compute_position_id_with_mask + +logger = logging.getLogger(__name__) + + +@dataclass +class DatasetConfig: + """Dataset configuration for different datasets and templates.""" + + dataset_name: str = "default" + prompt_template: str = "" + input_field: str = "input" + output_field: str = "output" + response_field: str = "answer" + label_field: str = "label" + score_field: str = "score" + max_samples: int = -1 + custom_principles: List[str] = field(default_factory=list) + custom_task_description: str = "" + template_config: Dict[str, Any] = field(default_factory=dict) + file_format: str = "" + file_formats_supported: List[str] = field(default_factory=lambda: ["parquet", "json", "jsonl"]) + + +class BaseChatRLDataset(Dataset): + """Base class for chat reinforcement learning 
dataset."""
+
+    def __init__(
+        self,
+        data_files: Union[str, List[str]],
+        tokenizer: PreTrainedTokenizer,
+        config: Union[DictConfig, Dict[str, Any]],
+        processor=None,  # Keep for backward compatibility, but not used
+        max_samples: int = -1,  # Add max_samples parameter
+    ):
+        # Initialize basic attributes
+        self.data_files = self._normalize_data_files(data_files)
+        self.original_data_files = copy.deepcopy(self.data_files)
+        self.tokenizer = tokenizer
+        self.config = config
+        self.max_samples = max_samples
+
+        # Parse configuration - support both DictConfig and regular dict
+        self._parse_config(config)
+
+        # Validate file formats and load data
+        self._validate_file_formats()
+        self._load_dataset()
+
+    def _normalize_data_files(self, data_files):
+        """Convert data files to list format."""
+        if isinstance(data_files, str):
+            data_files = [data_files]
+        elif hasattr(data_files, "__iter__"):  # Handle ListConfig or similar
+            data_files = list(data_files)
+        return copy.deepcopy(data_files)
+
+    def _parse_config(self, config: Union[DictConfig, Dict[str, Any]]):
+        """Parse configuration parameters from either DictConfig or dict."""
+        if hasattr(config, "__dict__") or isinstance(config, dict):
+            config_dict = dict(config) if hasattr(config, "__dict__") else config
+        else:
+            config_dict = {}
+
+        # Load configuration settings with defaults
+        self.cache_dir = os.path.expanduser(config_dict.get("cache_dir", "~/.cache/verl/rlhf"))
+        self.prompt_key = config_dict.get("prompt_key", "prompt")
+        self.max_prompt_length = config_dict.get("max_prompt_length", 1024)
+        self.return_raw_chat = config_dict.get("return_raw_chat", False)
+        self.truncation = config_dict.get("truncation", "error")
+        self.filter_overlong_prompts = config_dict.get("filter_overlong_prompts", True)
+        self.num_workers = min(
+            config_dict.get("filter_overlong_prompts_workers", max(1, os.cpu_count() // 4)), os.cpu_count()
+        )
+        self.serialize_dataset = False
+
+        # Add dataset-specific configuration
self.dataset_config_dict = config_dict.get("dataset_config", {}) + self.input_field = self.dataset_config_dict.get("input_field", "input") + self.output_field = self.dataset_config_dict.get("output_field", "output") + self.file_format = self.dataset_config_dict.get("file_format", "") + + def _validate_file_formats(self): + """Validate file formats and determine actual formats.""" + validated_files = [] + for file_path in self.data_files: + actual_format = self._detect_file_format(file_path) + if actual_format not in ["parquet", "json", "jsonl"]: + raise ValueError( + f"Unsupported file format for {file_path}: {actual_format}. " + f"Supported formats: parquet, json, jsonl" + ) + validated_files.append((file_path, actual_format)) + + self.validated_files = validated_files + pprint(f"Validated {len(validated_files)} files with formats: {[fmt for _, fmt in validated_files]}") + + def _detect_file_format(self, file_path: str) -> str: + """Detect file format based on extension or auto-detection.""" + if self.file_format: + return self.file_format + + _, ext = os.path.splitext(file_path.lower()) + if ext in [".parquet", ".json", ".jsonl"]: + return ext[1:] + else: + # Try to detect format by attempting to load + try: + # Test if it's parquet + test_ds = datasets.load_dataset("parquet", data_files=file_path) + return "parquet" + except: + try: + # Test if it's json + test_ds = datasets.load_dataset("json", data_files=file_path) + return "json" + except: + raise ValueError(f"Cannot determine format for file: {file_path}") + + def _load_dataset(self): + """Load and process dataset from multiple files of different formats.""" + self._download_files() + + # Load dataframes from different formats + dataframes = [] + for file_path, file_format in self.validated_files: + df = self._load_single_file(file_path, file_format) + dataframes.append(df) + + self.dataframe = datasets.concatenate_datasets(dataframes) + total = len(self.dataframe) + pprint(f"Combined dataset length: 
{total}") + + # Handle max_samples parameter + if self.max_samples > 0 and self.max_samples < total: + import numpy as np + + indices = np.arange(self.max_samples) + self.dataframe = self.dataframe.select(indices.tolist()) + pprint(f"Selected {self.max_samples} samples (total: {total})") + + # Filter overlong prompts + if self.filter_overlong_prompts: + self._filter_long_prompts() + + def _download_files(self): + """Download files to local cache.""" + from verl.utils.fs import copy_to_local + + downloaded_files = [] + for file_path, file_format in self.validated_files: + downloaded_path = copy_to_local(src=file_path, cache_dir=self.cache_dir) + downloaded_files.append((downloaded_path, file_format)) + + self.validated_files = downloaded_files + + def _load_single_file(self, file_path: str, file_format: str) -> Dataset: + """Load a single file based on its format.""" + pprint(f"Loading {file_format} file: {file_path}") + + if file_format == "parquet": + dataset = datasets.load_dataset("parquet", data_files=file_path)["train"] + elif file_format in ["json", "jsonl"]: + dataset = datasets.load_dataset("json", data_files=file_path)["train"] + else: + raise ValueError(f"Unsupported file format: {file_format}") + + pprint(f"Loaded {len(dataset)} samples from {file_path}") + return dataset + + def _filter_long_prompts(self): + """Filter out overlong prompts.""" + # Extract tokenizer and params to local variables to avoid pickle serialization issues + tokenizer = self.tokenizer + max_length = self.max_prompt_length + prompt_key = self.prompt_key + input_field = self.input_field + + def is_prompt_valid(doc): + try: + # Inline prompt extraction logic - handles both nested and flat structures + prompt = None + + # Try nested structure first: doc[input_field][prompt_key] + if input_field and input_field in doc: + inner = doc[input_field] + if isinstance(inner, dict) and prompt_key in inner: + prompt = inner[prompt_key] + elif isinstance(inner, str): + prompt = inner # 
Fallback: input_field might directly contain the prompt string + + # Fallback to top-level prompt_key if nested extraction failed + if prompt is None and prompt_key in doc: + prompt = doc[prompt_key] + + # Keep samples where prompt can't be extracted (safer than dropping data) + if not prompt or not isinstance(prompt, str): + return True + + # Check token length + return len(tokenizer.encode(prompt)) <= max_length + + except Exception as e: + logger.error(f"Error during filtering: {e}") + return True # Keep sample on error to avoid data loss + + original_len = len(self.dataframe) + self.dataframe = self.dataframe.filter( + is_prompt_valid, + num_proc=1, # Use single process to avoid serialization issues + desc=f"Filtering prompts exceeding {max_length} tokens", + ) + pprint(f"Dataset length after filtering: {len(self.dataframe)} (original: {original_len})") + + def _extract_prompt_from_doc(self, doc: dict, input_field: str, prompt_key: str) -> str: + """Extract prompt from document supporting multiple formats.""" + # Handle the new JSON structure with input.query + if input_field in doc: + input_data = doc[input_field] + if isinstance(input_data, dict) and "query" in input_data: + return input_data["query"] + elif isinstance(input_data, list): + for msg in input_data: + if isinstance(msg, dict) and msg.get("role") == "user" and msg.get("content"): + return msg["content"] + elif isinstance(input_data, str): + return input_data + + # Fallback to old data structure + prompt = doc.get(prompt_key) + if prompt is None: + prompt = doc.get("x", []) + if prompt: + return prompt[-1].get("content", "") + + if isinstance(prompt, str): + return prompt[: self.max_prompt_length] + elif isinstance(prompt, list) and prompt: + return prompt[0].get("content", "") if isinstance(prompt[0], dict) else str(prompt[0]) + + return "" + + def _extract_prompt(self, example): + """Extract prompt from example - supports configurable field names.""" + return 
self._extract_prompt_from_doc(example, self.input_field, self.prompt_key) + + def _build_messages(self, example: dict) -> List[dict]: + """Build chat messages from example - subclasses must override.""" + raise NotImplementedError("Subclasses must implement _build_messages") + + def _format_template(self, messages: List[dict], example: dict) -> str: + """Format template - subclasses must override.""" + raise NotImplementedError("Subclasses must implement _format_template") + + def _extract_ground_truth(self, row_dict): + """Extract ground truth label - subclasses must override.""" + raise NotImplementedError("Subclasses must implement _extract_ground_truth") + + def __getitem__(self, item): + """Get an item from the dataset.""" + row_dict = dict(self.dataframe[item]) + messages = self._build_messages(row_dict) + + # Format prompt + raw_prompt_messages = self._format_template(messages, row_dict) + + # Try using enable_thinking parameter, fallback if not supported + try: + raw_prompt = self.tokenizer.apply_chat_template( + raw_prompt_messages, add_generation_prompt=True, tokenize=False, enable_thinking=True + ) + except TypeError: + # If tokenizer doesn't support enable_thinking parameter, skip it + raw_prompt = self.tokenizer.apply_chat_template( + raw_prompt_messages, add_generation_prompt=True, tokenize=False + ) + + # Tokenize + model_inputs = self.tokenizer(raw_prompt, return_tensors="pt", add_special_tokens=False) + input_ids = model_inputs["input_ids"] + attention_mask = model_inputs["attention_mask"] + + # Post-process + input_ids, attention_mask = verl_F.postprocess_data( + input_ids=input_ids, + attention_mask=attention_mask, + max_length=self.max_prompt_length, + pad_token_id=self.tokenizer.pad_token_id, + left_pad=True, + truncation=self.truncation, + ) + + # Compute position IDs + position_ids = compute_position_id_with_mask(attention_mask) + + # Prepare raw prompt IDs + raw_prompt_ids = self.tokenizer.encode(raw_prompt, add_special_tokens=False) + if 
len(raw_prompt_ids) > self.max_prompt_length: + if self.truncation == "left": + raw_prompt_ids = raw_prompt_ids[-self.max_prompt_length :] + elif self.truncation == "right": + raw_prompt_ids = raw_prompt_ids[: self.max_prompt_length] + elif self.truncation == "error": + raise RuntimeError(f"Prompt length {len(raw_prompt_ids)} exceeds {self.max_prompt_length}") + + # Build result + result = { + "input_ids": input_ids[0], + "attention_mask": attention_mask[0], + "position_ids": position_ids[0], + "raw_prompt_ids": raw_prompt_ids, + "index": row_dict.get("index", item), + "extra_info": copy.deepcopy(row_dict), + "reward_model": {"ground_truth": self._extract_ground_truth(row_dict)}, + "data_source": row_dict.get("source", "default"), + } + + if self.return_raw_chat: + result["raw_prompt"] = messages + + return result + + def __len__(self): + return len(self.dataframe) + + def resume_dataset_state(self): + """Resume dataset state for checkpointing.""" + self.serialize_dataset = not hasattr(self, "original_data_files") + if not self.serialize_dataset: + self.data_files = copy.deepcopy(self.original_data_files) + self._load_dataset() + else: + pprint("Using old dataloader checkpoint file, recommend training from scratch") + + def __getstate__(self): + """Get state for serialization.""" + state = self.__dict__.copy() + state.pop("dataframe", None) + return state + + def get_dataset_info(self) -> Dict[str, Any]: + """Get basic information about the dataset.""" + return { + "total_samples": len(self), + "data_files": self.data_files, + "validated_files": [(path, fmt) for path, fmt in self.validated_files], + "input_field": getattr(self, "input_field", "input"), + "output_field": getattr(self, "output_field", "output"), + "max_prompt_length": self.max_prompt_length, + "filter_overlong_prompts": self.filter_overlong_prompts, + } + + +class PointwiseChatRLDataset(BaseChatRLDataset): + """Pointwise chat reinforcement learning dataset - for single response quality scoring.""" + 
+ def __init__(self, data_files, tokenizer, config, processor=None, max_samples: int = -1): + # Parse dataset config before calling parent constructor + self.dataset_config = self._parse_dataset_config(config) + super().__init__(data_files, tokenizer, config, processor, max_samples) + pprint(f"Using Pointwise mode with dataset: {self.dataset_config.dataset_name}") + pprint(f"Prompt template: {self.dataset_config.prompt_template}") + + def _parse_dataset_config(self, config) -> DatasetConfig: + """Parse dataset configuration from config.""" + if hasattr(config, "dataset_config"): + dataset_config_dict = config.dataset_config + if isinstance(dataset_config_dict, DatasetConfig): + return dataset_config_dict + elif isinstance(dataset_config_dict, dict): + config_dict = {**dataset_config_dict} + return DatasetConfig(**config_dict) + elif isinstance(config, dict) and "dataset_config" in config: + dataset_config_dict = config["dataset_config"] + if isinstance(dataset_config_dict, DatasetConfig): + return dataset_config_dict + elif isinstance(dataset_config_dict, dict): + config_dict = {**dataset_config_dict} + return DatasetConfig(**config_dict) + + return DatasetConfig() + + def _parse_config(self, config: Union[DictConfig, Dict[str, Any]]): + """Override parent method to incorporate dataset config.""" + super()._parse_config(config) + + # Use the parsed config dict from parent + self.input_field = self.dataset_config_dict.get("input_field", "input") + self.output_field = self.dataset_config_dict.get("output_field", "output") + self.file_format = self.dataset_config_dict.get("file_format", "") + + def _build_messages(self, example: dict) -> List[dict]: + """Build chat messages from example - Pointwise mode with text format only.""" + messages = [] + + # Check if it's the new JSON structure (has 'input' key with nested structure) + if "input" in example and isinstance(example["input"], dict) and "query" in example["input"]: + # New JSON format + query = 
example["input"].get("query", "") + if query: + messages.append({"role": "user", "content": query}) + + # Get chosen response (positive example) + if "chosen" in example and isinstance(example["chosen"], dict): + response_data = example["chosen"].get("response", {}) + if isinstance(response_data, dict): + response_content = response_data.get("content", "") + if response_content: + messages.append({"role": "assistant", "content": response_content}) + else: + # Old format - handle standard structure + messages = self._build_old_format_messages(example) + + return messages + + def _build_old_format_messages(self, example: dict) -> List[dict]: + """Build messages in old format for backward compatibility.""" + messages = [] + + # Extract user message from input field + input_key = self.dataset_config.input_field + if input_key in example and example[input_key]: + input_data = example[input_key] + if isinstance(input_data, list): + for msg in input_data: + if isinstance(msg, dict) and msg.get("role") == "user" and msg.get("content"): + messages.append({"role": "user", "content": msg["content"]}) + elif isinstance(input_data, str): + messages.append({"role": "user", "content": input_data}) + + # Pointwise mode: get first response + output_key = self.dataset_config.output_field + if output_key in example and example[output_key]: + output_data = example[output_key] + output_item = output_data[0] if isinstance(output_data, list) else output_data + response_key = self.dataset_config.response_field + + if isinstance(output_item, dict): + answer = output_item.get(response_key, {}) + if isinstance(answer, dict) and answer.get("role") == "assistant": + content = answer.get("content", "") + if content: + messages.append({"role": "assistant", "content": content}) + elif isinstance(answer, str): + messages.append({"role": "assistant", "content": answer}) + + # Fallback to original structure + if len(messages) <= 1: + prompt = self._extract_prompt(example) + if prompt: + messages = 
[{"role": "user", "content": prompt}] + + return messages + + def _import_with_fallback(self, module_path: str, name: str, fallback: str) -> str: + """Helper to import a variable from a module with a fallback.""" + try: + module = __import__(module_path, fromlist=[name]) + return getattr(module, name) + except ImportError: + return fallback + + def _format_template(self, messages: List[dict], example: dict) -> str: + """Format template based on configured prompt template type.""" + # Import the specific prompts from openjudge + CORRECTNESS_PROMPT_EN = self._import_with_fallback( + "openjudge.graders.common.correctness", + "CORRECTNESS_PROMPT_EN", + "Evaluate the factual correctness of the response. Query: {query}, Response: {response}", + ) + + HALLUCINATION_PROMPT_EN = self._import_with_fallback( + "openjudge.graders.common.hallucination", + "HALLUCINATION_PROMPT_EN", + "Check for hallucinations in the response. Query: {query}, Response: {response}", + ) + + RELEVANCE_PROMPT_EN = self._import_with_fallback( + "openjudge.graders.common.relevance", + "RELEVANCE_PROMPT_EN", + "Evaluate the relevance of the response to the query. Query: {query}, Response: {response}", + ) + + HARMFULNESS_PROMPT_EN = self._import_with_fallback( + "openjudge.graders.common.harmfulness", + "HARMFULNESS_PROMPT_EN", + "Evaluate the harmfulness of the response to the query. Query: {query}, Response: {response}", + ) + + INSTRUCTION_FOLLOWING_PROMPT_EN = self._import_with_fallback( + "openjudge.graders.common.instruction_following", + "INSTRUCTION_FOLLOWING_PROMPT_EN", + "Evaluate the instruction_following of the response to the query. 
Query: {query}, Response: {response}", + ) + + task_type = example.get("task_type", "unknown") + + if task_type == "correctness": + grader_template = CORRECTNESS_PROMPT_EN + elif task_type == "hallucination": + grader_template = HALLUCINATION_PROMPT_EN + elif task_type == "relevance": + grader_template = RELEVANCE_PROMPT_EN + elif task_type == "harmlessness": + grader_template = HARMFULNESS_PROMPT_EN + elif task_type == "instruction_following": + grader_template = INSTRUCTION_FOLLOWING_PROMPT_EN + else: + # Default to correctness if unknown template + pprint(f"task type: {task_type}") + raise ValueError( + f"Unknown task type: {task_type}. Valid types: correctness, hallucination, relevance, " + f"harmlessness, instruction_following" + ) + return self._format_grader_template(messages, example, grader_template) + + def _format_grader_template(self, messages: List[dict], example: dict, grader_prompt: str) -> str: + """Format correctness evaluation template using openjudge prompt.""" + if "input" in example and isinstance(example["input"], dict) and "query" in example["input"]: + # New JSON format + query = example["input"].get("query", "") + context = example["input"].get("context") or "" # Handle null value + reference_response = example["input"].get("reference", "") + + response = "" + if "answer" in example and isinstance(example["answer"], dict): + answer_response = example["answer"].get("response", {}) + if isinstance(answer_response, dict): + response = answer_response.get("content", "") + # Also try 'response' field as fallback + elif "response" in example and isinstance(example["response"], dict): + response = example["response"].get("content", "") + else: + # Old format - extract from messages + query = next((msg["content"] for msg in messages if msg["role"] == "user"), "") + response = self._get_response_content(example) + reference_response = None + context = None + + instruction = query + # Replace placeholders in the grader prompt + formatted_prompt = 
grader_prompt.format( + query=query or "", + response=response or "", + reference_response=reference_response or "", + context=str(context) or "", + instruction=instruction or "", + ) + + return [{"role": "user", "content": formatted_prompt}] + + def _get_response_content(self, example: dict) -> str: + """Helper method to extract response content consistently.""" + # Check if it's the new JSON structure + if "input" in example and isinstance(example["input"], dict) and "query" in example["input"]: + # New JSON format - get from chosen response + if "chosen" in example and isinstance(example["chosen"], dict): + response_data = example["chosen"].get("response", {}) + if isinstance(response_data, dict): + return response_data.get("content", "") + else: + # Old format - use original logic + output_key = self.dataset_config.output_field + response_key = self.dataset_config.response_field + + if output_key in example and example[output_key]: + output_data = example[output_key] + output_item = output_data[0] if isinstance(output_data, list) else output_data + if isinstance(output_item, dict): + answer = output_item.get(response_key, {}) + if isinstance(answer, dict): + return answer.get("content", "") + elif isinstance(answer, str): + return answer + return "" + + def _extract_ground_truth(self, row_dict): + """Extract pointwise ground truth label with configurable fields.""" + try: + score_value = 0 + # Check if it's the new JSON structure + if "input" in row_dict and isinstance(row_dict["input"], dict) and "query" in row_dict["input"]: + # New JSON format - extract score value + if "score" in row_dict: + score_value = row_dict["score"] + else: + # Old format - use original logic + output_key = self.dataset_config.output_field + label_key = self.dataset_config.label_field + score_key = self.dataset_config.score_field + + output_data = row_dict.get(output_key, []) + if output_data: + output_item = output_data[0] if isinstance(output_data, list) else output_data + if 
isinstance(output_item, dict): + answer = output_item.get(self.dataset_config.response_field, {}) + if isinstance(answer, dict): + label_data = answer.get(label_key, {}) + if isinstance(label_data, dict): + score_value = label_data.get(score_key, 0) + elif isinstance(label_data, (int, float)): + score_value = label_data + elif isinstance(label_data, str): + try: + score_value = float(label_data) + except ValueError: + score_value = 0 + + return { + self.dataset_config.score_field: score_value, + "task_type": "pointwise", + "prompt_template": self.dataset_config.prompt_template, + } + except Exception as e: + pprint(f"Failed to extract label from {row_dict}: {e}") + return { + self.dataset_config.score_field: 0, + "task_type": "pointwise", + "prompt_template": self.dataset_config.prompt_template, + } + + def get_dataset_info(self) -> Dict[str, Any]: + """Get information about the current dataset configuration.""" + base_info = super().get_dataset_info() + base_info.update( + { + "dataset_name": self.dataset_config.dataset_name, + "prompt_template": self.dataset_config.prompt_template, + "input_field": self.dataset_config.input_field, + "output_field": self.dataset_config.output_field, + "response_field": self.dataset_config.response_field, + "label_field": self.dataset_config.label_field, + "score_field": self.dataset_config.score_field, + "file_format": self.dataset_config.file_format, + "config": self.dataset_config.__dict__, + } + ) + return base_info diff --git a/cookbooks/training_judge_model/grpo/pointwise/grader_reward_fn.py b/cookbooks/training_judge_model/grpo/pointwise/grader_reward_fn.py new file mode 100644 index 000000000..9a1ee6442 --- /dev/null +++ b/cookbooks/training_judge_model/grpo/pointwise/grader_reward_fn.py @@ -0,0 +1,164 @@ +import json +import re +from pprint import pprint + + +def filter_thinking_parts(text): + """ + Filter thinking parts from text (for models like Qwen3 that support thinking mode). 
+
+    Supported thinking tag formats:
+    - <think>...</think>
+    """
+    if not isinstance(text, str):
+        return text
+
+    # Define regex patterns for thinking parts
+    thinking_patterns = [r"<think>.*?</think>"]
+
+    # Apply all patterns sequentially for filtering
+    filtered_text = text
+    for pattern in thinking_patterns:
+        filtered_text = re.sub(pattern, "", filtered_text, flags=re.DOTALL | re.IGNORECASE)
+
+    # Clean up extra whitespace
+    filtered_text = re.sub(r"\n\s*\n", "\n\n", filtered_text)  # Merge multiple newlines
+    filtered_text = filtered_text.strip()
+
+    return filtered_text
+
+
+def extract_score(response_text):
+    """
+    Extract score from model response.
+    Extract score from <score> tag.
+    """
+    # Handle case where response_text might not be a string
+    if not isinstance(response_text, str):
+        response_text = str(response_text)
+
+    # Extract score from <score> tag
+    score_pattern = r"<score>(.*?)</score>"
+    match = re.search(score_pattern, response_text, re.DOTALL)
+    pprint(f"response_text: {response_text}")
+    if match:
+        score_content = match.group(1).strip()
+        # Extract numbers from content
+        numbers = re.findall(r"\d+", score_content)
+        if numbers:
+            try:
+                score = int(numbers[0])  # Take the first number as score
+                if 0 <= score <= 5:  # Assume score range can be 0-1 (binary) or 1-5 (multi-class)
+                    return score
+            except Exception:
+                pass
+    else:
+        try:
+            response_dict = json.loads(response_text)
+            return response_dict.get("score", 0)
+        except Exception:
+            pass
+
+    return 0  # Default to 0 if extraction fails
+
+
+def calculate_reward(predicted_score, true_score):
+    """
+    Calculate reward based on the difference between predicted and true scores.
+    Smaller difference results in higher reward.
+ + For binary classification scenarios (true_score is 0 or 1): + - Correct prediction (exact match) -> Reward 1.0 + - Wrong prediction -> Reward 0.0 + """ + if true_score is None: + return 0.0 + + # Calculate difference + diff = abs(predicted_score - true_score) + + # For binary classification (0 or 1), use exact match + if true_score in [0, 1] and predicted_score in [0, 1]: + return 1.0 if diff == 0 else 0.0 + + # For multi-class scenarios (1-5), use difference calculation + # Convert difference to reward score (smaller difference = higher reward) + max_possible_diff = 4 + normalized_diff = min(diff / max_possible_diff, 1.0) + + # Reward = 1 - normalized difference + reward = 1.0 - normalized_diff + + return reward + + +def compute_score(data_source, solution_str, ground_truth, extra_info=None, **kwargs): + """ + compute_score function compatible with naive.py. + + Args: + data_source: Data source type + solution_str: Model generated response + ground_truth: Ground truth label (obtained from reward_model field) + extra_info: Additional information + """ + try: + # First filter out thinking parts (support thinking mode for models like Qwen3) + filtered_solution = filter_thinking_parts(solution_str) + + # Extract score from filtered solution_str + predicted_score = extract_score(filtered_solution) + + # Handle ground_truth - could be a number or dict + if isinstance(ground_truth, dict): + true_score = ground_truth.get("score", 0) + elif isinstance(ground_truth, (int, float)): + true_score = int(ground_truth) + elif isinstance(ground_truth, str) and ground_truth.isdigit(): + true_score = int(ground_truth) + else: + # If ground_truth is unavailable, try to get from extra_info + if extra_info and isinstance(extra_info, dict): + output_data = extra_info.get("output", []) + if output_data and len(output_data) > 0: + label_data = output_data[0].get("label", {}) + true_score = label_data.get("score", 0) + else: + true_score = 0 + else: + true_score = 0 + + # Calculate 
reward
+        reward = calculate_reward(predicted_score, true_score)
+
+        # Return detailed information
+        return {
+            "score": reward,
+            "predicted_score": predicted_score,
+            "true_score": true_score,
+            "data_source": data_source,
+        }
+
+    except Exception as e:
+        pprint(f"Error in compute_score: {e}")
+        # Return default values
+        return {"score": 0.0, "error": str(e), "data_source": data_source}
+
+
+if __name__ == "__main__":
+    # Test cases
+    test_response = """Let me analyze this answer step by step:
+1. First, I'll check if the answer is well-structured...
+4. Finally, I'll look at the overall score...
+
+<score>2</score>"""
+
+    ground_truth = {"score": 3, "task_type": "pointwise"}
+
+    # Test compute_score function
+    result = compute_score(data_source="test", solution_str=test_response, ground_truth=ground_truth)
+
+    print("Test Result:")
+    print(f"  Predicted Score: {result.get('predicted_score')}")
+    print(f"  True Score: {result.get('true_score')}")
+    print(f"  Reward: {result.get('score')}")
diff --git a/cookbooks/training_judge_model/grpo/pointwise/run_pointwise_grader.sh b/cookbooks/training_judge_model/grpo/pointwise/run_pointwise_grader.sh
new file mode 100644
index 000000000..02f9d55e3
--- /dev/null
+++ b/cookbooks/training_judge_model/grpo/pointwise/run_pointwise_grader.sh
@@ -0,0 +1,150 @@
+#!/bin/bash
+# OpenJudge GRPO Pointwise Training Script
+# Train judge models using GRPO reinforcement learning for scoring
+
+set -x
+TIMESTAMP=$(date "+%m%dT%H%M")
+
+# ============================================================================
+# Ray Cluster Configuration
+# ============================================================================
+RAY_ADDRESS=${RAY_ADDRESS:-http://127.0.0.1:8265}
+N_GPUS_PER_NODE=${N_GPUS_PER_NODE:-8}
+N_NODES=${N_NODES:-1}
+
+# ============================================================================
+# Path Configuration
+# ============================================================================
+# Model: Use HuggingFace model ID or local
path +MODEL_PATH=${MODEL_PATH:-Qwen/Qwen3-8B} + +# Data: Download from HuggingFace or use local parquet files +TRAIN_FILE=${TRAIN_FILE:-./data/rewardbench2_pointwise_train.parquet} +VAL_FILE=${VAL_FILE:-./data/rewardbench2_pointwise_val.parquet} + +# Output directory +SAVE_PATH=${SAVE_PATH:-./checkpoints/grpo/pointwise} + +# Set env variables +# export SAVE_PATH=/workspace/output +# export MODEL_PATH=/models/Qwen/Qwen3-0.6B +# export TRAIN_FILE=/data/train_rm/grpo/pointwise/train.parquet +# export VAL_FILE=/data/train_rm/grpo/pointwise/val.parquet + +# Grader data files +# export TRAIN_FILE='["/data/text/correctness/correctness_eval_v1_train.jsonl","/data/text/hallucination/hallucination_eval_v1_train.jsonl","/data/text/relevance/relevance_eval_v1_train.jsonl","/data/text/harmlessness/harmlessness_eval_v1_train.jsonl","/data/text/instruction_following/instruction_following_eval_v1_train.jsonl"]' +# export VAL_FILE='["/data/text/correctness/correctness_eval_v1_val.jsonl","/data/text/hallucination/hallucination_eval_v1_val.jsonl","/data/text/relevance/relevance_eval_v1_val.jsonl","/data/text/harmlessness/harmlessness_eval_v1_val.jsonl","/data/text/instruction_following/instruction_following_eval_v1_val.jsonl"]' + +# Get script directory for relative paths +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +GRPO_DIR="$(dirname "$SCRIPT_DIR")" + +# Custom modules +CUSTOM_REWARD_FUNCTION_PATH=${SCRIPT_DIR}/grader_reward_fn.py +CUSTOM_CHAT_RL_DATASET_PATH=${GRPO_DIR}/grader_rl_dataset.py +RUNTIME_ENV_PATH=${GRPO_DIR}/runtime_env.yaml + +# ============================================================================ +# Training Configuration +# ============================================================================ +PROJECT_NAME=OpenJudge +EXPERIMENT_NAME=grpo-pointwise-${TIMESTAMP} + +# ============================================================================ +# Hyperparameters +# 
============================================================================ +# Data settings +# TRAIN_BATCH_SIZE=96 +# VAL_BATCH_SIZE=192 +TRAIN_BATCH_SIZE=48 +VAL_BATCH_SIZE=48 + +MAX_PROMPT_LENGTH=4096 +MAX_RESPONSE_LENGTH=2048 + +# Optimizer settings +LR=1e-6 +KL_LOSS_COEF=0.001 + +# GRPO settings +ROLLOUT_N=4 # Number of samples per prompt + +# Training settings +TOTAL_EPOCHS=1 +SAVE_FREQ=20 +TEST_FREQ=10 + +# ============================================================================ +# Environment Setup +# ============================================================================ +# Disable PyTorch expandable segments for vLLM compatibility +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:False + +# ============================================================================ +# Run Training with Ray +# ============================================================================ +echo "=== GRPO Pointwise Training Configuration ===" +echo "RAY_ADDRESS: $RAY_ADDRESS" +echo "MODEL_PATH: $MODEL_PATH" +echo "TRAIN_FILE: $TRAIN_FILE" +echo "N_GPUS_PER_NODE: $N_GPUS_PER_NODE" +echo "N_NODES: $N_NODES" +echo "SCRIPT_DIR: $SCRIPT_DIR" +echo "GRPO_DIR: $GRPO_DIR" +echo "==============================================" + +# Change to GRPO directory to ensure runtime_env.yaml working_dir resolves correctly +cd "${GRPO_DIR}" + +ray job submit --address="${RAY_ADDRESS}" \ + --runtime-env="${RUNTIME_ENV_PATH}" \ + -- \ + python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${VAL_FILE}" \ + data.train_batch_size=$TRAIN_BATCH_SIZE \ + data.val_batch_size=$VAL_BATCH_SIZE \ + data.max_prompt_length=$MAX_PROMPT_LENGTH \ + data.max_response_length=$MAX_RESPONSE_LENGTH \ + data.filter_overlong_prompts=True \ + data.truncation='right' \ + data.prompt_key='input' \ + data.custom_cls.path="${CUSTOM_CHAT_RL_DATASET_PATH}" \ + data.custom_cls.name="PointwiseChatRLDataset" \ + reward_model.reward_manager='naive' 
\ + custom_reward_function.path="${CUSTOM_REWARD_FUNCTION_PATH}" \ + custom_reward_function.name='compute_score' \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.actor.optim.lr=$LR \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=24 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=$KL_LOSS_COEF \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=$N_GPUS_PER_NODE \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.n=$ROLLOUT_N \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger=['console','swanlab'] \ + trainer.project_name=${PROJECT_NAME} \ + trainer.experiment_name=${EXPERIMENT_NAME} \ + trainer.n_gpus_per_node=${N_GPUS_PER_NODE} \ + trainer.nnodes=${N_NODES} \ + trainer.save_freq=$SAVE_FREQ \ + trainer.test_freq=$TEST_FREQ \ + trainer.total_epochs=$TOTAL_EPOCHS \ + trainer.val_before_train=False \ + trainer.default_local_dir="${SAVE_PATH}/${EXPERIMENT_NAME}" + +echo "Training completed! 
Checkpoints saved to: ${SAVE_PATH}/${EXPERIMENT_NAME}" diff --git a/cookbooks/training_judge_model/grpo/pointwise/utils/preprocess_grader_data.py b/cookbooks/training_judge_model/grpo/pointwise/utils/preprocess_grader_data.py new file mode 100644 index 000000000..227bedfd9 --- /dev/null +++ b/cookbooks/training_judge_model/grpo/pointwise/utils/preprocess_grader_data.py @@ -0,0 +1,217 @@ +import json +import random +import sys +from pathlib import Path + + +def get_data_files(): + """Get data files list from user input""" + print("Please enter data file paths (one file per line, press Enter on empty line to finish):") + data_files = [] + while True: + file_path = input().strip() + if not file_path: # Empty line ends input + break + data_files.append(file_path) + + if not data_files: + # Exit with error if no files are entered + print("Error: No files provided. Please enter at least one file path.") + sys.exit(1) + + return data_files + + +def get_split_params(): + """Get train/validation split ratio and random seed from user""" + print("\nEnter train/validation split ratio (default 0.8):") + try: + split_ratio_input = input().strip() + if split_ratio_input: + split_ratio = float(split_ratio_input) + else: + split_ratio = 0.8 + except ValueError: + print("Invalid input, using default split ratio 0.8") + split_ratio = 0.8 + + if split_ratio <= 0 or split_ratio >= 1: + print("Split ratio must be between 0 and 1. 
Using default 0.8") + split_ratio = 0.8 + + print(f"Train/Validation split ratio: {split_ratio}/{1 - split_ratio}") + + print("\nEnter random seed (default 42):") + try: + seed_input = input().strip() + if seed_input: + seed = int(seed_input) + else: + seed = 42 + except ValueError: + print("Invalid input, using default seed 42") + seed = 42 + + print(f"Random seed: {seed}") + + print("\nEnter number of samples to keep (optional, press Enter for all):") + try: + sample_num_input = input().strip() + if sample_num_input: + sample_num = int(sample_num_input) + else: + sample_num = None + except ValueError: + print("Invalid input, keeping all samples") + sample_num = None + + return split_ratio, seed, sample_num + + +def validate_files(data_files): + """Validate if files exist""" + valid_files = [] + for file_path in data_files: + if Path(file_path).exists(): + valid_files.append(file_path) + print(f"✓ File exists: {file_path}") + else: + print(f"✗ File does not exist: {file_path}") + + return valid_files + + +def process_single_file(data_file: str, split_ratio: float, seed: int, sample_num: int) -> bool: + """Process a single data file with train/validation split""" + print(f"Processing file: {data_file}") + + try: + with open(data_file, "r", encoding="utf-8") as f: + data = json.load(f) + except FileNotFoundError: + print(f"Error: File not found - {data_file}") + return False + except json.JSONDecodeError as e: + print(f"Error: Invalid JSON format in file {data_file} - {e}") + return False + except UnicodeDecodeError as e: + print(f"Error: Encoding issue in file {data_file} - {e}") + return False + except Exception as e: + print(f"Error reading file {data_file}: {e}") + return False + + # Set random seed for reproducible results + random.seed(seed) + + output_data = [] + try: + for item in data: + if not isinstance(item, dict): + print(f"Warning: Skipping non-dict item in {data_file}") + continue + + # Check if required keys exist + if "input" not in item or "chosen" 
not in item or "rejected" not in item:
+                print(f"Warning: Missing required keys in item from {data_file}. Skipping item.")
+                continue
+
+            # Extract task_type from item if available, otherwise use "unknown"
+            task_type = item.get("task_type", "unknown")
+
+            output_data.append(
+                {
+                    "input": item["input"],
+                    "answer": item["chosen"],
+                    "label": 1,  # positive example
+                    "score": 5.0,
+                    "task_type": task_type,
+                }
+            )
+            output_data.append(
+                {
+                    "input": item["input"],
+                    "answer": item["rejected"],
+                    "label": 0,  # negative example
+                    "score": 1.0,
+                    "task_type": task_type,
+                }
+            )
+    except KeyError as e:
+        print(f"Error: Missing required key {e} in file {data_file}")
+        return False
+    except Exception as e:
+        print(f"Error processing data in file {data_file}: {e}")
+        return False
+
+    # Randomly shuffle the data
+    random.shuffle(output_data)
+
+    # Apply random sampling if specified (log before truncating so "total" is correct)
+    if sample_num is not None and len(output_data) > sample_num:
+        print(f"Sampled {sample_num} items from {len(output_data)} total")
+        output_data = output_data[:sample_num]
+
+    # Split data into train and validation sets
+    split_idx = int(len(output_data) * split_ratio)
+    train_data = output_data[:split_idx]
+    val_data = output_data[split_idx:]
+
+    print(f"Train samples: {len(train_data)}, Validation samples: {len(val_data)}")
+
+    try:
+        path = Path(data_file)
+        base_name = path.stem
+
+        # Write training data
+        train_output_file = path.parent.joinpath(base_name + "_train.jsonl").as_posix()
+        with open(train_output_file, "w", encoding="utf-8") as f:
+            for item in train_data:
+                f.write(json.dumps(item, ensure_ascii=False) + "\n")
+        print(f"Training file generated: {train_output_file}")
+
+        # Write validation data
+        val_output_file = path.parent.joinpath(base_name + "_val.jsonl").as_posix()
+        with open(val_output_file, "w", encoding="utf-8") as f:
+            for item in val_data:
+                f.write(json.dumps(item, ensure_ascii=False) + "\n")
+        print(f"Validation file generated: {val_output_file}")
+
+        return 
True + except Exception as e: + print(f"Error writing output files: {e}") + return False + + +def main(): + """Main function""" + print("# Preprocess grader data") + print("# Data source: https://huggingface.co/datasets/agentscope-ai/OpenJudge/tree/main") + print() + + # Get file list from user input + data_files = get_data_files() + + # Get split parameters + split_ratio, seed, sample_num = get_split_params() + + # Validate files + valid_files = validate_files(data_files) + + if not valid_files: + print("No valid files found, please check file paths.") + sys.exit(1) + + # Process each valid file + success_count = 0 + for data_file in valid_files: + if process_single_file(data_file, split_ratio, seed, sample_num): + success_count += 1 + # Add blank line separator + print() + + print(f"Processing completed! Successfully processed {success_count}/{len(valid_files)} files.") + + +if __name__ == "__main__": + main()