Skip to content

Commit b120446

Browse files
committed
update grid extension
1 parent c5ebd60 commit b120446

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+759
-155
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,4 @@ expressiveness
1212
figures
1313
molecule
1414
applications
15+
experiments

kan/.ipynb_checkpoints/KANLayer-checkpoint.py

+31-3
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def __init__(self, in_dim=3, out_dim=2, num=5, k=3, noise_scale=0.5, scale_base_
109109

110110
self.scale_base = torch.nn.Parameter(scale_base_mu * 1 / np.sqrt(in_dim) + \
111111
scale_base_sigma * (torch.rand(in_dim, out_dim)*2-1) * 1/np.sqrt(in_dim)).requires_grad_(sb_trainable)
112-
self.scale_sp = torch.nn.Parameter(torch.ones(in_dim, out_dim) * scale_sp * self.mask).requires_grad_(sp_trainable) # make scale trainable
112+
self.scale_sp = torch.nn.Parameter(torch.ones(in_dim, out_dim) * scale_sp * 1 / np.sqrt(in_dim) * self.mask).requires_grad_(sp_trainable) # make scale trainable
113113
self.base_fun = base_fun
114114

115115

@@ -197,11 +197,13 @@ def update_grid_from_samples(self, x, mode='sample'):
197197
def get_grid(num_interval):
198198
ids = [int(batch / num_interval * i) for i in range(num_interval)] + [-1]
199199
grid_adaptive = x_pos[ids, :].permute(1,0)
200-
h = (grid_adaptive[:,[-1]] - grid_adaptive[:,[0]])/num_interval
201-
grid_uniform = grid_adaptive[:,[0]] + h * torch.arange(num_interval+1,)[None, :].to(x.device)
200+
margin = 0.00
201+
h = (grid_adaptive[:,[-1]] - grid_adaptive[:,[0]] + 2 * margin)/num_interval
202+
grid_uniform = grid_adaptive[:,[0]] - margin + h * torch.arange(num_interval+1,)[None, :].to(x.device)
202203
grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
203204
return grid
204205

206+
205207
grid = get_grid(num_interval)
206208

207209
if mode == 'grid':
@@ -210,6 +212,8 @@ def get_grid(num_interval):
210212
y_eval = coef2curve(x_pos, self.grid, self.coef, self.k)
211213

212214
self.grid.data = extend_grid(grid, k_extend=self.k)
215+
#print('x_pos 2', x_pos.shape)
216+
#print('y_eval 2', y_eval.shape)
213217
self.coef.data = curve2coef(x_pos, y_eval, self.grid, self.k)
214218

215219
def initialize_grid_from_parent(self, parent, x, mode='sample'):
@@ -240,16 +244,40 @@ def initialize_grid_from_parent(self, parent, x, mode='sample'):
240244

241245
batch = x.shape[0]
242246

247+
# shrink grid
243248
x_pos = torch.sort(x, dim=0)[0]
244249
y_eval = coef2curve(x_pos, parent.grid, parent.coef, parent.k)
245250
num_interval = self.grid.shape[1] - 1 - 2*self.k
246251

252+
253+
'''
254+
# based on samples
247255
def get_grid(num_interval):
248256
ids = [int(batch / num_interval * i) for i in range(num_interval)] + [-1]
249257
grid_adaptive = x_pos[ids, :].permute(1,0)
250258
h = (grid_adaptive[:,[-1]] - grid_adaptive[:,[0]])/num_interval
251259
grid_uniform = grid_adaptive[:,[0]] + h * torch.arange(num_interval+1,)[None, :].to(x.device)
252260
grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
261+
return grid'''
262+
263+
#print('p', parent.grid)
264+
# based on interpolating parent grid
265+
def get_grid(num_interval):
266+
x_pos = parent.grid[:,parent.k:-parent.k]
267+
#print('x_pos', x_pos)
268+
sp2 = KANLayer(in_dim=1, out_dim=self.in_dim,k=1,num=x_pos.shape[1]-1,scale_base_mu=0.0, scale_base_sigma=0.0).to(x.device)
269+
270+
#print('sp2_grid', sp2.grid[:,sp2.k:-sp2.k].permute(1,0).expand(-1,self.in_dim))
271+
#print('sp2_coef_shape', sp2.coef.shape)
272+
sp2_coef = curve2coef(sp2.grid[:,sp2.k:-sp2.k].permute(1,0).expand(-1,self.in_dim), x_pos.permute(1,0).unsqueeze(dim=2), sp2.grid[:,:], k=1).permute(1,0,2)
273+
shp = sp2_coef.shape
274+
#sp2_coef = torch.cat([torch.zeros(shp[0], shp[1], 1), sp2_coef, torch.zeros(shp[0], shp[1], 1)], dim=2)
275+
#print('sp2_coef',sp2_coef)
276+
#print(sp2.coef.shape)
277+
sp2.coef.data = sp2_coef
278+
percentile = torch.linspace(-1,1,self.num+1).to(self.device)
279+
grid = sp2(percentile.unsqueeze(dim=1))[0].permute(1,0)
280+
#print('c', grid)
253281
return grid
254282

255283
grid = get_grid(num_interval)

kan/.ipynb_checkpoints/MultKAN-checkpoint.py

+38-13
Original file line numberDiff line numberDiff line change
@@ -164,9 +164,13 @@ def __init__(self, width=None, grid=3, k=3, mult_arity = 2, noise_scale=0.3, sca
164164
self.act_fun = []
165165
self.depth = len(width) - 1
166166

167+
#print('haha1', width)
167168
for i in range(len(width)):
168-
if type(width[i]) == int:
169+
#print(type(width[i]), type(width[i]) == int)
170+
if type(width[i]) == int or type(width[i]) == np.int64:
169171
width[i] = [width[i],0]
172+
173+
#print('haha2', width)
170174

171175
self.width = width
172176

@@ -196,7 +200,18 @@ def __init__(self, width=None, grid=3, k=3, mult_arity = 2, noise_scale=0.3, sca
196200

197201
for l in range(self.depth):
198202
# splines
199-
sp_batch = KANLayer(in_dim=width_in[l], out_dim=width_out[l+1], num=grid, k=k, noise_scale=noise_scale, scale_base_mu=scale_base_mu, scale_base_sigma=scale_base_sigma, scale_sp=1., base_fun=base_fun, grid_eps=grid_eps, grid_range=grid_range, sp_trainable=sp_trainable, sb_trainable=sb_trainable, sparse_init=sparse_init)
203+
if isinstance(grid, list):
204+
grid_l = grid[l]
205+
else:
206+
grid_l = grid
207+
208+
if isinstance(k, list):
209+
k_l = k[l]
210+
else:
211+
k_l = k
212+
213+
214+
sp_batch = KANLayer(in_dim=width_in[l], out_dim=width_out[l+1], num=grid_l, k=k_l, noise_scale=noise_scale, scale_base_mu=scale_base_mu, scale_base_sigma=scale_base_sigma, scale_sp=1., base_fun=base_fun, grid_eps=grid_eps, grid_range=grid_range, sp_trainable=sp_trainable, sb_trainable=sb_trainable, sparse_init=sparse_init)
200215
self.act_fun.append(sp_batch)
201216

202217
self.node_bias = []
@@ -951,14 +966,14 @@ def unfix_symbolic(self, l, i, j, log_history=True):
951966
if log_history:
952967
self.log_history('unfix_symbolic')
953968

954-
def unfix_symbolic_all(self):
969+
def unfix_symbolic_all(self, log_history=True):
955970
'''
956971
unfix all activation functions.
957972
'''
958973
for l in range(len(self.width) - 1):
959-
for i in range(self.width[l]):
960-
for j in range(self.width[l + 1]):
961-
self.unfix_symbolic(l, i, j)
974+
for i in range(self.width_in[l]):
975+
for j in range(self.width_out[l + 1]):
976+
self.unfix_symbolic(l, i, j, log_history)
962977

963978
def get_range(self, l, i, j, verbose=True):
964979
'''
@@ -1522,6 +1537,10 @@ def closure():
15221537

15231538
if _ == steps-1 and old_save_act:
15241539
self.save_act = True
1540+
1541+
if save_fig and _ % save_fig_freq == 0:
1542+
save_act = self.save_act
1543+
self.save_act = True
15251544

15261545
train_id = np.random.choice(dataset['train_input'].shape[0], batch_size, replace=False)
15271546
test_id = np.random.choice(dataset['test_input'].shape[0], batch_size_test, replace=False)
@@ -1579,6 +1598,7 @@ def closure():
15791598
self.plot(folder=img_folder, in_vars=in_vars, out_vars=out_vars, title="Step {}".format(_), beta=beta)
15801599
plt.savefig(img_folder + '/' + str(_) + '.jpg', bbox_inches='tight', dpi=200)
15811600
plt.close()
1601+
self.save_act = save_act
15821602

15831603
self.log_history('fit')
15841604
# revert back to original state
@@ -2160,7 +2180,7 @@ def suggest_symbolic(self, l, i, j, a_range=(-10, 10), b_range=(-10, 10), lib=No
21602180

21612181
return best_name, best_fun, best_r2, best_c;
21622182

2163-
def auto_symbolic(self, a_range=(-10, 10), b_range=(-10, 10), lib=None, verbose=1):
2183+
def auto_symbolic(self, a_range=(-10, 10), b_range=(-10, 10), lib=None, verbose=1, weight_simple = 0.8, r2_threshold=0.0):
21642184
'''
21652185
automatic symbolic regression for all edges
21662186
@@ -2174,7 +2194,10 @@ def auto_symbolic(self, a_range=(-10, 10), b_range=(-10, 10), lib=None, verbose=
21742194
library of candidate symbolic functions
21752195
verbose : int
21762196
larger verbosity => more verbosity
2177-
2197+
weight_simple : float
2198+
a weight that prioritizes simplicity (low complexity) over performance (high r2) - set to 0.0 to ignore complexity
2199+
r2_threshold : float
2200+
If r2 is below this threshold, the edge will not be fixed with any symbolic function - set to 0.0 to ignore this threshold
21782201
Returns:
21792202
--------
21802203
None
@@ -2191,17 +2214,19 @@ def auto_symbolic(self, a_range=(-10, 10), b_range=(-10, 10), lib=None, verbose=
21912214
for l in range(len(self.width_in) - 1):
21922215
for i in range(self.width_in[l]):
21932216
for j in range(self.width_out[l + 1]):
2194-
#if self.symbolic_fun[l].mask[j, i] > 0. and self.act_fun[l].mask[i][j] == 0.:
21952217
if self.symbolic_fun[l].mask[j, i] > 0. and self.act_fun[l].mask[i][j] == 0.:
21962218
print(f'skipping ({l},{i},{j}) since already symbolic')
21972219
elif self.symbolic_fun[l].mask[j, i] == 0. and self.act_fun[l].mask[i][j] == 0.:
21982220
self.fix_symbolic(l, i, j, '0', verbose=verbose > 1, log_history=False)
21992221
print(f'fixing ({l},{i},{j}) with 0')
22002222
else:
2201-
name, fun, r2, c = self.suggest_symbolic(l, i, j, a_range=a_range, b_range=b_range, lib=lib, verbose=False)
2202-
self.fix_symbolic(l, i, j, name, verbose=verbose > 1, log_history=False)
2203-
if verbose >= 1:
2204-
print(f'fixing ({l},{i},{j}) with {name}, r2={r2}, c={c}')
2223+
name, fun, r2, c = self.suggest_symbolic(l, i, j, a_range=a_range, b_range=b_range, lib=lib, verbose=False, weight_simple=weight_simple)
2224+
if r2 >= r2_threshold:
2225+
self.fix_symbolic(l, i, j, name, verbose=verbose > 1, log_history=False)
2226+
if verbose >= 1:
2227+
print(f'fixing ({l},{i},{j}) with {name}, r2={r2}, c={c}')
2228+
else:
2229+
print(f'For ({l},{i},{j}) the best fit was {name}, but r^2 = {r2} and this is lower than {r2_threshold}. This edge was omitted, keep training or try a different threshold.')
22052230

22062231
self.log_history('auto_symbolic')
22072232

kan/.ipynb_checkpoints/spline-checkpoint.py

+17-9
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def coef2curve(x_eval, grid, coef, k, device="cpu"):
6868
Returns:
6969
--------
7070
y_eval : 3D torch.tensor
71-
shape (number of samples, in_dim, out_dim)
71+
shape (batch, in_dim, out_dim)
7272
7373
'''
7474

@@ -78,16 +78,16 @@ def coef2curve(x_eval, grid, coef, k, device="cpu"):
7878
return y_eval
7979

8080

81-
def curve2coef(x_eval, y_eval, grid, k, lamb=1e-8):
81+
def curve2coef(x_eval, y_eval, grid, k):
8282
'''
8383
converting B-spline curves to B-spline coefficients using least squares.
8484
8585
Args:
8686
-----
8787
x_eval : 2D torch.tensor
88-
shape (in_dim, out_dim, number of samples)
89-
y_eval : 2D torch.tensor
90-
shape (in_dim, out_dim, number of samples)
88+
shape (batch, in_dim)
89+
y_eval : 3D torch.tensor
90+
shape (batch, in_dim, out_dim)
9191
grid : 2D torch.tensor
9292
shape (in_dim, grid+2*k)
9393
k : int
@@ -100,25 +100,33 @@ def curve2coef(x_eval, y_eval, grid, k, lamb=1e-8):
100100
coef : 3D torch.tensor
101101
shape (in_dim, out_dim, G+k)
102102
'''
103+
#print('haha', x_eval.shape, y_eval.shape, grid.shape)
103104
batch = x_eval.shape[0]
104105
in_dim = x_eval.shape[1]
105106
out_dim = y_eval.shape[2]
106107
n_coef = grid.shape[1] - k - 1
107108
mat = B_batch(x_eval, grid, k)
108109
mat = mat.permute(1,0,2)[:,None,:,:].expand(in_dim, out_dim, batch, n_coef)
110+
#print('mat', mat.shape)
109111
y_eval = y_eval.permute(1,2,0).unsqueeze(dim=3)
112+
#print('y_eval', y_eval.shape)
110113
device = mat.device
111114

112-
#coef = torch.linalg.lstsq(mat, y_eval,
113-
#driver='gelsy' if device == 'cpu' else 'gels').solution[:,:,:,0]
114-
115+
#coef = torch.linalg.lstsq(mat, y_eval, driver='gelsy' if device == 'cpu' else 'gels').solution[:,:,:,0]
116+
try:
117+
coef = torch.linalg.lstsq(mat, y_eval).solution[:,:,:,0]
118+
except:
119+
print('lstsq failed')
120+
121+
# manual pseudo-inverse
122+
'''lamb=1e-8
115123
XtX = torch.einsum('ijmn,ijnp->ijmp', mat.permute(0,1,3,2), mat)
116124
Xty = torch.einsum('ijmn,ijnp->ijmp', mat.permute(0,1,3,2), y_eval)
117125
n1, n2, n = XtX.shape[0], XtX.shape[1], XtX.shape[2]
118126
identity = torch.eye(n,n)[None, None, :, :].expand(n1, n2, n, n).to(device)
119127
A = XtX + lamb * identity
120128
B = Xty
121-
coef = (A.pinverse() @ B)[:,:,:,0]
129+
coef = (A.pinverse() @ B)[:,:,:,0]'''
122130

123131
return coef
124132

kan/.ipynb_checkpoints/utils-checkpoint.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ def augment_input(orig_vars, aux_vars, x):
384384
return x
385385

386386

387-
def batch_jacobian(func, x, create_graph=False):
387+
def batch_jacobian(func, x, create_graph=False, mode='scalar'):
388388
'''
389389
jacobian
390390
@@ -408,7 +408,10 @@ def batch_jacobian(func, x, create_graph=False):
408408
# x in shape (Batch, Length)
409409
def _func_sum(x):
410410
return func(x).sum(dim=0)
411-
return torch.autograd.functional.jacobian(_func_sum, x, create_graph=create_graph)[0]
411+
if mode == 'scalar':
412+
return torch.autograd.functional.jacobian(_func_sum, x, create_graph=create_graph)[0]
413+
elif mode == 'vector':
414+
return torch.autograd.functional.jacobian(_func_sum, x, create_graph=create_graph).permute(1,0,2)
412415

413416
def batch_hessian(model, x, create_graph=False):
414417
'''
@@ -588,4 +591,4 @@ def model2param(model):
588591
p = torch.tensor([]).to(model.device)
589592
for params in model.parameters():
590593
p = torch.cat([p, params.reshape(-1,)], dim=0)
591-
return p
594+
return p

kan/KANLayer.py

+31-3
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def __init__(self, in_dim=3, out_dim=2, num=5, k=3, noise_scale=0.5, scale_base_
109109

110110
self.scale_base = torch.nn.Parameter(scale_base_mu * 1 / np.sqrt(in_dim) + \
111111
scale_base_sigma * (torch.rand(in_dim, out_dim)*2-1) * 1/np.sqrt(in_dim)).requires_grad_(sb_trainable)
112-
self.scale_sp = torch.nn.Parameter(torch.ones(in_dim, out_dim) * scale_sp * self.mask).requires_grad_(sp_trainable) # make scale trainable
112+
self.scale_sp = torch.nn.Parameter(torch.ones(in_dim, out_dim) * scale_sp * 1 / np.sqrt(in_dim) * self.mask).requires_grad_(sp_trainable) # make scale trainable
113113
self.base_fun = base_fun
114114

115115

@@ -197,11 +197,13 @@ def update_grid_from_samples(self, x, mode='sample'):
197197
def get_grid(num_interval):
198198
ids = [int(batch / num_interval * i) for i in range(num_interval)] + [-1]
199199
grid_adaptive = x_pos[ids, :].permute(1,0)
200-
h = (grid_adaptive[:,[-1]] - grid_adaptive[:,[0]])/num_interval
201-
grid_uniform = grid_adaptive[:,[0]] + h * torch.arange(num_interval+1,)[None, :].to(x.device)
200+
margin = 0.00
201+
h = (grid_adaptive[:,[-1]] - grid_adaptive[:,[0]] + 2 * margin)/num_interval
202+
grid_uniform = grid_adaptive[:,[0]] - margin + h * torch.arange(num_interval+1,)[None, :].to(x.device)
202203
grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
203204
return grid
204205

206+
205207
grid = get_grid(num_interval)
206208

207209
if mode == 'grid':
@@ -210,6 +212,8 @@ def get_grid(num_interval):
210212
y_eval = coef2curve(x_pos, self.grid, self.coef, self.k)
211213

212214
self.grid.data = extend_grid(grid, k_extend=self.k)
215+
#print('x_pos 2', x_pos.shape)
216+
#print('y_eval 2', y_eval.shape)
213217
self.coef.data = curve2coef(x_pos, y_eval, self.grid, self.k)
214218

215219
def initialize_grid_from_parent(self, parent, x, mode='sample'):
@@ -240,16 +244,40 @@ def initialize_grid_from_parent(self, parent, x, mode='sample'):
240244

241245
batch = x.shape[0]
242246

247+
# shrink grid
243248
x_pos = torch.sort(x, dim=0)[0]
244249
y_eval = coef2curve(x_pos, parent.grid, parent.coef, parent.k)
245250
num_interval = self.grid.shape[1] - 1 - 2*self.k
246251

252+
253+
'''
254+
# based on samples
247255
def get_grid(num_interval):
248256
ids = [int(batch / num_interval * i) for i in range(num_interval)] + [-1]
249257
grid_adaptive = x_pos[ids, :].permute(1,0)
250258
h = (grid_adaptive[:,[-1]] - grid_adaptive[:,[0]])/num_interval
251259
grid_uniform = grid_adaptive[:,[0]] + h * torch.arange(num_interval+1,)[None, :].to(x.device)
252260
grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
261+
return grid'''
262+
263+
#print('p', parent.grid)
264+
# based on interpolating parent grid
265+
def get_grid(num_interval):
266+
x_pos = parent.grid[:,parent.k:-parent.k]
267+
#print('x_pos', x_pos)
268+
sp2 = KANLayer(in_dim=1, out_dim=self.in_dim,k=1,num=x_pos.shape[1]-1,scale_base_mu=0.0, scale_base_sigma=0.0).to(x.device)
269+
270+
#print('sp2_grid', sp2.grid[:,sp2.k:-sp2.k].permute(1,0).expand(-1,self.in_dim))
271+
#print('sp2_coef_shape', sp2.coef.shape)
272+
sp2_coef = curve2coef(sp2.grid[:,sp2.k:-sp2.k].permute(1,0).expand(-1,self.in_dim), x_pos.permute(1,0).unsqueeze(dim=2), sp2.grid[:,:], k=1).permute(1,0,2)
273+
shp = sp2_coef.shape
274+
#sp2_coef = torch.cat([torch.zeros(shp[0], shp[1], 1), sp2_coef, torch.zeros(shp[0], shp[1], 1)], dim=2)
275+
#print('sp2_coef',sp2_coef)
276+
#print(sp2.coef.shape)
277+
sp2.coef.data = sp2_coef
278+
percentile = torch.linspace(-1,1,self.num+1).to(self.device)
279+
grid = sp2(percentile.unsqueeze(dim=1))[0].permute(1,0)
280+
#print('c', grid)
253281
return grid
254282

255283
grid = get_grid(num_interval)

0 commit comments

Comments
 (0)