Chroma Pipeline #11698


Merged: 117 commits, Jun 14, 2025

Commits (117)
ff0b9a3
working state from hameerabbasi and iddl
Ednaordinary Jun 10, 2025
3c2865c
working state form hameerabbasi and iddl (transformer)
Ednaordinary Jun 10, 2025
e271af9
working state (normalization)
Ednaordinary Jun 10, 2025
15f2bd5
working state (embeddings)
Ednaordinary Jun 10, 2025
32e6a00
add chroma loader
Ednaordinary Jun 10, 2025
bc36a0d
add chroma to mappings
Ednaordinary Jun 10, 2025
33ea0b6
add chroma to transformer init
Ednaordinary Jun 10, 2025
22ecd19
take out variant stuff
Ednaordinary Jun 10, 2025
b0df969
get decently far in changing variant stuff
Ednaordinary Jun 10, 2025
c8cbb31
add chroma init
Ednaordinary Jun 10, 2025
3265923
make chroma output class
Ednaordinary Jun 10, 2025
7400278
add chroma transformer to dummy tp
Ednaordinary Jun 12, 2025
c22930d
add chroma to init
Ednaordinary Jun 12, 2025
4e698b1
add chroma to init
Ednaordinary Jun 12, 2025
5eb4b82
fix single file
Ednaordinary Jun 12, 2025
f0c75b6
update
Ednaordinary Jun 12, 2025
6441e70
update
Ednaordinary Jun 12, 2025
a6f231c
add chroma to auto pipeline
Ednaordinary Jun 12, 2025
7445cf4
add chroma to pipeline init
Ednaordinary Jun 12, 2025
af918c8
change to chroma transformer
Ednaordinary Jun 12, 2025
2fcc75a
take out variant from blocks
Ednaordinary Jun 12, 2025
0b027a2
swap embedder location
Ednaordinary Jun 12, 2025
6c0aed1
remove prompt_2
Ednaordinary Jun 12, 2025
f190c02
work on swapping text encoders
Ednaordinary Jun 12, 2025
38429ff
remove mask function
Ednaordinary Jun 12, 2025
7c75d8e
dont modify mask (for now)
Ednaordinary Jun 12, 2025
c9b46af
wrap attn mask
Ednaordinary Jun 12, 2025
146255a
no attn mask (can't get it to work)
Ednaordinary Jun 12, 2025
3309ffe
remove pooled prompt embeds
Ednaordinary Jun 12, 2025
77b429e
change to my own unpooled embeddeer
Ednaordinary Jun 12, 2025
df7fde7
fix load
Ednaordinary Jun 12, 2025
68f771b
take pooled projections out of transformer
Ednaordinary Jun 12, 2025
a3b6697
Merge branch 'main' into chroma
Ednaordinary Jun 12, 2025
f783f38
ensure correct dtype for chroma embeddings
Ednaordinary Jun 12, 2025
f6de1af
update
Ednaordinary Jun 12, 2025
ab79421
use dn6 attn mask + fix true_cfg_scale
Ednaordinary Jun 12, 2025
442f77a
use chroma pipeline output
Ednaordinary Jun 12, 2025
e69d730
use DN6 embeddings
Ednaordinary Jun 12, 2025
01bc0dc
remove guidance
Ednaordinary Jun 12, 2025
e31c948
remove guidance embed (pipeline)
Ednaordinary Jun 12, 2025
406ab3b
remove guidance from embeddings
Ednaordinary Jun 12, 2025
1bd8fdf
don't return length
Ednaordinary Jun 12, 2025
2d57f3d
Merge branch 'main' into chroma
Ednaordinary Jun 12, 2025
3e2452d
dont change dtype
Ednaordinary Jun 12, 2025
1efa772
remove unused stuff, fix up docs
Ednaordinary Jun 12, 2025
619921c
add chroma autodoc
Ednaordinary Jun 12, 2025
f821f2a
add .md (oops)
Ednaordinary Jun 12, 2025
b0cf680
initial chroma docs
Ednaordinary Jun 12, 2025
0c5eb44
undo don't change dtype
Ednaordinary Jun 12, 2025
42c0e8e
undo arxiv change
Ednaordinary Jun 12, 2025
da846d1
fix hf papers regression in more places
Ednaordinary Jun 12, 2025
18327cb
Update docs/source/en/api/pipelines/chroma.md
Ednaordinary Jun 12, 2025
3f39b1a
do_cfg -> self.do_classifier_free_guidance
Ednaordinary Jun 12, 2025
a93e64d
Update docs/source/en/api/models/chroma_transformer.md
Ednaordinary Jun 12, 2025
3e36a21
Update chroma.md
Ednaordinary Jun 12, 2025
a1fac68
Move chroma layers into transformer
Ednaordinary Jun 12, 2025
1442c97
Remove pruned AdaLayerNorms
Ednaordinary Jun 12, 2025
03fbd52
Add chroma fast tests
Ednaordinary Jun 12, 2025
bedb320
(untested) batch cond and uncond
Ednaordinary Jun 12, 2025
fe5af79
Add # Copied from for shift
Ednaordinary Jun 12, 2025
6a0db55
Update # Copied from statements
Ednaordinary Jun 12, 2025
abf8a33
update norm imports
Ednaordinary Jun 12, 2025
7235805
Revert cond + uncond batching
Ednaordinary Jun 12, 2025
15ca813
Add transformer tests
Ednaordinary Jun 12, 2025
f8d4a1a
move chroma test (oops)
Ednaordinary Jun 12, 2025
c8d6aef
chroma init
Ednaordinary Jun 12, 2025
cfd5b34
fix chroma pipeline fast tests
Ednaordinary Jun 12, 2025
2347d53
Update src/diffusers/models/transformers/transformer_chroma.py
Ednaordinary Jun 12, 2025
d31cf81
Move Approximator and Embeddings
Ednaordinary Jun 12, 2025
c85e46b
Fix auto pipeline + make style, quality
Ednaordinary Jun 12, 2025
19733af
make style
DN6 Jun 13, 2025
f49b149
Apply style fixes
github-actions[bot] Jun 13, 2025
68b9cce
switch to new input ids
Ednaordinary Jun 13, 2025
ad01d63
Merge branch 'main' into chroma
Ednaordinary Jun 13, 2025
e97a4dd
fix # Copied from error
Ednaordinary Jun 13, 2025
fd36924
remove # Copied from on protected members
Ednaordinary Jun 13, 2025
2bc51c8
try to fix import
Ednaordinary Jun 13, 2025
523150f
fix import
Ednaordinary Jun 13, 2025
c330f08
make fix-copes
Ednaordinary Jun 13, 2025
381e64b
revert style fix
Ednaordinary Jun 13, 2025
f35ec17
Merge remote-tracking branch '11698/chroma' into chroma-final
DN6 Jun 13, 2025
35dc65b
update chroma transformer params
DN6 Jun 13, 2025
74fe45e
update chroma transformer approximator init params
DN6 Jun 13, 2025
926dcc6
update to pad tokens
DN6 Jun 13, 2025
89faa71
fix batch inference
DN6 Jun 13, 2025
829c6f1
Make more pipeline tests work
Ednaordinary Jun 13, 2025
272685c
Merge branch 'main' into chroma
Ednaordinary Jun 13, 2025
8766493
Make most transformer tests work
Ednaordinary Jun 13, 2025
60e41b7
Merge remote-tracking branch 'origin/chroma' into chroma
Ednaordinary Jun 13, 2025
28dea06
fix docs
Ednaordinary Jun 13, 2025
bea8b0d
make style, make quality
Ednaordinary Jun 13, 2025
00ebba9
skip batch tests
Ednaordinary Jun 13, 2025
2b6722e
fix test skipping
Ednaordinary Jun 13, 2025
de9a07f
fix test skipping again
Ednaordinary Jun 13, 2025
6735507
fix for tests
DN6 Jun 13, 2025
f1be3eb
Merge branch 'chroma-fork' into chroma-final
DN6 Jun 13, 2025
b85229e
Fix all pipeline test
Ednaordinary Jun 13, 2025
bf56c95
Merge branch 'chroma-fork' into chroma-final
DN6 Jun 13, 2025
292469d
update
DN6 Jun 13, 2025
178c4ec
push local changes, fix docs
Ednaordinary Jun 13, 2025
16b6e33
add encoder test, remove pooled dim
Ednaordinary Jun 13, 2025
06fb995
default proj dim
Ednaordinary Jun 13, 2025
49a4c8b
fix tests
Ednaordinary Jun 13, 2025
3fe4ad6
fix equal size list input
Ednaordinary Jun 13, 2025
41751a3
update
DN6 Jun 13, 2025
fd3e944
push local changes, fix docs
Ednaordinary Jun 13, 2025
8694f2c
add encoder test, remove pooled dim
Ednaordinary Jun 13, 2025
4e24f26
default proj dim
Ednaordinary Jun 13, 2025
0978b60
fix tests
Ednaordinary Jun 13, 2025
c711e8f
fix equal size list input
Ednaordinary Jun 13, 2025
589e939
Revert "fix equal size list input"
DN6 Jun 13, 2025
2b559e9
Merge branch 'chroma-fork' into chroma-final
DN6 Jun 13, 2025
a967e66
update
DN6 Jun 13, 2025
4f00bae
update
DN6 Jun 13, 2025
0497faa
update
DN6 Jun 13, 2025
e10f701
update
DN6 Jun 13, 2025
d267bb6
update
DN6 Jun 13, 2025
Files changed
4 changes: 4 additions & 0 deletions docs/source/en/_toctree.yml
@@ -283,6 +283,8 @@
title: AllegroTransformer3DModel
- local: api/models/aura_flow_transformer2d
title: AuraFlowTransformer2DModel
- local: api/models/chroma_transformer
title: ChromaTransformer2DModel
- local: api/models/cogvideox_transformer3d
title: CogVideoXTransformer3DModel
- local: api/models/cogview3plus_transformer2d
@@ -405,6 +407,8 @@
title: AutoPipeline
- local: api/pipelines/blip_diffusion
title: BLIP-Diffusion
- local: api/pipelines/chroma
title: Chroma
- local: api/pipelines/cogvideox
title: CogVideoX
- local: api/pipelines/cogview3
19 changes: 19 additions & 0 deletions docs/source/en/api/models/chroma_transformer.md
@@ -0,0 +1,19 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# ChromaTransformer2DModel

A modified Flux transformer model from [Chroma](https://huggingface.co/lodestones/Chroma).
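A minimal loading sketch (the checkpoint URL matches the pipeline example in this PR; the dtype is illustrative):

```python
import torch

from diffusers import ChromaTransformer2DModel

transformer = ChromaTransformer2DModel.from_single_file(
    "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v35.safetensors",
    torch_dtype=torch.bfloat16,
)
```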

## ChromaTransformer2DModel

[[autodoc]] ChromaTransformer2DModel
71 changes: 71 additions & 0 deletions docs/source/en/api/pipelines/chroma.md
@@ -0,0 +1,71 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Chroma

<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
<img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
</div>

Chroma is a text-to-image generation model based on Flux.

Original model checkpoints for Chroma can be found [here](https://huggingface.co/lodestones/Chroma).

<Tip>

Chroma can use all the same optimizations as Flux.

</Tip>
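For instance, the standard Flux memory optimizations carry over. A sketch, assuming `pipe` is a `ChromaPipeline` constructed as in the example below:

```python
pipe.enable_model_cpu_offload()  # keep only the active submodule on the GPU
pipe.vae.enable_slicing()  # decode latents in slices to reduce peak memory
pipe.vae.enable_tiling()  # tile VAE decode for large resolutions
```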

## Inference (Single File)

The `ChromaTransformer2DModel` supports loading checkpoints in the original format, which is also useful for loading community-published finetunes and quantized versions of the model.

The following example demonstrates how to run Chroma from a single-file checkpoint.

```python
import torch
from diffusers import ChromaTransformer2DModel, ChromaPipeline
from transformers import T5EncoderModel, T5Tokenizer

bfl_repo = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16

transformer = ChromaTransformer2DModel.from_single_file("https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v35.safetensors", torch_dtype=dtype)

text_encoder = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
tokenizer = T5Tokenizer.from_pretrained(bfl_repo, subfolder="tokenizer_2")

pipe = ChromaPipeline.from_pretrained(bfl_repo, transformer=transformer, text_encoder=text_encoder, tokenizer=tokenizer, torch_dtype=dtype)

pipe.enable_model_cpu_offload()

prompt = "A cat holding a sign that says hello world"
image = pipe(
    prompt,
    guidance_scale=4.0,
    output_type="pil",
    num_inference_steps=26,
    generator=torch.Generator("cpu").manual_seed(0),
).images[0]

image.save("image.png")
```

## ChromaPipeline

[[autodoc]] ChromaPipeline
- all
- __call__
4 changes: 4 additions & 0 deletions src/diffusers/__init__.py
@@ -159,6 +159,7 @@
"AutoencoderTiny",
"AutoModel",
"CacheMixin",
"ChromaTransformer2DModel",
"CogVideoXTransformer3DModel",
"CogView3PlusTransformer2DModel",
"CogView4Transformer2DModel",
@@ -352,6 +353,7 @@
"AuraFlowPipeline",
"BlipDiffusionControlNetPipeline",
"BlipDiffusionPipeline",
"ChromaPipeline",
"CLIPImageProjection",
"CogVideoXFunControlPipeline",
"CogVideoXImageToVideoPipeline",
@@ -768,6 +770,7 @@
AutoencoderTiny,
AutoModel,
CacheMixin,
ChromaTransformer2DModel,
CogVideoXTransformer3DModel,
CogView3PlusTransformer2DModel,
CogView4Transformer2DModel,
@@ -940,6 +943,7 @@
AudioLDM2UNet2DConditionModel,
AudioLDMPipeline,
AuraFlowPipeline,
ChromaPipeline,
CLIPImageProjection,
CogVideoXFunControlPipeline,
CogVideoXImageToVideoPipeline,
1 change: 1 addition & 0 deletions src/diffusers/loaders/peft.py
@@ -60,6 +60,7 @@
"HiDreamImageTransformer2DModel": lambda model_cls, weights: weights,
"HunyuanVideoFramepackTransformer3DModel": lambda model_cls, weights: weights,
"WanVACETransformer3DModel": lambda model_cls, weights: weights,
"ChromaTransformer2DModel": lambda model_cls, weights: weights,
}


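The identity lambda registered above means Chroma LoRA state dicts are used as-is, with no weight rescaling or expansion. A sketch of the loading path this enables (the LoRA badge in the docs above suggests the usual pipeline-level helper applies; the repo id here is hypothetical):

```python
# Hypothetical LoRA checkpoint, shown only to illustrate the loading path.
pipe.load_lora_weights("your-username/chroma-lora")
```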
5 changes: 5 additions & 0 deletions src/diffusers/loaders/single_file_model.py
@@ -29,6 +29,7 @@
convert_animatediff_checkpoint_to_diffusers,
convert_auraflow_transformer_checkpoint_to_diffusers,
convert_autoencoder_dc_checkpoint_to_diffusers,
convert_chroma_transformer_checkpoint_to_diffusers,
convert_controlnet_checkpoint,
convert_flux_transformer_checkpoint_to_diffusers,
convert_hidream_transformer_to_diffusers,
@@ -97,6 +98,10 @@
"checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
"ChromaTransformer2DModel": {
"checkpoint_mapping_fn": convert_chroma_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
"LTXVideoTransformer3DModel": {
"checkpoint_mapping_fn": convert_ltx_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
169 changes: 169 additions & 0 deletions src/diffusers/loaders/single_file_utils.py
@@ -3310,3 +3310,172 @@ def convert_hidream_transformer_to_diffusers(checkpoint, **kwargs):
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)

return checkpoint


def convert_chroma_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
converted_state_dict = {}
keys = list(checkpoint.keys())

for k in keys:
if "model.diffusion_model." in k:
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)

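# Infer block counts from the highest layer index present in the checkpoint keys.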
num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1 # noqa: C401
num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1 # noqa: C401
num_guidance_layers = (
list(set(int(k.split(".", 3)[2]) for k in checkpoint if "distilled_guidance_layer.layers." in k))[-1] + 1 # noqa: C401
)
mlp_ratio = 4.0
inner_dim = 3072

# In the original SD3 implementation of AdaLayerNormContinuous, the linear projection output is split into
# (shift, scale), while diffusers splits it into (scale, shift). Swap the halves of the projection weights
# so the diffusers implementation can be used.
def swap_scale_shift(weight):
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
return new_weight

# guidance
converted_state_dict["distilled_guidance_layer.in_proj.bias"] = checkpoint.pop(
"distilled_guidance_layer.in_proj.bias"
)
converted_state_dict["distilled_guidance_layer.in_proj.weight"] = checkpoint.pop(
"distilled_guidance_layer.in_proj.weight"
)
converted_state_dict["distilled_guidance_layer.out_proj.bias"] = checkpoint.pop(
"distilled_guidance_layer.out_proj.bias"
)
converted_state_dict["distilled_guidance_layer.out_proj.weight"] = checkpoint.pop(
"distilled_guidance_layer.out_proj.weight"
)
for i in range(num_guidance_layers):
block_prefix = f"distilled_guidance_layer.layers.{i}."
converted_state_dict[f"{block_prefix}linear_1.bias"] = checkpoint.pop(
f"distilled_guidance_layer.layers.{i}.in_layer.bias"
)
converted_state_dict[f"{block_prefix}linear_1.weight"] = checkpoint.pop(
f"distilled_guidance_layer.layers.{i}.in_layer.weight"
)
converted_state_dict[f"{block_prefix}linear_2.bias"] = checkpoint.pop(
f"distilled_guidance_layer.layers.{i}.out_layer.bias"
)
converted_state_dict[f"{block_prefix}linear_2.weight"] = checkpoint.pop(
f"distilled_guidance_layer.layers.{i}.out_layer.weight"
)
converted_state_dict[f"distilled_guidance_layer.norms.{i}.weight"] = checkpoint.pop(
f"distilled_guidance_layer.norms.{i}.scale"
)

# context_embedder
converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight")
converted_state_dict["context_embedder.bias"] = checkpoint.pop("txt_in.bias")

# x_embedder
converted_state_dict["x_embedder.weight"] = checkpoint.pop("img_in.weight")
converted_state_dict["x_embedder.bias"] = checkpoint.pop("img_in.bias")

# double transformer blocks
for i in range(num_layers):
block_prefix = f"transformer_blocks.{i}."
# Q, K, V
sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0)
context_q, context_k, context_v = torch.chunk(
checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0
)
sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0
)
context_q_bias, context_k_bias, context_v_bias = torch.chunk(
checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0
)
converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q])
converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias])
converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k])
converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias])
converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v])
converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias])
converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q])
converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias])
converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k])
converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias])
converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v])
converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias])
# qk_norm
converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
f"double_blocks.{i}.img_attn.norm.query_norm.scale"
)
converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
f"double_blocks.{i}.img_attn.norm.key_norm.scale"
)
converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = checkpoint.pop(
f"double_blocks.{i}.txt_attn.norm.query_norm.scale"
)
converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = checkpoint.pop(
f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
)
# ff img_mlp
converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = checkpoint.pop(
f"double_blocks.{i}.img_mlp.0.weight"
)
converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.0.bias")
converted_state_dict[f"{block_prefix}ff.net.2.weight"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.weight")
converted_state_dict[f"{block_prefix}ff.net.2.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.bias")
converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = checkpoint.pop(
f"double_blocks.{i}.txt_mlp.0.weight"
)
converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = checkpoint.pop(
f"double_blocks.{i}.txt_mlp.0.bias"
)
converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = checkpoint.pop(
f"double_blocks.{i}.txt_mlp.2.weight"
)
converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = checkpoint.pop(
f"double_blocks.{i}.txt_mlp.2.bias"
)
# output projections.
converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = checkpoint.pop(
f"double_blocks.{i}.img_attn.proj.weight"
)
converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = checkpoint.pop(
f"double_blocks.{i}.img_attn.proj.bias"
)
converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = checkpoint.pop(
f"double_blocks.{i}.txt_attn.proj.weight"
)
converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = checkpoint.pop(
f"double_blocks.{i}.txt_attn.proj.bias"
)

# single transformer blocks
for i in range(num_single_layers):
block_prefix = f"single_transformer_blocks.{i}."
# Q, K, V, mlp
mlp_hidden_dim = int(inner_dim * mlp_ratio)
split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
q, k, v, mlp = torch.split(checkpoint.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0)
q_bias, k_bias, v_bias, mlp_bias = torch.split(
checkpoint.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0
)
converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q])
converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias])
converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k])
converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias])
converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v])
converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias])
converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp])
converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias])
# qk norm
converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
f"single_blocks.{i}.norm.query_norm.scale"
)
converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
f"single_blocks.{i}.norm.key_norm.scale"
)
# output projections.
converted_state_dict[f"{block_prefix}proj_out.weight"] = checkpoint.pop(f"single_blocks.{i}.linear2.weight")
converted_state_dict[f"{block_prefix}proj_out.bias"] = checkpoint.pop(f"single_blocks.{i}.linear2.bias")

converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")

return converted_state_dict
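The fused-projection handling above follows the Flux checkpoint layout: double blocks store QKV fused along dim 0, and single blocks additionally fuse the MLP input projection. A tiny standalone check of the double-block split with random weights:

```python
import torch

inner_dim = 3072
# Fused QKV weight, as stored in the original checkpoint layout.
qkv = torch.randn(3 * inner_dim, inner_dim)
q, k, v = torch.chunk(qkv, 3, dim=0)  # the converter splits this into separate projections
assert q.shape == k.shape == v.shape == (inner_dim, inner_dim)
```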
2 changes: 2 additions & 0 deletions src/diffusers/models/__init__.py
@@ -74,6 +74,7 @@
_import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
_import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
_import_structure["transformers.transformer_chroma"] = ["ChromaTransformer2DModel"]
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
_import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"]
_import_structure["transformers.transformer_cosmos"] = ["CosmosTransformer3DModel"]
@@ -151,6 +152,7 @@
from .transformers import (
AllegroTransformer3DModel,
AuraFlowTransformer2DModel,
ChromaTransformer2DModel,
CogVideoXTransformer3DModel,
CogView3PlusTransformer2DModel,
CogView4Transformer2DModel,
4 changes: 2 additions & 2 deletions src/diffusers/models/embeddings.py
@@ -31,7 +31,7 @@ def get_timestep_embedding(
downscale_freq_shift: float = 1,
scale: float = 1,
max_period: int = 10000,
):
) -> torch.Tensor:
"""
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
Expand Down Expand Up @@ -1325,7 +1325,7 @@ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shif
self.downscale_freq_shift = downscale_freq_shift
self.scale = scale

def forward(self, timesteps):
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
t_emb = get_timestep_embedding(
timesteps,
self.num_channels,
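A quick sketch of the function whose annotations were tightened here (shapes assume the default arguments):

```python
import torch

from diffusers.models.embeddings import get_timestep_embedding

timesteps = torch.tensor([0, 250, 999])
emb = get_timestep_embedding(timesteps, embedding_dim=256)
print(emb.shape)  # torch.Size([3, 256])
```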
1 change: 1 addition & 0 deletions src/diffusers/models/transformers/__init__.py
@@ -17,6 +17,7 @@
from .t5_film_transformer import T5FilmDecoder
from .transformer_2d import Transformer2DModel
from .transformer_allegro import AllegroTransformer3DModel
from .transformer_chroma import ChromaTransformer2DModel
from .transformer_cogview3plus import CogView3PlusTransformer2DModel
from .transformer_cogview4 import CogView4Transformer2DModel
from .transformer_cosmos import CosmosTransformer3DModel