Anthony Chen1,2 · Jianjin Xu3 · Wenzhao Zheng4 · Gaole Dai1 · Yida Wang5 · Renrui Zhang6 · Haofan Wang2 · Shanghang Zhang1*
1Peking University · 2InstantX Team · 3Carnegie Mellon University · 4UC Berkeley · 5Li Auto Inc. · 6CUHK
Training-free Regional Prompting for Diffusion Transformers (Regional-Prompting-FLUX) gives Diffusion Transformers (i.e., FLUX) fine-grained compositional text-to-image generation capability in a training-free manner. Empirically, we show that our method is highly effective and compatible with LoRA and ControlNet.
Our inference is much faster than the RPG-based implementation and uses less GPU memory.
- [2024/11/05] 🔥 We release the code, feel free to try it out!
- [2024/11/05] 🔥 We release the technical report!
We pin diffusers to an earlier commit to ensure reproducibility, as we found that newer diffusers versions may produce different results.
# install diffusers locally
git clone https://github.com/huggingface/diffusers.git
cd diffusers
# reset diffusers to the 0.31.dev commit on which we developed Regional-Prompting-FLUX; other versions may produce different results
git reset --hard d13b0d63c0208f2c4c078c4261caf8bf587beb3b
pip install -e ".[torch]"
cd ..
# install other dependencies
pip install -U transformers sentencepiece protobuf PEFT
# clone this repo
git clone https://github.com/antonioo-c/Regional-Prompting-FLUX.git
# replace file in diffusers
cd Regional-Prompting-FLUX
cp transformer_flux.py ../diffusers/src/diffusers/models/transformers/transformer_flux.py
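After the steps above, it can be worth confirming which diffusers version Python actually picks up. The following is a minimal sketch, not part of this repository: the helper names are hypothetical, and it assumes the pinned commit reports a version string beginning with `0.31.`, as the comment in the install steps suggests.

```python
from importlib import metadata
from typing import Optional


def installed_diffusers_version() -> Optional[str]:
    """Return the installed diffusers version string, or None if it is not installed."""
    try:
        return metadata.version("diffusers")
    except metadata.PackageNotFoundError:
        return None


def is_expected_version(installed: str, expected_prefix: str = "0.31.") -> bool:
    """Check the version prefix (assumption: the pinned commit reports 0.31.x)."""
    return installed.startswith(expected_prefix)


version = installed_diffusers_version()
if version is None:
    print("diffusers is not installed in this environment")
elif not is_expected_version(version):
    print(f"diffusers {version} found; results may differ from the pinned commit")
else:
    print(f"diffusers {version} matches the expected 0.31 series")
```

This only checks the version string; it does not verify that `transformer_flux.py` was replaced.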
See infer_flux_regional.py for detailed examples (including LoRAs and ControlNets). Below is a quick start example.
import torch
from pipeline_flux_regional import RegionalFluxPipeline, RegionalFluxAttnProcessor2_0
pipeline = RegionalFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
attn_procs = {}
for name in pipeline.transformer.attn_processors.keys():
    if 'transformer_blocks' in name and name.endswith("attn.processor"):
        attn_procs[name] = RegionalFluxAttnProcessor2_0()
    else:
        attn_procs[name] = pipeline.transformer.attn_processors[name]
pipeline.transformer.set_attn_processor(attn_procs)
## general settings
image_width = 1280
image_height = 768
num_inference_steps = 24
seed = 124
base_prompt = "An ancient woman stands solemnly holding a blazing torch, while a fierce battle rages in the background, capturing both strength and tragedy in a historical war scene."
background_prompt = "a photo" # set by default, but if you want to enrich background, you can set it to a more descriptive prompt
regional_prompt_mask_pairs = {
    "0": {
        "description": "A dignified woman in ancient robes stands in the foreground, her face illuminated by the torch she holds high. Her expression is one of determination and sorrow, her clothing and appearance reflecting the historical period. The torch casts dramatic shadows across her features, its flames dancing vibrantly against the darkness.",
        "mask": [128, 128, 640, 768]
    }
}
## region control factor settings
mask_inject_steps = 10 # larger means stronger control, recommended between 5-10
double_inject_blocks_interval = 1 # 1 means strongest control
single_inject_blocks_interval = 1 # 1 means strongest control
base_ratio = 0.2 # smaller means stronger control
regional_prompts = []
regional_masks = []
background_mask = torch.ones((image_height, image_width))
for region_idx, region in regional_prompt_mask_pairs.items():
    description = region['description']
    mask = region['mask']
    x1, y1, x2, y2 = mask
    mask = torch.zeros((image_height, image_width))
    mask[y1:y2, x1:x2] = 1.0
    background_mask -= mask
    regional_prompts.append(description)
    regional_masks.append(mask)
# if regional masks don't cover the whole image, append background prompt and mask
if background_mask.sum() > 0:
    regional_prompts.append(background_prompt)
    regional_masks.append(background_mask)
image = pipeline(
    prompt=base_prompt,
    width=image_width, height=image_height,
    mask_inject_steps=mask_inject_steps,
    num_inference_steps=num_inference_steps,
    generator=torch.Generator("cuda").manual_seed(seed),
    joint_attention_kwargs={
        "regional_prompts": regional_prompts,
        "regional_masks": regional_masks,
        "double_inject_blocks_interval": double_inject_blocks_interval,
        "single_inject_blocks_interval": single_inject_blocks_interval,
        "base_ratio": base_ratio
    },
).images[0]
image.save("output.jpg")
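The comments on the region control factors say which direction is "stronger", but not how the knobs combine. The sketch below is one plausible reading, written in plain Python so it runs without a GPU. It is an illustration built on assumptions, not the code in pipeline_flux_regional.py: the function names are hypothetical, and it assumes that regional control is applied only during the first `mask_inject_steps` denoising steps, only in every `interval`-th block, and that `base_ratio` linearly blends the base-prompt output with the regional output.

```python
from typing import List


def step_uses_regions(step_index: int, mask_inject_steps: int) -> bool:
    """Assumed rule: only the first `mask_inject_steps` steps use regional control."""
    return step_index < mask_inject_steps


def blocks_with_injection(num_blocks: int, interval: int) -> List[int]:
    """Assumed rule: every `interval`-th block (starting at 0) gets regional control."""
    if interval < 1:
        raise ValueError("interval must be >= 1")
    return list(range(0, num_blocks, interval))


def blend(base_value: float, regional_value: float, base_ratio: float) -> float:
    """Assumed linear blend: `base_ratio` is the weight of the base-prompt output."""
    if not 0.0 <= base_ratio <= 1.0:
        raise ValueError("base_ratio must be in [0, 1]")
    return base_ratio * base_value + (1.0 - base_ratio) * regional_value


# Settings from the quick start example above.
steps_with_regions = [s for s in range(24) if step_uses_regions(s, 10)]
print(len(steps_with_regions))              # number of steps with regional control
print(blocks_with_injection(19, 1))         # 19 is an illustrative block count
print(blend(base_value=1.0, regional_value=0.0, base_ratio=0.2))
```

Under these assumptions, raising `mask_inject_steps`, lowering either interval toward 1, or lowering `base_ratio` all increase how much of the computation the regional prompts influence, which matches the "stronger control" notes in the settings.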
Our work is sponsored by HuggingFace and fal.ai. Thanks!
If you find Regional-Prompting-FLUX useful for your research and applications, please cite us using this BibTeX:
@article{chen2024training,
title={Training-free Regional Prompting for Diffusion Transformers},
author={Chen, Anthony and Xu, Jianjin and Zheng, Wenzhao and Dai, Gaole and Wang, Yida and Zhang, Renrui and Wang, Haofan and Zhang, Shanghang},
journal={arXiv preprint arXiv:2411.02395},
year={2024}
}
For any questions, feel free to contact us via [email protected].