import torch
import math

from .model import QwenImageTransformer2DModel


class QwenImageControlNetModel(QwenImageTransformer2DModel):
    def __init__(
        self,
        extra_condition_channels=0,
        dtype=None,
        device=None,
        operations=None,
        **kwargs
    ):
        # Build the base transformer without its final projection layer;
        # the ControlNet only emits per-block residual features.
        super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
        self.main_model_double = 60  # number of transformer blocks in the full Qwen Image model

        # controlnet_blocks: one linear projection per transformer block,
        # mapping hidden states back into the main model's residual stream.
        self.controlnet_blocks = torch.nn.ModuleList([])
        for _ in range(len(self.transformer_blocks)):
            self.controlnet_blocks.append(operations.Linear(self.inner_dim, self.inner_dim, device=device, dtype=dtype))
        # Embeds the patchified control hint (plus any extra condition channels)
        # into the inner dimension.
        self.controlnet_x_embedder = operations.Linear(self.in_channels + extra_condition_channels, self.inner_dim, device=device, dtype=dtype)

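    # Hypothetical construction, for illustration only (in practice the model
    # loader fills these arguments in from the checkpoint config; `ops` stands
    # in for a ComfyUI-style `operations` namespace providing Linear etc.):
    #
    #     controlnet = QwenImageControlNetModel(
    #         extra_condition_channels=0,
    #         dtype=torch.bfloat16,
    #         device="cuda",
    #         operations=ops,
    #     )
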
    def forward(
        self,
        x,
        timesteps,
        context,
        attention_mask=None,
        guidance: torch.Tensor = None,
        ref_latents=None,
        hint=None,
        transformer_options={},
        **kwargs
    ):
        # Rename the ComfyUI-convention arguments to the names the base
        # implementation uses internally.
        timestep = timesteps
        encoder_hidden_states = context
        encoder_hidden_states_mask = attention_mask

        # Patchify both the noisy latents and the control hint into token sequences.
        hidden_states, img_ids, orig_shape = self.process_img(x)
        hint, _, _ = self.process_img(hint)

        # Text tokens get positional ids starting beyond the image token grid,
        # so the rotary embedding keeps the two modalities separated.
        txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
        txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
        ids = torch.cat((txt_ids, img_ids), dim=1)
        image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
        del ids, txt_ids, img_ids

        # Embed the image tokens and inject the control hint through the
        # ControlNet-specific embedder; text tokens follow the base model's path.
        hidden_states = self.img_in(hidden_states) + self.controlnet_x_embedder(hint)
        encoder_hidden_states = self.txt_norm(encoder_hidden_states)
        encoder_hidden_states = self.txt_in(encoder_hidden_states)

        if guidance is not None:
            guidance = guidance * 1000

        # Timestep (and optional guidance) conditioning embedding.
        temb = (
            self.time_text_embed(timestep, hidden_states)
            if guidance is None
            else self.time_text_embed(timestep, guidance, hidden_states)
        )

        # Repeat each block's residual so the output tuple covers all 60
        # main-model blocks, e.g. 5 ControlNet blocks give
        # repeat = ceil(60 / 5) = 12: block 0 feeds main blocks 0-11,
        # block 1 feeds blocks 12-23, and so on.
        repeat = math.ceil(self.main_model_double / len(self.controlnet_blocks))

        controlnet_block_samples = ()
        for i, block in enumerate(self.transformer_blocks):
            encoder_hidden_states, hidden_states = block(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                encoder_hidden_states_mask=encoder_hidden_states_mask,
                temb=temb,
                image_rotary_emb=image_rotary_emb,
            )

            controlnet_block_samples = controlnet_block_samples + (self.controlnet_blocks[i](hidden_states),) * repeat

        return {"input": controlnet_block_samples[:self.main_model_double]}
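
# How the returned {"input": ...} tuple is typically consumed (a sketch under
# assumptions, not code from this change): the main transformer indexes the
# tuple by block position and adds the matching residual to its hidden states.
# The names `main_model` and `control` below are hypothetical.
#
#     control = controlnet(x, timesteps, context, attention_mask, hint=hint)
#     for i, block in enumerate(main_model.transformer_blocks):
#         encoder_hidden_states, hidden_states = block(...)
#         control_i = control["input"]
#         if i < len(control_i):
#             hidden_states = hidden_states + control_i[i]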