Chroma Pipeline #11698


Merged: 117 commits into huggingface:main on Jun 14, 2025

Conversation

Ednaordinary
Contributor

What does this PR do?

Fixes #11010

Relevant: #11566

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

@DN6

@ghunkins
Contributor

ghunkins commented Jun 13, 2025

Amazing work on this @Ednaordinary, huge thanks. I'm getting the error below when adding a LoRA:

 scale_expansion_fn = _SET_ADAPTER_SCALE_FN_MAPPING[self.__class__.__name__]
E       KeyError: 'ChromaTransformer2DModel'

Can we add ChromaTransformer2DModel here: https://github.com/huggingface/diffusers/blob/main/src/diffusers/loaders/peft.py#L45

_SET_ADAPTER_SCALE_FN_MAPPING = {
  ...
  "ChromaTransformer2DModel": lambda model_cls, weights: weights,
  ...
}

I think that will fix it.

@ghunkins
Contributor

[image attachment]

Looks great 🔥

@DN6 DN6 merged commit 8adc600 into huggingface:main Jun 14, 2025
30 checks passed
@DN6
Collaborator

DN6 commented Jun 14, 2025

Great work @Ednaordinary @hameerabbasi and @iddl! 🚀

@nitinmukesh

Awesome. Thank you all. 👍

@tin2tin

tin2tin commented Jun 14, 2025

Thank you for this great commit!

A question: is bitsandbytes quantization not supported for this model? (As it is, it needs 18.5 GB of VRAM, which is a bit too heavy for a lot of consumer cards.)

import torch
from diffusers import ChromaTransformer2DModel, ChromaPipeline, BitsAndBytesConfig
from transformers import T5EncoderModel, T5Tokenizer

bfl_repo = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

transformer = ChromaTransformer2DModel.from_single_file("https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v35.safetensors", torch_dtype=dtype, quantization_config=nf4_config) 

text_encoder = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
tokenizer = T5Tokenizer.from_pretrained(bfl_repo, subfolder="tokenizer_2", torch_dtype=dtype)

pipe = ChromaPipeline.from_pretrained(bfl_repo, transformer=transformer, text_encoder=text_encoder, tokenizer=tokenizer, torch_dtype=dtype)

pipe.enable_model_cpu_offload()

Gives me this error:

 File ".\python\Lib\site-packages\huggingface_hub\utils\_validators.py", line 114, in _inner_fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File ".\Python\Python311\site-packages\diffusers\loaders\single_file_model.py", line 415, in from_single_file
    load_model_dict_into_meta(
  File ".\Python\Python311\site-packages\diffusers\models\model_loading_utils.py", line 298, in load_model_dict_into_meta
    hf_quantizer.create_quantized_param(
  File ".\Python\Python311\site-packages\diffusers\quantizers\bitsandbytes\bnb_quantizer.py", line 182, in create_quantized_param
    raise ValueError(
ValueError: Supplied state dict for distilled_guidance_layer.in_proj.weight does not contain `bitsandbytes__*` and possibly other `quantized_stats` components.

@Ednaordinary
Contributor Author

Ednaordinary commented Jun 14, 2025

@tin2tin bitsandbytes is supported! Just save the diffusers version first with .save_pretrained() and reload it with the quantization config via .from_pretrained(). My diffusers weights may also work, but I'm not sure how out of date they are with the current code: https://huggingface.co/imnotednamode/Chroma-v36-dc-diffusers

Actually, my weights still load fine; they just print some unnecessary attribute warnings. I'll fix that when I get around to it.
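
A minimal sketch of that save-then-reload flow (the local folder name is an arbitrary choice, and the checkpoint URL is the one from the earlier snippet in this thread):

import torch
from diffusers import ChromaTransformer2DModel, BitsAndBytesConfig

dtype = torch.bfloat16

# One-time conversion: load the single-file checkpoint and save it in diffusers format.
transformer = ChromaTransformer2DModel.from_single_file(
    "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v35.safetensors",
    torch_dtype=dtype,
)
transformer.save_pretrained("chroma-transformer-diffusers")  # local path, pick anything

# Reload the saved folder with a bitsandbytes NF4 config.
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=dtype,
)
transformer = ChromaTransformer2DModel.from_pretrained(
    "chroma-transformer-diffusers",
    quantization_config=nf4_config,
    torch_dtype=dtype,
)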

@tin2tin

tin2tin commented Jun 14, 2025

@Ednaordinary Oh, that's super cool! How do you load your diffusers version? Do you load just the transformer and quantize that? Can you share a snippet that shows how you use it? Like this?

import torch
from diffusers import ChromaTransformer2DModel, ChromaPipeline, BitsAndBytesConfig
from transformers import T5EncoderModel, T5Tokenizer

bfl_repo = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
transformer = ChromaTransformer2DModel.from_pretrained(
    "imnotednamode/Chroma-v36-dc-diffusers",
    subfolder="transformer",
    quantization_config=nf4_config,
    torch_dtype=torch.bfloat16,
)

text_encoder = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
tokenizer = T5Tokenizer.from_pretrained(bfl_repo, subfolder="tokenizer_2", torch_dtype=dtype)

pipe = ChromaPipeline.from_pretrained(bfl_repo, transformer=transformer, text_encoder=text_encoder, tokenizer=tokenizer, torch_dtype=dtype)

pipe.enable_model_cpu_offload()

I'm getting this notice:

The config attributes {'approximator_in_factor': 16} were passed to ChromaTransformer2DModel, but are not expected and will be ignored. Please verify your config.json configuration file.

@Ednaordinary
Contributor Author

Ednaordinary commented Jun 14, 2025

You can safely ignore the config notice; it appears because changes have been made to the diffusers code since I generated that checkpoint. Also, be sure to add llm_int8_skip_modules=["distilled_guidance_layer"] as noted in #11698 (comment) for the best quality.

@tin2tin

tin2tin commented Jun 14, 2025

@Ednaordinary Where should I add llm_int8_skip_modules=["distilled_guidance_layer"]? Could you help me out with a quantized example code snippet?

@nitinmukesh

nitinmukesh commented Jun 14, 2025

@tin2tin

Here is the code for a different model:
bitsandbytes-foundation/bitsandbytes#1611

model = AutoModelForCausalLM.from_pretrained(
    model_name, device_map='cuda:0', attn_implementation='flash_attention_2',
    torch_dtype=torch.bfloat16,
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,

        llm_int8_skip_modules=["lm_head"],

        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True
    )
)

Also, instead of applying quantization to text_encoder_2 and transformer separately, both modules can be specified in the quantization_config. If I remember correctly, @sayakpaul posted an example somewhere, but I can't find it.

It was something like

pipe=FluxPipeline (
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=True,
            llm_int8_skip_modules=["lm_head"],
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            ??=[text_encoder_2, transformer]
        )
)

@nitinmukesh

OK, I found it: #11648 (comment)

components_to_quantize=["transformer", "text_encoder_2"]

@Ednaordinary
Contributor Author

Ednaordinary commented Jun 14, 2025

Another pipeline-level quant example is here: #11698 (comment)

Also, yes, the parameter is passed to the BitsAndBytesConfig:

BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    llm_int8_skip_modules=["distilled_guidance_layer"],
)
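
For context, a fuller sketch of where that config would go when loading just the transformer; the repo name and subfolder are carried over from the earlier snippet in this thread, not verified here:

import torch
from diffusers import ChromaTransformer2DModel, BitsAndBytesConfig

# NF4 config with the distilled guidance layer kept in full precision.
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    llm_int8_skip_modules=["distilled_guidance_layer"],
)

# The config is passed to from_pretrained when loading the transformer.
transformer = ChromaTransformer2DModel.from_pretrained(
    "imnotednamode/Chroma-v36-dc-diffusers",  # repo from the earlier comment
    subfolder="transformer",
    quantization_config=nf4_config,
    torch_dtype=torch.bfloat16,
)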

@nitinmukesh

nitinmukesh commented Jun 14, 2025

> Another pipeline-level quant example is here #11698 (comment)

Ah, this is what I was referring to (the sample from @sayakpaul). Thanks.
#11698 (comment)

@asomoza
Member

asomoza commented Jun 14, 2025

If you want ready-to-use code, this one works with the main branch:

import torch

from diffusers import ChromaPipeline
from diffusers.quantizers import PipelineQuantizationConfig


dtype = torch.bfloat16

repo_id = "imnotednamode/Chroma-v36-dc-diffusers"

pipeline_quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",
    quant_kwargs={
        "load_in_4bit": True,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_compute_dtype": dtype,
        "llm_int8_skip_modules": ["distilled_guidance_layer"],
    },
    components_to_quantize=["transformer", "text_encoder"],
)

pipe = ChromaPipeline.from_pretrained(
    "imnotednamode/Chroma-v36-dc-diffusers",
    quantization_config=pipeline_quant_config,
    torch_dtype=dtype,
)
pipe.enable_model_cpu_offload()

prompt = 'Ultra-realistic, high-quality photo of an anthropomorphic capybara with a tough, streetwise attitude, wearing a worn black leather jacket, dark sunglasses, and ripped jeans. The capybara is leaning casually against a gritty urban wall covered in vibrant graffiti. Behind it, in bold, dripping yellow spray paint, the word "HuggingFace" is scrawled in large street-art style letters. The scene is set in a dimly lit alleyway with moody lighting, scattered trash, and an edgy, rebellious vibe — like a character straight out of an underground comic book.'
negative = "low quality, bad anatomy, extra digits, missing digits, extra limbs, missing limbs"

image = pipe(
    prompt=prompt,
    negative_prompt=negative,
    num_inference_steps=30,
    guidance_scale=4.0,
    width=1024,
    height=1024,
    generator=torch.Generator().manual_seed(42),
).images[0]

image.save("chroma.png")


Successfully merging this pull request may close these issues.

Support Chroma - Flux based model with architecture changes