diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index f13b7d54aec4..5492dff04cae 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -283,6 +283,8 @@
         title: AllegroTransformer3DModel
       - local: api/models/aura_flow_transformer2d
         title: AuraFlowTransformer2DModel
+      - local: api/models/chroma_transformer
+        title: ChromaTransformer2DModel
       - local: api/models/cogvideox_transformer3d
         title: CogVideoXTransformer3DModel
       - local: api/models/cogview3plus_transformer2d
@@ -405,6 +407,8 @@
       title: AutoPipeline
     - local: api/pipelines/blip_diffusion
       title: BLIP-Diffusion
+    - local: api/pipelines/chroma
+      title: Chroma
     - local: api/pipelines/cogvideox
       title: CogVideoX
     - local: api/pipelines/cogview3
diff --git a/docs/source/en/api/models/chroma_transformer.md b/docs/source/en/api/models/chroma_transformer.md
new file mode 100644
index 000000000000..681e81f7a584
--- /dev/null
+++ b/docs/source/en/api/models/chroma_transformer.md
@@ -0,0 +1,19 @@
+<!--Copyright 2025 The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License.
+-->
+
+# ChromaTransformer2DModel
+
+A modified flux Transformer model from [Chroma](https://huggingface.co/lodestones/Chroma)
+
+## ChromaTransformer2DModel
+
+[[autodoc]] ChromaTransformer2DModel
diff --git a/docs/source/en/api/pipelines/chroma.md b/docs/source/en/api/pipelines/chroma.md
new file mode 100644
index 000000000000..22448d88e06b
--- /dev/null
+++ b/docs/source/en/api/pipelines/chroma.md
@@ -0,0 +1,71 @@
+<!--Copyright 2025 The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License.
+-->
+
+# Chroma
+
+<div class="flex flex-wrap space-x-1">
+  <img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
+  <img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
+</div>
+
+Chroma is a text to image generation model based on Flux.
+
+Original model checkpoints for Chroma can be found [here](https://huggingface.co/lodestones/Chroma).
+
+<Tip>
+
+Chroma can use all the same optimizations as Flux.
+
+</Tip>
+
+## Inference (Single File)
+
+The `ChromaTransformer2DModel` supports loading checkpoints in the original format. This is also useful when trying to load finetunes or quantized versions of the models that have been published by the community.
+
+The following example demonstrates how to run Chroma from a single file.
+
+Then run the following example
+
+```python
+import torch
+from diffusers import ChromaTransformer2DModel, ChromaPipeline
+from transformers import T5EncoderModel
+
+bfl_repo = "black-forest-labs/FLUX.1-dev"
+dtype = torch.bfloat16
+
+transformer = ChromaTransformer2DModel.from_single_file("https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v35.safetensors", torch_dtype=dtype)
+
+text_encoder = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
+tokenizer = T5Tokenizer.from_pretrained(bfl_repo, subfolder="tokenizer_2", torch_dtype=dtype)
+
+pipe = ChromaPipeline.from_pretrained(bfl_repo, transformer=transformer, text_encoder=text_encoder, tokenizer=tokenizer, torch_dtype=dtype)
+
+pipe.enable_model_cpu_offload()
+
+prompt = "A cat holding a sign that says hello world"
+image = pipe(
+    prompt,
+    guidance_scale=4.0,
+    output_type="pil",
+    num_inference_steps=26,
+    generator=torch.Generator("cpu").manual_seed(0)
+).images[0]
+
+image.save("image.png")
+```
+
+## ChromaPipeline
+
+[[autodoc]] ChromaPipeline
+	- all
+	- __call__
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index ce0777fdef68..27bbd3501680 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -159,6 +159,7 @@
             "AutoencoderTiny",
             "AutoModel",
             "CacheMixin",
+            "ChromaTransformer2DModel",
             "CogVideoXTransformer3DModel",
             "CogView3PlusTransformer2DModel",
             "CogView4Transformer2DModel",
@@ -352,6 +353,7 @@
             "AuraFlowPipeline",
             "BlipDiffusionControlNetPipeline",
             "BlipDiffusionPipeline",
+            "ChromaPipeline",
             "CLIPImageProjection",
             "CogVideoXFunControlPipeline",
             "CogVideoXImageToVideoPipeline",
@@ -768,6 +770,7 @@
             AutoencoderTiny,
             AutoModel,
             CacheMixin,
+            ChromaTransformer2DModel,
             CogVideoXTransformer3DModel,
             CogView3PlusTransformer2DModel,
             CogView4Transformer2DModel,
@@ -940,6 +943,7 @@
             AudioLDM2UNet2DConditionModel,
             AudioLDMPipeline,
             AuraFlowPipeline,
+            ChromaPipeline,
             CLIPImageProjection,
             CogVideoXFunControlPipeline,
             CogVideoXImageToVideoPipeline,
diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py
index 0480e93f356f..e7a458f28ef9 100644
--- a/src/diffusers/loaders/peft.py
+++ b/src/diffusers/loaders/peft.py
@@ -60,6 +60,7 @@
     "HiDreamImageTransformer2DModel": lambda model_cls, weights: weights,
     "HunyuanVideoFramepackTransformer3DModel": lambda model_cls, weights: weights,
     "WanVACETransformer3DModel": lambda model_cls, weights: weights,
+    "ChromaTransformer2DModel": lambda model_cls, weights: weights,
 }
 
 
diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py
index 6919c4949d59..c2eb62ba1222 100644
--- a/src/diffusers/loaders/single_file_model.py
+++ b/src/diffusers/loaders/single_file_model.py
@@ -29,6 +29,7 @@
     convert_animatediff_checkpoint_to_diffusers,
     convert_auraflow_transformer_checkpoint_to_diffusers,
     convert_autoencoder_dc_checkpoint_to_diffusers,
+    convert_chroma_transformer_checkpoint_to_diffusers,
     convert_controlnet_checkpoint,
     convert_flux_transformer_checkpoint_to_diffusers,
     convert_hidream_transformer_to_diffusers,
@@ -97,6 +98,10 @@
         "checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
         "default_subfolder": "transformer",
     },
+    "ChromaTransformer2DModel": {
+        "checkpoint_mapping_fn": convert_chroma_transformer_checkpoint_to_diffusers,
+        "default_subfolder": "transformer",
+    },
     "LTXVideoTransformer3DModel": {
         "checkpoint_mapping_fn": convert_ltx_transformer_checkpoint_to_diffusers,
         "default_subfolder": "transformer",
diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py
index 0f762b949d47..d8d183304e9a 100644
--- a/src/diffusers/loaders/single_file_utils.py
+++ b/src/diffusers/loaders/single_file_utils.py
@@ -3310,3 +3310,172 @@ def convert_hidream_transformer_to_diffusers(checkpoint, **kwargs):
             checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
 
     return checkpoint
+
+
+def convert_chroma_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
+    converted_state_dict = {}
+    keys = list(checkpoint.keys())
+
+    for k in keys:
+        if "model.diffusion_model." in k:
+            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
+
+    num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1  # noqa: C401
+    num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1  # noqa: C401
+    num_guidance_layers = (
+        list(set(int(k.split(".", 3)[2]) for k in checkpoint if "distilled_guidance_layer.layers." in k))[-1] + 1  # noqa: C401
+    )
+    mlp_ratio = 4.0
+    inner_dim = 3072
+
+    # in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale;
+    # while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation
+    def swap_scale_shift(weight):
+        shift, scale = weight.chunk(2, dim=0)
+        new_weight = torch.cat([scale, shift], dim=0)
+        return new_weight
+
+    # guidance
+    converted_state_dict["distilled_guidance_layer.in_proj.bias"] = checkpoint.pop(
+        "distilled_guidance_layer.in_proj.bias"
+    )
+    converted_state_dict["distilled_guidance_layer.in_proj.weight"] = checkpoint.pop(
+        "distilled_guidance_layer.in_proj.weight"
+    )
+    converted_state_dict["distilled_guidance_layer.out_proj.bias"] = checkpoint.pop(
+        "distilled_guidance_layer.out_proj.bias"
+    )
+    converted_state_dict["distilled_guidance_layer.out_proj.weight"] = checkpoint.pop(
+        "distilled_guidance_layer.out_proj.weight"
+    )
+    for i in range(num_guidance_layers):
+        block_prefix = f"distilled_guidance_layer.layers.{i}."
+        converted_state_dict[f"{block_prefix}linear_1.bias"] = checkpoint.pop(
+            f"distilled_guidance_layer.layers.{i}.in_layer.bias"
+        )
+        converted_state_dict[f"{block_prefix}linear_1.weight"] = checkpoint.pop(
+            f"distilled_guidance_layer.layers.{i}.in_layer.weight"
+        )
+        converted_state_dict[f"{block_prefix}linear_2.bias"] = checkpoint.pop(
+            f"distilled_guidance_layer.layers.{i}.out_layer.bias"
+        )
+        converted_state_dict[f"{block_prefix}linear_2.weight"] = checkpoint.pop(
+            f"distilled_guidance_layer.layers.{i}.out_layer.weight"
+        )
+        converted_state_dict[f"distilled_guidance_layer.norms.{i}.weight"] = checkpoint.pop(
+            f"distilled_guidance_layer.norms.{i}.scale"
+        )
+
+    # context_embedder
+    converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight")
+    converted_state_dict["context_embedder.bias"] = checkpoint.pop("txt_in.bias")
+
+    # x_embedder
+    converted_state_dict["x_embedder.weight"] = checkpoint.pop("img_in.weight")
+    converted_state_dict["x_embedder.bias"] = checkpoint.pop("img_in.bias")
+
+    # double transformer blocks
+    for i in range(num_layers):
+        block_prefix = f"transformer_blocks.{i}."
+        # Q, K, V
+        sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0)
+        context_q, context_k, context_v = torch.chunk(
+            checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0
+        )
+        sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
+            checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0
+        )
+        context_q_bias, context_k_bias, context_v_bias = torch.chunk(
+            checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0
+        )
+        converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q])
+        converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias])
+        converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k])
+        converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias])
+        converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v])
+        converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias])
+        converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q])
+        converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias])
+        converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k])
+        converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias])
+        converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v])
+        converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias])
+        # qk_norm
+        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.img_attn.norm.query_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.img_attn.norm.key_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_attn.norm.query_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
+        )
+        # ff img_mlp
+        converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.img_mlp.0.weight"
+        )
+        converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.0.bias")
+        converted_state_dict[f"{block_prefix}ff.net.2.weight"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.weight")
+        converted_state_dict[f"{block_prefix}ff.net.2.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.bias")
+        converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_mlp.0.weight"
+        )
+        converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_mlp.0.bias"
+        )
+        converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_mlp.2.weight"
+        )
+        converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_mlp.2.bias"
+        )
+        # output projections.
+        converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.img_attn.proj.weight"
+        )
+        converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = checkpoint.pop(
+            f"double_blocks.{i}.img_attn.proj.bias"
+        )
+        converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_attn.proj.weight"
+        )
+        converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = checkpoint.pop(
+            f"double_blocks.{i}.txt_attn.proj.bias"
+        )
+
+    # single transformer blocks
+    for i in range(num_single_layers):
+        block_prefix = f"single_transformer_blocks.{i}."
+        # Q, K, V, mlp
+        mlp_hidden_dim = int(inner_dim * mlp_ratio)
+        split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
+        q, k, v, mlp = torch.split(checkpoint.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0)
+        q_bias, k_bias, v_bias, mlp_bias = torch.split(
+            checkpoint.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0
+        )
+        converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q])
+        converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias])
+        converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k])
+        converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias])
+        converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v])
+        converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias])
+        converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp])
+        converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias])
+        # qk norm
+        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
+            f"single_blocks.{i}.norm.query_norm.scale"
+        )
+        converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
+            f"single_blocks.{i}.norm.key_norm.scale"
+        )
+        # output projections.
+        converted_state_dict[f"{block_prefix}proj_out.weight"] = checkpoint.pop(f"single_blocks.{i}.linear2.weight")
+        converted_state_dict[f"{block_prefix}proj_out.bias"] = checkpoint.pop(f"single_blocks.{i}.linear2.bias")
+
+    converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
+    converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
+
+    return converted_state_dict
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index 8723fbca2187..b493d651f4ba 100755
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -74,6 +74,7 @@
     _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
     _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
     _import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
+    _import_structure["transformers.transformer_chroma"] = ["ChromaTransformer2DModel"]
     _import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
     _import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"]
     _import_structure["transformers.transformer_cosmos"] = ["CosmosTransformer3DModel"]
@@ -151,6 +152,7 @@
         from .transformers import (
             AllegroTransformer3DModel,
             AuraFlowTransformer2DModel,
+            ChromaTransformer2DModel,
             CogVideoXTransformer3DModel,
             CogView3PlusTransformer2DModel,
             CogView4Transformer2DModel,
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 09e3621c2c7b..cfc501c47ed9 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -31,7 +31,7 @@ def get_timestep_embedding(
     downscale_freq_shift: float = 1,
     scale: float = 1,
     max_period: int = 10000,
-):
+) -> torch.Tensor:
     """
     This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
 
@@ -1325,7 +1325,7 @@ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shif
         self.downscale_freq_shift = downscale_freq_shift
         self.scale = scale
 
-    def forward(self, timesteps):
+    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
         t_emb = get_timestep_embedding(
             timesteps,
             self.num_channels,
diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py
index e7b8ba55ca61..cc03a0ccbcdf 100755
--- a/src/diffusers/models/transformers/__init__.py
+++ b/src/diffusers/models/transformers/__init__.py
@@ -17,6 +17,7 @@
     from .t5_film_transformer import T5FilmDecoder
     from .transformer_2d import Transformer2DModel
     from .transformer_allegro import AllegroTransformer3DModel
+    from .transformer_chroma import ChromaTransformer2DModel
     from .transformer_cogview3plus import CogView3PlusTransformer2DModel
     from .transformer_cogview4 import CogView4Transformer2DModel
     from .transformer_cosmos import CosmosTransformer3DModel
diff --git a/src/diffusers/models/transformers/transformer_chroma.py b/src/diffusers/models/transformers/transformer_chroma.py
new file mode 100644
index 000000000000..2b415cfed2fe
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_chroma.py
@@ -0,0 +1,732 @@
+# Copyright 2025 Black Forest Labs, The HuggingFace Team and loadstone-rock . All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Any, Dict, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.import_utils import is_torch_npu_available
+from ...utils.torch_utils import maybe_allow_in_graph
+from ..attention import FeedForward
+from ..attention_processor import (
+    Attention,
+    AttentionProcessor,
+    FluxAttnProcessor2_0,
+    FluxAttnProcessor2_0_NPU,
+    FusedFluxAttnProcessor2_0,
+)
+from ..cache_utils import CacheMixin
+from ..embeddings import FluxPosEmbed, PixArtAlphaTextProjection, Timesteps, get_timestep_embedding
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import CombinedTimestepLabelEmbeddings, FP32LayerNorm, RMSNorm
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+class ChromaAdaLayerNormZeroPruned(nn.Module):
+    r"""
+    Norm layer adaptive layer norm zero (adaLN-Zero).
+
+    Parameters:
+        embedding_dim (`int`): The size of each embedding vector.
+        num_embeddings (`int`): The size of the embeddings dictionary.
+    """
+
+    def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None, norm_type="layer_norm", bias=True):
+        super().__init__()
+        if num_embeddings is not None:
+            self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
+        else:
+            self.emb = None
+
+        if norm_type == "layer_norm":
+            self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
+        elif norm_type == "fp32_layer_norm":
+            self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False)
+        else:
+            raise ValueError(
+                f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
+            )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        timestep: Optional[torch.Tensor] = None,
+        class_labels: Optional[torch.LongTensor] = None,
+        hidden_dtype: Optional[torch.dtype] = None,
+        emb: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        if self.emb is not None:
+            emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.flatten(1, 2).chunk(6, dim=1)
+        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
+
+
+class ChromaAdaLayerNormZeroSinglePruned(nn.Module):
+    r"""
+    Norm layer adaptive layer norm zero (adaLN-Zero).
+
+    Parameters:
+        embedding_dim (`int`): The size of each embedding vector.
+        num_embeddings (`int`): The size of the embeddings dictionary.
+    """
+
+    def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
+        super().__init__()
+
+        if norm_type == "layer_norm":
+            self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
+        else:
+            raise ValueError(
+                f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
+            )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        emb: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        shift_msa, scale_msa, gate_msa = emb.flatten(1, 2).chunk(3, dim=1)
+        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+        return x, gate_msa
+
+
+class ChromaAdaLayerNormContinuousPruned(nn.Module):
+    r"""
+    Adaptive normalization layer with a norm layer (layer_norm or rms_norm).
+
+    Args:
+        embedding_dim (`int`): Embedding dimension to use during projection.
+        conditioning_embedding_dim (`int`): Dimension of the input condition.
+        elementwise_affine (`bool`, defaults to `True`):
+            Boolean flag to denote if affine transformation should be applied.
+        eps (`float`, defaults to 1e-5): Epsilon factor.
+        bias (`bias`, defaults to `True`): Boolean flag to denote if bias should be use.
+        norm_type (`str`, defaults to `"layer_norm"`):
+            Normalization layer to use. Values supported: "layer_norm", "rms_norm".
+    """
+
+    def __init__(
+        self,
+        embedding_dim: int,
+        conditioning_embedding_dim: int,
+        # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
+        # because the output is immediately scaled and shifted by the projected conditioning embeddings.
+        # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
+        # However, this is how it was implemented in the original code, and it's rather likely you should
+        # set `elementwise_affine` to False.
+        elementwise_affine=True,
+        eps=1e-5,
+        bias=True,
+        norm_type="layer_norm",
+    ):
+        super().__init__()
+        if norm_type == "layer_norm":
+            self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias)
+        elif norm_type == "rms_norm":
+            self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
+        else:
+            raise ValueError(f"unknown norm_type {norm_type}")
+
+    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
+        # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
+        shift, scale = torch.chunk(emb.flatten(1, 2).to(x.dtype), 2, dim=1)
+        x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
+        return x
+
+
+class ChromaCombinedTimestepTextProjEmbeddings(nn.Module):
+    def __init__(self, num_channels: int, out_dim: int):
+        super().__init__()
+
+        self.time_proj = Timesteps(num_channels=num_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.guidance_proj = Timesteps(num_channels=num_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
+
+        self.register_buffer(
+            "mod_proj",
+            get_timestep_embedding(
+                torch.arange(out_dim) * 1000, 2 * num_channels, flip_sin_to_cos=True, downscale_freq_shift=0
+            ),
+            persistent=False,
+        )
+
+    def forward(self, timestep: torch.Tensor) -> torch.Tensor:
+        mod_index_length = self.mod_proj.shape[0]
+        batch_size = timestep.shape[0]
+
+        timesteps_proj = self.time_proj(timestep).to(dtype=timestep.dtype)
+        guidance_proj = self.guidance_proj(torch.tensor([0] * batch_size)).to(
+            dtype=timestep.dtype, device=timestep.device
+        )
+
+        mod_proj = self.mod_proj.to(dtype=timesteps_proj.dtype, device=timesteps_proj.device).repeat(batch_size, 1, 1)
+        timestep_guidance = (
+            torch.cat([timesteps_proj, guidance_proj], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1)
+        )
+        input_vec = torch.cat([timestep_guidance, mod_proj], dim=-1)
+        return input_vec.to(timestep.dtype)
+
+
+class ChromaApproximator(nn.Module):
+    def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers: int = 5):
+        super().__init__()
+        self.in_proj = nn.Linear(in_dim, hidden_dim, bias=True)
+        self.layers = nn.ModuleList(
+            [PixArtAlphaTextProjection(hidden_dim, hidden_dim, act_fn="silu") for _ in range(n_layers)]
+        )
+        self.norms = nn.ModuleList([nn.RMSNorm(hidden_dim) for _ in range(n_layers)])
+        self.out_proj = nn.Linear(hidden_dim, out_dim)
+
+    def forward(self, x):
+        x = self.in_proj(x)
+
+        for layer, norms in zip(self.layers, self.norms):
+            x = x + layer(norms(x))
+
+        return self.out_proj(x)
+
+
+@maybe_allow_in_graph
+class ChromaSingleTransformerBlock(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        mlp_ratio: float = 4.0,
+    ):
+        super().__init__()
+        self.mlp_hidden_dim = int(dim * mlp_ratio)
+        self.norm = ChromaAdaLayerNormZeroSinglePruned(dim)
+        self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
+        self.act_mlp = nn.GELU(approximate="tanh")
+        self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
+
+        if is_torch_npu_available():
+            deprecation_message = (
+                "Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors "
+                "should be set explicitly using the `set_attn_processor` method."
+            )
+            deprecate("npu_processor", "0.34.0", deprecation_message)
+            processor = FluxAttnProcessor2_0_NPU()
+        else:
+            processor = FluxAttnProcessor2_0()
+
+        self.attn = Attention(
+            query_dim=dim,
+            cross_attention_dim=None,
+            dim_head=attention_head_dim,
+            heads=num_attention_heads,
+            out_dim=dim,
+            bias=True,
+            processor=processor,
+            qk_norm="rms_norm",
+            eps=1e-6,
+            pre_only=True,
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> torch.Tensor:
+        residual = hidden_states
+        norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
+        mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
+        joint_attention_kwargs = joint_attention_kwargs or {}
+        attn_output = self.attn(
+            hidden_states=norm_hidden_states,
+            image_rotary_emb=image_rotary_emb,
+            **joint_attention_kwargs,
+        )
+
+        hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
+        gate = gate.unsqueeze(1)
+        hidden_states = gate * self.proj_out(hidden_states)
+        hidden_states = residual + hidden_states
+        if hidden_states.dtype == torch.float16:
+            hidden_states = hidden_states.clip(-65504, 65504)
+
+        return hidden_states
+
+
+@maybe_allow_in_graph
+class ChromaTransformerBlock(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        qk_norm: str = "rms_norm",
+        eps: float = 1e-6,
+    ):
+        super().__init__()
+        self.norm1 = ChromaAdaLayerNormZeroPruned(dim)
+        self.norm1_context = ChromaAdaLayerNormZeroPruned(dim)
+
+        self.attn = Attention(
+            query_dim=dim,
+            cross_attention_dim=None,
+            added_kv_proj_dim=dim,
+            dim_head=attention_head_dim,
+            heads=num_attention_heads,
+            out_dim=dim,
+            context_pre_only=False,
+            bias=True,
+            processor=FluxAttnProcessor2_0(),
+            qk_norm=qk_norm,
+            eps=eps,
+        )
+
+        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+        self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+        self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        temb_img, temb_txt = temb[:, :6], temb[:, 6:]
+        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb_img)
+
+        norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
+            encoder_hidden_states, emb=temb_txt
+        )
+        joint_attention_kwargs = joint_attention_kwargs or {}
+        # Attention.
+        attention_outputs = self.attn(
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            image_rotary_emb=image_rotary_emb,
+            **joint_attention_kwargs,
+        )
+
+        if len(attention_outputs) == 2:
+            attn_output, context_attn_output = attention_outputs
+        elif len(attention_outputs) == 3:
+            attn_output, context_attn_output, ip_attn_output = attention_outputs
+
+        # Process attention outputs for the `hidden_states`.
+        attn_output = gate_msa.unsqueeze(1) * attn_output
+        hidden_states = hidden_states + attn_output
+
+        norm_hidden_states = self.norm2(hidden_states)
+        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+        ff_output = self.ff(norm_hidden_states)
+        ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+        hidden_states = hidden_states + ff_output
+        if len(attention_outputs) == 3:
+            hidden_states = hidden_states + ip_attn_output
+
+        # Process attention outputs for the `encoder_hidden_states`.
+
+        context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
+        encoder_hidden_states = encoder_hidden_states + context_attn_output
+
+        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+
+        context_ff_output = self.ff_context(norm_encoder_hidden_states)
+        encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+        if encoder_hidden_states.dtype == torch.float16:
+            encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
+
+        return encoder_hidden_states, hidden_states
+
+
+class ChromaTransformer2DModel(
+    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
+):
+    """
+    The Transformer model introduced in Flux, modified for Chroma.
+
+    Reference: https://huggingface.co/lodestones/Chroma
+
+    Args:
+        patch_size (`int`, defaults to `1`):
+            Patch size to turn the input data into small patches.
+        in_channels (`int`, defaults to `64`):
+            The number of channels in the input.
+        out_channels (`int`, *optional*, defaults to `None`):
+            The number of channels in the output. If not specified, it defaults to `in_channels`.
+        num_layers (`int`, defaults to `19`):
+            The number of layers of dual stream DiT blocks to use.
+        num_single_layers (`int`, defaults to `38`):
+            The number of layers of single stream DiT blocks to use.
+        attention_head_dim (`int`, defaults to `128`):
+            The number of dimensions to use for each attention head.
+        num_attention_heads (`int`, defaults to `24`):
+            The number of attention heads to use.
+        joint_attention_dim (`int`, defaults to `4096`):
+            The number of dimensions to use for the joint attention (embedding/channel dimension of
+            `encoder_hidden_states`).
+        axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
+            The dimensions to use for the rotary positional embeddings.
+    """
+
+    _supports_gradient_checkpointing = True
+    _no_split_modules = ["ChromaTransformerBlock", "ChromaSingleTransformerBlock"]
+    _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
+
+    @register_to_config
+    def __init__(
+        self,
+        patch_size: int = 1,
+        in_channels: int = 64,
+        out_channels: Optional[int] = None,
+        num_layers: int = 19,
+        num_single_layers: int = 38,
+        attention_head_dim: int = 128,
+        num_attention_heads: int = 24,
+        joint_attention_dim: int = 4096,
+        axes_dims_rope: Tuple[int, ...] = (16, 56, 56),
+        approximator_num_channels: int = 64,
+        approximator_hidden_dim: int = 5120,
+        approximator_layers: int = 5,
+    ):
+        super().__init__()
+        self.out_channels = out_channels or in_channels
+        self.inner_dim = num_attention_heads * attention_head_dim
+
+        self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
+
+        self.time_text_embed = ChromaCombinedTimestepTextProjEmbeddings(
+            num_channels=approximator_num_channels // 4,
+            out_dim=3 * num_single_layers + 2 * 6 * num_layers + 2,
+        )
+        self.distilled_guidance_layer = ChromaApproximator(
+            in_dim=approximator_num_channels,
+            out_dim=self.inner_dim,
+            hidden_dim=approximator_hidden_dim,
+            n_layers=approximator_layers,
+        )
+
+        self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
+        self.x_embedder = nn.Linear(in_channels, self.inner_dim)
+
+        self.transformer_blocks = nn.ModuleList(
+            [
+                ChromaTransformerBlock(
+                    dim=self.inner_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
+                )
+                for _ in range(num_layers)
+            ]
+        )
+
+        self.single_transformer_blocks = nn.ModuleList(
+            [
+                ChromaSingleTransformerBlock(
+                    dim=self.inner_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
+                )
+                for _ in range(num_single_layers)
+            ]
+        )
+
+        self.norm_out = ChromaAdaLayerNormContinuousPruned(
+            self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6
+        )
+        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
+
+        self.gradient_checkpointing = False
+
+    @property
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model with
+            indexed by its weight name.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "get_processor"):
+                processors[f"{name}.processor"] = module.get_processor()
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
+    def fuse_qkv_projections(self):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+        """
+        self.original_attn_processors = None
+
+        for _, attn_processor in self.attn_processors.items():
+            if "Added" in str(attn_processor.__class__.__name__):
+                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+        self.original_attn_processors = self.attn_processors
+
+        for module in self.modules():
+            if isinstance(module, Attention):
+                module.fuse_projections(fuse=True)
+
+        self.set_attn_processor(FusedFluxAttnProcessor2_0())
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+    def unfuse_qkv_projections(self):
+        """Disables the fused QKV projection if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        """
+        if self.original_attn_processors is not None:
+            self.set_attn_processor(self.original_attn_processors)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor = None,
+        timestep: torch.LongTensor = None,
+        img_ids: torch.Tensor = None,
+        txt_ids: torch.Tensor = None,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+        controlnet_block_samples=None,
+        controlnet_single_block_samples=None,
+        return_dict: bool = True,
+        controlnet_blocks_repeat: bool = False,
+    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+        """
+        The [`FluxTransformer2DModel`] forward method.
+
+        Args:
+            hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
+                Input `hidden_states`.
+            encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
+                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+            timestep ( `torch.LongTensor`):
+                Used to indicate denoising step.
+            block_controlnet_hidden_states: (`list` of `torch.Tensor`):
+                A list of tensors that if specified are added to the residuals of transformer blocks.
+            joint_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+                tuple.
+
+        Returns:
+            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+            `tuple` where the first element is the sample tensor.
+        """
+        if joint_attention_kwargs is not None:
+            joint_attention_kwargs = joint_attention_kwargs.copy()
+            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
+        else:
+            lora_scale = 1.0
+
+        if USE_PEFT_BACKEND:
+            # weight the lora layers by setting `lora_scale` for each PEFT layer
+            scale_lora_layers(self, lora_scale)
+        else:
+            if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
+                logger.warning(
+                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
+                )
+
+        hidden_states = self.x_embedder(hidden_states)
+
+        timestep = timestep.to(hidden_states.dtype) * 1000
+
+        input_vec = self.time_text_embed(timestep)
+        pooled_temb = self.distilled_guidance_layer(input_vec)
+
+        encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+        if txt_ids.ndim == 3:
+            logger.warning(
+                "Passing `txt_ids` 3d torch.Tensor is deprecated."
+                "Please remove the batch dimension and pass it as a 2d torch Tensor"
+            )
+            txt_ids = txt_ids[0]
+        if img_ids.ndim == 3:
+            logger.warning(
+                "Passing `img_ids` 3d torch.Tensor is deprecated."
+                "Please remove the batch dimension and pass it as a 2d torch Tensor"
+            )
+            img_ids = img_ids[0]
+
+        ids = torch.cat((txt_ids, img_ids), dim=0)
+        image_rotary_emb = self.pos_embed(ids)
+
+        if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
+            ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
+            ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
+            joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
+
+        for index_block, block in enumerate(self.transformer_blocks):
+            img_offset = 3 * len(self.single_transformer_blocks)
+            txt_offset = img_offset + 6 * len(self.transformer_blocks)
+            img_modulation = img_offset + 6 * index_block
+            text_modulation = txt_offset + 6 * index_block
+            temb = torch.cat(
+                (
+                    pooled_temb[:, img_modulation : img_modulation + 6],
+                    pooled_temb[:, text_modulation : text_modulation + 6],
+                ),
+                dim=1,
+            )
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
+                    block,
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    image_rotary_emb,
+                )
+
+            else:
+                encoder_hidden_states, hidden_states = block(
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    temb=temb,
+                    image_rotary_emb=image_rotary_emb,
+                    joint_attention_kwargs=joint_attention_kwargs,
+                )
+
+            # controlnet residual
+            if controlnet_block_samples is not None:
+                interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
+                interval_control = int(np.ceil(interval_control))
+                # For Xlabs ControlNet.
+                if controlnet_blocks_repeat:
+                    hidden_states = (
+                        hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
+                    )
+                else:
+                    hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
+        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+        for index_block, block in enumerate(self.single_transformer_blocks):
+            start_idx = 3 * index_block
+            temb = pooled_temb[:, start_idx : start_idx + 3]
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+                hidden_states = self._gradient_checkpointing_func(
+                    block,
+                    hidden_states,
+                    temb,
+                    image_rotary_emb,
+                )
+
+            else:
+                hidden_states = block(
+                    hidden_states=hidden_states,
+                    temb=temb,
+                    image_rotary_emb=image_rotary_emb,
+                    joint_attention_kwargs=joint_attention_kwargs,
+                )
+
+            # controlnet residual
+            if controlnet_single_block_samples is not None:
+                interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
+                interval_control = int(np.ceil(interval_control))
+                hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
+                    hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+                    + controlnet_single_block_samples[index_block // interval_control]
+                )
+
+        hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+
+        temb = pooled_temb[:, -2:]
+        hidden_states = self.norm_out(hidden_states, temb)
+        output = self.proj_out(hidden_states)
+
+        if USE_PEFT_BACKEND:
+            # remove `lora_scale` from each PEFT layer
+            unscale_lora_layers(self, lora_scale)
+
+        if not return_dict:
+            return (output,)
+
+        return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 268e5c2a8c39..058411bd65f9 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -148,6 +148,7 @@
         "AudioLDM2UNet2DConditionModel",
     ]
     _import_structure["blip_diffusion"] = ["BlipDiffusionPipeline"]
+    _import_structure["chroma"] = ["ChromaPipeline"]
     _import_structure["cogvideo"] = [
         "CogVideoXPipeline",
         "CogVideoXImageToVideoPipeline",
@@ -531,6 +532,7 @@
         )
         from .aura_flow import AuraFlowPipeline
         from .blip_diffusion import BlipDiffusionPipeline
+        from .chroma import ChromaPipeline
         from .cogvideo import (
             CogVideoXFunControlPipeline,
             CogVideoXImageToVideoPipeline,
diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py
index ed8ad79ca781..b1a7ffaaea9c 100644
--- a/src/diffusers/pipelines/auto_pipeline.py
+++ b/src/diffusers/pipelines/auto_pipeline.py
@@ -21,6 +21,7 @@
 from ..models.controlnets import ControlNetUnionModel
 from ..utils import is_sentencepiece_available
 from .aura_flow import AuraFlowPipeline
+from .chroma import ChromaPipeline
 from .cogview3 import CogView3PlusPipeline
 from .cogview4 import CogView4ControlPipeline, CogView4Pipeline
 from .controlnet import (
@@ -143,6 +144,7 @@
         ("flux-controlnet", FluxControlNetPipeline),
         ("lumina", LuminaPipeline),
         ("lumina2", Lumina2Pipeline),
+        ("chroma", ChromaPipeline),
         ("cogview3", CogView3PlusPipeline),
         ("cogview4", CogView4Pipeline),
         ("cogview4-control", CogView4ControlPipeline),
diff --git a/src/diffusers/pipelines/chroma/__init__.py b/src/diffusers/pipelines/chroma/__init__.py
new file mode 100644
index 000000000000..9faa7902a15c
--- /dev/null
+++ b/src/diffusers/pipelines/chroma/__init__.py
@@ -0,0 +1,47 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+    DIFFUSERS_SLOW_IMPORT,
+    OptionalDependencyNotAvailable,
+    _LazyModule,
+    get_objects_from_module,
+    is_torch_available,
+    is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_additional_imports = {}
+_import_structure = {"pipeline_output": ["ChromaPipelineOutput"]}
+
+try:
+    if not (is_transformers_available() and is_torch_available()):
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from ...utils import dummy_torch_and_transformers_objects  # noqa F403
+
+    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+    _import_structure["pipeline_chroma"] = ["ChromaPipeline"]
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+    try:
+        if not (is_transformers_available() and is_torch_available()):
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
+    else:
+        from .pipeline_chroma import ChromaPipeline
+else:
+    import sys
+
+    sys.modules[__name__] = _LazyModule(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+        module_spec=__spec__,
+    )
+
+    for name, value in _dummy_objects.items():
+        setattr(sys.modules[__name__], name, value)
+    for name, value in _additional_imports.items():
+        setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma.py b/src/diffusers/pipelines/chroma/pipeline_chroma.py
new file mode 100644
index 000000000000..c111458d3320
--- /dev/null
+++ b/src/diffusers/pipelines/chroma/pipeline_chroma.py
@@ -0,0 +1,863 @@
+# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, T5EncoderModel, T5TokenizerFast
+
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, ChromaTransformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import (
+    USE_PEFT_BACKEND,
+    is_torch_xla_available,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import ChromaPipelineOutput
+
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```py
+        >>> import torch
+        >>> from diffusers import ChromaPipeline
+
+        >>> pipe = ChromaPipeline.from_single_file(
+        ...     "chroma-unlocked-v35-detail-calibrated.safetensors", torch_dtype=torch.bfloat16
+        ... )
+        >>> pipe.to("cuda")
+        >>> prompt = "A cat holding a sign that says hello world"
+        >>> image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0]
+        >>> image.save("chroma.png")
+        ```
+"""
+
+
+# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
+def calculate_shift(
+    image_seq_len,
+    base_seq_len: int = 256,
+    max_seq_len: int = 4096,
+    base_shift: float = 0.5,
+    max_shift: float = 1.15,
+):
+    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+    b = base_shift - m * base_seq_len
+    mu = image_seq_len * m + b
+    return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+    scheduler,
+    num_inference_steps: Optional[int] = None,
+    device: Optional[Union[str, torch.device]] = None,
+    timesteps: Optional[List[int]] = None,
+    sigmas: Optional[List[float]] = None,
+    **kwargs,
+):
+    r"""
+    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+    Args:
+        scheduler (`SchedulerMixin`):
+            The scheduler to get timesteps from.
+        num_inference_steps (`int`):
+            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+            must be `None`.
+        device (`str` or `torch.device`, *optional*):
+            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+        timesteps (`List[int]`, *optional*):
+            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+            `num_inference_steps` and `sigmas` must be `None`.
+        sigmas (`List[float]`, *optional*):
+            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+            `num_inference_steps` and `timesteps` must be `None`.
+
+    Returns:
+        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+        second element is the number of inference steps.
+    """
+    if timesteps is not None and sigmas is not None:
+        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+    if timesteps is not None:
+        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accepts_timesteps:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" timestep schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    elif sigmas is not None:
+        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accept_sigmas:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" sigmas schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    else:
+        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+    return timesteps, num_inference_steps
+
+
+class ChromaPipeline(
+    DiffusionPipeline,
+    FluxLoraLoaderMixin,
+    FromSingleFileMixin,
+    TextualInversionLoaderMixin,
+    FluxIPAdapterMixin,
+):
+    r"""
+    The Chroma pipeline for text-to-image generation.
+
+    Reference: https://huggingface.co/lodestones/Chroma/
+
+    Args:
+        transformer ([`ChromaTransformer2DModel`]):
+            Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+        vae ([`AutoencoderKL`]):
+            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representation
+        text_encoder ([`T5EncoderModel`]):
+            [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+            the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+        tokenizer (`T5TokenizerFast`):
+            Second Tokenizer of class
+            [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
+    """
+
+    model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae"
+    _optional_components = ["image_encoder", "feature_extractor"]
+    _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+    def __init__(
+        self,
+        scheduler: FlowMatchEulerDiscreteScheduler,
+        vae: AutoencoderKL,
+        text_encoder: T5EncoderModel,
+        tokenizer: T5TokenizerFast,
+        transformer: ChromaTransformer2DModel,
+        image_encoder: CLIPVisionModelWithProjection = None,
+        feature_extractor: CLIPImageProcessor = None,
+    ):
+        super().__init__()
+
+        self.register_modules(
+            vae=vae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            transformer=transformer,
+            scheduler=scheduler,
+            image_encoder=image_encoder,
+            feature_extractor=feature_extractor,
+        )
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+        # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
+        # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+        self.default_sample_size = 128
+
+    def _get_t5_prompt_embeds(
+        self,
+        prompt: Union[str, List[str]] = None,
+        num_images_per_prompt: int = 1,
+        max_sequence_length: int = 512,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        device = device or self._execution_device
+        dtype = dtype or self.text_encoder.dtype
+
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+        batch_size = len(prompt)
+
+        if isinstance(self, TextualInversionLoaderMixin):
+            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+        text_inputs = self.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=max_sequence_length,
+            truncation=True,
+            return_length=False,
+            return_overflowing_tokens=False,
+            return_tensors="pt",
+        )
+        text_input_ids = text_inputs.input_ids
+        attention_mask = text_inputs.attention_mask.clone()
+
+        # Chroma requires the attention mask to include one padding token
+        seq_lengths = attention_mask.sum(dim=1)
+        mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)
+        attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).long()
+
+        prompt_embeds = self.text_encoder(
+            text_input_ids.to(device), output_hidden_states=False, attention_mask=attention_mask.to(device)
+        )[0]
+
+        dtype = self.text_encoder.dtype
+        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+        _, seq_len, _ = prompt_embeds.shape
+
+        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+        return prompt_embeds
+
+    def encode_prompt(
+        self,
+        prompt: Union[str, List[str]],
+        negative_prompt: Union[str, List[str]] = None,
+        device: Optional[torch.device] = None,
+        num_images_per_prompt: int = 1,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        do_classifier_free_guidance: bool = True,
+        max_sequence_length: int = 512,
+        lora_scale: Optional[float] = None,
+    ):
+        r"""
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                prompt to be encoded
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
+                instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
+            device: (`torch.device`):
+                torch device
+            num_images_per_prompt (`int`):
+                number of images that should be generated per prompt
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            lora_scale (`float`, *optional*):
+                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+        """
+        device = device or self._execution_device
+
+        # set lora scale so that monkey patched LoRA
+        # function of text encoder can correctly access it
+        if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
+            self._lora_scale = lora_scale
+
+            # dynamically adjust the LoRA scale
+            if self.text_encoder is not None and USE_PEFT_BACKEND:
+                scale_lora_layers(self.text_encoder, lora_scale)
+
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+
+        if prompt is not None:
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        if prompt_embeds is None:
+            prompt_embeds = self._get_t5_prompt_embeds(
+                prompt=prompt,
+                num_images_per_prompt=num_images_per_prompt,
+                max_sequence_length=max_sequence_length,
+                device=device,
+            )
+
+        dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
+        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+        negative_text_ids = None
+
+        if do_classifier_free_guidance:
+            if negative_prompt_embeds is None:
+                negative_prompt = negative_prompt or ""
+                negative_prompt = (
+                    batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+                )
+
+                if prompt is not None and type(prompt) is not type(negative_prompt):
+                    raise TypeError(
+                        f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+                        f" {type(prompt)}."
+                    )
+                elif batch_size != len(negative_prompt):
+                    raise ValueError(
+                        f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                        f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                        " the batch size of `prompt`."
+                    )
+
+                negative_prompt_embeds = self._get_t5_prompt_embeds(
+                    prompt=negative_prompt,
+                    num_images_per_prompt=num_images_per_prompt,
+                    max_sequence_length=max_sequence_length,
+                    device=device,
+                )
+            negative_text_ids = torch.zeros(negative_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+
+        if self.text_encoder is not None:
+            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+                # Retrieve the original scale by scaling back the LoRA layers
+                unscale_lora_layers(self.text_encoder, lora_scale)
+
+        return prompt_embeds, text_ids, negative_prompt_embeds, negative_text_ids
+
+    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image
+    def encode_image(self, image, device, num_images_per_prompt):
+        dtype = next(self.image_encoder.parameters()).dtype
+
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=dtype)
+        image_embeds = self.image_encoder(image).image_embeds
+        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+        return image_embeds
+
+    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds
+    def prepare_ip_adapter_image_embeds(
+        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+    ):
+        image_embeds = []
+        if ip_adapter_image_embeds is None:
+            if not isinstance(ip_adapter_image, list):
+                ip_adapter_image = [ip_adapter_image]
+
+            if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
+                raise ValueError(
+                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
+                )
+
+            for single_ip_adapter_image in ip_adapter_image:
+                single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
+                image_embeds.append(single_image_embeds[None, :])
+        else:
+            if not isinstance(ip_adapter_image_embeds, list):
+                ip_adapter_image_embeds = [ip_adapter_image_embeds]
+
+            if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
+                raise ValueError(
+                    f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
+                )
+
+            for single_image_embeds in ip_adapter_image_embeds:
+                image_embeds.append(single_image_embeds)
+
+        ip_adapter_image_embeds = []
+        for single_image_embeds in image_embeds:
+            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+            single_image_embeds = single_image_embeds.to(device=device)
+            ip_adapter_image_embeds.append(single_image_embeds)
+
+        return ip_adapter_image_embeds
+
+    def check_inputs(
+        self,
+        prompt,
+        height,
+        width,
+        negative_prompt=None,
+        prompt_embeds=None,
+        negative_prompt_embeds=None,
+        callback_on_step_end_tensor_inputs=None,
+        max_sequence_length=None,
+    ):
+        if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+            logger.warning(
+                f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+            )
+
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )
+
+        if prompt is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt is None and prompt_embeds is None:
+            raise ValueError(
+                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+            )
+        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+        if negative_prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+
+        if max_sequence_length is not None and max_sequence_length > 512:
+            raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
+
+    @staticmethod
+    def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
+        latent_image_ids = torch.zeros(height, width, 3)
+        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
+
+        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+        latent_image_ids = latent_image_ids.reshape(
+            latent_image_id_height * latent_image_id_width, latent_image_id_channels
+        )
+
+        return latent_image_ids.to(device=device, dtype=dtype)
+
+    @staticmethod
+    def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+        latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+        latents = latents.permute(0, 2, 4, 1, 3, 5)
+        latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+        return latents
+
+    @staticmethod
+    def _unpack_latents(latents, height, width, vae_scale_factor):
+        batch_size, num_patches, channels = latents.shape
+
+        # VAE applies 8x compression on images but we must also account for packing which requires
+        # latent height and width to be divisible by 2.
+        height = 2 * (int(height) // (vae_scale_factor * 2))
+        width = 2 * (int(width) // (vae_scale_factor * 2))
+
+        latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
+        latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+        latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
+
+        return latents
+
+    def enable_vae_slicing(self):
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.vae.enable_slicing()
+
+    def disable_vae_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_slicing()
+
+    def enable_vae_tiling(self):
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+        processing larger images.
+        """
+        self.vae.enable_tiling()
+
+    def disable_vae_tiling(self):
+        r"""
+        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_tiling()
+
+    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
+    def prepare_latents(
+        self,
+        batch_size,
+        num_channels_latents,
+        height,
+        width,
+        dtype,
+        device,
+        generator,
+        latents=None,
+    ):
+        # VAE applies 8x compression on images but we must also account for packing which requires
+        # latent height and width to be divisible by 2.
+        height = 2 * (int(height) // (self.vae_scale_factor * 2))
+        width = 2 * (int(width) // (self.vae_scale_factor * 2))
+
+        shape = (batch_size, num_channels_latents, height, width)
+
+        if latents is not None:
+            latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+            return latents.to(device=device, dtype=dtype), latent_image_ids
+
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
+        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+
+        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+
+        return latents, latent_image_ids
+
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    @property
+    def joint_attention_kwargs(self):
+        return self._joint_attention_kwargs
+
+    @property
+    def do_classifier_free_guidance(self):
+        return self._guidance_scale > 1
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
+
+    @property
+    def current_timestep(self):
+        return self._current_timestep
+
+    @property
+    def interrupt(self):
+        return self._interrupt
+
+    @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
+    def __call__(
+        self,
+        prompt: Union[str, List[str]] = None,
+        negative_prompt: Union[str, List[str]] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_inference_steps: int = 28,
+        sigmas: Optional[List[float]] = None,
+        guidance_scale: float = 3.5,
+        num_images_per_prompt: Optional[int] = 1,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+        negative_ip_adapter_image: Optional[PipelineImageInput] = None,
+        negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        max_sequence_length: int = 512,
+    ):
+        r"""
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+                instead.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                not greater than `1`).
+            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The height in pixels of the generated image. This is set to 1024 by default for the best results.
+            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The width in pixels of the generated image. This is set to 1024 by default for the best results.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            sigmas (`List[float]`, *optional*):
+                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+                will be used.
+            guidance_scale (`float`, *optional*, defaults to 3.5):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+                usually at the expense of lower image quality.
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will ge generated by sampling using the supplied random `generator`.
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+                provided, embeddings are computed from the `ip_adapter_image` input argument.
+            negative_ip_adapter_image:
+                (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+            negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+                provided, embeddings are computed from the `ip_adapter_image` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generate image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.flux.ChromaPipelineOutput`] instead of a plain tuple.
+            joint_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            callback_on_step_end (`Callable`, *optional*):
+                A function that calls at the end of each denoising steps during the inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
+            max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
+
+        Examples:
+
+        Returns:
+            [`~pipelines.chroma.ChromaPipelineOutput`] or `tuple`: [`~pipelines.chroma.ChromaPipelineOutput`] if
+            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
+            generated images.
+        """
+
+        height = height or self.default_sample_size * self.vae_scale_factor
+        width = width or self.default_sample_size * self.vae_scale_factor
+
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(
+            prompt,
+            height,
+            width,
+            negative_prompt=negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+            max_sequence_length=max_sequence_length,
+        )
+
+        self._guidance_scale = guidance_scale
+        self._joint_attention_kwargs = joint_attention_kwargs
+        self._current_timestep = None
+        self._interrupt = False
+
+        # 2. Define call parameters
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        device = self._execution_device
+
+        lora_scale = (
+            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+        )
+        (
+            prompt_embeds,
+            text_ids,
+            negative_prompt_embeds,
+            negative_text_ids,
+        ) = self.encode_prompt(
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            do_classifier_free_guidance=self.do_classifier_free_guidance,
+            device=device,
+            num_images_per_prompt=num_images_per_prompt,
+            max_sequence_length=max_sequence_length,
+            lora_scale=lora_scale,
+        )
+
+        # 4. Prepare latent variables
+        num_channels_latents = self.transformer.config.in_channels // 4
+        latents, latent_image_ids = self.prepare_latents(
+            batch_size * num_images_per_prompt,
+            num_channels_latents,
+            height,
+            width,
+            prompt_embeds.dtype,
+            device,
+            generator,
+            latents,
+        )
+        # 5. Prepare timesteps
+        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+        image_seq_len = latents.shape[1]
+        mu = calculate_shift(
+            image_seq_len,
+            self.scheduler.config.get("base_image_seq_len", 256),
+            self.scheduler.config.get("max_image_seq_len", 4096),
+            self.scheduler.config.get("base_shift", 0.5),
+            self.scheduler.config.get("max_shift", 1.15),
+        )
+        timesteps, num_inference_steps = retrieve_timesteps(
+            self.scheduler,
+            num_inference_steps,
+            device,
+            sigmas=sigmas,
+            mu=mu,
+        )
+        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+        self._num_timesteps = len(timesteps)
+
+        if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
+            negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
+        ):
+            negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+            negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
+
+        elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
+            negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
+        ):
+            ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+            ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
+
+        if self.joint_attention_kwargs is None:
+            self._joint_attention_kwargs = {}
+
+        image_embeds = None
+        negative_image_embeds = None
+        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+            image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+            )
+        if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
+            negative_image_embeds = self.prepare_ip_adapter_image_embeds(
+                negative_ip_adapter_image,
+                negative_ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+            )
+
+        # 6. Denoising loop
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
+                self._current_timestep = t
+                if image_embeds is not None:
+                    self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
+
+                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+                timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+                noise_pred = self.transformer(
+                    hidden_states=latents,
+                    timestep=timestep / 1000,
+                    encoder_hidden_states=prompt_embeds,
+                    txt_ids=text_ids,
+                    img_ids=latent_image_ids,
+                    joint_attention_kwargs=self.joint_attention_kwargs,
+                    return_dict=False,
+                )[0]
+
+                if self.do_classifier_free_guidance:
+                    if negative_image_embeds is not None:
+                        self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
+                    neg_noise_pred = self.transformer(
+                        hidden_states=latents,
+                        timestep=timestep / 1000,
+                        encoder_hidden_states=negative_prompt_embeds,
+                        txt_ids=negative_text_ids,
+                        img_ids=latent_image_ids,
+                        joint_attention_kwargs=self.joint_attention_kwargs,
+                        return_dict=False,
+                    )[0]
+                    noise_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents_dtype = latents.dtype
+                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+                if latents.dtype != latents_dtype:
+                    if torch.backends.mps.is_available():
+                        # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+                        latents = latents.to(latents_dtype)
+
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+                # call the callback, if provided
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
+        self._current_timestep = None
+
+        if output_type == "latent":
+            image = latents
+        else:
+            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+            image = self.vae.decode(latents, return_dict=False)[0]
+            image = self.image_processor.postprocess(image, output_type=output_type)
+
+        # Offload all models
+        self.maybe_free_model_hooks()
+
+        if not return_dict:
+            return (image,)
+
+        return ChromaPipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/chroma/pipeline_output.py b/src/diffusers/pipelines/chroma/pipeline_output.py
new file mode 100644
index 000000000000..951d132dba2e
--- /dev/null
+++ b/src/diffusers/pipelines/chroma/pipeline_output.py
@@ -0,0 +1,21 @@
+from dataclasses import dataclass
+from typing import List, Union
+
+import numpy as np
+import PIL.Image
+
+from ...utils import BaseOutput
+
+
+@dataclass
+class ChromaPipelineOutput(BaseOutput):
+    """
+    Output class for Stable Diffusion pipelines.
+
+    Args:
+        images (`List[PIL.Image.Image]` or `np.ndarray`)
+            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+            num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
+    """
+
+    images: Union[List[PIL.Image.Image], np.ndarray]
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index 24b3c3d7be59..2981f3a420d6 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -325,6 +325,21 @@ def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["torch"])
 
 
+class ChromaTransformer2DModel(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class CogVideoXTransformer3DModel(metaclass=DummyObject):
     _backends = ["torch"]
 
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
index cc8f3e01ee78..deebdc757faa 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -272,6 +272,21 @@ def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["torch", "transformers"])
 
 
+class ChromaPipeline(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
 class CLIPImageProjection(metaclass=DummyObject):
     _backends = ["torch", "transformers"]
 
diff --git a/tests/models/transformers/test_models_transformer_chroma.py b/tests/models/transformers/test_models_transformer_chroma.py
new file mode 100644
index 000000000000..93df7ca35c4a
--- /dev/null
+++ b/tests/models/transformers/test_models_transformer_chroma.py
@@ -0,0 +1,183 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import ChromaTransformer2DModel
+from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
+from diffusers.models.embeddings import ImageProjection
+from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+
+from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
+
+
+enable_full_determinism()
+
+
+def create_chroma_ip_adapter_state_dict(model):
+    # "ip_adapter" (cross-attention weights)
+    ip_cross_attn_state_dict = {}
+    key_id = 0
+
+    for name in model.attn_processors.keys():
+        if name.startswith("single_transformer_blocks"):
+            continue
+
+        joint_attention_dim = model.config["joint_attention_dim"]
+        hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"]
+        sd = FluxIPAdapterJointAttnProcessor2_0(
+            hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0
+        ).state_dict()
+        ip_cross_attn_state_dict.update(
+            {
+                f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
+                f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
+                f"{key_id}.to_k_ip.bias": sd["to_k_ip.0.bias"],
+                f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"],
+            }
+        )
+
+        key_id += 1
+
+    # "image_proj" (ImageProjection layer weights)
+
+    image_projection = ImageProjection(
+        cross_attention_dim=model.config["joint_attention_dim"],
+        image_embed_dim=model.config["pooled_projection_dim"],
+        num_image_text_embeds=4,
+    )
+
+    ip_image_projection_state_dict = {}
+    sd = image_projection.state_dict()
+    ip_image_projection_state_dict.update(
+        {
+            "proj.weight": sd["image_embeds.weight"],
+            "proj.bias": sd["image_embeds.bias"],
+            "norm.weight": sd["norm.weight"],
+            "norm.bias": sd["norm.bias"],
+        }
+    )
+
+    del sd
+    ip_state_dict = {}
+    ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
+    return ip_state_dict
+
+
+class ChromaTransformerTests(ModelTesterMixin, unittest.TestCase):
+    model_class = ChromaTransformer2DModel
+    main_input_name = "hidden_states"
+    # We override the items here because the transformer under consideration is small.
+    model_split_percents = [0.8, 0.7, 0.7]
+
+    # Skip setting testing with default: AttnProcessor
+    uses_custom_attn_processor = True
+
+    @property
+    def dummy_input(self):
+        batch_size = 1
+        num_latent_channels = 4
+        num_image_channels = 3
+        height = width = 4
+        sequence_length = 48
+        embedding_dim = 32
+
+        hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
+        encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
+        text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device)
+        image_ids = torch.randn((height * width, num_image_channels)).to(torch_device)
+        timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
+
+        return {
+            "hidden_states": hidden_states,
+            "encoder_hidden_states": encoder_hidden_states,
+            "img_ids": image_ids,
+            "txt_ids": text_ids,
+            "timestep": timestep,
+        }
+
+    @property
+    def input_shape(self):
+        return (16, 4)
+
+    @property
+    def output_shape(self):
+        return (16, 4)
+
+    def prepare_init_args_and_inputs_for_common(self):
+        init_dict = {
+            "patch_size": 1,
+            "in_channels": 4,
+            "num_layers": 1,
+            "num_single_layers": 1,
+            "attention_head_dim": 16,
+            "num_attention_heads": 2,
+            "joint_attention_dim": 32,
+            "axes_dims_rope": [4, 4, 8],
+            "approximator_num_channels": 8,
+            "approximator_hidden_dim": 16,
+            "approximator_layers": 1,
+        }
+
+        inputs_dict = self.dummy_input
+        return init_dict, inputs_dict
+
+    def test_deprecated_inputs_img_txt_ids_3d(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model.eval()
+
+        with torch.no_grad():
+            output_1 = model(**inputs_dict).to_tuple()[0]
+
+        # update inputs_dict with txt_ids and img_ids as 3d tensors (deprecated)
+        text_ids_3d = inputs_dict["txt_ids"].unsqueeze(0)
+        image_ids_3d = inputs_dict["img_ids"].unsqueeze(0)
+
+        assert text_ids_3d.ndim == 3, "text_ids_3d should be a 3d tensor"
+        assert image_ids_3d.ndim == 3, "img_ids_3d should be a 3d tensor"
+
+        inputs_dict["txt_ids"] = text_ids_3d
+        inputs_dict["img_ids"] = image_ids_3d
+
+        with torch.no_grad():
+            output_2 = model(**inputs_dict).to_tuple()[0]
+
+        self.assertEqual(output_1.shape, output_2.shape)
+        self.assertTrue(
+            torch.allclose(output_1, output_2, atol=1e-5),
+            msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) are not equal as them as 2d inputs",
+        )
+
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"ChromaTransformer2DModel"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+
+class ChromaTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
+    model_class = ChromaTransformer2DModel
+
+    def prepare_init_args_and_inputs_for_common(self):
+        return ChromaTransformerTests().prepare_init_args_and_inputs_for_common()
+
+
+class ChromaTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
+    model_class = ChromaTransformer2DModel
+
+    def prepare_init_args_and_inputs_for_common(self):
+        return ChromaTransformerTests().prepare_init_args_and_inputs_for_common()
diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py
index 33c876535871..036ed2ea3039 100644
--- a/tests/models/transformers/test_models_transformer_flux.py
+++ b/tests/models/transformers/test_models_transformer_flux.py
@@ -57,7 +57,9 @@ def create_flux_ip_adapter_state_dict(model):
 
     image_projection = ImageProjection(
         cross_attention_dim=model.config["joint_attention_dim"],
-        image_embed_dim=model.config["pooled_projection_dim"],
+        image_embed_dim=(
+            model.config["pooled_projection_dim"] if "pooled_projection_dim" in model.config.keys() else 768
+        ),
         num_image_text_embeds=4,
     )
 
diff --git a/tests/pipelines/chroma/__init__.py b/tests/pipelines/chroma/__init__.py
new file mode 100644
index 000000000000..8b137891791f
--- /dev/null
+++ b/tests/pipelines/chroma/__init__.py
@@ -0,0 +1 @@
+
diff --git a/tests/pipelines/chroma/test_pipeline_chroma.py b/tests/pipelines/chroma/test_pipeline_chroma.py
new file mode 100644
index 000000000000..fc5749f96cd8
--- /dev/null
+++ b/tests/pipelines/chroma/test_pipeline_chroma.py
@@ -0,0 +1,167 @@
+import unittest
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer, T5EncoderModel
+
+from diffusers import AutoencoderKL, ChromaPipeline, ChromaTransformer2DModel, FlowMatchEulerDiscreteScheduler
+from diffusers.utils.testing_utils import torch_device
+
+from ..test_pipelines_common import (
+    FluxIPAdapterTesterMixin,
+    PipelineTesterMixin,
+    check_qkv_fusion_matches_attn_procs_length,
+    check_qkv_fusion_processors_exist,
+)
+
+
+class ChromaPipelineFastTests(
+    unittest.TestCase,
+    PipelineTesterMixin,
+    FluxIPAdapterTesterMixin,
+):
+    pipeline_class = ChromaPipeline
+    params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds"])
+    batch_params = frozenset(["prompt"])
+
+    # there is no xformers processor for Flux
+    test_xformers_attention = False
+    test_layerwise_casting = True
+    test_group_offloading = True
+
+    def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
+        torch.manual_seed(0)
+        transformer = ChromaTransformer2DModel(
+            patch_size=1,
+            in_channels=4,
+            num_layers=num_layers,
+            num_single_layers=num_single_layers,
+            attention_head_dim=16,
+            num_attention_heads=2,
+            joint_attention_dim=32,
+            axes_dims_rope=[4, 4, 8],
+            approximator_hidden_dim=32,
+            approximator_layers=1,
+            approximator_num_channels=16,
+        )
+
+        torch.manual_seed(0)
+        text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+        torch.manual_seed(0)
+        vae = AutoencoderKL(
+            sample_size=32,
+            in_channels=3,
+            out_channels=3,
+            block_out_channels=(4,),
+            layers_per_block=1,
+            latent_channels=1,
+            norm_num_groups=1,
+            use_quant_conv=False,
+            use_post_quant_conv=False,
+            shift_factor=0.0609,
+            scaling_factor=1.5035,
+        )
+
+        scheduler = FlowMatchEulerDiscreteScheduler()
+
+        return {
+            "scheduler": scheduler,
+            "text_encoder": text_encoder,
+            "tokenizer": tokenizer,
+            "transformer": transformer,
+            "vae": vae,
+            "image_encoder": None,
+            "feature_extractor": None,
+        }
+
+    def get_dummy_inputs(self, device, seed=0):
+        if str(device).startswith("mps"):
+            generator = torch.manual_seed(seed)
+        else:
+            generator = torch.Generator(device="cpu").manual_seed(seed)
+
+        inputs = {
+            "prompt": "A painting of a squirrel eating a burger",
+            "negative_prompt": "bad, ugly",
+            "generator": generator,
+            "num_inference_steps": 2,
+            "guidance_scale": 5.0,
+            "height": 8,
+            "width": 8,
+            "max_sequence_length": 48,
+            "output_type": "np",
+        }
+        return inputs
+
+    def test_chroma_different_prompts(self):
+        pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+
+        inputs = self.get_dummy_inputs(torch_device)
+        output_same_prompt = pipe(**inputs).images[0]
+
+        inputs = self.get_dummy_inputs(torch_device)
+        inputs["prompt"] = "a different prompt"
+        output_different_prompts = pipe(**inputs).images[0]
+
+        max_diff = np.abs(output_same_prompt - output_different_prompts).max()
+
+        # Outputs should be different here
+        # For some reasons, they don't show large differences
+        assert max_diff > 1e-6
+
+    def test_fused_qkv_projections(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        image = pipe(**inputs).images
+        original_image_slice = image[0, -3:, -3:, -1]
+
+        # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
+        # to the pipeline level.
+        pipe.transformer.fuse_qkv_projections()
+        assert check_qkv_fusion_processors_exist(pipe.transformer), (
+            "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+        )
+        assert check_qkv_fusion_matches_attn_procs_length(
+            pipe.transformer, pipe.transformer.original_attn_processors
+        ), "Something wrong with the attention processors concerning the fused QKV projections."
+
+        inputs = self.get_dummy_inputs(device)
+        image = pipe(**inputs).images
+        image_slice_fused = image[0, -3:, -3:, -1]
+
+        pipe.transformer.unfuse_qkv_projections()
+        inputs = self.get_dummy_inputs(device)
+        image = pipe(**inputs).images
+        image_slice_disabled = image[0, -3:, -3:, -1]
+
+        assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
+            "Fusion of QKV projections shouldn't affect the outputs."
+        )
+        assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
+            "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+        )
+        assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+            "Original outputs should match when fused QKV projections are disabled."
+        )
+
+    def test_chroma_image_output_shape(self):
+        pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+        inputs = self.get_dummy_inputs(torch_device)
+
+        height_width_pairs = [(32, 32), (72, 57)]
+        for height, width in height_width_pairs:
+            expected_height = height - height % (pipe.vae_scale_factor * 2)
+            expected_width = width - width % (pipe.vae_scale_factor * 2)
+
+            inputs.update({"height": height, "width": width})
+            image = pipe(**inputs).images[0]
+            output_height, output_width, _ = image.shape
+            assert (output_height, output_width) == (expected_height, expected_width)
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index 91ffc0ae537d..687a28294c9a 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -521,7 +521,8 @@ def _get_dummy_image_embeds(self, image_embed_dim: int = 768):
 
     def _modify_inputs_for_ip_adapter_test(self, inputs: Dict[str, Any]):
         inputs["negative_prompt"] = ""
-        inputs["true_cfg_scale"] = 4.0
+        if "true_cfg_scale" in inspect.signature(self.pipeline_class.__call__).parameters:
+            inputs["true_cfg_scale"] = 4.0
         inputs["output_type"] = "np"
         inputs["return_dict"] = False
         return inputs
@@ -542,7 +543,11 @@ def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=N
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components).to(torch_device)
         pipe.set_progress_bar_config(disable=None)
-        image_embed_dim = pipe.transformer.config.pooled_projection_dim
+        image_embed_dim = (
+            pipe.transformer.config.pooled_projection_dim
+            if hasattr(pipe.transformer.config, "pooled_projection_dim")
+            else 768
+        )
 
         # forward pass without ip adapter
         inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))