Skip to content

Commit 91a2f63

Browse files
authored
Update utils.py
1 parent c162d38 commit 91a2f63

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

kan/utils.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ def augment_input(orig_vars, aux_vars, x):
384384
return x
385385

386386

387-
def batch_jacobian(func, x, create_graph=False):
387+
def batch_jacobian(func, x, create_graph=False, mode='scalar'):
388388
'''
389389
jacobian
390390
@@ -408,7 +408,10 @@ def batch_jacobian(func, x, create_graph=False):
408408
# x in shape (Batch, Length)
409409
def _func_sum(x):
410410
return func(x).sum(dim=0)
411-
return torch.autograd.functional.jacobian(_func_sum, x, create_graph=create_graph)[0]
411+
if mode == 'scalar':
412+
return torch.autograd.functional.jacobian(_func_sum, x, create_graph=create_graph)[0]
413+
elif mode == 'vector':
414+
return torch.autograd.functional.jacobian(_func_sum, x, create_graph=create_graph).permute(1,0,2)
412415

413416
def batch_hessian(model, x, create_graph=False):
414417
'''
@@ -588,4 +591,4 @@ def model2param(model):
588591
p = torch.tensor([]).to(model.device)
589592
for params in model.parameters():
590593
p = torch.cat([p, params.reshape(-1,)], dim=0)
591-
return p
594+
return p

0 commit comments

Comments
 (0)