Skip to content
This repository was archived by the owner on Sep 23, 2025. It is now read-only.

Commit aa2d08e

Browse files
author
harborn
authored
[Finetune] Support fine-tuning on Gaudi (#155)
* [Fine-tuning] Enable fine-tuning on Gaudi * update * update * update * update * update
1 parent a51fd46 commit aa2d08e

File tree

7 files changed

+161
-83
lines changed

7 files changed

+161
-83
lines changed

docs/finetune_parameters.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@ The following are the parameters supported in the finetuning workflow.
2323
|validation_file|None|A json file containing the validation data.|
2424
|validation_split_percentage|5|The percentage of the train set used as validation set in case there's no validation split|
2525
|preprocessing_num_workers|None|The number of processes to use for the preprocessing.|
26+
|max_length|512|Padding sequential data to max length of a batch|
27+
|group|True|Whether to concatenate the sentence for more efficient training|
28+
|block_size|512|The block size of concatenated sentence|
29+
|shuffle|False|Whether to shuffle the data at every epoch|
30+
2631

2732
## Training Parameters
2833
|Configuration Name| Default|Meaning|

llm_on_ray/common/dataprocesser/general_processer.py

Lines changed: 26 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,12 @@ def torch_call(self, examples):
8484

8585
class GeneralProcesser(DataProcesser):
8686
def prepare(self, tokenizer, dataset):
87-
per_device_train_batch_size = self.config.get("per_device_train_batch_size", 1)
88-
per_device_eval_batch_size = self.config.get("per_device_eval_batch_size", 1)
89-
group = self.config.get("group", False)
90-
self.config.get("shuffle", False)
87+
per_device_train_batch_size = self.config.get("per_device_train_batch_size")
88+
per_device_eval_batch_size = self.config.get("per_device_eval_batch_size")
89+
max_length = self.config.get("max_length")
90+
group = self.config.get("group")
91+
block_size = self.config.get("block_size")
92+
shuffle = self.config.get("shuffle")
9193
tokenizer.pad_token = tokenizer.eos_token
9294

9395
if isinstance(dataset, datasets.Dataset):
@@ -123,8 +125,6 @@ def prompt(rec):
123125
)
124126
column_names += [TEXT_COLUMN_NAME]
125127

126-
max_length = self.config.get("max_length", 1024)
127-
128128
def tokenize_function(examples):
129129
return tokenizer(examples[TEXT_COLUMN_NAME], max_length=max_length)
130130

@@ -136,7 +136,6 @@ def tokenize_function(examples):
136136
)
137137

138138
if group:
139-
block_size = self.config.get("block_size", 1024)
140139

141140
def group_texts(examples):
142141
# Concatenate all texts.
@@ -160,30 +159,30 @@ def group_texts(examples):
160159
load_from_cache_file=False,
161160
desc=f"Grouping texts in chunks of {block_size}",
162161
)
163-
default_data_collator = transformers.default_data_collator
164-
165-
else:
166-
default_data_collator = DataCollatorForCompletionOnlyLM(
167-
tokenizer=tokenizer,
168-
mlm=False,
169-
return_tensors="pt",
170-
pad_to_multiple_of=8,
171-
)
172162

173-
train_dataset = tokenized_datasets["train"]
174-
train_dataloader = torch.utils.data.DataLoader(
175-
train_dataset,
176-
shuffle=True,
177-
collate_fn=default_data_collator,
178-
batch_size=per_device_train_batch_size,
163+
data_collator = DataCollatorForCompletionOnlyLM(
164+
tokenizer=tokenizer,
165+
mlm=False,
166+
return_tensors="pt",
167+
pad_to_multiple_of=8,
179168
)
180169

170+
train_dataset = tokenized_datasets["train"]
171+
train_dataloader_params = {
172+
"shuffle": shuffle,
173+
"collate_fn": data_collator,
174+
"batch_size": per_device_train_batch_size,
175+
"pin_memory": True,
176+
}
177+
train_dataloader = torch.utils.data.DataLoader(train_dataset, **train_dataloader_params)
178+
181179
eval_dataloader = None
182180
if "validation" in tokenized_datasets:
183181
eval_dataset = tokenized_datasets["validation"]
184-
eval_dataloader = torch.utils.data.DataLoader(
185-
eval_dataset,
186-
collate_fn=default_data_collator,
187-
batch_size=per_device_eval_batch_size,
188-
)
182+
eval_dataloader_params = {
183+
"shuffle": shuffle,
184+
"collate_fn": data_collator,
185+
"batch_size": per_device_eval_batch_size,
186+
}
187+
eval_dataloader = torch.utils.data.DataLoader(eval_dataset, **eval_dataloader_params)
189188
return train_dataloader, eval_dataloader

llm_on_ray/common/torch_config.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def backend_cls(self):
2323
return EnableCCLBackend
2424

2525

26-
def libs_import():
26+
def xpu_libs_import():
2727
"""try to import IPEX and oneCCL."""
2828
try:
2929
import intel_extension_for_pytorch
@@ -39,6 +39,14 @@ def libs_import():
3939
raise ImportError("Please install torch-ccl") from ccl_not_exist
4040

4141

42+
def hpu_libs_import():
43+
"""try to import habana frameworkfs for torch"""
44+
try:
45+
import habana_frameworks.torch # noqa: F401
46+
except ImportError as habana_not_exist:
47+
raise ImportError("Please install habana_frameworks") from habana_not_exist
48+
49+
4250
def _set_torch_distributed_env_vars(device):
4351
if device is not None:
4452
os.environ["ACCELERATE_TORCH_DEVICE"] = device
@@ -48,6 +56,11 @@ class EnableCCLBackend(_TorchBackend):
4856
device: Optional[str] = None
4957

5058
def on_start(self, worker_group: WorkerGroup, backend_config: RayTorchConfig):
59+
libs_import = (
60+
hpu_libs_import
61+
if self.device is not None and self.device.startswith("hpu")
62+
else xpu_libs_import
63+
)
5164
for i in range(len(worker_group)):
5265
worker_group.execute_single_async(i, libs_import)
5366
super().on_start(worker_group, backend_config)

llm_on_ray/common/trainer/default_trainer.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,8 @@ def prepare(self, model, tokenizer, dataset, optimizer, accelerator):
135135
# self.model, self.optimizer, self.lr_scheduler, ..., are prepared with 2 steps
136136
# because it is recommended way to prepare model and optimizer while using FSDP.
137137
# https://huggingface.co/docs/accelerate/usage_guides/fsdp#a-few-caveats-to-be-aware-of
138-
accelerate_mode = self.config.get("accelerate_mode")
139-
if accelerate_mode in ["GPU_DEEPSPEED"]:
138+
self.accelerate_mode = self.config.get("accelerate_mode")
139+
if self.accelerate_mode in ["GPU_DEEPSPEED"]:
140140
lr = lr_scheduler_config.get("learning_rate", 0.001)
141141
weight_decay = lr_scheduler_config.get("weight_decay", 0)
142142
from accelerate.utils import DummyOptim, DummyScheduler
@@ -163,6 +163,14 @@ def prepare(self, model, tokenizer, dataset, optimizer, accelerator):
163163
self.lr_scheduler,
164164
) = accelerator.prepare(optimizer, train_dataloader, eval_dataloader, lr_scheduler)
165165

166+
if self.accelerate_mode in ["HPU_DDP"]:
167+
import habana_frameworks.torch.core as htcore
168+
from habana_frameworks.torch.utils.internal import is_lazy
169+
170+
self.htcore = htcore
171+
else:
172+
self.htcore = None
173+
166174
checkpoint = self.config.get("checkpoint")
167175
if checkpoint is not None:
168176
self.recovery(checkpoint)
@@ -180,12 +188,20 @@ def train(self):
180188
logger.info(f"Start training epoch {idx}, total_steps {total_steps}")
181189
for step, batch in enumerate(self.train_dataloader):
182190
with self.accelerator.accumulate(self.model):
191+
self.model.train()
192+
batch = batch.to(device=self.accelerator.device)
183193
outputs = self.model(**batch)
184194
loss = outputs.loss
185195
self.accelerator.backward(loss)
196+
if self.htcore is not None:
197+
self.htcore.mark_step()
186198
self.optimizer.step()
199+
if self.htcore is not None:
200+
self.htcore.mark_step()
187201
if self.lr_scheduler is not None:
188202
self.lr_scheduler.step()
203+
if self.htcore is not None:
204+
self.htcore.mark_step()
189205
self.optimizer.zero_grad()
190206

191207
if step % logging_steps == 0:

0 commit comments

Comments
 (0)