-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathmodel.py
More file actions
35 lines (25 loc) · 866 Bytes
/
model.py
File metadata and controls
35 lines (25 loc) · 866 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
from torch import nn
from torch.nn import functional as F
# Pick the compute device once at import time: use CUDA when PyTorch reports
# it as available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Feed-forward network with three hidden layers.
class SimpleModel(nn.Module):
    """Fully connected network with LeakyReLU hidden layers and a sigmoid output.

    The network consists of ``n_hidden`` blocks of ``Linear -> LeakyReLU``
    followed by a final ``Linear -> Sigmoid``. After construction, every
    ``nn.Linear`` weight is re-initialised with ``nn.init.kaiming_normal_``
    (called with its default arguments) and every bias is set to zero.

    Args:
        D_i: Number of input features.
        D_k: Width of each hidden layer.
        D_o: Number of output features.
        n_hidden: Number of ``Linear -> LeakyReLU`` blocks placed before the
            output layer. The default of 3 builds the same eight-module
            sequence (and the same ``state_dict`` keys) as the original
            fixed architecture.

    Raises:
        ValueError: If ``n_hidden`` is negative.
    """

    def __init__(self, D_i, D_k, D_o, *, n_hidden: int = 3) -> None:
        super().__init__()
        if n_hidden < 0:
            raise ValueError(f"n_hidden must be non-negative, got {n_hidden}")

        # Build the layers in order so module indices inside the Sequential
        # (0, 2, 4, ... for the Linear layers) match the hand-written layout.
        layers = []
        in_features = D_i
        for _ in range(n_hidden):
            layers.append(nn.Linear(in_features, D_k))
            layers.append(nn.LeakyReLU())
            in_features = D_k
        layers.append(nn.Linear(in_features, D_o))
        layers.append(nn.Sigmoid())
        self.ffn = nn.Sequential(*layers)

        # Module.apply visits every submodule; only Linear layers are touched.
        self.ffn.apply(self._init_linear)

    @staticmethod
    def _init_linear(layer_in: nn.Module) -> None:
        """Re-initialise one module in place if it is an ``nn.Linear``."""
        if isinstance(layer_in, nn.Linear):
            nn.init.kaiming_normal_(layer_in.weight)
            # nn.init.zeros_ performs the in-place fill without going through
            # the discouraged ``.data`` attribute; skip layers without a bias.
            if layer_in.bias is not None:
                nn.init.zeros_(layer_in.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the feed-forward stack to ``x`` and return its output."""
        return self.ffn(x)