System Info
There is a bug in how the trainer (SFTTrainer) saves the checkpoint when we use FSDPv2 (SPMD) on TPU. This behavior does not show up with the old method of running Torch XLA code (xla_spawn.py). It causes the saved checkpoint to be almost exactly the same as the base model, which throws this error with PEFT:
Found missing adapter keys while loading the checkpoint: {missing_keys}
Even without PEFT, the model weights appear to be unaffected by the training process.
The problem may be related to how the saving logic for FSDPv2 Torch XLA works in the trainer file. The same code works correctly on GPU and also with the xla_spawn.py FSDP method.
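One quick way to confirm this is to diff the saved tensors against the base model. This is only a minimal sketch of mine (not part of the original reproduction); the checkpoint path and file name are assumptions and should be pointed at whatever the run actually wrote:

```python
# Minimal check for the symptom above: compare the tensors in the saved
# checkpoint with the base model. The path below is an assumption; point it
# at whatever file the run actually produced.
import torch
from safetensors.torch import load_file
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B-Instruct", torch_dtype=torch.bfloat16
)
base_sd = base.state_dict()
saved_sd = load_file("output/model.safetensors")  # assumed location of the saved weights

identical, compared = 0, 0
for name, tensor in saved_sd.items():
    if name in base_sd:
        compared += 1
        if torch.equal(tensor.cpu(), base_sd[name].cpu()):
            identical += 1
print(f"{identical}/{compared} saved tensors are bit-identical to the base model")
```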
Who can help?
No response
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
To reproduce, save the code below as `sft.py` and run it with `PJRT_DEVICE=TPU XLA_USE_SPMD=1 python3 sft.py`:
```python
import torch
import torch_xla
import peft
import trl
import torch_xla.core.xla_model as xm
from datasets import load_dataset
from peft import LoraConfig, PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from trl import SFTTrainer, SFTConfig
import wandb

wandb.init(mode="disabled")

device = xm.xla_device()  # Set up TPU device.
print(device)


def train():
    model_id = "meta-llama/Llama-3.2-1B-Instruct"
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token
    data = load_dataset("philschmid/dolly-15k-oai-style", split="train")
    lora_config = LoraConfig(r=8, target_modules=["k_proj", "v_proj"], task_type="CAUSAL_LM")
    fsdp_config = {
        "fsdp_transformer_layer_cls_to_wrap": ["LlamaDecoderLayer"],
        "xla": True,
        "xla_fsdp_v2": True,
        "xla_fsdp_grad_ckpt": True,
    }
    args = SFTConfig(
        per_device_train_batch_size=8,
        num_train_epochs=1,
        max_steps=-1,
        output_dir="output",
        optim="adafactor",
        logging_steps=50,
        learning_rate=2e-5,
        max_seq_length=2048,
        packing=True,
        dataset_text_field=None,
        save_strategy="no",
        dataloader_drop_last=True,  # Required for SPMD.
        fsdp="full_shard",
        fsdp_config=fsdp_config,
    )
    trainer = SFTTrainer(
        model=model,
        train_dataset=data,
        tokenizer=tokenizer,
        args=args,
        peft_config=lora_config,
    )
    trainer.train()
    final_model = trainer.model
    final_model.to("cpu")
    final_model.save_pretrained("./LoRa")


if __name__ == "__main__":
    train()
```
You will notice in the output folder that the saved model is not in LoRA format (there are no adapter files, i.e. no adapter_config.json and adapter_model.safetensors). This is because, with FSDPv2, we end up here (you can check by adding a print statement):
transformers/src/transformers/trainer.py, line 3821 (commit 62db3e6): `state_dict = xm._maybe_convert_to_cpu(model.state_dict())`
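As a stopgap, something along these lines might recover the adapter weights from the SPMD run. This is only a sketch of mine, untested on SPMD (the key names may still differ depending on how FSDPv2 wraps the model); it reuses the same CPU conversion the trainer performs in this branch and would go inside train(), right after trainer.train():

```python
# Hypothetical workaround sketch, not verified on SPMD: pull the state dict to
# CPU the same way the trainer does in this branch, but keep only the LoRA
# adapter tensors (in PEFT's adapter format) instead of the full state dict.
import torch_xla.core.xla_model as xm
from peft import get_peft_model_state_dict
from safetensors.torch import save_file

cpu_state_dict = xm._maybe_convert_to_cpu(trainer.model.state_dict())
adapter_state_dict = get_peft_model_state_dict(trainer.model, state_dict=cpu_state_dict)
save_file(adapter_state_dict, "output/adapter_model.safetensors")
lora_config.save_pretrained("output")  # writes the matching adapter_config.json
```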
However, if we run the same code on GPU or with the old xla_spawn (FSDP) method, this issue disappears. To reproduce with FSDP, first run
`wget https://raw.githubusercontent.com/huggingface/transformers/refs/heads/main/examples/pytorch/xla_spawn.py`
then save the code below as `sft.py` and run it with `python3 xla_spawn.py --num_cores <x> sft.py`:
```python
from datasets import load_dataset
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import TrainingArguments
from trl import SFTTrainer, SFTConfig
import os
from peft import LoraConfig, get_peft_model, PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, BitsAndBytesConfig
import transformers
import wandb

wandb.init(mode="disabled")


def main():
    data = load_dataset("philschmid/dolly-15k-oai-style", split="train")
    model_id = "meta-llama/Llama-3.2-1B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.add_special_tokens({"pad_token": tokenizer.eos_token})
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
    # target_modules=["k_proj", "v_proj", "embed_tokens", "lm_head"]
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        target_modules=["q_proj", "k_proj", "v_proj", "embed_tokens", "lm_head"],
        task_type="CAUSAL_LM",
    )
    trainer = SFTTrainer(
        model=model,
        train_dataset=data,
        args=SFTConfig(
            per_device_train_batch_size=1,
            num_train_epochs=3,
            max_steps=-1,
            output_dir="./output",
            logging_steps=50,
            learning_rate=5e-5,
            max_seq_length=2048,
            save_steps=1000000,
            save_only_model=True,
            packing=True,
            dataset_num_proc=40,
        ),
        peft_config=lora_config,
    )
    trainer.train()
    final_model = trainer.model
    final_model.to("cpu")
    final_model.save_pretrained("./LoRa")


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()
```
With this code everything works great, because the saving function ends up here:
transformers/src/transformers/trainer.py, line 3824 (commit 62db3e6): `model.save_pretrained(`
I merged the LoRA adapter into the base model and the generated output is as expected from a fine-tuned model!
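For reference, the merge step looks roughly like this (my own snippet, assuming the adapter was saved to ./LoRa as in the script above):

```python
# Rough sketch of the merge step: load the saved adapter on top of the base
# model and fold it into the base weights. PEFT 0.14.0 would surface the
# "missing adapter keys" warning here if the adapter checkpoint were broken.
import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B-Instruct", torch_dtype=torch.bfloat16
)
peft_model = PeftModel.from_pretrained(base, "./LoRa")
merged = peft_model.merge_and_unload()
merged.save_pretrained("./merged")
```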
Finally, please note that this issue is not related to PEFT: even if you use SFTTrainer without PEFT, the issue still exists. I believe it has to do with how checkpoints are saved with FSDPv2 on TPUs.
Expected behavior
The model trained with LoRA should be saved as the two adapter files, and when we merge the LoRA adapter with the base model we should not get this message (you should update PEFT to the latest version, 0.14.0, as it adds an additional check that detects problems with LoRA checkpoints):
Found missing adapter keys while loading the checkpoint: {missing_keys}
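In other words, a simple acceptance check (my own sketch; the output path is the one used in the scripts above) is that both adapter files exist and the adapter then loads and merges without that warning:

```python
# Sketch of the expected outcome: the save directory should contain exactly
# the two adapter files named above.
import os

out_dir = "./LoRa"  # adapter directory written by the repro scripts
for fname in ("adapter_config.json", "adapter_model.safetensors"):
    print(fname, "->", "found" if os.path.isfile(os.path.join(out_dir, fname)) else "MISSING")
```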