
FSDP Torch XLA vs. FSDPv2 (SPMD) Torch XLA checkpoint saving bug #36004

@salrowili

System Info

There is a bug in how the trainer (SFTTrainer) saves checkpoints when FSDPv2 (SPMD) is used on TPU. It does not occur with the older way of running Torch XLA code (xla_spawn.py). The bug causes the saved checkpoint to be almost identical to the base model, and with PEFT it raises this error:

Found missing adapter keys while loading the checkpoint: {missing_keys}

Even without PEFT, the model weights do not seem to be affected by the training process.

The problem may be related to how the trainer's save function works with FSDPv2 on Torch XLA. The same code works correctly on GPU and with the xla_spawn.py FSDP method.
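To check whether training actually changed the weights, one can compare the base and fine-tuned state dicts parameter by parameter. The sketch below is a hypothetical helper (`unchanged_parameters` is not part of transformers, PEFT, or TRL). It takes any two name-to-value mappings plus an equality function, so with real PyTorch state dicts one would pass `equal=torch.equal`; the demo uses plain lists as stand-ins for tensors.

```python
from typing import Callable, List, Mapping


def unchanged_parameters(
    base: Mapping[str, object],
    tuned: Mapping[str, object],
    equal: Callable[[object, object], bool] = lambda a, b: a == b,
) -> List[str]:
    """Hypothetical helper: names of parameters whose values did not change.

    Only names present in both mappings are compared. For PyTorch state
    dicts, pass equal=torch.equal (plain == on tensors is elementwise).
    """
    shared = base.keys() & tuned.keys()
    return sorted(name for name in shared if equal(base[name], tuned[name]))


# Toy stand-ins for state dicts: plain lists instead of tensors.
base_state = {"layer.0.weight": [1.0, 2.0], "layer.1.weight": [3.0]}
tuned_state = {"layer.0.weight": [1.0, 2.0], "layer.1.weight": [3.5]}
print(unchanged_parameters(base_state, tuned_state))  # ['layer.0.weight']
```

For a full fine-tune without PEFT, a result listing nearly every parameter would match the symptom described here. With LoRA the base weights are frozen by design, so the comparison is only meaningful on the merged model.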

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

To reproduce, save the code below as sft.py and run it with PJRT_DEVICE=TPU XLA_USE_SPMD=1 python3 sft.py:

import torch
import torch_xla
import peft
import trl
import torch_xla.core.xla_model as xm
from datasets import load_dataset
from peft import LoraConfig,PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from trl import SFTTrainer, SFTConfig
import wandb
wandb.init(mode="disabled")
device = xm.xla_device() # Set up TPU device.
print(device)

def train():
    model_id = "meta-llama/Llama-3.2-1B-Instruct"
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token
    data = load_dataset("philschmid/dolly-15k-oai-style", split="train")
    lora_config = LoraConfig(r=8, target_modules=["k_proj", "v_proj"], task_type="CAUSAL_LM")
    fsdp_config = {
        "fsdp_transformer_layer_cls_to_wrap": ["LlamaDecoderLayer"],
        "xla": True,
        "xla_fsdp_v2": True,
        "xla_fsdp_grad_ckpt": True,
    }
    args = SFTConfig(
        per_device_train_batch_size=8,
        num_train_epochs=1,
        max_steps=-1,
        output_dir="output",
        optim="adafactor",
        logging_steps=50,
        learning_rate=2e-5,
        max_seq_length=2048,
        packing=True,
        dataset_text_field=None,
        save_strategy="no",
        dataloader_drop_last=True,  # Required for SPMD.
        fsdp="full_shard",
        fsdp_config=fsdp_config,
    )
    trainer = SFTTrainer(
        model=model,
        train_dataset=data,
        tokenizer=tokenizer,
        args=args,
        peft_config=lora_config,
    )
    trainer.train()
    final_model = trainer.model
    final_model.to("cpu")
    final_model.save_pretrained("./LoRa")


if __name__ == "__main__":
    train()

In the output folder, you will notice that the saved model is not in LoRA format: the two adapter files (adapter_config.json and adapter_model.safetensors) are not there. This is because with FSDPv2 the save ends up at this line (you can check by adding a print statement):

state_dict = xm._maybe_convert_to_cpu(model.state_dict())
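One possible explanation, stated as an assumption rather than something confirmed against the trainer source: under FSDPv2 the model handed to the save routine is still wrapped, the routine does not recognize the wrapper as a model that has `save_pretrained`, and so it falls through to writing only the raw state dict, which would explain the missing adapter files. The toy sketch below models only that decision. `SaveableModel`, `ShardedWrapper`, `choose_save_path`, and both unwrap functions are hypothetical stand-ins, not transformers classes.

```python
# Toy model of a save-dispatch decision. All names are hypothetical;
# this is NOT the actual transformers Trainer code.


class SaveableModel:
    """Stand-in for a model that has save_pretrained (e.g. a PEFT model)."""


class ShardedWrapper:
    """Stand-in for an SPMD/FSDP wrapper that hides the inner model."""

    def __init__(self, module):
        self.module = module


def choose_save_path(model, unwrap):
    """Return the name of the save branch this model would take.

    `unwrap` is a caller-supplied function that strips known wrappers
    (an assumed stand-in for whatever unwrapping the trainer performs).
    """
    if isinstance(model, SaveableModel):
        return "save_pretrained"
    if isinstance(unwrap(model), SaveableModel):
        return "unwrapped.save_pretrained"
    # Fallback: only the raw state dict is written, so no adapter files.
    return "raw_state_dict"


def no_unwrap(model):
    # Does not recognize the wrapper at all.
    return model


def unwrap_sharded(model):
    # Recognizes the wrapper and returns the inner model.
    return model.module if isinstance(model, ShardedWrapper) else model


wrapped = ShardedWrapper(SaveableModel())
print(choose_save_path(wrapped, no_unwrap))       # raw_state_dict
print(choose_save_path(wrapped, unwrap_sharded))  # unwrapped.save_pretrained
```

If this reading is correct, the difference between the two save paths comes down to whether the wrapper is stripped before the branch is chosen; that remains a hypothesis to verify against the actual trainer code.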

However, if the same code is run on GPU or with the older xla_spawn (FSDP) method, the issue disappears. To reproduce with FSDP, first run
wget https://raw.githubusercontent.com/huggingface/transformers/refs/heads/main/examples/pytorch/xla_spawn.py
then save the code below as sft.py and run it with python3 xla_spawn.py --num_cores x sft.py (where x is the number of TPU cores):

from datasets import load_dataset
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import TrainingArguments
from trl import SFTTrainer,SFTConfig
import os
from peft import LoraConfig, get_peft_model, PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, BitsAndBytesConfig
import transformers
import wandb
wandb.init(mode="disabled")
def main():
    data = load_dataset("philschmid/dolly-15k-oai-style", split="train")
    model_id = "meta-llama/Llama-3.2-1B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.add_special_tokens({'pad_token': tokenizer.eos_token})
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

    #target_modules=["k_proj", "v_proj","embed_tokens", "lm_head"]
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        target_modules=["q_proj", "k_proj", "v_proj", "embed_tokens", "lm_head"],
        task_type="CAUSAL_LM",
    )

    trainer = SFTTrainer(
        model=model,
        train_dataset=data,
        args=SFTConfig(
            per_device_train_batch_size=1,
            num_train_epochs=3,
            max_steps=-1,
            output_dir="./output",
            logging_steps=50,
            learning_rate=5e-5,
            max_seq_length=2048,
            save_steps=1000000,
            save_only_model=True,
            packing=True,
            dataset_num_proc=40,
        ),
        peft_config=lora_config,
    )

    trainer.train()
    final_model = trainer.model
    final_model.to("cpu")
    final_model.save_pretrained("./LoRa")


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()

With this code everything works as expected, because the save function ends up here:

model.save_pretrained(

I merged the LoRA adapter into the base model, and the generated output is what one would expect from a fine-tuned model.

Finally, note that this issue is not caused by PEFT: it still occurs when SFTTrainer is used without PEFT. I believe it has to do with how checkpoints are saved with FSDPv2 on TPUs.

Expected behavior

The LoRA model should be saved as two adapter files, and merging the LoRA adapter with the base model should not produce the following message (update PEFT to the latest version, 0.14.0, since it adds an extra check that detects problems with LoRA checkpoints):

Found missing adapter keys while loading the checkpoint: {missing_keys}
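A quick way to test the expected behavior is to check the save directory for the two adapter files named in this report. This is a minimal standard-library sketch; `missing_adapter_files` is a hypothetical helper, and the demo runs against a temporary directory rather than a real checkpoint.

```python
import tempfile
from pathlib import Path
from typing import List

# File names taken from this report: a LoRA save should produce both.
EXPECTED_ADAPTER_FILES = ("adapter_config.json", "adapter_model.safetensors")


def missing_adapter_files(save_dir: str) -> List[str]:
    """Hypothetical helper: expected adapter files absent from save_dir."""
    directory = Path(save_dir)
    return [name for name in EXPECTED_ADAPTER_FILES if not (directory / name).is_file()]


# Demo on a temporary directory that only contains the config file.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "adapter_config.json").write_text("{}")
    print(missing_adapter_files(tmp))  # ['adapter_model.safetensors']
```

Running it against ./LoRa after training should return an empty list when both adapter files were written.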
