
Commit abdb863

refactor
Signed-off-by: minmingzhu <[email protected]>
1 parent 8e6e352

File tree: 5 files changed, +495 -422 lines changed


llm_on_ray/finetune/dpo_funetuing.py renamed to llm_on_ray/finetune/data_preprocess.py

Lines changed: 5 additions & 90 deletions
@@ -14,23 +14,15 @@
 # limitations under the License.
 #
 import datasets
-import torch
-from peft import LoraConfig
-from transformers import AutoModelForCausalLM
 from typing import Dict

 IGNORE_INDEX = -100


-class DPOIntelOrcaProcesser:
+class DPOIntelOrcaPreprocesser:
     @staticmethod
     def tokenize_dataset(config, tokenizer, dataset):
         tokenizer.pad_token = tokenizer.eos_token
-        if isinstance(dataset, datasets.Dataset):
-            column_names = dataset.column_names
-
-        if isinstance(dataset, datasets.DatasetDict):
-            column_names = dataset["train"].column_names

         def return_prompt_and_responses(samples) -> Dict[str, str]:
             return {
@@ -44,15 +36,11 @@ def return_prompt_and_responses(samples) -> Dict[str, str]:
                 "rejected": samples["rejected"],
             }

-        raw_datasets = dataset.map(
+        dataset = dataset.map(
             return_prompt_and_responses,
-            remove_columns=column_names,
             load_from_cache_file=False,
             desc="Tokenize dataset",
         )
-        train_dataset = raw_datasets["train"]
-        column_names = train_dataset.column_names
-
         """
         Copied from https://github.com/intel/intel-extension-for-transformers/blob/5ba5fa8048b63bec8a3be8a7122a3db8344ad065/
         intel_extension_for_transformers/neural_chat/examples/finetuning/dpo_pipeline/dpo_clm.py#L308
@@ -145,6 +133,8 @@ def preprocess_function(examples):

             return examples

+        train_dataset = dataset["train"]
+        column_names = list(train_dataset.features)
         if train_dataset is not None:
             # Create train feature from dataset
             train_dataset = train_dataset.map(
@@ -154,7 +144,7 @@ def preprocess_function(examples):
                 desc="Running tokenizer on train dataset",
             )

-        eval_dataset = raw_datasets.get("validation")
+        eval_dataset = dataset.get("validation")

         if eval_dataset is not None:
             column_names = eval_dataset.column_names
@@ -167,78 +157,3 @@ def preprocess_function(examples):
         tokenized_datasets = {"train": train_dataset, "validation": eval_dataset}

         return tokenized_datasets
-
-
-class DPOFuneTuning:
-    def __init__(self, config):
-        self.config = config
-        self.torch_dtype = (
-            self.config["Dataset"]["torch_dtype"]
-            if self.config["Dataset"]["torch_dtype"] in ["auto", None]
-            else getattr(torch, self.config["Dataset"]["torch_dtype"])
-        )
-
-    def get_model(self):
-        # load policy model
-        model = AutoModelForCausalLM.from_pretrained(
-            self.config["General"]["base_model"],
-            config=self.config,
-            low_cpu_mem_usage=True,
-            torch_dtype=self.torch_dtype,
-            use_auth_token=True if self.config["General"]["config"]["use_auth_token"] else None,
-        )
-        model.config.use_cache = False
-        return model
-
-    def get_model_ref(self):
-        # load reference model
-        model_ref = AutoModelForCausalLM.from_pretrained(
-            self.config["General"]["base_model"],
-            config=self.config,
-            low_cpu_mem_usage=True,
-            torch_dtype=self.torch_dtype,
-            use_auth_token=True if self.config["General"]["config"]["use_auth_token"] else None,
-        )
-        model_ref.config.use_cache = False
-        return model_ref
-
-    def dpo_train(self, training_args, tokenized_datasets, tokenizer):
-        from trl import DPOTrainer
-
-        lora_config = self.config["General"].get("lora_config", None)
-        return DPOTrainer(
-            self.get_model(),
-            self.get_model_ref() if lora_config is not None else None,
-            args=training_args,
-            beta=self.config["Training"].get("beta"),
-            train_dataset=tokenized_datasets["train"],
-            eval_dataset=tokenized_datasets["validation"]
-            if tokenized_datasets.get("validation") is not None
-            else None,
-            tokenizer=tokenizer,
-            peft_config=LoraConfig(**lora_config) if lora_config is not None else None,
-            max_length=self.config["Dataset"].get("max_length"),
-            max_prompt_length=self.config["Dataset"].get("max_prompt_length"),
-        )
-
-
-class GaudiDPOFuneTuning(DPOFuneTuning):
-    def dpo_train(self, training_args, gaudi_config, tokenized_datasets, tokenizer):
-        from optimum.habana.trl import GaudiDPOTrainer as DPOTrainer
-
-        lora_config = self.config["General"].get("lora_config", None)
-        return DPOTrainer(
-            self.get_model(),
-            self.get_model_ref() if lora_config is not None else None,
-            args=training_args,
-            gaudi_config=gaudi_config,
-            beta=self.config["Training"].get("beta"),
-            train_dataset=tokenized_datasets["train"],
-            eval_dataset=tokenized_datasets["validation"]
-            if tokenized_datasets.get("validation") is not None
-            else None,
-            tokenizer=tokenizer,
-            peft_config=LoraConfig(**lora_config) if lora_config is not None else None,
-            max_length=self.config["Dataset"].get("max_length"),
-            max_prompt_length=self.config["Dataset"].get("max_prompt_length"),
-        )
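
For context, a minimal driver sketch for the renamed preprocessor. This is not code from the commit: the tokenizer name, the config values, and the Intel/orca_dpo_pairs dataset are illustrative assumptions, and the elided preprocess_function may read further config keys.

# Hypothetical usage of DPOIntelOrcaPreprocesser.tokenize_dataset -- a sketch,
# not part of this commit.
import datasets
from transformers import AutoTokenizer

from llm_on_ray.finetune.data_preprocess import DPOIntelOrcaPreprocesser

config = {
    "Dataset": {"max_length": 1024, "max_prompt_length": 512},  # assumed keys/values
}
tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder causal-LM tokenizer

# A DatasetDict with a "train" split; after this commit, tokenize_dataset
# indexes dataset["train"] itself rather than using a precomputed column_names.
raw = datasets.load_dataset("Intel/orca_dpo_pairs")

tokenized = DPOIntelOrcaPreprocesser.tokenize_dataset(config, tokenizer, raw)
print(tokenized["train"])           # tokenized train split
print(tokenized.get("validation"))  # None when the source has no validation split

Note the behavioral change in this rename: the first map no longer drops the original columns (remove_columns was removed), and column_names is now derived from train_dataset.features after the prompt/response mapping.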
Lines changed: 129 additions & 0 deletions
@@ -0,0 +1,129 @@
+#
+# Copyright 2023 The LLM-on-Ray Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import torch
+import transformers
+from peft import LoraConfig
+from transformers import AutoModelForCausalLM
+from typing import Dict
+
+from llm_on_ray.finetune.data_preprocess import DPOIntelOrcaPreprocesser
+from itertools import chain
+
+from llm_on_ray.finetune.finetuning import Finetuning
+
+IGNORE_INDEX = -100
+
+
+class DPOFineTuning(Finetuning):
+    def tokenize_dataset(self, config: Dict, tokenizer, dataset):
+        print("tokenize_dataset")
+        print(dataset)
+        config["Dataset"].get("group", True)
+        config["Dataset"].get("block_size", 512)
+        tokenizer.pad_token = tokenizer.eos_token
+        tokenized_dataset = DPOIntelOrcaPreprocesser.tokenize_dataset(config, tokenizer, dataset)
+        print(tokenized_dataset)
+        return tokenized_dataset
+
+    def load_model(self, config: Dict):
+        model_name = config["General"]["base_model"]
+        model_dtype = self.convert_dtype(config["Training"].get("mixed_precision", "no"))
+        model_config = config["General"].get("config", {})
+        model = transformers.AutoModelForCausalLM.from_pretrained(
+            model_name, torch_dtype=model_dtype, **model_config
+        )
+
+        egc = config["General"].get("enable_gradient_checkpointing", False)
+        if egc:
+            model.enable_input_require_grads()
+            model.gradient_checkpointing_enable()
+            model.config.use_cache = False
+
+        model.to(dtype=model_dtype, device=torch.device(config["Training"]["device"]))
+
+        return model
+
+    def load_model_ref(self, config: Dict):
+        model_name = config["General"]["base_model"]
+        model_dtype = self.convert_dtype(config["Training"].get("mixed_precision", "no"))
+        model_config = config["General"].get("config", {})
+
+        # load reference model
+        model_ref = transformers.AutoModelForCausalLM.from_pretrained(
+            model_name, torch_dtype=model_dtype, **model_config
+        )
+
+        model_ref.config.use_cache = False
+        model_ref.to(dtype=model_dtype, device=torch.device(config["Training"]["device"]))
+
+        return model_ref
+
+    def get_trainer(self, config: Dict, model, tokenizer, tokenized_dataset, data_collator):
+        device = config["Training"]["device"]
+        lora_config = config["General"].get("lora_config", None)
+
+        if device in ["cpu", "gpu"]:
+            from transformers import Trainer, TrainingArguments
+            from trl import DPOTrainer
+
+            training_args = self.convert_to_training_args(TrainingArguments, config)
+
+            trainer = DPOTrainer(
+                model,
+                self.load_model_ref(config) if lora_config is not None else None,
+                args=training_args,
+                beta=config["Training"].get("beta"),
+                train_dataset=tokenized_dataset["train"],
+                eval_dataset=tokenized_dataset["validation"]
+                if tokenized_dataset.get("validation") is not None
+                else None,
+                tokenizer=tokenizer,
+                peft_config=LoraConfig(**lora_config) if lora_config is not None else None,
+                max_length=config["Dataset"].get("max_length"),
+                max_prompt_length=config["Dataset"].get("max_prompt_length"),
+            )
+        elif device in ["hpu"]:
+            from optimum.habana.trl import GaudiDPOTrainer as DPOTrainer
+            from optimum.habana.transformers import GaudiTrainingArguments
+            from optimum.habana import GaudiConfig
+
+            # If gaudi_config_name is provided, load gaudi_config from huggingface model hub(https://huggingface.co/Habana), otherwise use default gaudi_config
+            gaudi_config_name = config["General"].get("gaudi_config_name", None)
+            if gaudi_config_name is not None:
+                gaudi_config = GaudiConfig.from_pretrained(gaudi_config_name)
+            else:
+                gaudi_config = GaudiConfig()
+                gaudi_config.use_fused_adam = True
+                gaudi_config.use_fused_clip_norm = True
+
+            training_args = self.convert_to_training_args(GaudiTrainingArguments, config)
+            trainer = DPOTrainer(
+                model,
+                self.load_model_ref(config) if lora_config is not None else None,
+                args=training_args,
+                gaudi_config=gaudi_config,
+                beta=config["Training"].get("beta"),
+                train_dataset=tokenized_dataset["train"],
+                eval_dataset=tokenized_dataset["validation"]
+                if tokenized_dataset.get("validation") is not None
+                else None,
+                tokenizer=tokenizer,
+                peft_config=LoraConfig(**lora_config) if lora_config is not None else None,
+                max_length=config["Dataset"].get("max_length"),
+                max_prompt_length=config["Dataset"].get("max_prompt_length"),
+            )
+
+        return training_args, trainer
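
A hedged end-to-end sketch of the new DPOFineTuning entry point follows. It is not part of the commit: the module path, the no-argument constructor, and all config values are assumptions; only the config keys appear in the diff, and convert_to_training_args (inherited from Finetuning) may require additional Training fields.

# Hypothetical wiring of DPOFineTuning -- a sketch, not code from this commit.
import datasets
from transformers import AutoTokenizer

from llm_on_ray.finetune.dpo_finetuning import DPOFineTuning  # assumed module path

config = {
    "General": {
        "base_model": "facebook/opt-125m",  # placeholder model
        "config": {},
        "lora_config": None,  # supply a dict to enable LoRA and the reference model
        "enable_gradient_checkpointing": False,
    },
    "Training": {"device": "cpu", "mixed_precision": "no", "beta": 0.1},
    "Dataset": {"max_length": 1024, "max_prompt_length": 512},
}

task = DPOFineTuning()  # assumes a no-argument constructor
tokenizer = AutoTokenizer.from_pretrained(config["General"]["base_model"])
raw = datasets.load_dataset("Intel/orca_dpo_pairs")  # placeholder DPO pairs

tokenized = task.tokenize_dataset(config, tokenizer, raw)
model = task.load_model(config)
training_args, trainer = task.get_trainer(
    config, model, tokenizer, tokenized, data_collator=None
)
trainer.train()

As the diff shows, "cpu"/"gpu" devices route through trl's DPOTrainer while "hpu" uses optimum-habana's GaudiDPOTrainer, and the reference model is only instantiated when lora_config is set; otherwise None is passed and trl handles the reference model internally.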
