
Commit f9c1e61

Qwen Image Layered Support (#12853)

Authored by naykun, with github-actions[bot] and sayakpaul

* [qwen-image] qwen image layered support
* [qwen-image] update doc
* [qwen-image] fix pr comments
* Apply style fixes
* make fix-copies

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Sayak Paul <[email protected]>

1 parent 87f7d11 · commit f9c1e61

File tree: 7 files changed, +1070 −10 lines changed


src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -564,6 +564,7 @@
         "QwenImageEditPlusPipeline",
         "QwenImageImg2ImgPipeline",
         "QwenImageInpaintPipeline",
+        "QwenImageLayeredPipeline",
         "QwenImagePipeline",
         "ReduxImageEncoder",
         "SanaControlNetPipeline",
@@ -1272,6 +1273,7 @@
         QwenImageEditPlusPipeline,
         QwenImageImg2ImgPipeline,
         QwenImageInpaintPipeline,
+        QwenImageLayeredPipeline,
         QwenImagePipeline,
         ReduxImageEncoder,
         SanaControlNetPipeline,
```
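
With the top-level export in place, the new pipeline resolves through diffusers' lazy-import machinery like any other:

```python
from diffusers import QwenImageLayeredPipeline
```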

src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py

Lines changed: 7 additions & 4 deletions
```diff
@@ -394,6 +394,7 @@ def __init__(
         attn_scales=[],
         temperal_downsample=[True, True, False],
         dropout=0.0,
+        input_channels=3,
         non_linearity: str = "silu",
     ):
         super().__init__()
@@ -410,7 +411,7 @@ def __init__(
         scale = 1.0
 
         # init block
-        self.conv_in = QwenImageCausalConv3d(3, dims[0], 3, padding=1)
+        self.conv_in = QwenImageCausalConv3d(input_channels, dims[0], 3, padding=1)
 
         # downsample blocks
         self.down_blocks = nn.ModuleList([])
@@ -570,6 +571,7 @@ def __init__(
         attn_scales=[],
         temperal_upsample=[False, True, True],
         dropout=0.0,
+        input_channels=3,
         non_linearity: str = "silu",
     ):
         super().__init__()
@@ -621,7 +623,7 @@ def __init__(
 
         # output blocks
         self.norm_out = QwenImageRMS_norm(out_dim, images=False)
-        self.conv_out = QwenImageCausalConv3d(out_dim, 3, 3, padding=1)
+        self.conv_out = QwenImageCausalConv3d(out_dim, input_channels, 3, padding=1)
 
         self.gradient_checkpointing = False
 
@@ -684,6 +686,7 @@ def __init__(
         attn_scales: List[float] = [],
         temperal_downsample: List[bool] = [False, True, True],
         dropout: float = 0.0,
+        input_channels: int = 3,
         latents_mean: List[float] = [-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921],
         latents_std: List[float] = [2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160],
     ) -> None:
@@ -695,13 +698,13 @@ def __init__(
         self.temperal_upsample = temperal_downsample[::-1]
 
         self.encoder = QwenImageEncoder3d(
-            base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout
+            base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout, input_channels
         )
         self.quant_conv = QwenImageCausalConv3d(z_dim * 2, z_dim * 2, 1)
         self.post_quant_conv = QwenImageCausalConv3d(z_dim, z_dim, 1)
 
         self.decoder = QwenImageDecoder3d(
-            base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
+            base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout, input_channels
         )
 
         self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
```
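
The new `input_channels` argument threads through both encoder and decoder, so the VAE can operate on more than RGB; for layered output, a fourth alpha channel is the natural use. A minimal sketch, instantiating with library defaults (`input_channels=4` for RGBA is an assumption here, not pinned by the diff, and real checkpoints ship their own config):

```python
import torch
from diffusers import AutoencoderKLQwenImage

# hypothetical RGBA configuration: 4 channels in, 4 channels out
vae = AutoencoderKLQwenImage(input_channels=4)

rgba = torch.randn(1, 4, 1, 64, 64)  # (batch, channels, frames, height, width)
latents = vae.encode(rgba).latent_dist.sample()
reconstruction = vae.decode(latents).sample
print(latents.shape, reconstruction.shape)
```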

src/diffusers/models/transformers/transformer_qwenimage.py

Lines changed: 137 additions & 6 deletions
```diff
@@ -143,17 +143,26 @@ def apply_rotary_emb_qwen(
 
 
 class QwenTimestepProjEmbeddings(nn.Module):
-    def __init__(self, embedding_dim):
+    def __init__(self, embedding_dim, use_additional_t_cond=False):
         super().__init__()
 
         self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
         self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+        self.use_additional_t_cond = use_additional_t_cond
+        if use_additional_t_cond:
+            self.addition_t_embedding = nn.Embedding(2, embedding_dim)
 
-    def forward(self, timestep, hidden_states):
+    def forward(self, timestep, hidden_states, addition_t_cond=None):
         timesteps_proj = self.time_proj(timestep)
         timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype))  # (N, D)
 
         conditioning = timesteps_emb
+        if self.use_additional_t_cond:
+            if addition_t_cond is None:
+                raise ValueError("When `use_additional_t_cond` is True, `addition_t_cond` must be provided.")
+            addition_t_emb = self.addition_t_embedding(addition_t_cond)
+            addition_t_emb = addition_t_emb.to(dtype=hidden_states.dtype)
+            conditioning = conditioning + addition_t_emb
 
         return conditioning
 
```
```diff
@@ -259,6 +268,120 @@ def _compute_video_freqs(self, frame: int, height: int, width: int, idx: int = 0
         return freqs.clone().contiguous()
 
 
+class QwenEmbedLayer3DRope(nn.Module):
+    def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
+        super().__init__()
+        self.theta = theta
+        self.axes_dim = axes_dim
+        pos_index = torch.arange(4096)
+        neg_index = torch.arange(4096).flip(0) * -1 - 1
+        self.pos_freqs = torch.cat(
+            [
+                self.rope_params(pos_index, self.axes_dim[0], self.theta),
+                self.rope_params(pos_index, self.axes_dim[1], self.theta),
+                self.rope_params(pos_index, self.axes_dim[2], self.theta),
+            ],
+            dim=1,
+        )
+        self.neg_freqs = torch.cat(
+            [
+                self.rope_params(neg_index, self.axes_dim[0], self.theta),
+                self.rope_params(neg_index, self.axes_dim[1], self.theta),
+                self.rope_params(neg_index, self.axes_dim[2], self.theta),
+            ],
+            dim=1,
+        )
+
+        self.scale_rope = scale_rope
+
+    def rope_params(self, index, dim, theta=10000):
+        """
+        Args:
+            index: 1D tensor of token position indices (e.g. [0, 1, 2, 3])
+        """
+        assert dim % 2 == 0
+        freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)))
+        freqs = torch.polar(torch.ones_like(freqs), freqs)
+        return freqs
+
+    def forward(self, video_fhw, txt_seq_lens, device):
+        """
+        Args:
+            video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video
+            txt_seq_lens: [bs] a list of integers representing the length of the text
+        """
+        if self.pos_freqs.device != device:
+            self.pos_freqs = self.pos_freqs.to(device)
+            self.neg_freqs = self.neg_freqs.to(device)
+
+        if isinstance(video_fhw, list):
+            video_fhw = video_fhw[0]
+        if not isinstance(video_fhw, list):
+            video_fhw = [video_fhw]
+
+        vid_freqs = []
+        max_vid_index = 0
+        layer_num = len(video_fhw) - 1
+        for idx, fhw in enumerate(video_fhw):
+            frame, height, width = fhw
+            if idx != layer_num:
+                video_freq = self._compute_video_freqs(frame, height, width, idx)
+            else:
+                # For the condition image, we set the layer index to -1
+                video_freq = self._compute_condition_freqs(frame, height, width)
+            video_freq = video_freq.to(device)
+            vid_freqs.append(video_freq)
+
+            if self.scale_rope:
+                max_vid_index = max(height // 2, width // 2, max_vid_index)
+            else:
+                max_vid_index = max(height, width, max_vid_index)
+
+        max_vid_index = max(max_vid_index, layer_num)
+        max_len = max(txt_seq_lens)
+        txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
+        vid_freqs = torch.cat(vid_freqs, dim=0)
+
+        return vid_freqs, txt_freqs
+
+    @functools.lru_cache(maxsize=None)
+    def _compute_video_freqs(self, frame, height, width, idx=0):
+        seq_lens = frame * height * width
+        freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+        freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+
+        freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
+        if self.scale_rope:
+            freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
+            freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
+            freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
+            freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
+        else:
+            freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
+            freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
+
+        freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
+        return freqs.clone().contiguous()
+
+    @functools.lru_cache(maxsize=None)
+    def _compute_condition_freqs(self, frame, height, width):
+        seq_lens = frame * height * width
+        freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+        freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+
+        freqs_frame = freqs_neg[0][-1:].view(frame, 1, 1, -1).expand(frame, height, width, -1)
+        if self.scale_rope:
+            freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
+            freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
+            freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
+            freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
+        else:
+            freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
+            freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
+
+        freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
+        return freqs.clone().contiguous()
+
+
 class QwenDoubleStreamAttnProcessor2_0:
     """
     Attention processor for Qwen double-stream architecture, matching DoubleStreamLayerMegatron logic. This processor
```
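
`QwenEmbedLayer3DRope` reuses the 3D RoPE machinery but repurposes the frame axis as a layer index: layer k takes frame frequency k, while the condition image takes index -1 from the negative-frequency table so it stays distinguishable from every layer. A self-contained sketch of the underlying `rope_params` math (the axis dimension of 8 is illustrative):

```python
import torch

def rope_params(index, dim, theta=10000):
    # complex rotary frequencies: one rotation angle per (position, frequency) pair
    freqs = torch.outer(index.float(), 1.0 / theta ** (torch.arange(0, dim, 2).float() / dim))
    return torch.polar(torch.ones_like(freqs), freqs)

pos = rope_params(torch.arange(4), 8)                   # layer indices 0..3
neg = rope_params(torch.arange(4).flip(0) * -1 - 1, 8)  # indices -4..-1; last row is -1
print(pos.shape, neg.shape)  # torch.Size([4, 4]) torch.Size([4, 4])
```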
```diff
@@ -578,14 +701,21 @@ def __init__(
         guidance_embeds: bool = False,  # TODO: this should probably be removed
         axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
         zero_cond_t: bool = False,
+        use_additional_t_cond: bool = False,
+        use_layer3d_rope: bool = False,
     ):
         super().__init__()
         self.out_channels = out_channels or in_channels
         self.inner_dim = num_attention_heads * attention_head_dim
 
-        self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)
+        if not use_layer3d_rope:
+            self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)
+        else:
+            self.pos_embed = QwenEmbedLayer3DRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)
 
-        self.time_text_embed = QwenTimestepProjEmbeddings(embedding_dim=self.inner_dim)
+        self.time_text_embed = QwenTimestepProjEmbeddings(
+            embedding_dim=self.inner_dim, use_additional_t_cond=use_additional_t_cond
+        )
 
         self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6)
 
@@ -621,6 +751,7 @@ def forward(
         guidance: torch.Tensor = None,  # TODO: this should probably be removed
         attention_kwargs: Optional[Dict[str, Any]] = None,
         controlnet_block_samples=None,
+        additional_t_cond=None,
         return_dict: bool = True,
     ) -> Union[torch.Tensor, Transformer2DModelOutput]:
         """
@@ -683,9 +814,9 @@ def forward(
             guidance = guidance.to(hidden_states.dtype) * 1000
 
         temb = (
-            self.time_text_embed(timestep, hidden_states)
+            self.time_text_embed(timestep, hidden_states, additional_t_cond)
             if guidance is None
-            else self.time_text_embed(timestep, guidance, hidden_states)
+            else self.time_text_embed(timestep, guidance, hidden_states, additional_t_cond)
         )
 
         image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
```
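
Both behaviors are opt-in through the model config, so existing QwenImage checkpoints are unaffected. A hedged construction sketch (the size parameters below are deliberately tiny and illustrative, not a real checkpoint config):

```python
from diffusers import QwenImageTransformer2DModel

model = QwenImageTransformer2DModel(
    num_layers=2,
    attention_head_dim=32,
    num_attention_heads=4,
    joint_attention_dim=64,
    use_additional_t_cond=True,  # enables the extra nn.Embedding(2, inner_dim) term
    use_layer3d_rope=True,       # selects QwenEmbedLayer3DRope over QwenEmbedRope
)
```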

src/diffusers/pipelines/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -422,6 +422,7 @@
         "QwenImageEditInpaintPipeline",
         "QwenImageControlNetInpaintPipeline",
         "QwenImageControlNetPipeline",
+        "QwenImageLayeredPipeline",
     ]
     _import_structure["chronoedit"] = ["ChronoEditPipeline"]
     try:
@@ -764,6 +765,7 @@
         QwenImageEditPlusPipeline,
         QwenImageImg2ImgPipeline,
         QwenImageInpaintPipeline,
+        QwenImageLayeredPipeline,
         QwenImagePipeline,
     )
     from .sana import (
```

src/diffusers/pipelines/qwenimage/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -31,6 +31,7 @@
     _import_structure["pipeline_qwenimage_edit_plus"] = ["QwenImageEditPlusPipeline"]
     _import_structure["pipeline_qwenimage_img2img"] = ["QwenImageImg2ImgPipeline"]
     _import_structure["pipeline_qwenimage_inpaint"] = ["QwenImageInpaintPipeline"]
+    _import_structure["pipeline_qwenimage_layered"] = ["QwenImageLayeredPipeline"]
 
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     try:
@@ -47,6 +48,7 @@
         from .pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
         from .pipeline_qwenimage_img2img import QwenImageImg2ImgPipeline
         from .pipeline_qwenimage_inpaint import QwenImageInpaintPipeline
+        from .pipeline_qwenimage_layered import QwenImageLayeredPipeline
 else:
     import sys
```
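
The pipeline module itself (`pipeline_qwenimage_layered.py`) is part of this commit but not shown in this view. A usage sketch, assuming it follows the `from_pretrained` conventions of the other QwenImage pipelines (the checkpoint id and call signature below are assumptions, not pinned by the diff):

```python
import torch
from diffusers import QwenImageLayeredPipeline

# hypothetical repo id; the commit does not pin a checkpoint
pipe = QwenImageLayeredPipeline.from_pretrained(
    "Qwen/Qwen-Image-Layered", torch_dtype=torch.bfloat16
).to("cuda")

# expected to return decomposed RGBA layers for one prompt
layers = pipe(prompt="a mug on a desk, foreground and background as separate layers")
```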
