From 4366ca0f684342f30cb0d230d987ebffc7f863a2 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Fri, 6 Feb 2026 06:42:20 -0800 Subject: [PATCH 1/3] support data subset Signed-off-by: Yuki Huang --- docs/guides/dpo.md | 2 ++ docs/guides/grpo.md | 1 + docs/guides/rm.md | 2 ++ docs/guides/sft.md | 1 + examples/configs/dpo.yaml | 2 ++ examples/configs/rm.yaml | 2 ++ nemo_rl/data/__init__.py | 2 ++ .../binary_preference_dataset.py | 4 ++- .../preference_datasets/preference_dataset.py | 4 ++- .../response_datasets/response_dataset.py | 4 ++- nemo_rl/data/datasets/utils.py | 18 +++++++++-- .../data/datasets/test_response_dataset.py | 31 +++++++++++++++++++ 12 files changed, 68 insertions(+), 5 deletions(-) diff --git a/docs/guides/dpo.md b/docs/guides/dpo.md index 40471bbea5..14a7c0c0e5 100644 --- a/docs/guides/dpo.md +++ b/docs/guides/dpo.md @@ -102,6 +102,7 @@ data: train: # this dataset will override prompt_key and use the default values for other vars data_path: /path/to/local/train_dataset.jsonl # local file or hf_org/hf_dataset_name (HuggingFace) + subset: null # used for HuggingFace datasets split: train # used for HuggingFace datasets validation: # this dataset will use the default values for other vars except data_path @@ -146,6 +147,7 @@ data: # this dataset will override prompt_key and use the default values for other vars data_path: /path/to/local/train_dataset.jsonl # local file or hf_org/hf_dataset_name (HuggingFace) prompt_key: context + subset: null # used for HuggingFace datasets split: train # used for HuggingFace datasets validation: # this dataset will use the default values for other vars except data_path diff --git a/docs/guides/grpo.md b/docs/guides/grpo.md index a084aab251..9fc8687920 100755 --- a/docs/guides/grpo.md +++ b/docs/guides/grpo.md @@ -51,6 +51,7 @@ data: # this dataset will override input_key and use the default values for other vars data_path: /path/to/local/train_dataset.jsonl # local file or hf_org/hf_dataset_name (HuggingFace) input_key: question + subset: null # used for HuggingFace datasets split: train # used for HuggingFace datasets split_validation_size: 0.05 # use 5% of the training data as validation data seed: 42 # seed for train/validation split when split_validation_size > 0 diff --git a/docs/guides/rm.md b/docs/guides/rm.md index 888774cc28..85ef76398f 100644 --- a/docs/guides/rm.md +++ b/docs/guides/rm.md @@ -91,6 +91,7 @@ data: train: # this dataset will override prompt_key and use the default values for other vars data_path: /path/to/local/train_dataset.jsonl # local file or hf_org/hf_dataset_name (HuggingFace) + subset: null # used for HuggingFace datasets split: train # used for HuggingFace datasets validation: # this dataset will use the default values for other vars except data_path @@ -135,6 +136,7 @@ data: # this dataset will override prompt_key and use the default values for other vars data_path: /path/to/local/train_dataset.jsonl # local file or hf_org/hf_dataset_name (HuggingFace) prompt_key: context + subset: null # used for HuggingFace datasets split: train # used for HuggingFace datasets validation: # this dataset will use the default values for other vars except data_path diff --git a/docs/guides/sft.md b/docs/guides/sft.md index 09df3d5d2a..e98518db03 100644 --- a/docs/guides/sft.md +++ b/docs/guides/sft.md @@ -84,6 +84,7 @@ data: # this dataset will override input_key and use the default values for other vars data_path: /path/to/local/train_dataset.jsonl # local file or hf_org/hf_dataset_name (HuggingFace) input_key: question + 
subset: null # used for HuggingFace datasets split: train # used for HuggingFace datasets split_validation_size: 0.05 # use 5% of the training data as validation data seed: 42 # seed for train/validation split when split_validation_size > 0 diff --git a/examples/configs/dpo.yaml b/examples/configs/dpo.yaml index f2b57b0bbd..579b53f264 100755 --- a/examples/configs/dpo.yaml +++ b/examples/configs/dpo.yaml @@ -195,6 +195,7 @@ data: # train: # # this dataset will override prompt_key and use the default values for other vars # data_path: /path/to/local/train_dataset.jsonl # local file or hf_org/hf_dataset_name (HuggingFace) + # subset: null # used for HuggingFace datasets # split: train # used for HuggingFace datasets # validation: # # this dataset will use the default values for other vars except data_path @@ -214,6 +215,7 @@ data: # # this dataset will override prompt_key and use the default values for other vars # data_path: /path/to/local/train_dataset.jsonl # local file or hf_org/hf_dataset_name (HuggingFace) # prompt_key: context + # subset: null # used for HuggingFace datasets # split: train # used for HuggingFace datasets # validation: # # this dataset will use the default values for other vars except data_path diff --git a/examples/configs/rm.yaml b/examples/configs/rm.yaml index 4b0936fec5..9e89d6b199 100644 --- a/examples/configs/rm.yaml +++ b/examples/configs/rm.yaml @@ -146,6 +146,7 @@ data: # train: # # this dataset will override prompt_key and use the default values for other vars # data_path: /path/to/local/train_dataset.jsonl # local file or hf_org/hf_dataset_name (HuggingFace) + # subset: null # used for HuggingFace datasets # split: train # used for HuggingFace datasets # validation: # # this dataset will use the default values for other vars except data_path @@ -165,6 +166,7 @@ data: # # this dataset will override prompt_key and use the default values for other vars # data_path: /path/to/local/train_dataset.jsonl # local file or hf_org/hf_dataset_name (HuggingFace) # prompt_key: context + # subset: null # used for HuggingFace datasets # split: train # used for HuggingFace datasets # validation: # # this dataset will use the default values for other vars except data_path diff --git a/nemo_rl/data/__init__.py b/nemo_rl/data/__init__.py index 2fb26ebd90..abede196bf 100644 --- a/nemo_rl/data/__init__.py +++ b/nemo_rl/data/__init__.py @@ -20,6 +20,7 @@ class ResponseDatasetConfig(TypedDict): data_path: NotRequired[str] input_key: NotRequired[str] output_key: NotRequired[str] + subset: NotRequired[str] split: NotRequired[str] prompt_file: NotRequired[str | None] system_prompt_file: NotRequired[str | None] @@ -38,6 +39,7 @@ class PreferenceDatasetConfig(TypedDict): prompt_key: NotRequired[str] chosen_key: NotRequired[str] rejected_key: NotRequired[str] + subset: NotRequired[str] split: NotRequired[str] prompt_file: NotRequired[str | None] system_prompt_file: NotRequired[str | None] diff --git a/nemo_rl/data/datasets/preference_datasets/binary_preference_dataset.py b/nemo_rl/data/datasets/preference_datasets/binary_preference_dataset.py index e2cf8d7024..c28d146944 100644 --- a/nemo_rl/data/datasets/preference_datasets/binary_preference_dataset.py +++ b/nemo_rl/data/datasets/preference_datasets/binary_preference_dataset.py @@ -36,6 +36,7 @@ class BinaryPreferenceDataset(RawDataset): prompt_key: Key for the input prompt/context, default is "prompt" chosen_key: Key for the preferred/winning response, default is "chosen" rejected_key: Key for the non-preferred/losing response, default 
is "rejected" + subset: Optional subset name for the dataset, used for HuggingFace datasets split: Optional split name for the dataset, used for HuggingFace datasets """ @@ -45,6 +46,7 @@ def __init__( prompt_key: str = "prompt", chosen_key: str = "chosen", rejected_key: str = "rejected", + subset: Optional[str] = None, split: Optional[str] = None, **kwargs, ): @@ -57,7 +59,7 @@ def __init__( self.task_name = self.task_name[1:] # load from local or huggingface - self.dataset = load_dataset_from_path(data_path, split) + self.dataset = load_dataset_from_path(data_path, subset, split) # format the dataset self.dataset = self.dataset.map( diff --git a/nemo_rl/data/datasets/preference_datasets/preference_dataset.py b/nemo_rl/data/datasets/preference_datasets/preference_dataset.py index 939b62295c..660970e257 100644 --- a/nemo_rl/data/datasets/preference_datasets/preference_dataset.py +++ b/nemo_rl/data/datasets/preference_datasets/preference_dataset.py @@ -40,12 +40,14 @@ class PreferenceDataset(RawDataset): Args: data_path: Path to the dataset JSON file + subset: Optional subset name for the dataset, used for HuggingFace datasets split: Optional split name for the dataset, used for HuggingFace datasets """ def __init__( self, data_path: str, + subset: Optional[str] = None, split: Optional[str] = None, **kwargs, ): @@ -54,7 +56,7 @@ def __init__( self.task_name = self.task_name[1:] # load from local or huggingface - self.dataset = load_dataset_from_path(data_path, split) + self.dataset = load_dataset_from_path(data_path, subset, split) # format the dataset self.dataset = self.dataset.add_column( diff --git a/nemo_rl/data/datasets/response_datasets/response_dataset.py b/nemo_rl/data/datasets/response_datasets/response_dataset.py index 554c5db452..e6197a501f 100644 --- a/nemo_rl/data/datasets/response_datasets/response_dataset.py +++ b/nemo_rl/data/datasets/response_datasets/response_dataset.py @@ -33,6 +33,7 @@ class ResponseDataset(RawDataset): data_path: Path to the dataset JSON file input_key: Key for the input text, default is "input" output_key: Key for the output text, default is "output" + subset: Optional subset name for the dataset, used for HuggingFace datasets split: Optional split name for the dataset, used for HuggingFace datasets split_validation_size: Size of the validation data, default is 0 seed: Seed for train/validation split when split_validation_size > 0, default is 42 @@ -43,6 +44,7 @@ def __init__( data_path: str, input_key: str = "input", output_key: str = "output", + subset: Optional[str] = None, split: Optional[str] = None, split_validation_size: float = 0, seed: int = 42, @@ -56,7 +58,7 @@ def __init__( self.task_name = self.task_name[1:] # load from local or huggingface - self.dataset = load_dataset_from_path(data_path, split) + self.dataset = load_dataset_from_path(data_path, subset, split) # format the dataset if "messages" not in self.dataset.column_names: diff --git a/nemo_rl/data/datasets/utils.py b/nemo_rl/data/datasets/utils.py index 24d646878c..3a7d269c71 100644 --- a/nemo_rl/data/datasets/utils.py +++ b/nemo_rl/data/datasets/utils.py @@ -62,11 +62,16 @@ def pil_to_base64(image: Image.Image, format: str = "PNG") -> str: return f"data:image/png;base64,{img_str}" -def load_dataset_from_path(data_path: str, data_split: Optional[str] = "train"): +def load_dataset_from_path( + data_path: str, + data_subset: Optional[str] = None, + data_split: Optional[str] = "train", +): """Load a dataset from a local file, huggingface dataset, or Arrow dataset (saved with 
save_to_disk). Args: data_path: The path to the dataset. + data_subset: The subset to load from the dataset. Only supported for huggingface datasets. data_split: The split to load from the dataset. """ FILEEXT2TYPE = { @@ -78,12 +83,21 @@ def load_dataset_from_path(data_path: str, data_split: Optional[str] = "train"): ".txt": "text", } suffix = os.path.splitext(data_path)[-1] + # load from local file (not save_to_disk format) if dataset_type := FILEEXT2TYPE.get(suffix): + assert data_subset is None, ( + "data_subset is only supported for huggingface datasets" + ) raw_dataset = load_dataset(dataset_type, data_files=data_path) else: try: - raw_dataset = load_dataset(data_path) + # load from huggingface + if data_subset: + raw_dataset = load_dataset(data_path, data_subset) + else: + raw_dataset = load_dataset(data_path) except ValueError as e: + # load from local file (save_to_disk format) if "load_from_disk" in str(e): raw_dataset = load_from_disk(data_path) else: diff --git a/tests/unit/data/datasets/test_response_dataset.py b/tests/unit/data/datasets/test_response_dataset.py index 661f85c14e..fa27c74a01 100644 --- a/tests/unit/data/datasets/test_response_dataset.py +++ b/tests/unit/data/datasets/test_response_dataset.py @@ -110,6 +110,37 @@ def test_response_dataset(input_key, output_key, is_save_to_disk, file_ext, toke assert combined_message == " Question: Hello Answer: Hi there!" +def test_response_dataset_gsm8k_with_subset(): + # load the dataset + data_config = { + "dataset_name": "ResponseDataset", + "data_path": "openai/gsm8k", + "input_key": "question", + "output_key": "answer", + "subset": "main", + "split": "train", + } + dataset = load_response_dataset(data_config) + + # check the input and output keys + assert dataset.input_key == "question" + assert dataset.output_key == "answer" + + # check the first example + first_example = dataset.dataset[0] + + # only contains messages and task_name + assert len(first_example.keys()) == 2 + assert "messages" in first_example + assert "task_name" in first_example + + # check the content + assert first_example["messages"][0]["role"] == "user" + assert first_example["messages"][0]["content"][:20] == "Natalia sold clips t" + assert first_example["messages"][1]["role"] == "assistant" + assert first_example["messages"][1]["content"][:20] == "Natalia sold 48/2 = " + + def test_helpsteer3_dataset(): # load the dataset data_config = {"dataset_name": "HelpSteer3"} From 8904c7ef26ce618b2172b62b1b2ecbab337f69a4 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Fri, 6 Feb 2026 08:10:56 -0800 Subject: [PATCH 2/3] update doc Signed-off-by: Yuki Huang --- docs/guides/dpo.md | 17 ++++++++++++----- docs/guides/grpo.md | 21 ++++++++++++++++++++- docs/guides/rm.md | 17 ++++++++++++----- docs/guides/sft.md | 12 ++++++------ 4 files changed, 50 insertions(+), 17 deletions(-) diff --git a/docs/guides/dpo.md b/docs/guides/dpo.md index 14a7c0c0e5..d25b1195a9 100644 --- a/docs/guides/dpo.md +++ b/docs/guides/dpo.md @@ -32,8 +32,14 @@ uv run examples/run_dpo.py \ ## Datasets -Each DPO dataset class is expected to have the following attributes: -1. `dataset`: The formatted dataset, which should be formatted like +DPO datasets in NeMo RL are encapsulated using classes. Each DPO data class is expected to have the following attributes: + 1. `dataset`: A dictionary containing the formatted datasets. Each example in the dataset must conform to the format described below. + 2. `task_name`: A string identifier that uniquely identifies the dataset. 
+ +If your data is not in the correct format, simply write a preprocessing script to convert the data into this format. An example implementation can be found in [response_datasets/tulu3.py](../../nemo_rl/data/datasets/preference_datasets/tulu3.py). + +**Note:** The `task_name` field is required in each formatted example. + ```json { "context": [], // list of dicts - The prompt message (including previous turns, if any) @@ -46,10 +52,10 @@ Each DPO dataset class is expected to have the following attributes: "rank": 1, // int — The rank of the completion (lower rank is preferred) "completion": [] // list of dicts — The completion message(s) } - ] + ], + "task_name": "task_name" // identifier for the task } ``` -2. `task_name`: The unique task identifier for this dataset. This should specify the name you choose for this dataset. DPO training supports only two completions (where the lowest rank is preferred and the highest one is rejected), with each completion being a single response. For example: ```json @@ -87,7 +93,8 @@ DPO training supports only two completions (where the lowest rank is preferred a } ] } - ] + ], + "task_name": "task_name" } ``` diff --git a/docs/guides/grpo.md b/docs/guides/grpo.md index 9fc8687920..fa9f2fd76d 100755 --- a/docs/guides/grpo.md +++ b/docs/guides/grpo.md @@ -38,6 +38,25 @@ To support this, we need to know: #### Dataset +GRPO datasets in NeMo RL are encapsulated using classes. Each GRPO data class is expected to have the following attributes: + 1. `dataset`: A dictionary containing the formatted datasets. Each example in the dataset must conform to the format described below. + 2. `task_name`: A string identifier that uniquely identifies the dataset. + +GRPO datasets are expected to follow the HuggingFace chat format. Refer to the [chat dataset document](../design-docs/chat-datasets.md) for details. If your data is not in the correct format, simply write a preprocessing script to convert the data into this format. [response_datasets/deepscaler.py](../../nemo_rl/data/datasets/response_datasets/deepscaler.py) has an example: + +**Note:** The `task_name` field is required in each formatted example. + +```python +def format_data(self, data: dict[str, Any]) -> dict[str, Any]: + return { + "messages": [ + {"role": "user", "content": data["problem"]}, + {"role": "assistant", "content": data["answer"]}, + ], + "task_name": self.task_name, + } +``` + By default, NeMo RL has some built-in supported datasets (e.g., [OpenAssistant](../../nemo_rl/data/datasets/response_datasets/oasst.py), [OpenMathInstruct-2](../../nemo_rl/data/datasets/response_datasets/openmathinstruct2.py), [Squad](../../nemo_rl/data/datasets/response_datasets/squad.py), etc.). You can see the full list [here](../../nemo_rl/data/datasets/response_datasets/__init__.py). All of these datasets are downloaded from HuggingFace and preprocessed on-the-fly, so there's no need to provide a path to any datasets on disk. @@ -82,7 +101,7 @@ We support using multiple datasets for train and validation. You can refer to `e ```yaml data: _override_: true # override the data config instead of merging with it - # other data settings, see `examples/configs/sft.yaml` for more details + # other data settings, see `examples/configs/grpo_math_1B.yaml` for more details ... 
# dataset settings train: diff --git a/docs/guides/rm.md b/docs/guides/rm.md index 85ef76398f..b768d62775 100644 --- a/docs/guides/rm.md +++ b/docs/guides/rm.md @@ -21,8 +21,14 @@ The default YAML config shares the same base template as the SFT config but incl ## Datasets -Each RM dataset class is expected to have the following attributes: -1. `dataset`: The formatted dataset, which should be formatted like +RM datasets in NeMo RL are encapsulated using classes. Each RM data class is expected to have the following attributes: + 1. `dataset`: A dictionary containing the formatted datasets. Each example in the dataset must conform to the format described below. + 2. `task_name`: A string identifier that uniquely identifies the dataset. + +If your data is not in the correct format, simply write a preprocessing script to convert the data into this format. An example implementation can be found in [response_datasets/tulu3.py](../../nemo_rl/data/datasets/preference_datasets/tulu3.py). + +**Note:** The `task_name` field is required in each formatted example. + ```json { "context": [], // list of dicts - The prompt message (including previous turns, if any) @@ -35,10 +41,10 @@ Each RM dataset class is expected to have the following attributes: "rank": 1, // int — The rank of the completion (lower rank is preferred) "completion": [] // list of dicts — The completion message(s) } - ] + ], + "task_name": "task_name" // identifier for the task } ``` -2. `task_name`: The unique task identifier for this dataset. This should specify the name you choose for this dataset. Currently, RM training supports only two completions (where the lowest rank is preferred and the highest one is rejected), with each completion being a single response. For example: ```json @@ -76,7 +82,8 @@ Currently, RM training supports only two completions (where the lowest rank is p } ] } - ] + ], + "task_name": "task_name" } ``` diff --git a/docs/guides/sft.md b/docs/guides/sft.md index e98518db03..368ccd216e 100644 --- a/docs/guides/sft.md +++ b/docs/guides/sft.md @@ -31,11 +31,13 @@ uv run examples/run_sft.py \ ## Datasets SFT datasets in NeMo RL are encapsulated using classes. Each SFT data class is expected to have the following attributes: - 1. `formatted_ds`: The dictionary of formatted datasets. This dictionary should contain `train` and `validation` splits, and each split should conform to the format described below. - 2. `task_spec`: The `TaskDataSpec` for this dataset. This should specify the name you choose for this dataset. + 1. `dataset`: A dictionary containing the formatted datasets. Each example in the dataset must conform to the format described below. + 2. `task_name`: A string identifier that uniquely identifies the dataset. SFT datasets are expected to follow the HuggingFace chat format. Refer to the [chat dataset document](../design-docs/chat-datasets.md) for details. If your data is not in the correct format, simply write a preprocessing script to convert the data into this format. [response_datasets/squad.py](../../nemo_rl/data/datasets/response_datasets/squad.py) has an example: +**Note:** The `task_name` field is required in each formatted example. 
+ ```python def format_data(self, data: dict[str, Any]) -> dict[str, Any]: return { @@ -52,7 +54,8 @@ def format_data(self, data: dict[str, Any]) -> dict[str, Any]: "role": "assistant", "content": data["answers"]["text"][0], }, - ] + ], + "task_name": self.task_name, } ``` @@ -216,9 +219,6 @@ Without `use_preserving_dataset: true`, the loader would incorrectly add: This corrupts your training data and can lead to models generating invalid tool calls. The `PreservingDataset` mode maintains the exact structure of each tool call. -Adding a new dataset is a straightforward process. -As long as your custom dataset has the `formatted_ds` and `task_spec` attributes described above, it can serve as a drop-in replacement for Squad and OpenAssistant. - ## Evaluate the Trained Model Upon completion of the training process, you can refer to our [evaluation guide](eval.md) to assess model capabilities. From aeeb9eb20c4e80964d435164fbbd91a07094e19f Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Fri, 6 Feb 2026 08:24:50 -0800 Subject: [PATCH 3/3] coderabbit Signed-off-by: Yuki Huang --- docs/guides/dpo.md | 2 +- docs/guides/rm.md | 2 +- nemo_rl/data/__init__.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/guides/dpo.md b/docs/guides/dpo.md index d25b1195a9..f5451dc760 100644 --- a/docs/guides/dpo.md +++ b/docs/guides/dpo.md @@ -36,7 +36,7 @@ DPO datasets in NeMo RL are encapsulated using classes. Each DPO data class is e 1. `dataset`: A dictionary containing the formatted datasets. Each example in the dataset must conform to the format described below. 2. `task_name`: A string identifier that uniquely identifies the dataset. -If your data is not in the correct format, simply write a preprocessing script to convert the data into this format. An example implementation can be found in [response_datasets/tulu3.py](../../nemo_rl/data/datasets/preference_datasets/tulu3.py). +If your data is not in the correct format, simply write a preprocessing script to convert the data into this format. An example implementation can be found in [preference_datasets/tulu3.py](../../nemo_rl/data/datasets/preference_datasets/tulu3.py). **Note:** The `task_name` field is required in each formatted example. diff --git a/docs/guides/rm.md b/docs/guides/rm.md index b768d62775..50e7c3950f 100644 --- a/docs/guides/rm.md +++ b/docs/guides/rm.md @@ -25,7 +25,7 @@ RM datasets in NeMo RL are encapsulated using classes. Each RM data class is exp 1. `dataset`: A dictionary containing the formatted datasets. Each example in the dataset must conform to the format described below. 2. `task_name`: A string identifier that uniquely identifies the dataset. -If your data is not in the correct format, simply write a preprocessing script to convert the data into this format. An example implementation can be found in [response_datasets/tulu3.py](../../nemo_rl/data/datasets/preference_datasets/tulu3.py). +If your data is not in the correct format, simply write a preprocessing script to convert the data into this format. An example implementation can be found in [preference_datasets/tulu3.py](../../nemo_rl/data/datasets/preference_datasets/tulu3.py). **Note:** The `task_name` field is required in each formatted example. 
diff --git a/nemo_rl/data/__init__.py b/nemo_rl/data/__init__.py index abede196bf..82e5e57bf5 100644 --- a/nemo_rl/data/__init__.py +++ b/nemo_rl/data/__init__.py @@ -20,7 +20,7 @@ class ResponseDatasetConfig(TypedDict): data_path: NotRequired[str] input_key: NotRequired[str] output_key: NotRequired[str] - subset: NotRequired[str] + subset: NotRequired[str | None] split: NotRequired[str] prompt_file: NotRequired[str | None] system_prompt_file: NotRequired[str | None] @@ -39,7 +39,7 @@ class PreferenceDatasetConfig(TypedDict): prompt_key: NotRequired[str] chosen_key: NotRequired[str] rejected_key: NotRequired[str] - subset: NotRequired[str] + subset: NotRequired[str | None] split: NotRequired[str] prompt_file: NotRequired[str | None] system_prompt_file: NotRequired[str | None]
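
A minimal usage sketch of the new `subset` plumbing, mirroring the `openai/gsm8k` unit test added in PATCH 1/3. The import paths follow the modules this series touches; how `load_dataset_from_path` resolves the split downstream of the hunks shown above is assumed to match the pre-existing behavior.

```python
from nemo_rl.data.datasets.utils import load_dataset_from_path
from nemo_rl.data.datasets.response_datasets.response_dataset import ResponseDataset

# HuggingFace dataset with a named subset (config): with the updated signature
# this forwards to load_dataset("openai/gsm8k", "main"); split selection past
# the shown hunks is assumed unchanged from the existing code path.
gsm8k = load_dataset_from_path("openai/gsm8k", "main", "train")

# Local files do not accept a subset: passing one trips the new assert
# ("data_subset is only supported for huggingface datasets").
local = load_dataset_from_path("/path/to/train_dataset.jsonl")

# The same `subset` key is threaded through the dataset classes, matching the
# data_config in test_response_dataset_gsm8k_with_subset:
ds = ResponseDataset(
    data_path="openai/gsm8k",
    input_key="question",
    output_key="answer",
    subset="main",
    split="train",
)
```

Omitting `subset` (or leaving it `null` in YAML, as the doc examples do) falls through to the original `load_dataset(data_path)` call, so existing configs behave exactly as before.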