-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel.py
More file actions
157 lines (140 loc) · 7.2 KB
/
model.py
File metadata and controls
157 lines (140 loc) · 7.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import torch
import pytorch_lightning as pl
class LitDiffusionModel(pl.LightningModule):
    """Denoising diffusion probabilistic model (DDPM) over `n_dim`-dimensional points.

    `model` is an MLP that predicts the noise epsilon contained in a noisy sample x_t.
    The diffusion step t enters the network as one extra input feature produced by
    `time_embed`.

    Args:
        n_dim: dimensionality of each data point.
        n_steps: number of diffusion steps T.
        lbeta: first beta of the 'linear' schedule (ignored by 'cosine').
        ubeta: last beta of the 'linear' schedule (ignored by 'cosine').
        schedule: noise schedule, 'linear' or 'cosine'.
        s: offset of the 'cosine' schedule (ignored by 'linear').
    """

    # Upper bound on beta_t for the cosine schedule. Without it the last step has
    # alpha_t ~ 0, and the 1/sqrt(alpha_t) factor in the reverse step blows up.
    MAX_COSINE_BETA = 0.999

    def __init__(self, n_dim=3, n_steps=200, lbeta=1e-5, ubeta=1e-2, schedule='linear', s=0.008):
        super().__init__()
        # save_hyperparameters() records every __init__ argument (written to `hparams.yaml`).
        # If you add a hyperparameter (e.g. `n_layers`), also add it to `argparse` in
        # `train.py` and make sure it is saved and restored as expected.
        self.save_hyperparameters()

        # Fixed (non-learned) time embedding; see `time_embedding`.
        self.time_embed = self.time_embedding
        # Noise-prediction network: input is a point concatenated with its 1-d time embedding.
        self.model = torch.nn.Sequential(
            torch.nn.Linear(n_dim + 1, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, n_dim),
        )

        # At least `n_steps` and `n_dim` must be kept on the instance.
        self.n_steps = n_steps
        self.n_dim = n_dim
        self.schedule = schedule
        self.s = s

        # The noise schedule is registered as non-persistent buffers. They then follow the
        # module across devices (`.to()`, Lightning's accelerator) but stay out of the
        # state_dict, exactly like the plain attributes they replace.
        alpha, alpha_bar, beta = self.init_alpha_beta_schedule(lbeta, ubeta)
        self.register_buffer('alpha', alpha, persistent=False)
        self.register_buffer('alpha_bar', alpha_bar, persistent=False)
        self.register_buffer('beta', beta, persistent=False)

    def time_embedding(self, t):
        """Scale step index/indices `t` from [0, n_steps) to [0, 1)."""
        return t / self.n_steps

    def forward(self, x, t):
        """Predict the noise contained in `x` at diffusion step(s) `t`.

        Args:
            x: tensor of shape (batch, n_dim).
            t: either a single step (int or 0-d tensor) shared by the whole batch, or a
               tensor of shape (batch,) with one step per row.

        Returns:
            float32 tensor of shape (batch, n_dim).
        """
        if isinstance(t, torch.Tensor):
            t = t.to(x.device)
        else:
            t = torch.as_tensor(t, dtype=torch.long, device=x.device)
        if t.dim() == 0:
            # One step for the whole batch: give every row its own copy.
            t = t.expand(x.size(0))
        t_embed = self.time_embed(t)
        return self.model(torch.cat((x, t_embed[:, None]), dim=1).float())

    def init_alpha_beta_schedule(self, lbeta, ubeta):
        """Build the noise schedule selected by `self.schedule`.

        'linear': beta_t spaced linearly from `lbeta` to `ubeta`.
        'cosine': alpha_bar_t = f(t) / f(0) with f(t) = cos^2(((t/T + s) / (1 + s)) * pi/2);
                  beta_t = 1 - alpha_bar_t / alpha_bar_{t-1}, clipped to MAX_COSINE_BETA.

        Returns:
            [alpha, alpha_bar, beta], each a tensor of length `n_steps`, with
            alpha = 1 - beta and alpha_bar the cumulative product of alpha.

        Raises:
            ValueError: if `self.schedule` is not a known schedule name.
        """
        if self.schedule == 'linear':
            beta = torch.linspace(start=lbeta, end=ubeta, steps=self.n_steps)
        elif self.schedule == 'cosine':
            steps = torch.linspace(start=0, end=self.n_steps, steps=self.n_steps + 1)
            f = torch.cos((steps / self.n_steps + self.s) / (1 + self.s) * torch.pi / 2) ** 2
            alpha_bar_full = f / f[0]
            beta = (1 - alpha_bar_full[1:] / alpha_bar_full[:-1]).clamp(max=self.MAX_COSINE_BETA)
        else:
            raise ValueError(f"Unknown noise schedule {self.schedule!r}; expected 'linear' or 'cosine'.")
        alpha = 1 - beta
        alpha_bar = torch.cumprod(alpha, dim=0)
        return [alpha, alpha_bar, beta]

    def _coef(self, values, t):
        """Return `values[t]` shaped to broadcast against a (batch, n_dim) tensor."""
        if isinstance(t, torch.Tensor):
            t = t.to(values.device)
        coef = values[t]
        # A per-row step tensor of shape (batch,) yields (batch,): make it a column.
        return coef[:, None] if coef.dim() > 0 else coef

    def q_sample(self, x, t):
        """Sample x_t ~ q(x_t | x_{t-1}) = N(sqrt(alpha_t) * x_{t-1}, beta_t * I), with x = x_{t-1}.

        `t` may be an int shared by the batch or a (batch,) tensor of steps.
        """
        mean = torch.sqrt(self._coef(self.alpha, t)) * x
        std = torch.sqrt(self._coef(self.beta, t))
        return mean + std * torch.randn_like(mean)

    def p_sample(self, x, t):
        """Sample x_{t-1} ~ p_theta(x_{t-1} | x_t), with x = x_t (DDPM reverse step, sigma_t^2 = beta_t).

        `t` may be an int shared by the batch or a (batch,) tensor of steps. Rows whose
        step is 0 get no noise, i.e. the final denoising step returns the mean.
        """
        alpha = self._coef(self.alpha, t)
        alpha_bar = self._coef(self.alpha_bar, t)
        beta = self._coef(self.beta, t)
        mean = 1 / torch.sqrt(alpha) * (x - beta / torch.sqrt(1 - alpha_bar) * self.forward(x, t))
        nonzero = (torch.as_tensor(t, device=mean.device) > 0).to(mean.dtype)
        if nonzero.dim() > 0:
            nonzero = nonzero[:, None]
        return mean + nonzero * torch.sqrt(beta) * torch.randn_like(mean)

    def training_step(self, batch, batch_idx):
        """Run one training step with the simplified DDPM objective (L_simple).

        For each sample x_0 in `batch` (shape (n_samples, n_dim)), draw a step t and noise
        epsilon, and form x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * epsilon.
        Return the mean squared error between epsilon and the network's prediction of it.
        Lightning back-propagates the returned loss.

        References:
            [1]: https://arxiv.org/abs/2006.11239
            [2]: https://pytorch-lightning.readthedocs.io/en/stable/
            [3]: https://www.pytorchlightning.ai/tutorials
        """
        n_samples, _ = batch.shape
        t = torch.randint(low=0, high=self.n_steps, size=(n_samples,), device=batch.device)
        epsilon = torch.randn(batch.shape, device=batch.device)
        alpha_bar = self._coef(self.alpha_bar, t)
        x_t = torch.sqrt(alpha_bar) * batch + torch.sqrt(1 - alpha_bar) * epsilon
        epsilon_theta = self.forward(x_t, t)
        return torch.nn.functional.mse_loss(epsilon_theta, epsilon)

    def sample(self, n_samples, progress=False, return_intermediate=False):
        """Draw samples by running the full reverse diffusion process.

        Runs without gradient tracking, starting from standard normal noise on the
        module's device.

        Args:
            n_samples: number of points to generate.
            progress: accepted for API compatibility; no progress is currently shown.
            return_intermediate: also return the sample after every reverse step.

        Returns:
            If `return_intermediate` is False: tensor of shape (n_samples, n_dim).
            Otherwise: (final tensor, list of `n_steps` tensors of shape (n_samples, n_dim)),
            ordered from the first reverse step (t = n_steps - 1) to the last (t = 0).
        """
        intermediate_samples = []
        with torch.no_grad():
            samples = torch.randn(n_samples, self.n_dim, device=self.device)
            for t in reversed(range(self.n_steps)):
                samples = self.p_sample(samples, t)
                intermediate_samples.append(samples)
        if return_intermediate:
            return samples, intermediate_samples
        return samples

    def configure_optimizers(self):
        """Return an Adam optimizer (default hyperparameters) over all of the module's parameters."""
        return torch.optim.Adam(self.parameters())