Pull Request: Addition of "general-preference" Branch for General Preference Model (GPM) #201

Open · wants to merge 8 commits into base: main
7 changes: 5 additions & 2 deletions .gitignore
@@ -9,7 +9,6 @@
Thumbs.db

# model outputs/files
results/

# Byte-compiled / optimized / DLL files
__pycache__/
@@ -175,4 +174,8 @@ beaker_configs/auto_created
# generated local directory
hf_snapshot_evals/
data/
output/
output/

scripts/results/eval-set/cephfs/
scripts/run_rm_rewardbench_27b.sh
scripts/run_batch_rm_rewardbench.sh
34 changes: 33 additions & 1 deletion rewardbench/models/__init__.py
@@ -41,7 +41,7 @@
build_starling_rm,
)
from .ziya import ZiyaPipeline

from .gpm import GPMPipeline
# Please open a PR if you need to add more custom modeling code / utilize existing code for your model
REWARD_MODEL_CONFIG = {
"default": {
@@ -208,6 +208,38 @@
"model_type": "Seq. Classifier",
"torch_dtype": torch.bfloat16,
},
"general-preference/GPM-Llama-3.1-8B-Instruct": {
"model_builder": AutoModelForCausalLM.from_pretrained,
"pipeline_builder": GPMPipeline,
"quantized": False,
"custom_dialogue": True,
"model_type": "Custom Classifier",
"torch_dtype": torch.bfloat16,
},
"general-preference/GPM-Gemma-2B": {
"model_builder": AutoModelForCausalLM.from_pretrained,
"pipeline_builder": GPMPipeline,
"quantized": False,
"custom_dialogue": True,
"model_type": "Custom Classifier",
"torch_dtype": torch.bfloat16,
},
"general-preference/GPM-Gemma-2-2B": {
"model_builder": AutoModelForCausalLM.from_pretrained,
"pipeline_builder": GPMPipeline,
"quantized": False,
"custom_dialogue": True,
"model_type": "Custom Classifier",
"torch_dtype": torch.bfloat16,
},
"general-preference/GPM-Gemma-2-9B": {
"model_builder": AutoModelForCausalLM.from_pretrained,
"pipeline_builder": GPMPipeline,
"quantized": False,
"custom_dialogue": True,
"model_type": "Custom Classifier",
"torch_dtype": torch.bfloat16,
},
}

DPO_MODEL_CONFIG = {
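The entries above follow the registry's existing pattern: the evaluation scripts look the requested model name up in `REWARD_MODEL_CONFIG` and use the registered `pipeline_builder` to construct the scorer. Below is a minimal, hypothetical sketch of that lookup (not part of this PR; the variable names are invented for illustration, and only keys visible in the diff are used):

```python
# Illustrative sketch only: resolving one of the new GPM registry entries.
from rewardbench.models import REWARD_MODEL_CONFIG
from rewardbench.models.gpm import GPMPipeline

model_name = "general-preference/GPM-Gemma-2-2B"
entry = REWARD_MODEL_CONFIG.get(model_name, REWARD_MODEL_CONFIG["default"])

assert entry["pipeline_builder"] is GPMPipeline   # custom pipeline registered above
assert entry["custom_dialogue"] is True           # the pipeline formats chosen/rejected dialogues itself
pipeline = entry["pipeline_builder"](model_name, torch_dtype=entry["torch_dtype"])
```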
228 changes: 228 additions & 0 deletions rewardbench/models/gpm.py
@@ -0,0 +1,228 @@
from typing import Optional, List, Dict
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
import torch.nn.functional as F
from transformers import AutoTokenizer
import os
from safetensors.torch import load_file
from huggingface_hub import snapshot_download

def get_tokenizer(pretrain, model, padding_side="left", use_fast=True):
tokenizer = AutoTokenizer.from_pretrained(pretrain, trust_remote_code=True, use_fast=use_fast)
tokenizer.padding_side = padding_side
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
model.config.pad_token_id = tokenizer.pad_token_id
return tokenizer

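# Factory that wraps the given causal-LM class with a reward head: the base transformer is kept under
# its usual attribute name, a linear value head maps the final hidden state to a value_head_dim-dimensional
# reward embedding, and an optional prompt head supports prompt-conditioned preference scoring.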
def get_reward_model(base_causal_model, base_llm_model, value_head_dim: int, add_prompt_head: bool, is_general_preference: bool=False):
class CustomRewardModel(base_causal_model):

def __init__(self, config: AutoConfig):
super().__init__(config)
setattr(self, self.base_model_prefix, base_llm_model(config))
self.is_general_preference = is_general_preference

self.value_head = nn.Linear(config.hidden_size, value_head_dim, bias=False)
if add_prompt_head:
self.prompt_head = nn.Linear(config.hidden_size, value_head_dim // 2, bias=False)

def custom_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
return_output=False,
) -> torch.Tensor:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
outputs = getattr(self, self.base_model_prefix)(
input_ids, attention_mask=attention_mask, position_ids=position_ids
)
last_hidden_states = outputs["last_hidden_state"]

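# Two modes: a scalar reward read at the final (EOS) token, or a general-preference embedding
# that is L2-normalized over value_head_dim.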
if not self.is_general_preference:
values = self.value_head(last_hidden_states).squeeze(-1)
# left padding in training mode
if self.training:
reward = values[:, -1]
else:
eos_indices = attention_mask.size(1) - 1 - attention_mask.long().fliplr().argmax(dim=1, keepdim=True)
reward = values.gather(dim=1, index=eos_indices).squeeze(1)
if return_output:
return reward, outputs
else:
return reward, None
else:
values = self.value_head(last_hidden_states)
# left padding in training mode
if self.training:
reward = values[:, -1, :]
reward = F.normalize(reward, p=2, dim=-1) # Shape will be [batch_size, value_head_dim]
else:
eos_indices = attention_mask.size(1) - 1 - attention_mask.long().fliplr().argmax(dim=1)
eos_indices = eos_indices.unsqueeze(1) # Change shape to [batch_size, 1]
reward_list = []
for dim in range(value_head_dim):
reward_list.append(values[:,:,dim].gather(dim=1, index=eos_indices))
reward = torch.cat(reward_list, dim=1)
reward = F.normalize(reward, p=2, dim=-1) # Shape will be [batch_size, value_head_dim]
if return_output:
return reward, outputs
else:
return reward, None

def create_skew_symmetric_block_matrix(self, dim, device, dtype, prompt_hidden_states):
"""
Create a batch of skew-symmetric block matrices where each matrix is data-dependent on
the corresponding prompt_hidden_states. Only the relevant block diagonal parts are generated.

Args:
- dim: Dimension of the square matrix (must be even).
- prompt_hidden_states: Tensor of shape [batch_size, hidden_dim].

Returns:
- batch_R_matrices: Tensor of shape [batch_size, dim, dim], with skew-symmetric block entries.
"""
if hasattr(self, 'prompt_head'):
batch_size = prompt_hidden_states.shape[0]

# Ensure that dim is even, as we're creating blocks of size 2x2
assert dim % 2 == 0, "dim must be even for skew-symmetric block generation"

# Pass through the linear layer to get the block diagonal entries (half of the matrix's off-diagonal blocks)
block_values = self.prompt_head(prompt_hidden_states).view(batch_size, dim // 2)
block_values = torch.softmax(block_values, dim=-1)

# Create a batch of zero matrices [batch_size, dim, dim]
batch_R_matrices = torch.zeros((batch_size, dim, dim), device=device, dtype=dtype)

# Fill only the block diagonal entries with the learned values
for i in range(0, dim, 2):
batch_R_matrices[:, i, i + 1] = -block_values[:, i // 2]
batch_R_matrices[:, i + 1, i] = block_values[:, i // 2] # Skew-symmetric condition
else:
raise AttributeError("prompt_head is not defined. Ensure 'add_prompt_head' is set to True during initialization.")

return batch_R_matrices

return CustomRewardModel

class GPMPipeline:
def __init__(self, model_name_or_path, device=torch.device("cuda:0"), is_general_preference: bool=True, tau: float=0.1, trust_remote_code: bool=True, **kwargs):

self.device = device
self.is_general_preference = is_general_preference
self.tau = tau

config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code)
config._attn_implementation = kwargs.get("attn_implementation", None)
base_class = AutoModel._model_mapping[type(config)]
base_causal_class = AutoModelForCausalLM._model_mapping.get(type(config), None)

try:
dir_path = snapshot_download(repo_id=model_name_or_path)
except Exception:
dir_path = model_name_or_path
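# Scan the checkpoint's .safetensors shards to infer the value-head dimension and whether a prompt head is present.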
combined_weights = {}
for filename in os.listdir(dir_path):
if filename.endswith(".safetensors"):
file_path = os.path.join(dir_path, filename)
weights = load_file(file_path)
combined_weights.update(weights)

if "value_head.weight" in combined_weights:
self.value_head_dim = combined_weights["value_head.weight"].shape[0]

self.add_prompt_head = "prompt_head.weight" in combined_weights

cls_class = get_reward_model(base_causal_class, base_class, value_head_dim=self.value_head_dim, add_prompt_head=self.add_prompt_head, is_general_preference=self.is_general_preference)

# configure model
self.model = cls_class.from_pretrained(
model_name_or_path,
config=config,
trust_remote_code=trust_remote_code,
torch_dtype=kwargs.get("torch_dtype", "float16"),
)
# configure tokenizer
self.tokenizer = get_tokenizer(model_name_or_path, self.model, "left", use_fast=True)
self.tokenizer.truncation_side = "right"

# prepare model
self.model.to(device)
self.model.eval()

def __call__(self, samples: List[List[Dict[str, str]]], return_prompt=False, **kwargs):
_ = kwargs.get("batch_size", 1)
self.truncation = kwargs.get("truncation", True)
self.padding = kwargs.get("padding", True)
self.max_length = kwargs.get("max_length", 2048)

input_texts = [self.tokenizer.apply_chat_template(sample, tokenize=False) for sample in samples]

inputs = self.tokenizer(
input_texts,
truncation=self.truncation,
max_length=self.max_length,
padding=self.padding,
return_tensors="pt",
).to(self.device)

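# Force the last kept token to be EOS so the reward is read from a well-defined end-of-sequence position.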
inputs["input_ids"][:, -1] = self.tokenizer.eos_token_id
inputs["attention_mask"][:, -1] = 1

with torch.no_grad():
rewards, outputs = self.model.custom_forward(**inputs, return_output=return_prompt)

chosen_response_len_list = []
if return_prompt:
prompt_texts = [self.tokenizer.apply_chat_template([sample[0]], tokenize=False) for sample in samples]
for i in range(len(input_texts)):
prompt_token = self.tokenizer(
prompt_texts[i],
max_length=self.max_length,
padding=False,
truncation=self.truncation,
return_tensors="pt",
)
chosen_token = self.tokenizer(
input_texts[i],
max_length=self.max_length,
padding=False,
truncation=self.truncation,
return_tensors="pt",
)
chosen_response_len = chosen_token["attention_mask"].sum() - prompt_token["attention_mask"].sum()
chosen_response_len_list.append(chosen_response_len)
chosen_response_len = torch.tensor(chosen_response_len_list).view(-1, 1).to(self.device)
if return_prompt:
chosen_last_hidden_states = outputs["last_hidden_state"]
prompt_end_index = chosen_last_hidden_states.size(1) - chosen_response_len - 1
prompt_end_index_expanded = prompt_end_index.unsqueeze(-1).expand(-1, -1, chosen_last_hidden_states.size(-1))
prompt_hidden_state = torch.gather(chosen_last_hidden_states, dim=1, index=prompt_end_index_expanded).squeeze(1)
return rewards, prompt_hidden_state
else:
return rewards

def generate_high_dim_result(self, chosen_reward, rejected_reward):
R_matrix = torch.zeros((self.value_head_dim, self.value_head_dim), device=chosen_reward.device, dtype=chosen_reward.dtype)
for i in range(0, self.value_head_dim, 2):
R_matrix[i, i+1] = -1
R_matrix[i+1, i] = 1
# All tensors must live on the same device; fail loudly instead of silently returning an unbound result.
if not (chosen_reward.device == rejected_reward.device == R_matrix.device):
raise ValueError("chosen_reward, rejected_reward, and R_matrix must be on the same device")
transformed_chosen = torch.matmul(chosen_reward, R_matrix.T)
result = torch.bmm(transformed_chosen.view(chosen_reward.shape[0], 1, self.value_head_dim), rejected_reward.view(rejected_reward.shape[0], self.value_head_dim, 1))
result = result.view(chosen_reward.shape[0])
return result

def generate_high_dim_result_with_prompt(self, chosen_reward, rejected_reward, prompt_hidden_states):
R_matrix = self.model.create_skew_symmetric_block_matrix(self.value_head_dim, chosen_reward.device, chosen_reward.dtype, prompt_hidden_states)
# Same device guard as above.
if not (chosen_reward.device == rejected_reward.device == R_matrix.device):
raise ValueError("chosen_reward, rejected_reward, and R_matrix must be on the same device")
transformed_chosen = torch.bmm(chosen_reward.view(chosen_reward.shape[0], 1, self.value_head_dim), R_matrix.transpose(1, 2))
result = torch.bmm(transformed_chosen, rejected_reward.view(rejected_reward.shape[0], self.value_head_dim, 1))
result = result.view(chosen_reward.shape[0])
return result
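For completeness, here is a minimal, hypothetical usage sketch (not part of this PR; the model choice and messages are placeholders). In general-preference mode the pipeline returns an L2-normalized reward embedding of size value_head_dim per response, and generate_high_dim_result contracts a chosen/rejected pair through a fixed block skew-symmetric matrix R, i.e. it computes s = c^T R^T r, which flips sign when the two responses are swapped:

```python
# Illustrative sketch only: scoring one chosen/rejected pair with the pipeline defined above.
import torch

from rewardbench.models.gpm import GPMPipeline

pipe = GPMPipeline(
    "general-preference/GPM-Gemma-2-2B",   # any checkpoint registered in REWARD_MODEL_CONFIG
    device=torch.device("cuda:0"),
    torch_dtype=torch.bfloat16,
)

prompt = [{"role": "user", "content": "Explain what a reward model does."}]
chosen = prompt + [{"role": "assistant", "content": "It assigns scores to candidate responses so that better responses rank higher."}]
rejected = prompt + [{"role": "assistant", "content": "It generates brand-new responses from scratch."}]

# Each call returns a [batch_size, value_head_dim] embedding in general-preference mode.
chosen_reward = pipe([chosen])
rejected_reward = pipe([rejected])

# Antisymmetric preference score s = c^T R^T r; its sign flips if the two responses are swapped.
score = pipe.generate_high_dim_result(chosen_reward, rejected_reward)
print(score.item())
```

When a checkpoint also ships a prompt_head, calling the pipeline with return_prompt=True additionally returns the prompt's last hidden state, which generate_high_dim_result_with_prompt uses to build a prompt-dependent R instead of the fixed one.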


28 changes: 28 additions & 0 deletions scripts/results/eval-set/general-preference/GPM-Gemma-2-2B.json
@@ -0,0 +1,28 @@
{
"alpacaeval-easy": 0.77,
"alpacaeval-hard": 0.7052631578947368,
"alpacaeval-length": 0.8736842105263158,
"chat_template": "tokenizer",
"donotanswer": 0.5588235294117647,
"hep-cpp": 0.7378048780487805,
"hep-go": 0.7073170731707317,
"hep-java": 0.8170731707317073,
"hep-js": 0.7317073170731707,
"hep-python": 0.725609756097561,
"hep-rust": 0.7195121951219512,
"llmbar-adver-GPTInst": 0.8369565217391305,
"llmbar-adver-GPTOut": 0.7446808510638298,
"llmbar-adver-manual": 0.4782608695652174,
"llmbar-adver-neighbor": 0.7313432835820896,
"llmbar-natural": 0.78,
"math-prm": 0.9485458612975392,
"model": "general-preference/GPM-Gemma-2-2B",
"model_type": "Custom Classifier",
"mt-bench-easy": 1.0,
"mt-bench-hard": 0.5945945945945946,
"mt-bench-med": 0.95,
"refusals-dangerous": 0.93,
"refusals-offensive": 0.99,
"xstest-should-refuse": 0.8896103896103896,
"xstest-should-respond": 0.984
}
28 changes: 28 additions & 0 deletions scripts/results/eval-set/general-preference/GPM-Gemma-2-9B.json
@@ -0,0 +1,28 @@
{
"alpacaeval-easy": 0.93,
"alpacaeval-hard": 0.9368421052631579,
"alpacaeval-length": 0.9052631578947369,
"chat_template": "tokenizer",
"donotanswer": 0.7573529411764706,
"hep-cpp": 0.9329268292682927,
"hep-go": 0.9329268292682927,
"hep-java": 0.9512195121951219,
"hep-js": 0.9451219512195121,
"hep-python": 0.9512195121951219,
"hep-rust": 0.9329268292682927,
"llmbar-adver-GPTInst": 0.9456521739130435,
"llmbar-adver-GPTOut": 0.8936170212765957,
"llmbar-adver-manual": 0.8043478260869565,
"llmbar-adver-neighbor": 0.835820895522388,
"llmbar-natural": 0.94,
"math-prm": 0.9731543624161074,
"model": "general-preference/GPM-Gemma-2-9B",
"model_type": "Custom Classifier",
"mt-bench-easy": 1.0,
"mt-bench-hard": 0.7567567567567568,
"mt-bench-med": 0.975,
"refusals-dangerous": 0.9,
"refusals-offensive": 0.99,
"xstest-should-refuse": 0.9415584415584416,
"xstest-should-respond": 0.984
}
28 changes: 28 additions & 0 deletions scripts/results/eval-set/general-preference/GPM-Gemma-2B.json
@@ -0,0 +1,28 @@
{
"alpacaeval-easy": 0.75,
"alpacaeval-hard": 0.7263157894736842,
"alpacaeval-length": 0.8315789473684211,
"chat_template": "tokenizer",
"donotanswer": 0.5735294117647058,
"hep-cpp": 0.6646341463414634,
"hep-go": 0.6707317073170732,
"hep-java": 0.6707317073170732,
"hep-js": 0.6951219512195121,
"hep-python": 0.6951219512195121,
"hep-rust": 0.6219512195121951,
"llmbar-adver-GPTInst": 0.782608695652174,
"llmbar-adver-GPTOut": 0.7021276595744681,
"llmbar-adver-manual": 0.45652173913043476,
"llmbar-adver-neighbor": 0.664179104477612,
"llmbar-natural": 0.71,
"math-prm": 0.9395973154362416,
"model": "general-preference/GPM-Gemma-2B",
"model_type": "Custom Classifier",
"mt-bench-easy": 0.9285714285714286,
"mt-bench-hard": 0.43243243243243246,
"mt-bench-med": 0.8,
"refusals-dangerous": 0.94,
"refusals-offensive": 1.0,
"xstest-should-refuse": 0.8506493506493507,
"xstest-should-respond": 0.884
}