Describe the bug
In the train_sana_sprint_diffusers.py example script, the custom save_model_hook saves the wrong model. The isinstance() check is too broad: both the trained transformer and the frozen pretrained_model are instances of the same class, so the hook saves the trained model and then immediately overwrites it with the frozen, untrained reference model in the same checkpoint directory.
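To see why the check is too broad, here is a tiny, self-contained sketch (using a stand-in class rather than the real SanaTransformer2DModel, purely for illustration):

# Hypothetical stand-in for SanaTransformer2DModel; both models share this class.
class FakeTransformer:
    pass

transformer = FakeTransformer()       # stands in for the trained model
pretrained_model = FakeTransformer()  # stands in for the frozen reference model

# The check used in save_model_hook matches both objects, so each loop iteration
# writes to the same "transformer" folder and the frozen model wins.
for model in (transformer, pretrained_model):
    print(isinstance(model, type(transformer)))  # True, True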
Reproduction
# In train_sana_sprint_diffusers.py, this save hook causes the issue:
def save_model_hook(models, weights, output_dir):
    if accelerator.is_main_process:
        for model in models:
            unwrapped_model = unwrap_model(model)
            # This check is too broad and matches both the trained and frozen models
            if isinstance(unwrapped_model, type(unwrap_model(transformer))):
                model = unwrapped_model
                model.save_pretrained(os.path.join(output_dir, "transformer"))
            # ... rest of the function ...
Logs
System Info
- OS: Windows 11
- Python: 3.10.18
- diffusers version: 0.35.1
- transformers version: 4.57.0
- torch version: 2.8.0
- accelerate version: 1.10.1
- huggingface-hub version: 0.35.3
- safetensors version: 0.6.2
Who can help?
Suggested Fix
def save_model_hook(models, weights, output_dir):
    if accelerator.is_main_process:
        for model in models:
            unwrapped_model = unwrap_model(model)
            # Handle transformer model
            if isinstance(unwrapped_model, type(unwrap_model(transformer))):
                model = unwrapped_model
                if model.config.guidance_embeds:
                    model.save_pretrained(os.path.join(output_dir, "transformer"))
            # Handle discriminator model (only save heads)
            elif isinstance(unwrapped_model, type(unwrap_model(disc))):
                # Save only the heads
                torch.save(unwrapped_model.heads.state_dict(), os.path.join(output_dir, "disc_heads.pt"))
            else:
                raise ValueError(f"unexpected save model: {unwrapped_model.__class__}")

            # make sure to pop weight so that corresponding model is not saved again
            if weights:
                weights.pop()


def load_model_hook(models, input_dir):
    transformer_ = None
    disc_ = None

    if not accelerator.distributed_type == DistributedType.DEEPSPEED:
        while len(models) > 0:
            model = models.pop()
            unwrapped_model = unwrap_model(model)
            if isinstance(unwrapped_model, type(unwrap_model(transformer))):
                if unwrapped_model.config.guidance_embeds:
                    transformer_ = model  # noqa: F841
            elif isinstance(unwrapped_model, type(unwrap_model(disc))):
                # Load only the heads
                heads_state_dict = torch.load(os.path.join(input_dir, "disc_heads.pt"))
                unwrapped_model.heads.load_state_dict(heads_state_dict)
                disc_ = model  # noqa: F841
            else:
                raise ValueError(f"unexpected save model: {unwrapped_model.__class__}")
    else:
        # DeepSpeed case
        transformer_ = SanaTransformer2DModel.from_pretrained(input_dir, subfolder="transformer")  # noqa: F841
        disc_heads_state_dict = torch.load(os.path.join(input_dir, "disc_heads.pt"))  # noqa: F841
        # You'll need to handle how to load the heads in DeepSpeed case
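For context, the hooks above are assumed to be registered with the Accelerator in the usual way; this is only a sketch, and args.output_dir plus the "checkpoint-100" path are placeholders rather than values taken from the script:

# Registration and usage sketch (assumed to mirror the existing script's setup).
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)

# With the fix, save_state() writes output_dir/transformer and output_dir/disc_heads.pt
# once each, and load_state() restores the trained transformer rather than the frozen copy.
accelerator.save_state(os.path.join(args.output_dir, "checkpoint-100"))  # placeholder path
accelerator.load_state(os.path.join(args.output_dir, "checkpoint-100"))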