@@ -384,7 +384,7 @@ def augment_input(orig_vars, aux_vars, x):
     return x
 
 
-def batch_jacobian(func, x, create_graph=False):
+def batch_jacobian(func, x, create_graph=False, mode='scalar'):
     '''
     jacobian
     
@@ -408,7 +408,10 @@ def batch_jacobian(func, x, create_graph=False):
     # x in shape (Batch, Length)
     def _func_sum(x):
         return func(x).sum(dim=0)
-    return torch.autograd.functional.jacobian(_func_sum, x, create_graph=create_graph)[0]
+    if mode == 'scalar':
+        return torch.autograd.functional.jacobian(_func_sum, x, create_graph=create_graph)[0]
+    elif mode == 'vector':
+        return torch.autograd.functional.jacobian(_func_sum, x, create_graph=create_graph).permute(1,0,2)
 
 def batch_hessian(model, x, create_graph=False):
     '''
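
Below is a minimal usage sketch (not part of the patch) showing how the new mode argument changes the returned shape, assuming func maps (Batch, Length) to (Batch, Out). The toy function, tensor sizes, and variable names are illustrative only; batch_jacobian is restated from the hunk above so the snippet is self-contained.

import torch

# Restated from the patch above so this sketch runs on its own;
# in practice it would be imported from this module.
def batch_jacobian(func, x, create_graph=False, mode='scalar'):
    # x in shape (Batch, Length)
    def _func_sum(x):
        return func(x).sum(dim=0)
    if mode == 'scalar':
        return torch.autograd.functional.jacobian(_func_sum, x, create_graph=create_graph)[0]
    elif mode == 'vector':
        return torch.autograd.functional.jacobian(_func_sum, x, create_graph=create_graph).permute(1,0,2)

# Toy vector-valued function: (Batch, 3) -> (Batch, 2)
func = lambda x: torch.stack([x[:, 0] * x[:, 1], x[:, 2] ** 2], dim=1)
x = torch.randn(5, 3)

J_vec = batch_jacobian(func, x, mode='vector')     # per-sample Jacobian, shape (5, 2, 3)
J_scalar = batch_jacobian(func, x, mode='scalar')  # Jacobian of output component 0 only, shape (5, 3)

The 'scalar' branch keeps the previous behaviour (intended for functions with a single output per sample), while 'vector' permutes the summed Jacobian into (Batch, Out, Length), giving one Jacobian block per sample for vector-valued outputs.
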
@@ -588,4 +591,4 @@ def model2param(model):
     p = torch.tensor([]).to(model.device)
     for params in model.parameters():
         p = torch.cat([p, params.reshape(-1,)], dim=0)
-    return p
+    return p