Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,8 @@ class QEffFluxTransformer2DModel(FluxTransformer2DModel):
def forward(
self,
hidden_states: torch.Tensor,
image_rotary_emb_cos: Optional[torch.Tensor] = None,
image_rotary_emb_sin: Optional[torch.Tensor] = None,
encoder_hidden_states: torch.Tensor = None,
pooled_projections: torch.Tensor = None,
timestep: torch.LongTensor = None,
Expand Down Expand Up @@ -286,9 +288,6 @@ def forward(
)
img_ids = img_ids[0]

ids = torch.cat((txt_ids, img_ids), dim=0)
image_rotary_emb = self.pos_embed(ids)

if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
Expand All @@ -299,7 +298,7 @@ def forward(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=adaln_emb[index_block],
image_rotary_emb=image_rotary_emb,
image_rotary_emb=(image_rotary_emb_cos, image_rotary_emb_sin),
joint_attention_kwargs=joint_attention_kwargs,
)

Expand All @@ -320,7 +319,7 @@ def forward(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=adaln_single_emb[index_block],
image_rotary_emb=image_rotary_emb,
image_rotary_emb=(image_rotary_emb_cos, image_rotary_emb_sin),
joint_attention_kwargs=joint_attention_kwargs,
)

Expand Down
12 changes: 11 additions & 1 deletion QEfficient/diffusers/pipelines/flux/pipeline_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,10 @@ def compile(

# Prepare dynamic specialization updates based on image dimensions
specialization_updates = {
"transformer": {"cl": cl},
"transformer": {
"cl": cl,
"image_seq_len": cl + 256,
},
"vae_decoder": {
"latent_height": latent_height,
"latent_width": latent_width,
Expand Down Expand Up @@ -718,6 +721,11 @@ def __call__(
latents,
)

# Create latent position IDs for transformer
ids = torch.cat((text_ids, latent_image_ids), dim=0)
# Compute image rotary embeddings for transformer
image_rotary_emb_cos, image_rotary_emb_sin = self.transformer.model.pos_embed(ids)

# Step 6: Calculate compressed latent dimension for transformer buffer allocation
cl, _, _ = calculate_compressed_latent_dimension(height, width, self.model.vae_scale_factor)

Expand Down Expand Up @@ -779,6 +787,8 @@ def __call__(
"adaln_emb": adaln_dual_emb.detach().numpy(),
"adaln_single_emb": adaln_single_emb.detach().numpy(),
"adaln_out": adaln_out.detach().numpy(),
"image_rotary_emb_cos": image_rotary_emb_cos.detach().numpy(),
"image_rotary_emb_sin": image_rotary_emb_sin.detach().numpy(),
}

# Run transformer inference and measure time
Expand Down
4 changes: 4 additions & 0 deletions QEfficient/diffusers/pipelines/pipeline_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,8 @@ def get_onnx_params(
# Output AdaLN embedding
# Shape: [batch_size, FLUX_ADALN_OUTPUT_DIM] for final projection
"adaln_out": torch.randn(batch_size, constants.FLUX_ADALN_OUTPUT_DIM, dtype=torch.float32),
"image_rotary_emb_cos": torch.randn(constants.FLUX_IMAGE_SEQ_LENGTH, 128, dtype=torch.float32),
"image_rotary_emb_sin": torch.randn(constants.FLUX_IMAGE_SEQ_LENGTH, 128, dtype=torch.float32),
}

output_names = ["output"]
Expand All @@ -427,6 +429,8 @@ def get_onnx_params(
"pooled_projections": {0: "batch_size"},
"timestep": {0: "steps"},
"img_ids": {0: "cl"},
"image_rotary_emb_cos": {0: "image_seq_len"},
"image_rotary_emb_sin": {0: "image_seq_len"},
}

return example_inputs, dynamic_axes, output_names
Expand Down
1 change: 1 addition & 0 deletions QEfficient/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def get_models_dir():
FLUX_ADALN_DUAL_BLOCK_CHUNKS = 12 # 6 chunks for norm1 + 6 chunks for norm1_context
FLUX_ADALN_SINGLE_BLOCK_CHUNKS = 3
FLUX_ADALN_OUTPUT_DIM = 6144 # 2 * FLUX_ADALN_HIDDEN_DIM
FLUX_IMAGE_SEQ_LENGTH = 4352

# Wan Transformer Constants
WAN_TEXT_EMBED_DIM = 5120
Expand Down
Loading