
Commit 3696d16

Add support for GLIGEN textbox model.
1 parent 472b1cc commit 3696d16

File tree: 9 files changed, +491 -28 lines changed


comfy/gligen.py

Lines changed: 343 additions & 0 deletions
@@ -0,0 +1,343 @@
import math  # used by GatedSelfAttentionDense2.forward (math.sqrt)
import torch
from torch import nn, einsum
from ldm.modules.attention import CrossAttention
from inspect import isfunction


def exists(val):
    return val is not None


def uniq(arr):
    return {el: True for el in arr}.keys()


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d

# feedforward
class GEGLU(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * torch.nn.functional.gelu(gate)


class FeedForward(nn.Module):
    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        project_in = nn.Sequential(
            nn.Linear(dim, inner_dim),
            nn.GELU()
        ) if not glu else GEGLU(dim, inner_dim)

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim_out)
        )

    def forward(self, x):
        return self.net(x)

class GatedCrossAttentionDense(nn.Module):
    def __init__(self, query_dim, context_dim, n_heads, d_head):
        super().__init__()

        self.attn = CrossAttention(
            query_dim=query_dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head)
        self.ff = FeedForward(query_dim, glu=True)

        self.norm1 = nn.LayerNorm(query_dim)
        self.norm2 = nn.LayerNorm(query_dim)

        self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
        self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))

        # this can be useful: we can externally change the magnitude of
        # tanh(alpha); for example, when it is set to 0, the entire model
        # is the same as the original one
        self.scale = 1

    def forward(self, x, objs):

        x = x + self.scale * \
            torch.tanh(self.alpha_attn) * self.attn(self.norm1(x), objs, objs)
        x = x + self.scale * \
            torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))

        return x


class GatedSelfAttentionDense(nn.Module):
    def __init__(self, query_dim, context_dim, n_heads, d_head):
        super().__init__()

        # we need a linear projection since we need to cat the visual
        # feature and obj feature
        self.linear = nn.Linear(context_dim, query_dim)

        self.attn = CrossAttention(
            query_dim=query_dim,
            context_dim=query_dim,
            heads=n_heads,
            dim_head=d_head)
        self.ff = FeedForward(query_dim, glu=True)

        self.norm1 = nn.LayerNorm(query_dim)
        self.norm2 = nn.LayerNorm(query_dim)

        self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
        self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))

        # this can be useful: we can externally change the magnitude of
        # tanh(alpha); for example, when it is set to 0, the entire model
        # is the same as the original one
        self.scale = 1

    def forward(self, x, objs):

        N_visual = x.shape[1]
        objs = self.linear(objs)

        x = x + self.scale * torch.tanh(self.alpha_attn) * self.attn(
            self.norm1(torch.cat([x, objs], dim=1)))[:, 0:N_visual, :]
        x = x + self.scale * \
            torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))

        return x

class GatedSelfAttentionDense2(nn.Module):
    def __init__(self, query_dim, context_dim, n_heads, d_head):
        super().__init__()

        # we need a linear projection since we need to cat the visual
        # feature and obj feature
        self.linear = nn.Linear(context_dim, query_dim)

        self.attn = CrossAttention(
            query_dim=query_dim, context_dim=query_dim, dim_head=d_head)
        self.ff = FeedForward(query_dim, glu=True)

        self.norm1 = nn.LayerNorm(query_dim)
        self.norm2 = nn.LayerNorm(query_dim)

        self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
        self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))

        # this can be useful: we can externally change the magnitude of
        # tanh(alpha); for example, when it is set to 0, the entire model
        # is the same as the original one
        self.scale = 1

    def forward(self, x, objs):

        B, N_visual, _ = x.shape
        B, N_ground, _ = objs.shape

        objs = self.linear(objs)

        # sanity check
        size_v = math.sqrt(N_visual)
        size_g = math.sqrt(N_ground)
        assert int(size_v) == size_v, "Visual tokens must be square rootable"
        assert int(size_g) == size_g, "Grounding tokens must be square rootable"
        size_v = int(size_v)
        size_g = int(size_g)

        # select grounding token and resize it to visual token size as residual
        out = self.attn(self.norm1(torch.cat([x, objs], dim=1)))[
            :, N_visual:, :]
        out = out.permute(0, 2, 1).reshape(B, -1, size_g, size_g)
        out = torch.nn.functional.interpolate(
            out, (size_v, size_v), mode='bicubic')
        residual = out.reshape(B, -1, N_visual).permute(0, 2, 1)

        # add residual to visual feature
        x = x + self.scale * torch.tanh(self.alpha_attn) * residual
        x = x + self.scale * \
            torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))

        return x

class FourierEmbedder():
    def __init__(self, num_freqs=64, temperature=100):

        self.num_freqs = num_freqs
        self.temperature = temperature
        self.freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs)

    @torch.no_grad()
    def __call__(self, x, cat_dim=-1):
        "x: arbitrary shape of tensor. dim: cat dim"
        out = []
        for freq in self.freq_bands:
            out.append(torch.sin(freq * x))
            out.append(torch.cos(freq * x))
        return torch.cat(out, cat_dim)

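# Note: with num_freqs=8, a 4-dim xyxy box maps to an
# 8 (freqs) * 2 (sin & cos) * 4 (coords) = 64-dim embedding; this is the
# position_dim computed in PositionNet below.
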
class PositionNet(nn.Module):
    def __init__(self, in_dim, out_dim, fourier_freqs=8):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
        self.position_dim = fourier_freqs * 2 * 4  # 2 is sin&cos, 4 is xyxy

        self.linears = nn.Sequential(
            nn.Linear(self.in_dim + self.position_dim, 512),
            nn.SiLU(),
            nn.Linear(512, 512),
            nn.SiLU(),
            nn.Linear(512, out_dim),
        )

        self.null_positive_feature = torch.nn.Parameter(
            torch.zeros([self.in_dim]))
        self.null_position_feature = torch.nn.Parameter(
            torch.zeros([self.position_dim]))

    def forward(self, boxes, masks, positive_embeddings):
        B, N, _ = boxes.shape
        masks = masks.unsqueeze(-1)

        # embed position (it may include padding as placeholder)
        xyxy_embedding = self.fourier_embedder(boxes)  # B*N*4 --> B*N*C

        # learnable null embedding
        positive_null = self.null_positive_feature.view(1, 1, -1)
        xyxy_null = self.null_position_feature.view(1, 1, -1)

        # replace padding with learnable null embedding
        positive_embeddings = positive_embeddings * \
            masks + (1 - masks) * positive_null
        xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null

        objs = self.linears(
            torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
        assert objs.shape == torch.Size([B, N, self.out_dim])
        return objs

class Gligen(nn.Module):
    def __init__(self, modules, position_net, key_dim):
        super().__init__()
        self.module_list = nn.ModuleList(modules)
        self.position_net = position_net
        self.key_dim = key_dim
        self.max_objs = 30

    def _set_position(self, boxes, masks, positive_embeddings):
        objs = self.position_net(boxes, masks, positive_embeddings)

        def func(key, x):
            module = self.module_list[key]
            return module(x, objs)
        return func

    def set_position(self, latent_image_shape, position_params, device):
        batch, c, h, w = latent_image_shape
        masks = torch.zeros([self.max_objs], device="cpu")
        boxes = []
        positive_embeddings = []
        for p in position_params:
            x1 = (p[4]) / w
            y1 = (p[3]) / h
            x2 = (p[4] + p[2]) / w
            y2 = (p[3] + p[1]) / h
            masks[len(boxes)] = 1.0
            boxes += [torch.tensor((x1, y1, x2, y2)).unsqueeze(0)]
            positive_embeddings += [p[0]]
        append_boxes = []
        append_conds = []
        if len(boxes) < self.max_objs:
            append_boxes = [torch.zeros(
                [self.max_objs - len(boxes), 4], device="cpu")]
            append_conds = [torch.zeros(
                [self.max_objs - len(boxes), self.key_dim], device="cpu")]

        box_out = torch.cat(
            boxes + append_boxes).unsqueeze(0).repeat(batch, 1, 1)
        masks = masks.unsqueeze(0).repeat(batch, 1)
        conds = torch.cat(positive_embeddings +
                          append_conds).unsqueeze(0).repeat(batch, 1, 1)
        return self._set_position(
            box_out.to(device),
            masks.to(device),
            conds.to(device))

    def set_empty(self, latent_image_shape, device):
        batch, c, h, w = latent_image_shape
        masks = torch.zeros([self.max_objs], device="cpu").repeat(batch, 1)
        box_out = torch.zeros([self.max_objs, 4],
                              device="cpu").repeat(batch, 1, 1)
        conds = torch.zeros([self.max_objs, self.key_dim],
                            device="cpu").repeat(batch, 1, 1)
        return self._set_position(
            box_out.to(device),
            masks.to(device),
            conds.to(device))

    def cleanup(self):
        pass

    def get_models(self):
        return [self]

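# load_gligen builds one GatedSelfAttentionDense per UNet block that has
# ".fuser." keys in the state dict (e.g. "input_blocks.1.1.fuser.linear.weight"),
# infers the head configuration from the linear projection shape, and assumes
# the checkpoint also carries the shared "position_net." weights.
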
def load_gligen(sd):
    sd_k = sd.keys()
    output_list = []
    key_dim = 768
    for a in ["input_blocks", "middle_block", "output_blocks"]:
        for b in range(20):
            k_temp = filter(lambda k: "{}.{}.".format(a, b)
                            in k and ".fuser." in k, sd_k)
            k_temp = map(lambda k: (k, k.split(".fuser.")[-1]), k_temp)

            n_sd = {}
            for k in k_temp:
                n_sd[k[1]] = sd[k[0]]
            if len(n_sd) > 0:
                query_dim = n_sd["linear.weight"].shape[0]
                key_dim = n_sd["linear.weight"].shape[1]

                if key_dim == 768:  # SD1.x
                    n_heads = 8
                    d_head = query_dim // n_heads
                else:
                    d_head = 64
                    n_heads = query_dim // d_head

                gated = GatedSelfAttentionDense(
                    query_dim, key_dim, n_heads, d_head)
                gated.load_state_dict(n_sd, strict=False)
                output_list.append(gated)

    if "position_net.null_positive_feature" in sd_k:
        in_dim = sd["position_net.null_positive_feature"].shape[0]
        out_dim = sd["position_net.linears.4.weight"].shape[0]

        class WeightsLoader(torch.nn.Module):
            pass
        w = WeightsLoader()
        w.position_net = PositionNet(in_dim, out_dim)
        w.load_state_dict(sd, strict=False)

    gligen = Gligen(output_list, w.position_net, key_dim)
    return gligen
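
For orientation, here is a minimal, hypothetical sketch of driving this module. The `(embedding, height, width, y, x)` layout of each `position_params` entry is inferred from the indexing `p[0]`..`p[4]` in `set_position` above, and the checkpoint path is a placeholder:

import torch
from comfy.gligen import load_gligen

state_dict = torch.load("gligen.pt")  # hypothetical GLIGEN checkpoint path
gligen = load_gligen(state_dict)

# One grounding phrase: a placeholder [1, 768] CLIP embedding (SD1.x key_dim)
# and a 32x32 box at the top-left of a 64x64 latent.
emb = torch.zeros([1, 768])
position_params = [(emb, 32, 32, 0, 0)]  # (embedding, height, width, y, x)

func = gligen.set_position((1, 4, 64, 64), position_params, device="cpu")

# During the UNet forward pass, func(key, x) runs the key-th gated
# self-attention fuser on hidden states x with the grounding tokens baked in.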

comfy/ldm/modules/attention.py

Lines changed: 16 additions & 0 deletions
@@ -510,6 +510,14 @@ def forward(self, x, context=None, transformer_options={}):
         return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
 
     def _forward(self, x, context=None, transformer_options={}):
+        current_index = None
+        if "current_index" in transformer_options:
+            current_index = transformer_options["current_index"]
+        if "patches" in transformer_options:
+            transformer_patches = transformer_options["patches"]
+        else:
+            transformer_patches = {}
+
         n = self.norm1(x)
         if "tomesd" in transformer_options:
             m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"])
@@ -518,11 +526,19 @@ def _forward(self, x, context=None, transformer_options={}):
         n = self.attn1(n, context=context if self.disable_self_attn else None)
 
         x += n
+        if "middle_patch" in transformer_patches:
+            patch = transformer_patches["middle_patch"]
+            for p in patch:
+                x = p(current_index, x)
+
         n = self.norm2(x)
         n = self.attn2(n, context=context)
 
         x += n
         x = self.ff(self.norm3(x)) + x
+
+        if current_index is not None:
+            transformer_options["current_index"] += 1
         return x
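
The hook added here is generic: each callable registered under `transformer_options["patches"]["middle_patch"]` runs between a block's self-attention and cross-attention and receives the running block index. A minimal sketch with a hypothetical no-op patch:

def my_middle_patch(current_index, x):
    # current_index counts transformer blocks across one UNet forward pass;
    # x is the block's hidden states. GLIGEN plugs its fuser in here.
    return x

transformer_options = {"patches": {"middle_patch": [my_middle_patch]}}
# Pass transformer_options into the model's forward; every block then calls
# my_middle_patch(current_index, x) at this point in _forward.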

comfy/ldm/modules/diffusionmodules/openaimodel.py

Lines changed: 2 additions & 0 deletions
@@ -782,6 +782,8 @@ def forward(self, x, timesteps=None, context=None, y=None, control=None, transfo
         :return: an [N x C x ...] Tensor of outputs.
         """
         transformer_options["original_shape"] = list(x.shape)
+        transformer_options["current_index"] = 0
+
         assert (y is not None) == (
             self.num_classes is not None
         ), "must specify y if and only if the model is class-conditional"