Conversation

@zhangjiewu

add ChronoEdit

This PR adds ChronoEdit, a state-of-the-art image editing model that reframes image editing as a video generation task to achieve physically consistent edits.

HF Model: https://huggingface.co/nvidia/ChronoEdit-14B-Diffusers
Gradio Demo: https://huggingface.co/spaces/nvidia/ChronoEdit
Paper: https://arxiv.org/abs/2510.04290
Code: https://github.com/nv-tlabs/ChronoEdit
Website: https://research.nvidia.com/labs/toronto-ai/chronoedit/

cc: @sayakpaul @yiyixuxu @asomoza

Usage

Full model

import torch
import numpy as np
from diffusers import AutoencoderKLWan, ChronoEditTransformer3DModel, ChronoEditPipeline
from diffusers.utils import export_to_video, load_image
from transformers import CLIPVisionModel
from PIL import Image

model_id = "nvidia/ChronoEdit-14B-Diffusers"
image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
transformer = ChronoEditTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
pipe = ChronoEditPipeline.from_pretrained(model_id, image_encoder=image_encoder, transformer=transformer, vae=vae, torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = load_image(
    "https://huggingface.co/spaces/nvidia/ChronoEdit/resolve/main/examples/3.png"
)
max_area = 720 * 1280
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
print("width", width, "height", height)
image = image.resize((width, height))
prompt = (
    "The user wants to transform the image by adding a small, cute mouse sitting inside the floral teacup, enjoying a spa bath. The mouse should appear relaxed and cheerful, with a tiny white bath towel draped over its head like a turban. It should be positioned comfortably in the cup’s liquid, with gentle steam rising around it to blend with the cozy atmosphere. "
    "The mouse’s pose should be natural—perhaps sitting upright with paws resting lightly on the rim or submerged in the tea. The teacup’s floral design, gold trim, and warm lighting must remain unchanged to preserve the original aesthetic. The steam should softly swirl around the mouse, enhancing the spa-like, whimsical mood."
)

output = pipe(
    image=image,
    prompt=prompt,
    height=height,
    width=width,
    num_frames=5,
    num_inference_steps=50,
    guidance_scale=5.0,
    enable_temporal_reasoning=False,
    num_temporal_reasoning_steps=0,
).frames[0]
export_to_video(output, "output.mp4", fps=4)
Image.fromarray((output[-1] * 255).clip(0, 255).astype("uint8")).save("output.png")

Full model with temporal reasoning

output = pipe(
    image=image,
    prompt=prompt,
    height=height,
    width=width,
    num_frames=29,
    num_inference_steps=50,
    guidance_scale=5.0,
    enable_temporal_reasoning=True,
    num_temporal_reasoning_steps=50,
).frames[0]
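
As in the previous example, the reasoning trajectory can be exported as a video and the final frame saved as the edited image (the filenames and fps below are illustrative, not prescribed by the model card):

export_to_video(output, "output_temporal_reasoning.mp4", fps=8)
Image.fromarray((output[-1] * 255).clip(0, 255).astype("uint8")).save("output_temporal_reasoning.png")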

With 8-step distillation LoRA

import torch
import numpy as np
from diffusers import AutoencoderKLWan, ChronoEditTransformer3DModel, ChronoEditPipeline, UniPCMultistepScheduler
from diffusers.utils import export_to_video, load_image
from huggingface_hub import hf_hub_download
from transformers import CLIPVisionModel
from PIL import Image

model_id = "nvidia/ChronoEdit-14B-Diffusers"
image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
transformer = ChronoEditTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
pipe = ChronoEditPipeline.from_pretrained(model_id, image_encoder=image_encoder, transformer=transformer, vae=vae, torch_dtype=torch.bfloat16)
lora_path = hf_hub_download(repo_id=model_id, filename="lora/chronoedit_distill_lora.safetensors")
pipe.load_lora_weights(lora_path)
pipe.fuse_lora(lora_scale=1.0)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=2.0)
pipe.to("cuda")

image = load_image(
    "https://huggingface.co/spaces/nvidia/ChronoEdit/resolve/main/examples/3.png"
)
max_area = 720 * 1280
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
print("width", width, "height", height)
image = image.resize((width, height))
prompt = (
    "The user wants to transform the image by adding a small, cute mouse sitting inside the floral teacup, enjoying a spa bath. The mouse should appear relaxed and cheerful, with a tiny white bath towel draped over its head like a turban. It should be positioned comfortably in the cup’s liquid, with gentle steam rising around it to blend with the cozy atmosphere. "
    "The mouse’s pose should be natural—perhaps sitting upright with paws resting lightly on the rim or submerged in the tea. The teacup’s floral design, gold trim, and warm lighting must remain unchanged to preserve the original aesthetic. The steam should softly swirl around the mouse, enhancing the spa-like, whimsical mood."
)

output = pipe(
    image=image,
    prompt=prompt,
    height=height,
    width=width,
    num_frames=5,
    num_inference_steps=8,
    guidance_scale=1.0,
    enable_temporal_reasoning=False,
    num_temporal_reasoning_steps=0,
).frames[0]
export_to_video(output, "output.mp4", fps=4)
Image.fromarray((output[-1] * 255).clip(0, 255).astype("uint8")).save("output.png")

@sayakpaul requested review from DN6 and dg845 on November 5, 2025.
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import FP32LayerNorm
from .transformer_wan import WanTimeTextImageEmbedding, WanTransformerBlock
Collaborator

can we copy over these two classes and add a # Copied from, instead of importing from wan?
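
For illustration, the convention would look roughly like this (module path and target class name assumed; the actual body is duplicated from the Wan class, not subclassed):

import torch.nn as nn


# Copied from diffusers.models.transformers.transformer_wan.WanTransformerBlock with WanTransformerBlock->ChronoEditTransformerBlock
class ChronoEditTransformerBlock(nn.Module):
    # Body duplicated verbatim from WanTransformerBlock; `make fix-copies`
    # then keeps this local copy in sync with the Wan implementation.
    ...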

@zhangjiewu (Author) commented Nov 5, 2025

yep, that makes sense. so we’ll need to copy all the modules from transformer_wan here.

@yiyixuxu (Collaborator) left a comment

thanks for the PR! I left one question about whether we support any value of num_frames.
other than that, I think we should remove the parts that exist in Wan but aren't needed here for ChronoEdit, to simplify the code a bit. but if you want to keep it consistent and may support these features in the future, that's ok too.

self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
self.image_processor = image_processor

def _get_t5_prompt_embeds(
Collaborator

let's add a # Copied from if it's the same one as in Wan


return prompt_embeds

def encode_image(
Collaborator

same here

image_encoder: CLIPVisionModel = None,
transformer: ChronoEditTransformer3DModel = None,
transformer_2: ChronoEditTransformer3DModel = None,
boundary_ratio: Optional[float] = None,
Collaborator

Suggested change (remove):
boundary_ratio: Optional[float] = None,

if we don't support the two-stage denoising loop, let's remove this parameter and all of its related logic to simplify the pipeline a bit

num_frames: int = 81,
num_inference_steps: int = 50,
guidance_scale: float = 5.0,
guidance_scale_2: Optional[float] = None,
Collaborator

Suggested change (remove):
guidance_scale_2: Optional[float] = None,

Author

done.

prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
image_embeds: Optional[torch.Tensor] = None,
last_image: Optional[torch.Tensor] = None,
Collaborator

it's an image editing task and can output a video to show the reasoning process, no? what would be a meaningful use case for also passing a last_image parameter here?

Author

removed.

if self.config.boundary_ratio is not None and image_embeds is not None:
raise ValueError("Cannot forward `image_embeds` when the pipeline's `boundary_ratio` is not configured.")

def prepare_latents(
Collaborator

i think this is the same as in Wan I2V too?
if you want to just add a # Copied from and keep this method as it is, that's fine! we can also just remove all the logic we don't need here related to last_frame and expand_timesteps

Author

yes, it's the same as in Wan I2V. I added a reference to the original function and removed all the Wan 2.2 logic.

freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
freqs_sin = self.freqs_sin.split(split_sizes, dim=1)

assert num_frames == 2 or num_frames == self.temporal_skip_len, (
Collaborator

i don't understand this check here. I think after the temporal reasoning step, num_frames is 2, but other than that, e.g. if temporal reasoning is not enabled, this dimension will have various lengths based on the num_frames value the user passed to the pipeline, no?
if our model can only work with a fixed num_frames, maybe we can throw an error from the pipeline when we check the inputs?
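
For example, something along these lines in the pipeline's input validation (a minimal sketch; the helper name is hypothetical and not part of this PR):

def _validate_num_frames(num_frames: int) -> None:
    # ChronoEdit needs at least the conditioning frame plus one edited frame,
    # so reject smaller values with a readable error from the pipeline instead
    # of an assertion deep inside the transformer forward pass.
    if num_frames < 2:
        raise ValueError(f"`num_frames` must be at least 2 for ChronoEdit, got {num_frames}.")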

Author

yes, it works for num_frames >= 2. I've removed this check in the latest commit.

@zhangjiewu (Author)

Hi @yiyixuxu, thanks for your review and suggestions! I’ve updated the code accordingly in the latest commit. Please feel free to make any further changes if needed.

@yiyixuxu (Collaborator) left a comment

looking great! did you add a doc page in this PR?
also tests, but we can help with tests if you need

@yiyixuxu (Collaborator) commented Nov 6, 2025

tests can just follow what Wan did:
https://github.com/huggingface/diffusers/blob/main/tests/pipelines/wan/test_wan.py#L39
only fast tests are needed for ChronoEdit for now, I think! we don't need slow tests for now

@zhangjiewu (Author) commented Nov 7, 2025

looking great! did you add a doc page in this PR? also tests, but we can help with tests if you need

test added. will work on the doc now :)

@zhangjiewu (Author)

@yiyixuxu docs have been added: 104e886

@yiyixuxu (Collaborator) commented Nov 7, 2025

@bot /style

@github-actions bot commented Nov 7, 2025

Style fix runs successfully without any file modified.

@dg845 (Collaborator) commented Nov 8, 2025

Hi @zhangjiewu, could you perform the following?

  1. Can you run make fix-copies so that the CI repository consistency check succeeds?
  2. Can you add the docs at api/models/chronoedit_transformer_3d to the _toctree as well? Otherwise, the docs will not build successfully. For reference, here is how WanTransformer3DModel is added (a matching ChronoEdit entry is sketched below):
     - local: api/models/wan_transformer_3d
       title: WanTransformer3DModel
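
The ChronoEdit entry would then look something like this (the title is assumed to mirror the Wan naming, using the class name added in this PR):

     - local: api/models/chronoedit_transformer_3d
       title: ChronoEditTransformer3DModel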

Thanks!

@zhangjiewu (Author)

Hey @dg845, I’ve completed the two tasks you commented on. Thank you!

@dg845 (Collaborator) commented Nov 8, 2025

I see that tests/pipelines/chronoedit/test_chronoedit.py::ChronoEditPipelineFastTests::test_inference fails both on the CI and when I tried it locally because the generated_slice is not close enough to the expected_slice. Is this failure expected?

@zhangjiewu (Author) commented Nov 9, 2025

Hi @dg845, I get these errors when running pytest tests/pipelines/chronoedit/test_chronoedit.py, and also for tests/pipelines/wan/test_wan_image_to_video.py. Any thoughts?

============================ short test summary info =============================
FAILED tests/pipelines/chronoedit/test_chronoedit.py::ChronoEditPipelineFastTests::test_inference - RuntimeError: Expected all tensors to be on the same device, but found at lea...
FAILED tests/pipelines/chronoedit/test_chronoedit.py::ChronoEditPipelineFastTests::test_save_load_float16 - RuntimeError: expected scalar type Float but found Half
============== 2 failed, 29 passed, 3 skipped, 3 warnings in 38.25s =============

Could you check whether the following inputs work?

inputs = {
    "image": image,
    "prompt": "dance monkey",
    "negative_prompt": "negative",  # TODO
    "height": image_height,
    "width": image_width,
    "generator": generator,
    "num_inference_steps": 2,
    "guidance_scale": 6.0,
    "num_frames": 5,
    "max_sequence_length": 16,
    "output_type": "pt",
}
...
self.assertEqual(generated_video.shape, (5, 3, 16, 16))

@sayakpaul (Member)

For test_save_load_float16, #12500 might be relevant.
