
Commit 4a4c45f

Author: andreasmardt
Commit message: updated notebook
Parent: f30f8f0

3 files changed: +242 additions, −24 deletions

Attention_and_coarse_graining.ipynb

Lines changed: 25 additions & 16 deletions
@@ -27,16 +27,6 @@
     "from tqdm.notebook import tqdm # progress bar"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "from torch.utils.tensorboard import SummaryWriter\n",
-    "writer = SummaryWriter()"
-   ]
-  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -280,6 +270,25 @@
     "loader_val = DataLoader(val_data, batch_size=len(val_data), shuffle=False)"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Can be used to record the training performance with tensorboard.\n",
+    "# It is not necessary for training or using the methods;\n",
+    "# if you do not wish to install the additional package, just leave the flag at False!\n",
+    "tensorboard_installed = False\n",
+    "if tensorboard_installed:\n",
+    "    from torch.utils.tensorboard import SummaryWriter\n",
+    "    writer = SummaryWriter()\n",
+    "    input_model, _ = next(iter(loader_train))\n",
+    "    writer.add_graph(lobe, input_to_model=input_model.to(device))\n",
+    "else:\n",
+    "    writer = None"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},
@@ -392,7 +401,7 @@
    "outputs": [],
    "source": [
     "# train only for the matrix S\n",
-    "deepmsm.fit(loader_train, n_epochs=100, validation_loader=loader_val, train_mode='s')\n",
+    "deepmsm.fit(loader_train, n_epochs=100, validation_loader=loader_val, train_mode='s', tb_writer=writer)\n",
     "plt.loglog(*deepmsm.train_scores.T, label='training')\n",
     "plt.loglog(*deepmsm.validation_scores.T, label='validation')\n",
     "plt.xlabel('step')\n",
@@ -407,7 +416,7 @@
    "outputs": [],
    "source": [
     "# Train for S and u\n",
-    "deepmsm.fit(loader_train, n_epochs=100, validation_loader=loader_val, train_mode='us')\n",
+    "deepmsm.fit(loader_train, n_epochs=1000, validation_loader=loader_val, train_mode='us', tb_writer=writer)\n",
     "plt.loglog(*deepmsm.train_scores.T, label='training')\n",
     "plt.loglog(*deepmsm.validation_scores.T, label='validation')\n",
     "plt.xlabel('step')\n",
@@ -424,7 +433,7 @@
    "outputs": [],
    "source": [
     "# Train for chi, u, and S in an iterative manner\n",
-    "deepmsm.fit_routine(loader_train, n_epochs=50, validation_loader=loader_val, rel=0.001, reset_u=False, max_iter=1000)\n",
+    "deepmsm.fit_routine(loader_train, n_epochs=50, validation_loader=loader_val, rel=0.001, reset_u=False, max_iter=1000, tb_writer=writer)\n",
     "plt.loglog(*deepmsm.train_scores.T, label='training')\n",
     "plt.loglog(*deepmsm.validation_scores.T, label='validation')\n",
     "plt.xlabel('step')\n",
@@ -445,7 +454,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "plot_mask(data=data[0], lobe=lobe_msm, mask=mask, mask_const=mask_const, device=device, vmax=0.5)"
+    "plot_mask(data=data[0], lobe=lobe_msm, mask=mask, mask_const=mask_const, device=device, vmax=0.25)"
    ]
   },
   {
@@ -689,9 +698,9 @@
    "outputs": [],
    "source": [
     "for _ in range(5):\n",
-    "    model_msm_final = deepmsm.fit(loader_train, n_epochs=1000, validation_loader=loader_val, train_mode='s').fetch_model()\n",
+    "    model_msm_final = deepmsm.fit(loader_train, n_epochs=1000, validation_loader=loader_val, train_mode='s', tb_writer=writer).fetch_model()\n",
     "    # train for u and S\n",
-    "    model_msm_final = deepmsm.fit(loader_train, n_epochs=100, validation_loader=loader_val, train_mode='us').fetch_model()"
+    "    model_msm_final = deepmsm.fit(loader_train, n_epochs=100, validation_loader=loader_val, train_mode='us', tb_writer=writer).fetch_model()"
    ]
   },
   {
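Every training call in the notebook now threads the (possibly None) writer through as tb_writer=writer. On the receiving side the pattern is a guarded log call, as the tb_writer.add_scalars('ITS', ...) lines in deepmsm.py below show. A hedged sketch of that consumption pattern with a made-up loss, not the actual fit internals:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def train_epoch(model, optimizer, loader, step, tb_writer=None):
    # Hypothetical loss: mean squared difference between outputs at the two times.
    for x_0, x_t in loader:
        optimizer.zero_grad()
        loss = ((model(x_0) - model(x_t)) ** 2).mean()
        loss.backward()
        optimizer.step()
        if tb_writer is not None:  # logging stays strictly optional
            tb_writer.add_scalar('loss/train', loss.item(), step)
        step += 1
    return step

model = nn.Linear(4, 2)
opt = torch.optim.Adam(model.parameters())
loader = DataLoader(TensorDataset(torch.randn(16, 4), torch.randn(16, 4)), batch_size=4)
step = train_epoch(model, opt, loader, step=0, tb_writer=None)  # works with or without a writer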

deepmsm.py

Lines changed: 26 additions & 7 deletions
@@ -12,6 +12,7 @@
 from deeptime.util.torch import map_data
 from deeptime.markov.tools.analysis import pcca_memberships
 
+CLIP_VALUE = 1.
 
 def symeig_reg(mat, epsilon: float = 1e-6, mode='regularize', eigenvectors=True) \
         -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
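The new module-level CLIP_VALUE feeds the torch.nn.utils.clip_grad_norm_ calls that this commit inserts between every backward() and optimizer step below. A minimal sketch of the pattern in isolation (toy modules, illustrative only):

from itertools import chain
import torch
import torch.nn as nn

CLIP_VALUE = 1.

lobe = nn.Linear(4, 2)
ulayer = nn.Linear(2, 2)
opt = torch.optim.Adam(chain(lobe.parameters(), ulayer.parameters()), lr=1e-3)

x = torch.randn(8, 4)
loss = ulayer(lobe(x)).pow(2).mean()

opt.zero_grad()
loss.backward()
# Rescale all gradients jointly so their total L2 norm is at most CLIP_VALUE;
# chain() lets one call cover several modules' parameters, as in partial_fit.
torch.nn.utils.clip_grad_norm_(chain(lobe.parameters(), ulayer.parameters()), CLIP_VALUE)
opt.step()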
@@ -710,6 +711,14 @@ def transform(self, data, **kwargs):
             out.append(net(self.mask(data_tensor)).cpu().numpy())
         return out if len(out) > 1 else out[0]
 
+    def get_mu(self, data_t):
+        self._lobe.eval()
+        net = self._lobe
+        with torch.no_grad():
+            x_t = net(self.mask(torch.Tensor(data_t).to(self._device)))
+            mu = self._ulayer(x_t, x_t, return_mu=True)[-1]  # use dummy x_0
+        return mu.detach().to('cpu').numpy()
+
     def get_transition_matrix(self, data_0, data_t):
         self._lobe.eval()
         net = self._lobe
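The new get_mu reuses x_t as a dummy x_0 because only the time-lagged features are needed for the stationary weights mu, and it follows the standard PyTorch inference recipe: eval(), no_grad(), then convert to numpy. That recipe with a toy network, for reference (the network and data here are illustrative, not the DeepMSM lobe):

import numpy as np
import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 2), nn.Softmax(dim=1))
data_t = np.random.randn(10, 4).astype(np.float32)

net.eval()                      # inference behavior for dropout/batchnorm layers
with torch.no_grad():           # no autograd graph is built
    x_t = net(torch.Tensor(data_t))
out = x_t.detach().to('cpu').numpy()  # back to numpy for downstream analysis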
@@ -776,11 +785,11 @@ def observables(self, data_0, data_t, data_ev=None, data_ac=None, state1=None, s
         with torch.no_grad():
             x_0 = net(self.mask(torch.Tensor(data_0).to(self._device)))
             x_t = net(self.mask(torch.Tensor(data_t).to(self._device)))
-            output_u = self.ulayer(x_0, x_t, return_mu=return_mu)
+            output_u = self._ulayer(x_0, x_t, return_mu=return_mu)
             if return_mu:
                 mu = output_u[5]
             Sigma = output_u[4]
-            output_S = self.slayer(*output_u[:5], return_K=return_K, return_S=return_S)
+            output_S = self._slayer(*output_u[:5], return_K=return_K, return_S=return_S)
             if return_K:
                 K = output_S[1]
             if return_S:
@@ -789,14 +798,14 @@ def observables(self, data_0, data_t, data_ev=None, data_ac=None, state1=None, s
         if data_ev is not None:
             x_ev = torch.Tensor(data_ev).to(self._device)
             ev_est = obs_ev(x_ev, mu)
-            ret.append(ev_est)
+            ret.append(ev_est.detach().to('cpu').numpy())
         if data_ac is not None:
             x_ac = torch.Tensor(data_ac).to(self._device)
             ac_est = obs_ac(x_ac, mu, x_t, K, Sigma)
-            ret.append(ac_est)
+            ret.append(ac_est.detach().to('cpu').numpy())
         if state1 is not None:
             its_est = get_process_eigval(S, Sigma, state1, state2, epsilon=self._epsilon, mode=self._mode)
-            ret.append(its_est)
+            ret.append(its_est.detach().to('cpu').numpy())
         return ret
 
 class DeepMSM(DLEstimatorMixin, Transformer):
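The three ret.append fixes above matter because Tensor.numpy() only works on CPU tensors; on a CUDA tensor it raises a TypeError, so the explicit detach + to('cpu') hop is required before handing results to numpy-based code. A short demonstration:

import torch

t = torch.randn(3)
if torch.cuda.is_available():
    t = t.to('cuda')
# t.numpy() would raise TypeError on a CUDA tensor; this path is always safe:
arr = t.detach().to('cpu').numpy()
print(type(arr), arr.shape)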
@@ -887,8 +896,7 @@ def __init__(self, lobe: nn.Module, output_dim: int, coarse_grain: list = None,
         self.optimizer_u = torch.optim.Adam(self.ulayer.parameters(), lr=self.learning_rate*10)
         self.optimizer_s = torch.optim.Adam(self.slayer.parameters(), lr=self.learning_rate*100)
         self.optimizer_lobe = torch.optim.Adam(self.lobe.parameters(), lr=self.learning_rate)
-        self.optimimzer_all = torch.optim.Adam(chain(self.ulayer.parameters(), self.slayer.parameters(), self.lobe.parameters()),
-                                               lr=self.learning_rate)
+        self.optimimzer_all = torch.optim.Adam(chain(self.ulayer.parameters(), self.slayer.parameters(), self.lobe.parameters()), lr=self.learning_rate)
         self._train_scores = []
         self._validation_scores = []
         self._train_vampe = []
@@ -1075,6 +1083,7 @@ def partial_fit(self, data, mask: bool = False, train_score_callback: Callable[[
 
         loss_value = -vampe_loss_rev(x_0, x_t, self.ulayer, self.slayer)[0]
         loss_value.backward()
+        torch.nn.utils.clip_grad_norm_(chain(self.lobe.parameters(), self.mask.parameters(), self.ulayer.parameters(), self.slayer.parameters()), CLIP_VALUE)
         if self.mask is not None and mask:
             self.optimizer_mask.step()
         self.optimizer_lobe.step()
@@ -1194,6 +1203,7 @@ def fit(self, data_loader: torch.utils.data.DataLoader, n_epochs=1, validation_l
 
         loss_value = -vampe_loss_rev(x_0, x_t, self.ulayer, self.slayer)[0]
         loss_value.backward()
+        torch.nn.utils.clip_grad_norm_(chain(self.ulayer.parameters(), self.slayer.parameters()), CLIP_VALUE)
         self.optimizer_u.step()
         if train_mode=='us':
             self.optimizer_s.step()
@@ -1224,6 +1234,7 @@ def fit(self, data_loader: torch.utils.data.DataLoader, n_epochs=1, validation_l
 
         loss_value = -vampe_loss_rev_only_S(v, C_00, C_11, C_01, Sigma, self.slayer)[0]
         loss_value.backward()
+        torch.nn.utils.clip_grad_norm_(self.slayer.parameters(), CLIP_VALUE)
         self.optimizer_s.step()
 
         if train_score_callback is not None:
@@ -1355,6 +1366,7 @@ def fit_routine(self, data_loader: torch.utils.data.DataLoader, n_epochs=1, vali
         score = vampe_loss_rev(x_0, x_t, self.ulayer, self.slayer)[0]
         loss_value = -score
         loss_value.backward()
+        torch.nn.utils.clip_grad_norm_(chain(self.ulayer.parameters(), self.slayer.parameters()), CLIP_VALUE)
         self.optimizer_u.step()
         self.optimizer_s.step()
         if (score-score_value_before) < rel and counter > 0:
@@ -1455,6 +1467,9 @@ def fit_cg(self, data_loader: torch.utils.data.DataLoader, n_epochs=1, validatio
             loss_value += torch.trace(matrix_cg)
 
         loss_value.backward()
+        torch.nn.utils.clip_grad_norm_(chain(self.ulayer.parameters(), self.slayer.parameters()), CLIP_VALUE)
+        for lay_cg in self.cg_list:
+            torch.nn.utils.clip_grad_norm_(lay_cg.parameters(), CLIP_VALUE)
         self.optimizer_u.step()
         self.optimizer_s.step()
         for opt in self.cg_opt_list:
15031518
matrix_cg = self.cg_list[idx].get_cg_uS(chi_t, chi_tau, u_n, S_n, return_chi=False)[0]
15041519
loss_value = torch.trace(matrix_cg)
15051520
loss_value.backward()
1521+
torch.nn.utils.clip_grad_norm_(self.cg_list[idx].parameters(), CLIP_VALUE)
15061522
self.cg_opt_list[idx].step()
15071523

15081524
if train_score_callback is not None:
@@ -1634,6 +1650,7 @@ def partial_fit_obs(self, data, data_ev, data_ac, exp_ev=None, exp_ac=None, exp_
             loss_its, est_its = obs_its_loss(S, Sigma, its_state1, its_state2, exp_its, lam_its, epsilon=self.epsilon, mode=self.score_mode)
             loss_value += loss_its
         loss_value.backward()
+        torch.nn.utils.clip_grad_norm_(chain(self.lobe.parameters(), self.mask.parameters(), self.ulayer.parameters(), self.slayer.parameters()), CLIP_VALUE)
         self.optimizer_lobe.step()
         self.optimizer_u.step()
         self.optimizer_s.step()
@@ -1989,6 +2006,7 @@ def fit_obs(self, data_loader: torch.utils.data.DataLoader, n_epochs=1, validati
             for i in range(est_its.shape[0]):
                 tb_writer.add_scalars('ITS', {'train_'+str(i+1): est_its[i].item()}, self._step)
         loss_value.backward()
+        torch.nn.utils.clip_grad_norm_(chain(self.ulayer.parameters(), self.slayer.parameters()), CLIP_VALUE)
         self.optimizer_u.step()
         if train_mode=='us':
             self.optimizer_s.step()
@@ -2082,6 +2100,7 @@ def fit_obs(self, data_loader: torch.utils.data.DataLoader, n_epochs=1, validati
             for i in range(est_its.shape[0]):
                 tb_writer.add_scalars('ITS', {'train_'+str(i+1): est_its[i].item()}, self._step)
         loss_value.backward()
+        torch.nn.utils.clip_grad_norm_(self.slayer.parameters(), CLIP_VALUE)
         self.optimizer_s.step()
 
         if train_score_callback is not None:
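For reference, add_scalars (used in the fit_obs context lines above) groups several curves under one main tag, which is how the per-process ITS estimates end up overlaid in a single TensorBoard chart. A standalone sketch with fabricated values (requires the tensorboard package):

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()
for step in range(100):
    # One chart named 'ITS' with three overlaid curves, mirroring the
    # 'train_'+str(i+1) keys used in fit_obs.
    writer.add_scalars('ITS', {f'train_{i + 1}': 1.0 / (i + 1) + 0.01 * step for i in range(3)}, step)
writer.close()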
