Direct Preference Optimization in NeMo RL

Direct Preference Optimization (DPO) is an RL-free alignment algorithm that operates on preference data. Given a prompt and a pair of chosen and rejected responses, DPO aims to increase the probability of the chosen response and decrease the probability of the rejected response relative to a frozen reference model. The actor is initialized using the reference model. For more details, refer to the DPO paper.
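For reference, the objective described above can be written as in the DPO paper. Here x is the prompt, y_w the chosen response, y_l the rejected response, π_θ the actor, π_ref the frozen reference model, σ the logistic function, and β a scalar that controls how strongly the actor is held to the reference. This is the paper's formulation; the loss NeMo RL actually computes may include additional terms, such as the auxiliary SFT loss listed under the DPO-specific parameters.

```latex
\mathcal{L}_{\mathrm{DPO}}(\pi_\theta; \pi_{\mathrm{ref}})
  = -\,\mathbb{E}_{(x,\, y_w,\, y_l) \sim \mathcal{D}}
    \left[ \log \sigma\!\left(
      \beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)}
      - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}
    \right) \right]
```

Because the actor starts as a copy of the reference model, both log-ratios are zero at initialization and the loss begins at log 2 for every example.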

Launch a DPO Run

The script examples/run_dpo.py can be used to launch a DPO experiment. This script can either be launched locally or via Slurm. For details on how to set up Ray and launch a job using Slurm, refer to the cluster documentation.

Be sure to launch the job using uv. The command to launch a DPO job is as follows:

uv run examples/run_dpo.py --config <PATH TO YAML CONFIG> <OVERRIDES>

If not specified, config will default to examples/configs/dpo.yaml.

Configuration

NeMo RL allows users to configure DPO experiments using yaml config files. An example DPO configuration file can be found here.

To override a value in the config, either update the value in the yaml file directly, or pass the override via the command line. For example:

uv run examples/run_dpo.py \
    cluster.gpus_per_node=8 \
    dpo.sft_loss_weight=0.1 \
    dpo.preference_average_log_probs=True \
    logger.wandb.name="dpo-dev-8-gpu"
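To make the dotted override syntax concrete, the sketch below shows how keys such as `dpo.sft_loss_weight=0.1` could map onto a nested config dictionary. It is an illustration only: `apply_overrides` is a hypothetical helper, not the parser NeMo RL uses, and the value-parsing rules (Python literals, falling back to plain strings) are an assumption.

```python
import ast


def apply_overrides(config, overrides):
    """Apply 'a.b.c=value' strings to a nested dict in place and return it.

    Hypothetical helper for illustration; not part of NeMo RL.
    """
    for item in overrides:
        dotted_key, _, raw_value = item.partition("=")
        try:
            # Parse numbers, booleans, lists, etc. written as Python literals.
            value = ast.literal_eval(raw_value)
        except (ValueError, SyntaxError):
            # Anything else (for example a run name) stays a plain string.
            value = raw_value
        node = config
        *parents, leaf = dotted_key.split(".")
        for part in parents:
            # Create intermediate sections that do not exist yet.
            node = node.setdefault(part, {})
        node[leaf] = value
    return config


config = {"cluster": {"gpus_per_node": 1}, "dpo": {"sft_loss_weight": 0.0}}
apply_overrides(
    config,
    [
        "cluster.gpus_per_node=8",
        "dpo.sft_loss_weight=0.1",
        "dpo.preference_average_log_probs=True",
        "logger.wandb.name=dpo-dev-8-gpu",
    ],
)
```

Each dot in the key descends one level into the config, so an override touches only the leaf it names and leaves sibling settings unchanged.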

Reminder: Set HF_HOME, WANDB_API_KEY, and HF_DATASETS_CACHE (if needed) before launching. For Llama models, you also need to run huggingface-cli login.

Datasets

Each class representing a NeMo RL DPO dataset is expected to have the following attributes:

  1. formatted_ds: The dictionary of formatted datasets. This dictionary should contain train and validation splits, and each split should conform to the format described below.
  2. task_spec: The TaskDataSpec for this dataset. This should specify the name you choose for this dataset.

DPO datasets are expected to follow a specific format with three key fields:

  • prompt: The input prompt/context
  • chosen_response: The preferred/winning response
  • rejected_response: The non-preferred/losing response

data/hf_datasets/helpsteer3.py provides an example of how to format data for DPO:

def format_helpsteer3(data):
    response_1 = data["response1"]
    response_2 = data["response2"]
    overall_preference = data["overall_preference"]

    if overall_preference < 0:
        chosen = response_1
        rejected = response_2
    elif overall_preference == 0:
        # Tie: use the same response on both sides, so the pair carries no preference signal.
        chosen = response_1
        rejected = response_1
    else:
        chosen = response_2
        rejected = response_1

    return {
        "prompt": data["context"],
        "chosen_response": chosen,
        "rejected_response": rejected,
    }

We also provide a DPODataset class that is compatible with jsonl-formatted preference datasets. This class assumes the train and validation datasets have already been split and processed into the expected format offline. Each line of the jsonl files should be an example with prompt, chosen_response, and rejected_response keys.
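As a rough sketch of what such a file looks like, the snippet below writes two records in the expected format and checks that every line carries the three required keys. `validate_preference_jsonl` is a hypothetical helper written for this example, not a NeMo RL function, and the sample records are made up.

```python
import json
import os
import tempfile

REQUIRED_KEYS = ("prompt", "chosen_response", "rejected_response")


def validate_preference_jsonl(path):
    """Return the number of records; raise ValueError if a record lacks a key.

    Hypothetical helper for illustration; not part of NeMo RL.
    """
    count = 0
    with open(path, encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue  # tolerate blank lines
            record = json.loads(line)
            missing = [key for key in REQUIRED_KEYS if key not in record]
            if missing:
                raise ValueError(f"{path}:{line_no} is missing keys: {missing}")
            count += 1
    return count


examples = [
    {"prompt": "What is 2+2?", "chosen_response": "4", "rejected_response": "5"},
    {"prompt": "What is 3*3?", "chosen_response": "9", "rejected_response": "6"},
]

# One JSON object per line is what makes the file "jsonl".
with tempfile.NamedTemporaryFile(
    "w", suffix=".jsonl", delete=False, encoding="utf-8"
) as f:
    for example in examples:
        f.write(json.dumps(example) + "\n")
    path = f.name

num_records = validate_preference_jsonl(path)
os.remove(path)
```

Running a check like this on both the train and validation files before training surfaces formatting problems early, rather than partway through a run.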

Adding Custom DPO Datasets

Adding a new DPO dataset is straightforward. Your custom dataset class should:

  1. Implement the required format conversion in the constructor
  2. Set up the appropriate task_spec

Here's a minimal example which simply re-keys an existing jsonl dataset:

from datasets import load_dataset
from nemo_rl.data.interfaces import TaskDataSpec
from docs.helpers import make_dpo_dataset

class CustomDPODataset:
    def preprocess_dataset(
        self,
        data,
        prompt_key: str = "context",
        chosen_key: str = "chosen",
        rejected_key: str = "rejected"
    ):
        return {
            "prompt": data[prompt_key],
            "chosen_response": data[chosen_key],
            "rejected_response": data[rejected_key],
        }
    
    def __init__(
        self,
        train_data_path: str,
        val_data_path: str,
        prompt_key: str,
        chosen_key: str,
        rejected_key: str,
    ):
        # Load and format your dataset
        fn_kwargs = {
            "prompt_key": prompt_key,
            "chosen_key": chosen_key,
            "rejected_key": rejected_key,
        }
        formatted_ds = {
            "train": load_dataset("json", data_files=train_data_path, split="train").map(
                self.preprocess_dataset, 
                fn_kwargs=fn_kwargs,
            ),
            "validation": load_dataset("json", data_files=val_data_path, split="train").map(
                self.preprocess_dataset, 
                fn_kwargs=fn_kwargs,
            ),
        }
        
        # Initialize task spec with dataset name
        self.task_spec = TaskDataSpec(
            task_name="custom_dpo",
        )
        self.formatted_ds = formatted_ds

# Create temporary files using helper function
train_file, val_file = make_dpo_dataset()

# Initialize dataset
dataset = CustomDPODataset(
    train_data_path=train_file.name,
    val_data_path=val_file.name,
    prompt_key="context",
    chosen_key="chosen",
    rejected_key="rejected"
)

# Test dataset properties
print(f"Task name: {dataset.task_spec.task_name}")
print(f"Train examples: {len(dataset.formatted_ds['train'])}")
print(f"Validation examples: {len(dataset.formatted_ds['validation'])}")
print(f"First train example prompt: {dataset.formatted_ds['train'][0]['prompt']}")
print(f"First train example chosen response: {dataset.formatted_ds['train'][0]['chosen_response']}")
print(f"First train example rejected response: {dataset.formatted_ds['train'][0]['rejected_response']}")

Output:

Task name: custom_dpo
Train examples: 2
Validation examples: 2
First train example prompt: What is 2+2?
First train example chosen response: 4
First train example rejected response: 5

DPO-Specific Parameters

The DPO implementation in NeMo RL supports several key parameters that can be adjusted:

  • dpo.reference_policy_kl_penalty: Controls the strength of the KL penalty term that keeps the policy close to the reference model (the β coefficient in the DPO paper)
  • dpo.preference_loss_weight: Weight for the preference loss
  • dpo.sft_loss_weight: Weight for the auxiliary SFT loss
  • dpo.preference_average_log_probs: Whether to average log probabilities over tokens in the preference loss term
  • dpo.sft_average_log_probs: Whether to average log probabilities over tokens in the SFT loss term

These parameters can be adjusted in the config file or via command-line overrides to optimize training for your specific use case.
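The sketch below shows one plausible way these five parameters could interact for a single example. It rests on several assumptions that the text above does not confirm: that the total loss is a weighted sum of the two terms, that `reference_policy_kl_penalty` scales the reward margin inside the sigmoid, and that the SFT term is the negative log-likelihood of the chosen response. `combined_loss` and its default values are hypothetical and are not taken from the NeMo RL code.

```python
import math


def sequence_logp(token_logps, average):
    """Reduce per-token log-probabilities to one number per response."""
    total = sum(token_logps)
    return total / len(token_logps) if average else total


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def combined_loss(
    policy_chosen,
    policy_rejected,
    ref_chosen,
    ref_rejected,
    reference_policy_kl_penalty=0.05,  # assumed default, for illustration
    preference_loss_weight=1.0,
    sft_loss_weight=0.0,
    preference_average_log_probs=False,
    sft_average_log_probs=False,
):
    """Hypothetical per-example loss; inputs are lists of per-token log-probs."""
    avg = preference_average_log_probs
    chosen_reward = sequence_logp(policy_chosen, avg) - sequence_logp(ref_chosen, avg)
    rejected_reward = sequence_logp(policy_rejected, avg) - sequence_logp(ref_rejected, avg)
    margin = reference_policy_kl_penalty * (chosen_reward - rejected_reward)
    preference_loss = softplus(-margin)  # equals -log(sigmoid(margin))

    # Assumed auxiliary SFT term: negative log-likelihood of the chosen response.
    sft_loss = -sequence_logp(policy_chosen, sft_average_log_probs)

    return preference_loss_weight * preference_loss + sft_loss_weight * sft_loss


chosen = [-1.0, -3.0]
rejected = [-2.0, -2.0, -2.0]
# Policy equals reference here, so the preference term is log(2).
loss = combined_loss(
    chosen, rejected, chosen, rejected,
    sft_loss_weight=0.1, sft_average_log_probs=True,
)
```

Under these assumptions, averaging log probabilities removes the dependence on response length, which matters when chosen and rejected responses differ greatly in token count.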

Evaluate the Trained Model

After training completes, refer to our evaluation guide to assess the model's capabilities.