Skip to content

Commit e47b3b2

Browse files
Update deepcaps.py
1 parent b2cd5aa commit e47b3b2

File tree

1 file changed

+176
-0
lines changed

1 file changed

+176
-0
lines changed

deepcaps.py

+176
Original file line number | Diff line number | Diff line change
@@ -801,6 +801,75 @@ def forward(self, x):
801801
return x # dim: (batch, 1, imsize, imsize)
802802

803803

804+
class Decoder_mnist32x32(nn.Module):
    """Transposed-convolution decoder that rebuilds a 32x32 image from capsules.

    The flattened capsule vector is projected by a dense layer to a
    ``16 x 8 x 8`` feature map, which is then upsampled by two stride-2
    transposed convolutions (each followed by one pixel of top/left zero
    padding): 8 -> 15 -> 16 -> 31 -> 32.

    Args:
        caps_size: Length of a single capsule vector.
        num_caps: Number of capsules fed to the decoder; the dense layer
            expects ``caps_size * num_caps`` input features.
        img_size: Side length used by the final reshape. The layer stack
            always emits 32x32 maps, so only ``img_size=32`` keeps the batch
            dimension intact (the default is kept for interface
            compatibility).
        img_channels: Number of channels of the reconstructed image.
    """

    def __init__(self, caps_size=16, num_caps=1, img_size=28, img_channels=1):
        super().__init__()

        self.num_caps = num_caps
        self.img_channels = img_channels
        self.img_size = img_size

        # No device is pinned here: the layer follows its parent module when
        # the caller invokes ``.cuda()`` / ``.to(...)`` on the model.
        self.dense = torch.nn.Linear(caps_size * num_caps, 8 * 8 * 16)
        self.relu = nn.ReLU(inplace=True)

        self.reconst_layers1 = nn.Sequential(
            nn.BatchNorm2d(num_features=16, momentum=0.8),
            nn.ConvTranspose2d(in_channels=16, out_channels=64,
                               kernel_size=3, stride=1, padding=1),
        )

        self.reconst_layers2 = nn.ConvTranspose2d(in_channels=64, out_channels=32,
                                                  kernel_size=3, stride=2, padding=1)

        self.reconst_layers3 = nn.ConvTranspose2d(in_channels=32, out_channels=16,
                                                  kernel_size=3, stride=2, padding=1)

        # The output channel count follows ``img_channels`` so that the final
        # reshape in ``forward`` cannot fold channels into the batch axis.
        self.reconst_layers4 = nn.ConvTranspose2d(in_channels=16, out_channels=img_channels,
                                                  kernel_size=3, stride=1, padding=1)

        # Not applied in ``forward``; the attribute is kept so existing code
        # that accesses it keeps working.
        self.reconst_layers5 = nn.ReLU()

    def forward(self, x):
        """Decode capsule activations into an image batch.

        Args:
            x: Tensor whose last dimension holds ``caps_size * num_caps``
                features (the dense layer is applied over that dimension).

        Returns:
            Tensor reshaped to ``(-1, img_channels, img_size, img_size)``.
        """
        # Match the decoder's own parameters instead of forcing a CPU
        # FloatTensor followed by a move to the default CUDA device.
        weight = self.dense.weight
        x = x.to(device=weight.device, dtype=weight.dtype)

        x = self.dense(x)
        x = self.relu(x)
        x = x.reshape(-1, 16, 8, 8)

        x = self.reconst_layers1(x)
        x = self.reconst_layers2(x)  # 8x8 -> 15x15

        # One pixel of zero padding on the left and top: 15x15 -> 16x16.
        x = func.pad(x, (1, 0, 1, 0), "constant", 0)
        x = self.reconst_layers3(x)  # 16x16 -> 31x31

        # One pixel of zero padding on the left and top: 31x31 -> 32x32.
        x = func.pad(x, (1, 0, 1, 0), "constant", 0)
        x = self.reconst_layers4(x)

        x = x.reshape(-1, self.img_channels, self.img_size, self.img_size)
        return x  # dim: (batch, img_channels, img_size, img_size)
872+
804873
# In[13]:
805874

806875

@@ -899,7 +968,105 @@ def loss(self, x, recnstrcted, data, labels, lamda=0.5, m_plus=0.9, m_minus=0.1)
899968
loss = self.margin_loss(x, labels, lamda, m_plus, m_minus) + self.reconst_loss(recnstrcted, data)
900969
return loss.mean()
901970

971+
####################################################################################################################################################
972+
####################################################################################################################################################
973+
class Model32x32(nn.Module):
    """Capsule network for 3-channel 32x32 inputs with a reconstruction decoder.

    A convolutional stem feeds four capsule stages. Each stage applies a
    stride-2 capsule convolution, then adds a skip branch to a two-layer main
    branch. The outputs of stages three and four are flattened, concatenated
    and routed to ten class capsules; the masked class capsule is decoded
    back into an image.
    """

    def __init__(self):
        super().__init__()

        def caps2d(size, n_in, n_out, stride, ch_in=32):
            # Every 2D capsule conv here shares ch_j=32, a 3x3 kernel and
            # a single routing iteration.
            return Conv2DCaps(h=size, w=size, ch_i=ch_in, n_i=n_in, ch_j=32, n_j=n_out,
                              kernel_size=3, stride=stride, r_num=1)

        # Stem: plain convolution + batch norm, then conversion to capsules.
        self.conv2d = nn.Conv2d(3, 128, kernel_size=3, stride=1, padding=1)
        self.batchNorm = nn.BatchNorm2d(128, eps=1e-08, momentum=0.99)
        self.toCaps = ConvertToCaps()

        # Stage 1: 32x32 -> 16x16, 4-dimensional capsules.
        self.conv2dCaps1_nj_4_strd_2 = caps2d(32, 1, 4, 2, ch_in=128)
        self.conv2dCaps1_nj_4_strd_1_1 = caps2d(16, 4, 4, 1)
        self.conv2dCaps1_nj_4_strd_1_2 = caps2d(16, 4, 4, 1)
        self.conv2dCaps1_nj_4_strd_1_3 = caps2d(16, 4, 4, 1)

        # Stage 2: 16x16 -> 8x8, 8-dimensional capsules.
        self.conv2dCaps2_nj_8_strd_2 = caps2d(16, 4, 8, 2)
        self.conv2dCaps2_nj_8_strd_1_1 = caps2d(8, 8, 8, 1)
        self.conv2dCaps2_nj_8_strd_1_2 = caps2d(8, 8, 8, 1)
        self.conv2dCaps2_nj_8_strd_1_3 = caps2d(8, 8, 8, 1)

        # Stage 3: 8x8 -> 4x4.
        self.conv2dCaps3_nj_8_strd_2 = caps2d(8, 8, 8, 2)
        self.conv2dCaps3_nj_8_strd_1_1 = caps2d(4, 8, 8, 1)
        self.conv2dCaps3_nj_8_strd_1_2 = caps2d(4, 8, 8, 1)
        self.conv2dCaps3_nj_8_strd_1_3 = caps2d(4, 8, 8, 1)

        # Stage 4: 4x4 -> 2x2; the skip branch is a 3D capsule conv with
        # three routing iterations.
        self.conv2dCaps4_nj_8_strd_2 = caps2d(4, 8, 8, 2)
        self.conv3dCaps4_nj_8 = ConvCapsLayer3D(ch_i=32, n_i=8, ch_j=32, n_j=8,
                                                kernel_size=3, r_num=3)
        self.conv2dCaps4_nj_8_strd_1_1 = caps2d(2, 8, 8, 1)
        self.conv2dCaps4_nj_8_strd_1_2 = caps2d(2, 8, 8, 1)

        # Heads: reconstruction decoder, class capsules, masking and loss.
        self.decoder = Decoder_mnist32x32(caps_size=32, num_caps=1, img_size=32, img_channels=3)
        self.flatCaps = FlattenCaps()
        self.digCaps = CapsuleLayer(num_capsules=10, num_routes=640, in_channels=8,
                                    out_channels=32, routing_iters=3)
        self.capsToScalars = CapsToScalars()
        self.mask = Mask_CID()
        self.mse_loss = nn.MSELoss(reduction="none")

    @staticmethod
    def _residual_stage(feats, downsample, skip_branch, main_first, main_second):
        """Downsample, then return the main branch output plus the skip branch output."""
        feats = downsample(feats)
        shortcut = skip_branch(feats)
        feats = main_second(main_first(feats))
        return feats + shortcut

    def forward(self, x, target=None):
        """Run the network.

        Args:
            x: Input batch passed to the 3-channel convolutional stem.
            target: Optional value forwarded to the mask layer.

        Returns:
            Tuple ``(dig_caps, masked, decoded, indices)``.
        """
        x = self.toCaps(self.batchNorm(self.conv2d(x)))

        x = self._residual_stage(
            x,
            self.conv2dCaps1_nj_4_strd_2,
            self.conv2dCaps1_nj_4_strd_1_1,
            self.conv2dCaps1_nj_4_strd_1_2,
            self.conv2dCaps1_nj_4_strd_1_3,
        )
        x = self._residual_stage(
            x,
            self.conv2dCaps2_nj_8_strd_2,
            self.conv2dCaps2_nj_8_strd_1_1,
            self.conv2dCaps2_nj_8_strd_1_2,
            self.conv2dCaps2_nj_8_strd_1_3,
        )
        stage3 = self._residual_stage(
            x,
            self.conv2dCaps3_nj_8_strd_2,
            self.conv2dCaps3_nj_8_strd_1_1,
            self.conv2dCaps3_nj_8_strd_1_2,
            self.conv2dCaps3_nj_8_strd_1_3,
        )
        stage4 = self._residual_stage(
            stage3,
            self.conv2dCaps4_nj_8_strd_2,
            self.conv3dCaps4_nj_8,
            self.conv2dCaps4_nj_8_strd_1_1,
            self.conv2dCaps4_nj_8_strd_1_2,
        )

        # Author's note for CIFAR10 with batch 64:
        # stage3 is [64, 32, 8, 4, 4] and stage4 is [64, 32, 8, 2, 2].
        flat = torch.cat((self.flatCaps(stage3), self.flatCaps(stage4)), dim=-2)
        dig_caps = self.digCaps(flat)

        # NOTE(review): this result does not feed the return value; the call
        # is kept so the forward pass mirrors the existing behaviour.
        x = self.capsToScalars(dig_caps)
        masked, indices = self.mask(dig_caps, target)
        decoded = self.decoder(masked)

        return dig_caps, masked, decoded, indices

    def margin_loss(self, x, labels, lamda, m_plus, m_minus):
        """Per-sample margin loss computed from capsule lengths.

        Args:
            x: Capsule tensor; lengths are taken as the norm over dim 2.
            labels: Per-class targets, multiplied element-wise with the
                per-class terms.
            lamda: Weight applied to the absent-class term.
            m_plus: Margin for the present-class term.
            m_minus: Margin for the absent-class term.

        Returns:
            Loss summed over dim 1 (one value per sample).
        """
        n_samples = x.shape[0]
        lengths = torch.norm(x, dim=2, keepdim=True)
        present = func.relu(m_plus - lengths).view(n_samples, -1) ** 2
        absent = func.relu(lengths - m_minus).view(n_samples, -1) ** 2
        per_class = labels * present + lamda * (1 - labels) * absent
        return per_class.sum(dim=1)

    def reconst_loss(self, recnstrcted, data):
        """Per-sample reconstruction loss: 0.4 times the summed squared error."""
        n_samples = recnstrcted.shape[0]
        squared_err = self.mse_loss(recnstrcted.view(n_samples, -1), data.view(n_samples, -1))
        return 0.4 * squared_err.sum(dim=1)

    def loss(self, x, recnstrcted, data, labels, lamda=0.5, m_plus=0.9, m_minus=0.1):
        """Batch mean of margin loss plus reconstruction loss."""
        margin = self.margin_loss(x, labels, lamda, m_plus, m_minus)
        total = margin + self.reconst_loss(recnstrcted, data)
        return total.mean()
1066+
1067+
1068+
####################################################################################################################################################
1069+
####################################################################################################################################################
9031070

9041071

9051072
# In[14]:
@@ -929,6 +1096,10 @@ def loss(x, recnstrcted, data, labels, lamda=0.5, m_plus=0.9, m_minus=0.1):
9291096

9301097

9311098
# Instantiate the default network and move it to the GPU.
model = Model().cuda()
1099+
# Uncomment below line for CIFAR10
1100+
# model = Model32x32().cuda()
1101+
1102+
9321103
# lr = 0.001
9331104
# optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
9341105
# # torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1)
@@ -947,6 +1118,11 @@ def loss(x, recnstrcted, data, labels, lamda=0.5, m_plus=0.9, m_minus=0.1):
9471118
# FashionMNIST training split, downloaded if missing; samples go through the
# `trans(...)` transform (defined elsewhere) and batches are shuffled.
# NOTE(review): `root` is an absolute, machine-specific path.
train_loader = torch.utils.data.DataLoader(datasets.FashionMNIST(root='/home/mtech3/CODES/ankit/data/FashionMNIST/FashionMNIST/',train=True,download=True,transform=trans(rotation_range=0.1, translation_range=0.1, zoom_range=(0.1, 0.2))),batch_size=batch_size,shuffle=True)
9481119
# FashionMNIST test split with only `ToTensor()` applied.
# NOTE(review): `shuffle=True` is also set for the test split, and `root` is
# an absolute, machine-specific path.
test_loader = torch.utils.data.DataLoader(datasets.FashionMNIST(root='/home/mtech3/CODES/ankit/data/FashionMNIST/FashionMNIST/',train=False,download=True,transform=transforms.ToTensor()),batch_size=batch_size,shuffle=True)
9491120

1121+
######################
1122+
# Uncomment these for CIFAR10
1123+
# train_loader = torch.utils.data.DataLoader(datasets.CIFAR10(root='./CIFAR10',train=True,download=True,transform=trans(rotation_range=0.1, translation_range=0.1, zoom_range=(0.1, 0.2))),batch_size=batch_size,shuffle=True)
1124+
# test_loader = torch.utils.data.DataLoader(datasets.CIFAR10(root='./CIFAR10',train=False,download=True,transform=transforms.ToTensor()),batch_size=batch_size,shuffle=True)
1125+
######################
9501126

9511127
# In[17]:
9521128

0 commit comments

Comments
 (0)