 from deeptime.util.torch import map_data
 from deeptime.markov.tools.analysis import pcca_memberships
 
+CLIP_VALUE = 1.
 
 def symeig_reg(mat, epsilon: float = 1e-6, mode='regularize', eigenvectors=True) \
         -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
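Context for the additions below: `CLIP_VALUE` is the maximum gradient norm shared by every `clip_grad_norm_` call this change introduces. A minimal, self-contained sketch of the pattern, with hypothetical stand-in modules rather than the repository's lobe/u/S layers:

```python
import torch
from itertools import chain

CLIP_VALUE = 1.

# hypothetical stand-ins for lobe, ulayer, slayer
net_a, net_b = torch.nn.Linear(10, 5), torch.nn.Linear(5, 2)
opt = torch.optim.Adam(chain(net_a.parameters(), net_b.parameters()), lr=1e-3)

x = torch.randn(32, 10)
loss = net_b(net_a(x)).pow(2).mean()

opt.zero_grad()
loss.backward()
# clip between backward() and step(): rescales all gradients together
# so their global L2 norm is at most CLIP_VALUE
torch.nn.utils.clip_grad_norm_(chain(net_a.parameters(), net_b.parameters()), CLIP_VALUE)
opt.step()
```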
@@ -710,6 +711,14 @@ def transform(self, data, **kwargs):
             out.append(net(self.mask(data_tensor)).cpu().numpy())
         return out if len(out) > 1 else out[0]
 
+    def get_mu(self, data_t):
+        self._lobe.eval()
+        net = self._lobe
+        with torch.no_grad():
+            x_t = net(self.mask(torch.Tensor(data_t).to(self._device)))
+            mu = self._ulayer(x_t, x_t, return_mu=True)[-1]  # use dummy x_0
+        return mu.detach().to('cpu').numpy()
+
     def get_transition_matrix(self, data_0, data_t):
         self._lobe.eval()
         net = self._lobe
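A hedged usage sketch for the new `get_mu` helper; the fitted estimator instance `model` and the trajectory array are hypothetical, and the expected output shape follows from `mu` holding one stationary weight per transformed frame:

```python
import numpy as np

# hypothetical featurized trajectory, (n_frames, n_features) matching the lobe input
traj = np.random.rand(1000, 10).astype(np.float32)

# get_mu feeds x_t twice (the dummy x_0 noted in the diff) and returns the
# per-frame stationary weights as a CPU NumPy array
mu = model.get_mu(traj)
print(mu.shape)  # expected: (1000,)
```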
@@ -776,11 +785,11 @@ def observables(self, data_0, data_t, data_ev=None, data_ac=None, state1=None, s
         with torch.no_grad():
             x_0 = net(self.mask(torch.Tensor(data_0).to(self._device)))
             x_t = net(self.mask(torch.Tensor(data_t).to(self._device)))
-            output_u = self.ulayer(x_0, x_t, return_mu=return_mu)
+            output_u = self._ulayer(x_0, x_t, return_mu=return_mu)
             if return_mu:
                 mu = output_u[5]
             Sigma = output_u[4]
-            output_S = self.slayer(*output_u[:5], return_K=return_K, return_S=return_S)
+            output_S = self._slayer(*output_u[:5], return_K=return_K, return_S=return_S)
             if return_K:
                 K = output_S[1]
             if return_S:
@@ -789,14 +798,14 @@ def observables(self, data_0, data_t, data_ev=None, data_ac=None, state1=None, s
             if data_ev is not None:
                 x_ev = torch.Tensor(data_ev).to(self._device)
                 ev_est = obs_ev(x_ev, mu)
-                ret.append(ev_est)
+                ret.append(ev_est.detach().to('cpu').numpy())
             if data_ac is not None:
                 x_ac = torch.Tensor(data_ac).to(self._device)
                 ac_est = obs_ac(x_ac, mu, x_t, K, Sigma)
-                ret.append(ac_est)
+                ret.append(ac_est.detach().to('cpu').numpy())
             if state1 is not None:
                 its_est = get_process_eigval(S, Sigma, state1, state2, epsilon=self._epsilon, mode=self._mode)
-                ret.append(its_est)
+                ret.append(its_est.detach().to('cpu').numpy())
         return ret
 
 class DeepMSM(DLEstimatorMixin, Transformer):
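The `detach().to('cpu').numpy()` chain applied to `ev_est`, `ac_est`, and `its_est` above is the standard way to return results as NumPy arrays: calling `.numpy()` on a tensor that still requires grad raises a `RuntimeError`, so the tensor must be detached from the graph first. A minimal sketch:

```python
import torch

t = torch.randn(4, requires_grad=True)
arr = t.detach().to('cpu').numpy()  # drop autograd graph, move to host, expose as ndarray
```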
@@ -887,8 +896,7 @@ def __init__(self, lobe: nn.Module, output_dim: int, coarse_grain: list = None,
         self.optimizer_u = torch.optim.Adam(self.ulayer.parameters(), lr=self.learning_rate * 10)
         self.optimizer_s = torch.optim.Adam(self.slayer.parameters(), lr=self.learning_rate * 100)
         self.optimizer_lobe = torch.optim.Adam(self.lobe.parameters(), lr=self.learning_rate)
-        self.optimimzer_all = torch.optim.Adam(chain(self.ulayer.parameters(), self.slayer.parameters(), self.lobe.parameters()),
-                                               lr=self.learning_rate)
+        self.optimimzer_all = torch.optim.Adam(chain(self.ulayer.parameters(), self.slayer.parameters(), self.lobe.parameters()), lr=self.learning_rate)
         self._train_scores = []
         self._validation_scores = []
         self._train_vampe = []
@@ -1075,6 +1083,7 @@ def partial_fit(self, data, mask: bool = False, train_score_callback: Callable[[
 
         loss_value = -vampe_loss_rev(x_0, x_t, self.ulayer, self.slayer)[0]
         loss_value.backward()
+        torch.nn.utils.clip_grad_norm_(chain(self.lobe.parameters(), self.mask.parameters(), self.ulayer.parameters(), self.slayer.parameters()), CLIP_VALUE)
         if self.mask is not None and mask:
             self.optimizer_mask.step()
         self.optimizer_lobe.step()
@@ -1194,6 +1203,7 @@ def fit(self, data_loader: torch.utils.data.DataLoader, n_epochs=1, validation_l
 
                 loss_value = -vampe_loss_rev(x_0, x_t, self.ulayer, self.slayer)[0]
                 loss_value.backward()
+                torch.nn.utils.clip_grad_norm_(chain(self.ulayer.parameters(), self.slayer.parameters()), CLIP_VALUE)
                 self.optimizer_u.step()
                 if train_mode == 'us':
                     self.optimizer_s.step()
@@ -1224,6 +1234,7 @@ def fit(self, data_loader: torch.utils.data.DataLoader, n_epochs=1, validation_l
 
                 loss_value = -vampe_loss_rev_only_S(v, C_00, C_11, C_01, Sigma, self.slayer)[0]
                 loss_value.backward()
+                torch.nn.utils.clip_grad_norm_(self.slayer.parameters(), CLIP_VALUE)
                 self.optimizer_s.step()
 
                 if train_score_callback is not None:
@@ -1355,6 +1366,7 @@ def fit_routine(self, data_loader: torch.utils.data.DataLoader, n_epochs=1, vali
             score = vampe_loss_rev(x_0, x_t, self.ulayer, self.slayer)[0]
             loss_value = -score
             loss_value.backward()
+            torch.nn.utils.clip_grad_norm_(chain(self.ulayer.parameters(), self.slayer.parameters()), CLIP_VALUE)
             self.optimizer_u.step()
             self.optimizer_s.step()
             if (score - score_value_before) < rel and counter > 0:
@@ -1455,6 +1467,9 @@ def fit_cg(self, data_loader: torch.utils.data.DataLoader, n_epochs=1, validatio
                     loss_value += torch.trace(matrix_cg)
 
                 loss_value.backward()
+                torch.nn.utils.clip_grad_norm_(chain(self.ulayer.parameters(), self.slayer.parameters()), CLIP_VALUE)
+                for lay_cg in self.cg_list:
+                    torch.nn.utils.clip_grad_norm_(lay_cg.parameters(), CLIP_VALUE)
                 self.optimizer_u.step()
                 self.optimizer_s.step()
                 for opt in self.cg_opt_list:
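Note that `fit_cg` clips the coarse-grain layers one module at a time, so each layer gets its own norm budget rather than sharing a single global norm with `ulayer`/`slayer`. The distinction, sketched with hypothetical stand-in modules:

```python
import torch

cg_list = [torch.nn.Linear(4, 4) for _ in range(3)]  # stand-ins for the cg layers
loss = sum(m(torch.randn(8, 4)).sum() for m in cg_list)
loss.backward()

# a joint alternative would chain all parameters into one clip_grad_norm_ call,
# enforcing a single global L2 norm; the diff instead caps each layer separately:
for lay_cg in cg_list:
    torch.nn.utils.clip_grad_norm_(lay_cg.parameters(), 1.)
```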
@@ -1503,6 +1518,7 @@ def fit_cg(self, data_loader: torch.utils.data.DataLoader, n_epochs=1, validatio
                 matrix_cg = self.cg_list[idx].get_cg_uS(chi_t, chi_tau, u_n, S_n, return_chi=False)[0]
                 loss_value = torch.trace(matrix_cg)
                 loss_value.backward()
+                torch.nn.utils.clip_grad_norm_(self.cg_list[idx].parameters(), CLIP_VALUE)
                 self.cg_opt_list[idx].step()
 
                 if train_score_callback is not None:
@@ -1634,6 +1650,7 @@ def partial_fit_obs(self, data, data_ev, data_ac, exp_ev=None, exp_ac=None, exp_
                 loss_its, est_its = obs_its_loss(S, Sigma, its_state1, its_state2, exp_its, lam_its, epsilon=self.epsilon, mode=self.score_mode)
                 loss_value += loss_its
             loss_value.backward()
+            torch.nn.utils.clip_grad_norm_(chain(self.lobe.parameters(), self.mask.parameters(), self.ulayer.parameters(), self.slayer.parameters()), CLIP_VALUE)
             self.optimizer_lobe.step()
             self.optimizer_u.step()
             self.optimizer_s.step()
@@ -1989,6 +2006,7 @@ def fit_obs(self, data_loader: torch.utils.data.DataLoader, n_epochs=1, validati
                             for i in range(est_its.shape[0]):
                                 tb_writer.add_scalars('ITS', {'train_' + str(i + 1): est_its[i].item()}, self._step)
                     loss_value.backward()
+                    torch.nn.utils.clip_grad_norm_(chain(self.ulayer.parameters(), self.slayer.parameters()), CLIP_VALUE)
                     self.optimizer_u.step()
                     if train_mode == 'us':
                         self.optimizer_s.step()
@@ -2082,6 +2100,7 @@ def fit_obs(self, data_loader: torch.utils.data.DataLoader, n_epochs=1, validati
                             for i in range(est_its.shape[0]):
                                 tb_writer.add_scalars('ITS', {'train_' + str(i + 1): est_its[i].item()}, self._step)
                     loss_value.backward()
+                    torch.nn.utils.clip_grad_norm_(self.slayer.parameters(), CLIP_VALUE)
                     self.optimizer_s.step()
 
                     if train_score_callback is not None: