
Commit eb2f78b

[Training Node] algo support, grad acc, optional grad ckpt (Comfy-Org#9015)
* Add factorization utils for lokr
* Add lokr train impl
* Add loha train impl
* Add adapter map for algo selection
* Add optional grad ckpt and algo selection
* Update __init__.py
* correct key name for loha
* Use custom fwd/bwd func and better init for loha
* Support gradient accumulation
* Fix bugs of loha
* use more stable init
* Add OFT training
* linting
1 parent e729a5c commit eb2f78b

File tree: 6 files changed (+372, -15 lines)

comfy/weight_adapter/__init__.py

Lines changed: 12 additions & 1 deletion
@@ -15,9 +15,20 @@
     OFTAdapter,
     BOFTAdapter,
 ]
+adapter_maps: dict[str, type[WeightAdapterBase]] = {
+    "LoRA": LoRAAdapter,
+    "LoHa": LoHaAdapter,
+    "LoKr": LoKrAdapter,
+    "OFT": OFTAdapter,
+    ## We disable not implemented algo for now
+    # "GLoRA": GLoRAAdapter,
+    # "BOFT": BOFTAdapter,
+}
+

 __all__ = [
     "WeightAdapterBase",
     "WeightAdapterTrainBase",
-    "adapters"
+    "adapters",
+    "adapter_maps",
 ] + [a.__name__ for a in adapters]
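
The adapter_maps dict is what the training node uses to select an algorithm by name. A minimal usage sketch (the weight tensor, rank, and alpha below are illustrative placeholders, not part of this commit; it assumes the selected adapter class exposes create_train, as LoHaAdapter and LoKrAdapter do in this commit):

import torch
from comfy.weight_adapter import adapter_maps

# Hypothetical 2-D weight to adapt; shape, rank and alpha are placeholders.
weight = torch.randn(128, 64)
adapter_cls = adapter_maps["LoKr"]          # or "LoRA", "LoHa", "OFT"
train_module = adapter_cls.create_train(weight, rank=8, alpha=8.0)

# Calling the returned train module on the original weight yields the adapted weight.
adapted = train_module(weight)
print(adapted.shape)  # torch.Size([128, 64])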

comfy/weight_adapter/base.py

Lines changed: 40 additions & 0 deletions
@@ -133,3 +133,43 @@ def tucker_weight_from_conv(up, down, mid):
 def tucker_weight(wa, wb, t):
     temp = torch.einsum("i j ..., j r -> i r ...", t, wb)
     return torch.einsum("i j ..., i r -> r j ...", temp, wa)
+
+
+def factorization(dimension: int, factor: int = -1) -> tuple[int, int]:
+    """
+    Return a tuple of two values that decompose the input dimension, using the number closest to factor;
+    the second value is greater than or equal to the first.
+
+    examples)
+    factor
+            -1               2               4               8               16 ...
+    127  -> 1, 127    127 -> 1, 127   127 -> 1, 127   127 -> 1, 127   127 -> 1, 127
+    128  -> 8, 16     128 -> 2, 64    128 -> 4, 32    128 -> 8, 16    128 -> 8, 16
+    250  -> 10, 25    250 -> 2, 125   250 -> 2, 125   250 -> 5, 50    250 -> 10, 25
+    360  -> 8, 45     360 -> 2, 180   360 -> 4, 90    360 -> 8, 45    360 -> 12, 30
+    512  -> 16, 32    512 -> 2, 256   512 -> 4, 128   512 -> 8, 64    512 -> 16, 32
+    1024 -> 32, 32   1024 -> 2, 512  1024 -> 4, 256  1024 -> 8, 128  1024 -> 16, 64
+    """
+
+    if factor > 0 and (dimension % factor) == 0 and dimension >= factor**2:
+        m = factor
+        n = dimension // factor
+        if m > n:
+            n, m = m, n
+        return m, n
+    if factor < 0:
+        factor = dimension
+    m, n = 1, dimension
+    length = m + n
+    while m < n:
+        new_m = m + 1
+        while dimension % new_m != 0:
+            new_m += 1
+        new_n = dimension // new_m
+        if new_m + new_n > length or new_m > factor:
+            break
+        else:
+            m, n = new_m, new_n
+    if m > n:
+        n, m = m, n
+    return m, n
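
A quick sanity check of factorization(), matching the examples in its docstring (run standalone; the import path is the module edited here):

from comfy.weight_adapter.base import factorization

print(factorization(128, 4))    # (4, 32)  -> 128 = 4 * 32
print(factorization(1024, -1))  # (32, 32) -> the most "square" split when factor is -1
print(factorization(127, 8))    # (1, 127) -> a prime dimension cannot be split further

LoKrAdapter.create_train (further down in this commit) uses these pairs as the Kronecker factor shapes: the output and input dimensions are each factorized, and the full weight is rebuilt as the Kronecker product of a small (out1, in1) matrix and a larger (out2, in2) matrix.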

comfy/weight_adapter/loha.py

Lines changed: 133 additions & 1 deletion
@@ -3,7 +3,120 @@

 import torch
 import comfy.model_management
-from .base import WeightAdapterBase, weight_decompose
+from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose
+
+
+class HadaWeight(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, w1u, w1d, w2u, w2d, scale=torch.tensor(1)):
+        ctx.save_for_backward(w1d, w1u, w2d, w2u, scale)
+        diff_weight = ((w1u @ w1d) * (w2u @ w2d)) * scale
+        return diff_weight
+
+    @staticmethod
+    def backward(ctx, grad_out):
+        (w1d, w1u, w2d, w2u, scale) = ctx.saved_tensors
+        grad_out = grad_out * scale
+        temp = grad_out * (w2u @ w2d)
+        grad_w1u = temp @ w1d.T
+        grad_w1d = w1u.T @ temp
+
+        temp = grad_out * (w1u @ w1d)
+        grad_w2u = temp @ w2d.T
+        grad_w2d = w2u.T @ temp
+
+        del temp
+        return grad_w1u, grad_w1d, grad_w2u, grad_w2d, None
+
+
+class HadaWeightTucker(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, t1, w1u, w1d, t2, w2u, w2d, scale=torch.tensor(1)):
+        ctx.save_for_backward(t1, w1d, w1u, t2, w2d, w2u, scale)
+
+        rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1d, w1u)
+        rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2d, w2u)
+
+        return rebuild1 * rebuild2 * scale
+
+    @staticmethod
+    def backward(ctx, grad_out):
+        (t1, w1d, w1u, t2, w2d, w2u, scale) = ctx.saved_tensors
+        grad_out = grad_out * scale
+
+        temp = torch.einsum("i j ..., j r -> i r ...", t2, w2d)
+        rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w2u)
+
+        grad_w = rebuild * grad_out
+        del rebuild
+
+        grad_w1u = torch.einsum("r j ..., i j ... -> r i", temp, grad_w)
+        grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w1u.T)
+        del grad_w, temp
+
+        grad_w1d = torch.einsum("i r ..., i j ... -> r j", t1, grad_temp)
+        grad_t1 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w1d.T)
+        del grad_temp
+
+        temp = torch.einsum("i j ..., j r -> i r ...", t1, w1d)
+        rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w1u)
+
+        grad_w = rebuild * grad_out
+        del rebuild
+
+        grad_w2u = torch.einsum("r j ..., i j ... -> r i", temp, grad_w)
+        grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w2u.T)
+        del grad_w, temp
+
+        grad_w2d = torch.einsum("i r ..., i j ... -> r j", t2, grad_temp)
+        grad_t2 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w2d.T)
+        del grad_temp
+        return grad_t1, grad_w1u, grad_w1d, grad_t2, grad_w2u, grad_w2d, None
+
+
+class LohaDiff(WeightAdapterTrainBase):
+    def __init__(self, weights):
+        super().__init__()
+        # Unpack weights tuple from LoHaAdapter
+        w1a, w1b, alpha, w2a, w2b, t1, t2, _ = weights
+
+        # Create trainable parameters
+        self.hada_w1_a = torch.nn.Parameter(w1a)
+        self.hada_w1_b = torch.nn.Parameter(w1b)
+        self.hada_w2_a = torch.nn.Parameter(w2a)
+        self.hada_w2_b = torch.nn.Parameter(w2b)
+
+        self.use_tucker = False
+        if t1 is not None and t2 is not None:
+            self.use_tucker = True
+            self.hada_t1 = torch.nn.Parameter(t1)
+            self.hada_t2 = torch.nn.Parameter(t2)
+        else:
+            # Keep the attributes for consistent access
+            self.hada_t1 = None
+            self.hada_t2 = None
+
+        # Store rank and non-trainable alpha
+        self.rank = w1b.shape[0]
+        self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)
+
+    def __call__(self, w):
+        org_dtype = w.dtype
+
+        scale = self.alpha / self.rank
+        if self.use_tucker:
+            diff_weight = HadaWeightTucker.apply(self.hada_t1, self.hada_w1_a, self.hada_w1_b, self.hada_t2, self.hada_w2_a, self.hada_w2_b, scale)
+        else:
+            diff_weight = HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale)
+
+        # Add the scaled difference to the original weight
+        weight = w.to(diff_weight) + diff_weight.reshape(w.shape)
+
+        return weight.to(org_dtype)
+
+    def passive_memory_usage(self):
+        """Calculates memory usage of the trainable parameters."""
+        return sum(param.numel() * param.element_size() for param in self.parameters())


 class LoHaAdapter(WeightAdapterBase):
@@ -13,6 +126,25 @@ def __init__(self, loaded_keys, weights):
         self.loaded_keys = loaded_keys
         self.weights = weights

+    @classmethod
+    def create_train(cls, weight, rank=1, alpha=1.0):
+        out_dim = weight.shape[0]
+        in_dim = weight.shape[1:].numel()
+        mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
+        mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
+        torch.nn.init.normal_(mat1, 0.1)
+        torch.nn.init.constant_(mat2, 0.0)
+        mat3 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
+        mat4 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
+        torch.nn.init.normal_(mat3, 0.1)
+        torch.nn.init.normal_(mat4, 0.01)
+        return LohaDiff(
+            (mat1, mat2, alpha, mat3, mat4, None, None, None)
+        )
+
+    def to_train(self):
+        return LohaDiff(self.weights)
+
     @classmethod
     def load(
         cls,
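
As a quick illustration of the Hadamard-product composition that HadaWeight implements (arbitrary small shapes; this snippet is a sketch, not part of the commit):

import torch
from comfy.weight_adapter.loha import HadaWeight

# Two low-rank pairs: (8, 2) @ (2, 6) -> (8, 6), combined elementwise.
w1u, w1d = torch.randn(8, 2, requires_grad=True), torch.randn(2, 6, requires_grad=True)
w2u, w2d = torch.randn(8, 2, requires_grad=True), torch.randn(2, 6, requires_grad=True)
scale = torch.tensor(0.5)

diff = HadaWeight.apply(w1u, w1d, w2u, w2d, scale)
assert torch.allclose(diff, ((w1u @ w1d) * (w2u @ w2d)) * scale)

# The custom backward returns gradients for all four factors (and None for scale).
diff.sum().backward()
print(w1u.grad.shape, w2d.grad.shape)  # torch.Size([8, 2]) torch.Size([2, 6])

Note that the custom Function saves only the low-rank factors and recomputes the two full-size products in backward, rather than letting autograd keep those intermediates around, trading a little recomputation for activation memory.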

comfy/weight_adapter/lokr.py

Lines changed: 85 additions & 1 deletion
@@ -3,7 +3,77 @@

 import torch
 import comfy.model_management
-from .base import WeightAdapterBase, weight_decompose
+from .base import (
+    WeightAdapterBase,
+    WeightAdapterTrainBase,
+    weight_decompose,
+    factorization,
+)
+
+
+class LokrDiff(WeightAdapterTrainBase):
+    def __init__(self, weights):
+        super().__init__()
+        (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale) = weights
+        self.use_tucker = False
+        if lokr_w1_a is not None:
+            _, rank_a = lokr_w1_a.shape[0], lokr_w1_a.shape[1]
+            rank_a, _ = lokr_w1_b.shape[0], lokr_w1_b.shape[1]
+            self.lokr_w1_a = torch.nn.Parameter(lokr_w1_a)
+            self.lokr_w1_b = torch.nn.Parameter(lokr_w1_b)
+            self.w1_rebuild = True
+            self.ranka = rank_a
+
+        if lokr_w2_a is not None:
+            _, rank_b = lokr_w2_a.shape[0], lokr_w2_a.shape[1]
+            rank_b, _ = lokr_w2_b.shape[0], lokr_w2_b.shape[1]
+            self.lokr_w2_a = torch.nn.Parameter(lokr_w2_a)
+            self.lokr_w2_b = torch.nn.Parameter(lokr_w2_b)
+            if lokr_t2 is not None:
+                self.use_tucker = True
+                self.lokr_t2 = torch.nn.Parameter(lokr_t2)
+            self.w2_rebuild = True
+            self.rankb = rank_b
+
+        if lokr_w1 is not None:
+            self.lokr_w1 = torch.nn.Parameter(lokr_w1)
+            self.w1_rebuild = False
+
+        if lokr_w2 is not None:
+            self.lokr_w2 = torch.nn.Parameter(lokr_w2)
+            self.w2_rebuild = False
+
+        self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)
+
+    @property
+    def w1(self):
+        if self.w1_rebuild:
+            return (self.lokr_w1_a @ self.lokr_w1_b) * (self.alpha / self.ranka)
+        else:
+            return self.lokr_w1
+
+    @property
+    def w2(self):
+        if self.w2_rebuild:
+            if self.use_tucker:
+                w2 = torch.einsum(
+                    'i j k l, j r, i p -> p r k l',
+                    self.lokr_t2,
+                    self.lokr_w2_b,
+                    self.lokr_w2_a
+                )
+            else:
+                w2 = self.lokr_w2_a @ self.lokr_w2_b
+            return w2 * (self.alpha / self.rankb)
+        else:
+            return self.lokr_w2
+
+    def __call__(self, w):
+        diff = torch.kron(self.w1, self.w2)
+        return w + diff.reshape(w.shape).to(w)
+
+    def passive_memory_usage(self):
+        return sum(param.numel() * param.element_size() for param in self.parameters())


 class LoKrAdapter(WeightAdapterBase):
@@ -13,6 +83,20 @@ def __init__(self, loaded_keys, weights):
         self.loaded_keys = loaded_keys
         self.weights = weights

+    @classmethod
+    def create_train(cls, weight, rank=1, alpha=1.0):
+        out_dim = weight.shape[0]
+        in_dim = weight.shape[1:].numel()
+        out1, out2 = factorization(out_dim, rank)
+        in1, in2 = factorization(in_dim, rank)
+        mat1 = torch.empty(out1, in1, device=weight.device, dtype=weight.dtype)
+        mat2 = torch.empty(out2, in2, device=weight.device, dtype=weight.dtype)
+        torch.nn.init.kaiming_uniform_(mat2, a=5**0.5)
+        torch.nn.init.constant_(mat1, 0.0)
+        return LokrDiff(
+            (mat1, mat2, alpha, None, None, None, None, None, None)
+        )
+
     @classmethod
     def load(
         cls,
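
A corresponding sketch for the LoKr path, showing how create_train combines factorization() with torch.kron (shapes, rank, and alpha are illustrative, not from the commit; it assumes the train module can be called directly on the weight, as __call__ above suggests):

import torch
from comfy.weight_adapter.lokr import LoKrAdapter
from comfy.weight_adapter.base import factorization

# A conv-style weight: out_dim = 64, in_dim = 32 * 3 * 3 = 288.
weight = torch.randn(64, 32, 3, 3)
diff_module = LoKrAdapter.create_train(weight, rank=4, alpha=4.0)

print(factorization(64, 4), factorization(288, 4))   # (4, 16) (4, 72)
# mat1 is (4, 4) and mat2 is (16, 72); torch.kron rebuilds the full (64, 288)
# matrix, which __call__ reshapes back onto the original 4-D weight.
print(diff_module(weight).shape)                     # torch.Size([64, 32, 3, 3])

Since mat1 is zero-initialized, the initial Kronecker delta is zero and training starts from the unmodified weight.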
