import math
from inspect import isfunction

import torch
from torch import nn

from ldm.modules.attention import CrossAttention


def exists(val):
    return val is not None


def uniq(arr):
    return {el: True for el in arr}.keys()


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


# feedforward
class GEGLU(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * torch.nn.functional.gelu(gate)


class FeedForward(nn.Module):
    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        project_in = nn.Sequential(
            nn.Linear(dim, inner_dim),
            nn.GELU()
        ) if not glu else GEGLU(dim, inner_dim)

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim_out)
        )

    def forward(self, x):
        return self.net(x)


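# The Gated*Dense blocks below implement GLIGEN-style gated attention: an
# extra attention branch and feed-forward branch are added to an existing
# transformer block, and their outputs are scaled by scale * tanh(alpha) with
# alpha initialized to 0. At initialization the patched network therefore
# behaves exactly like the unpatched one, and the grounding signal is blended
# in as alpha is trained (or as `scale` is adjusted externally).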
class GatedCrossAttentionDense(nn.Module):
    def __init__(self, query_dim, context_dim, n_heads, d_head):
        super().__init__()

        self.attn = CrossAttention(
            query_dim=query_dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head)
        self.ff = FeedForward(query_dim, glu=True)

        self.norm1 = nn.LayerNorm(query_dim)
        self.norm2 = nn.LayerNorm(query_dim)

        self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
        self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))

        # scale can be changed externally to adjust the magnitude of the
        # tanh(alpha) gate; when it is 0 the model behaves exactly like the
        # original one
        self.scale = 1

    def forward(self, x, objs):
        x = x + self.scale * \
            torch.tanh(self.alpha_attn) * self.attn(self.norm1(x), objs, objs)
        x = x + self.scale * \
            torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))

        return x


class GatedSelfAttentionDense(nn.Module):
    def __init__(self, query_dim, context_dim, n_heads, d_head):
        super().__init__()

        # we need a linear projection since we concatenate the visual
        # features and the object features
        self.linear = nn.Linear(context_dim, query_dim)

        self.attn = CrossAttention(
            query_dim=query_dim,
            context_dim=query_dim,
            heads=n_heads,
            dim_head=d_head)
        self.ff = FeedForward(query_dim, glu=True)

        self.norm1 = nn.LayerNorm(query_dim)
        self.norm2 = nn.LayerNorm(query_dim)

        self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
        self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))

        # scale can be changed externally to adjust the magnitude of the
        # tanh(alpha) gate; when it is 0 the model behaves exactly like the
        # original one
        self.scale = 1

    def forward(self, x, objs):
        N_visual = x.shape[1]
        objs = self.linear(objs)

        # self-attend over the concatenated visual and grounding tokens,
        # then keep only the visual part of the output
        x = x + self.scale * torch.tanh(self.alpha_attn) * self.attn(
            self.norm1(torch.cat([x, objs], dim=1)))[:, 0:N_visual, :]
        x = x + self.scale * \
            torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))

        return x


class GatedSelfAttentionDense2(nn.Module):
    def __init__(self, query_dim, context_dim, n_heads, d_head):
        super().__init__()

        # we need a linear projection since we concatenate the visual
        # features and the object features
        self.linear = nn.Linear(context_dim, query_dim)

        self.attn = CrossAttention(
            query_dim=query_dim, context_dim=query_dim, dim_head=d_head)
        self.ff = FeedForward(query_dim, glu=True)

        self.norm1 = nn.LayerNorm(query_dim)
        self.norm2 = nn.LayerNorm(query_dim)

        self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
        self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))

        # scale can be changed externally to adjust the magnitude of the
        # tanh(alpha) gate; when it is 0 the model behaves exactly like the
        # original one
        self.scale = 1

    def forward(self, x, objs):
        B, N_visual, _ = x.shape
        B, N_ground, _ = objs.shape

        objs = self.linear(objs)

        # sanity check: both token sets must form square spatial grids
        size_v = math.sqrt(N_visual)
        size_g = math.sqrt(N_ground)
        assert int(size_v) == size_v, "Visual tokens must form a square grid"
        assert int(size_g) == size_g, "Grounding tokens must form a square grid"
        size_v = int(size_v)
        size_g = int(size_g)

        # select the grounding tokens and resize them to the visual token
        # resolution, to be used as a residual
        out = self.attn(self.norm1(torch.cat([x, objs], dim=1)))[
            :, N_visual:, :]
        out = out.permute(0, 2, 1).reshape(B, -1, size_g, size_g)
        out = torch.nn.functional.interpolate(
            out, (size_v, size_v), mode='bicubic')
        residual = out.reshape(B, -1, N_visual).permute(0, 2, 1)

        # add the residual to the visual features
        x = x + self.scale * torch.tanh(self.alpha_attn) * residual
        x = x + self.scale * \
            torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))

        return x


class FourierEmbedder:
    def __init__(self, num_freqs=64, temperature=100):
        self.num_freqs = num_freqs
        self.temperature = temperature
        self.freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs)

    @torch.no_grad()
    def __call__(self, x, cat_dim=-1):
        """x: tensor of arbitrary shape. cat_dim: dimension along which the
        sin/cos features are concatenated."""
        out = []
        for freq in self.freq_bands:
            out.append(torch.sin(freq * x))
            out.append(torch.cos(freq * x))
        return torch.cat(out, cat_dim)


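# Dimension note (derived from the code): FourierEmbedder maps each scalar
# coordinate to 2 * num_freqs features (a sin and a cos per frequency band).
# PositionNet uses fourier_freqs=8 on xyxy boxes, so each box of 4 coordinates
# becomes an 8 * 2 * 4 = 64-dimensional position embedding, matching
# self.position_dim. PositionNet.forward expects
#   boxes:               (B, N, 4)        normalized xyxy coordinates
#   masks:               (B, N)           1 for real objects, 0 for padding
#   positive_embeddings: (B, N, in_dim)   per-object conditioning embeddings
# and returns `objs` of shape (B, N, out_dim).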
class PositionNet(nn.Module):
    def __init__(self, in_dim, out_dim, fourier_freqs=8):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
        self.position_dim = fourier_freqs * 2 * 4  # 2 is sin&cos, 4 is xyxy

        self.linears = nn.Sequential(
            nn.Linear(self.in_dim + self.position_dim, 512),
            nn.SiLU(),
            nn.Linear(512, 512),
            nn.SiLU(),
            nn.Linear(512, out_dim),
        )

        self.null_positive_feature = torch.nn.Parameter(
            torch.zeros([self.in_dim]))
        self.null_position_feature = torch.nn.Parameter(
            torch.zeros([self.position_dim]))

    def forward(self, boxes, masks, positive_embeddings):
        B, N, _ = boxes.shape
        masks = masks.unsqueeze(-1)

        # embed the box positions (they may include padding as placeholders)
        xyxy_embedding = self.fourier_embedder(boxes)  # B*N*4 --> B*N*C

        # learnable null embeddings
        positive_null = self.null_positive_feature.view(1, 1, -1)
        xyxy_null = self.null_position_feature.view(1, 1, -1)

        # replace padding with the learnable null embeddings
        positive_embeddings = positive_embeddings * \
            masks + (1 - masks) * positive_null
        xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null

        objs = self.linears(
            torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
        assert objs.shape == torch.Size([B, N, self.out_dim])
        return objs


class Gligen(nn.Module):
    def __init__(self, modules, position_net, key_dim):
        super().__init__()
        self.module_list = nn.ModuleList(modules)
        self.position_net = position_net
        self.key_dim = key_dim
        self.max_objs = 30

    def _set_position(self, boxes, masks, positive_embeddings):
        objs = self.position_net(boxes, masks, positive_embeddings)

        def func(key, x):
            module = self.module_list[key]
            return module(x, objs)
        return func

    def set_position(self, latent_image_shape, position_params, device):
        batch, c, h, w = latent_image_shape
        masks = torch.zeros([self.max_objs], device="cpu")
        boxes = []
        positive_embeddings = []
        # each entry of position_params is interpreted as
        # (embedding, box_height, box_width, box_y, box_x) in latent units
        for p in position_params:
            x1 = (p[4]) / w
            y1 = (p[3]) / h
            x2 = (p[4] + p[2]) / w
            y2 = (p[3] + p[1]) / h
            masks[len(boxes)] = 1.0
            boxes += [torch.tensor((x1, y1, x2, y2)).unsqueeze(0)]
            positive_embeddings += [p[0]]
        append_boxes = []
        append_conds = []
        if len(boxes) < self.max_objs:
            append_boxes = [torch.zeros(
                [self.max_objs - len(boxes), 4], device="cpu")]
            append_conds = [torch.zeros(
                [self.max_objs - len(boxes), self.key_dim], device="cpu")]

        box_out = torch.cat(
            boxes + append_boxes).unsqueeze(0).repeat(batch, 1, 1)
        masks = masks.unsqueeze(0).repeat(batch, 1)
        conds = torch.cat(positive_embeddings +
                          append_conds).unsqueeze(0).repeat(batch, 1, 1)
        return self._set_position(
            box_out.to(device),
            masks.to(device),
            conds.to(device))

    def set_empty(self, latent_image_shape, device):
        batch, c, h, w = latent_image_shape
        masks = torch.zeros([self.max_objs], device="cpu").repeat(batch, 1)
        box_out = torch.zeros([self.max_objs, 4],
                              device="cpu").repeat(batch, 1, 1)
        conds = torch.zeros([self.max_objs, self.key_dim],
                            device="cpu").repeat(batch, 1, 1)
        return self._set_position(
            box_out.to(device),
            masks.to(device),
            conds.to(device))

    def cleanup(self):
        pass

    def get_models(self):
        return [self]


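# load_gligen builds the Gligen wrapper from a GLIGEN checkpoint state dict.
# It looks for gated self-attention weights under keys of the form
# "{input_blocks|middle_block|output_blocks}.{i}.fuser.<param>" and for
# PositionNet weights under "position_net.*" (this is what the filters below
# match). key_dim is read from the fuser "linear.weight" shape and defaults
# to 768 (SD1.x).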
def load_gligen(sd):
    sd_k = sd.keys()
    output_list = []
    key_dim = 768
    for a in ["input_blocks", "middle_block", "output_blocks"]:
        for b in range(20):
            k_temp = filter(lambda k: "{}.{}.".format(a, b)
                            in k and ".fuser." in k, sd_k)
            k_temp = map(lambda k: (k, k.split(".fuser.")[-1]), k_temp)

            n_sd = {}
            for k in k_temp:
                n_sd[k[1]] = sd[k[0]]
            if len(n_sd) > 0:
                query_dim = n_sd["linear.weight"].shape[0]
                key_dim = n_sd["linear.weight"].shape[1]

                if key_dim == 768:  # SD1.x
                    n_heads = 8
                    d_head = query_dim // n_heads
                else:
                    d_head = 64
                    n_heads = query_dim // d_head

                gated = GatedSelfAttentionDense(
                    query_dim, key_dim, n_heads, d_head)
                gated.load_state_dict(n_sd, strict=False)
                output_list.append(gated)

    if "position_net.null_positive_feature" in sd_k:
        in_dim = sd["position_net.null_positive_feature"].shape[0]
        out_dim = sd["position_net.linears.4.weight"].shape[0]

        class WeightsLoader(torch.nn.Module):
            pass
        w = WeightsLoader()
        w.position_net = PositionNet(in_dim, out_dim)
        w.load_state_dict(sd, strict=False)

    gligen = Gligen(output_list, w.position_net, key_dim)
    return gligen
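

# Example usage (a minimal sketch; the checkpoint path and variable names are
# illustrative assumptions, and the position_params layout is the one inferred
# in Gligen.set_position above):
#
#   sd = torch.load("gligen_checkpoint.pt", map_location="cpu")
#   gligen = load_gligen(sd)
#   # position_params: list of (embedding, box_height, box_width, box_y, box_x)
#   # tuples in latent units, with `embedding` of shape (1, key_dim)
#   patch = gligen.set_position((1, 4, 64, 64), position_params, device="cpu")
#   # patch(key, x) runs the key-th gated self-attention module on x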