[modular] i2i and t2i support for kontext modular #12454
Conversation
In case you need it, the code I'm using to test:
import torch
from diffusers.modular_pipelines import ComponentsManager, ModularPipeline
from diffusers.utils import load_image
# CONFIG
repo_id = "black-forest-labs/FLUX.1-Kontext-dev"
device = "cuda"
prompt = "make it sknow"
guidance_scale = 2.5
num_inference_steps = 28
seed = 0
image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/dog_source.png").convert("RGB")
# COMPONENTS MANAGER
components = ComponentsManager()
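# move each model to the GPU only while it is in use and back to CPU afterwards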
components.enable_auto_cpu_offload(device=device)
# BLOCKS
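# load the preset modular blocks for this repo; individual sub-blocks are popped off below to build standalone nodes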
blocks = ModularPipeline.from_pretrained(repo_id, components_manager=components).blocks
# ENCODE PROMPT
text_blocks = blocks.sub_blocks.pop("text_encoder")
text_encoder_node = text_blocks.init_pipeline(repo_id, components_manager=components)
text_encoder_node.load_components(torch_dtype=torch.bfloat16)
text_state = text_encoder_node(prompt=prompt, max_sequence_length=512)
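# gather the outputs tagged as denoiser inputs (prompt embeddings etc.) so they can be passed to the denoise node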
text_embeddings = text_state.get_by_kwargs("denoiser_input_fields")
# ENCODE IMAGE
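# locate the image (VAE) encoder block by name instead of hardcoding it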
encoder_block_name = next((name for name in blocks.block_names if "encode" in name.lower() and "text" not in name.lower()), None)
encoder_blocks = blocks.sub_blocks.pop(encoder_block_name)
encoder_node = encoder_blocks.init_pipeline(repo_id, components_manager=components)
encoder_node.load_components(torch_dtype=torch.bfloat16)
state = encoder_node(image=image)
image_latents = state.get("image_latents")
# DENOISE
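# the denoise node consumes the prompt embeddings and image latents produced above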
denoise_blocks = blocks.sub_blocks.pop("denoise")
denoise_node = denoise_blocks.init_pipeline(repo_id, components_manager=components)
denoise_node.load_components(torch_dtype=torch.bfloat16)
generator = torch.Generator(device=device).manual_seed(seed)
denoise_state = denoise_node(
**text_embeddings,
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps,
generator=generator,
image_latents=image_latents,
)
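# final denoised latents, ready for VAE decoding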
latents = denoise_state.get("latents")
# VAE DECODE
decoder_blocks = blocks.sub_blocks.pop("decode")
decoder_node = decoder_blocks.init_pipeline(repo_id, components_manager=components)
decoder_node.load_components(torch_dtype=torch.bfloat16)
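# decode the latents with the VAE; output="images" returns PIL images directly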
image = decoder_node(latents=latents, output="images")[0]
image.save("modular_result.png")
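For comparison, the same run can also be driven end to end without popping sub-blocks. A minimal sketch, assuming the preset Kontext pipeline routes prompt and image inputs through to the denoiser and supports the same output="images" kwarg used above (untested here):
import torch
from diffusers.modular_pipelines import ComponentsManager, ModularPipeline
from diffusers.utils import load_image
components = ComponentsManager()
components.enable_auto_cpu_offload(device="cuda")
# build the full pipeline instead of individual nodes
pipe = ModularPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", components_manager=components)
pipe.load_components(torch_dtype=torch.bfloat16)
image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/dog_source.png").convert("RGB")
# assumption: the preset blocks accept prompt/image directly in a single call
result = pipe(
prompt="make it snow",
image=image,
guidance_scale=2.5,
num_inference_steps=28,
generator=torch.Generator(device="cuda").manual_seed(0),
output="images",
)[0]
result.save("modular_result_e2e.png")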
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@asomoza you should be good to go now :)
LGTM! Both T2I and I2I now work with nodes, and I don't see any other issues.
What does this PR do?
Supersedes #12269.
Test code:
Results: