@@ -801,6 +801,75 @@ def forward(self, x):
801
801
return x # dim: (batch, 1, imsize, imsize)
802
802
803
803
804
class Decoder_mnist32x32(nn.Module):
    """Decode a capsule vector into a 32x32 image.

    A dense layer expands the capsule vector to a 16-channel 8x8 feature map,
    which four transposed convolutions (with one-sided zero padding between
    them) upsample to ``img_channels`` x 32 x 32.

    Args:
        caps_size: Dimension of a single capsule vector.
        num_caps: Number of capsule vectors fed in per sample; the dense layer
            expects ``caps_size * num_caps`` input features.
        img_size: Side length used for the final reshape. The layer stack
            always produces 32x32 maps, so any value other than 32 makes the
            final reshape fail or mix samples.
        img_channels: Number of channels of the reconstructed image.
    """

    def __init__(self, caps_size=16, num_caps=1, img_size=28, img_channels=1):
        super().__init__()

        self.num_caps = num_caps
        self.img_channels = img_channels
        self.img_size = img_size

        # The layer is intentionally not pinned to a device here: placement
        # follows the usual ``module.cuda()`` / ``module.to(...)`` call made
        # on the parent model.
        self.dense = torch.nn.Linear(caps_size * num_caps, 8 * 8 * 16)
        self.relu = nn.ReLU(inplace=True)

        self.reconst_layers1 = nn.Sequential(
            nn.BatchNorm2d(num_features=16, momentum=0.8),
            nn.ConvTranspose2d(in_channels=16, out_channels=64,
                               kernel_size=3, stride=1, padding=1),
        )
        self.reconst_layers2 = nn.ConvTranspose2d(in_channels=64, out_channels=32,
                                                  kernel_size=3, stride=2, padding=1)
        self.reconst_layers3 = nn.ConvTranspose2d(in_channels=32, out_channels=16,
                                                  kernel_size=3, stride=2, padding=1)
        # The output channel count follows ``img_channels`` (it used to be
        # hard-coded to 3, which only matched RGB images).
        self.reconst_layers4 = nn.ConvTranspose2d(in_channels=16, out_channels=img_channels,
                                                  kernel_size=3, stride=1, padding=1)

        # Kept for attribute compatibility; it is not applied in forward().
        self.reconst_layers5 = nn.ReLU()

    def forward(self, x):
        """Reconstruct images from capsule vectors.

        Args:
            x: Tensor whose last dimension is ``caps_size * num_caps``
                (the original code comment gives the shape as
                ``(batch, 1, capsule_dim)``).

        Returns:
            Tensor of shape ``(batch, img_channels, img_size, img_size)``.
        """
        # Move/cast the input to wherever the weights live instead of forcing
        # a CPU float tensor followed by an unconditional .cuda().
        weight = self.dense.weight
        x = x.to(device=weight.device, dtype=weight.dtype)

        x = self.relu(self.dense(x))
        x = x.reshape(-1, 16, 8, 8)

        x = self.reconst_layers1(x)  # stride 1: stays 8x8
        x = self.reconst_layers2(x)  # stride 2: 8x8 -> 15x15

        # Pad one row/column on the left/top so the map is 16x16.
        x = func.pad(x, (1, 0, 1, 0), "constant", 0)
        x = self.reconst_layers3(x)  # stride 2: 16x16 -> 31x31

        # Pad one row/column on the left/top so the map is 32x32.
        x = func.pad(x, (1, 0, 1, 0), "constant", 0)
        x = self.reconst_layers4(x)  # stride 1: stays 32x32

        return x.reshape(-1, self.img_channels, self.img_size, self.img_size)
872
+
804
873
# In[13]:
805
874
806
875
@@ -899,7 +968,105 @@ def loss(self, x, recnstrcted, data, labels, lamda=0.5, m_plus=0.9, m_minus=0.1)
899
968
loss = self .margin_loss (x , labels , lamda , m_plus , m_minus ) + self .reconst_loss (recnstrcted , data )
900
969
return loss .mean ()
901
970
971
+ ####################################################################################################################################################
972
+ ####################################################################################################################################################
973
class Model32x32(nn.Module):
    """Capsule network for 3-channel 32x32 inputs (CIFAR10 per the file's comments).

    A plain convolution feeds four capsule stages, each made of a stride-2
    capsule convolution followed by a two-layer main path and a one-layer
    skip path that are summed. The outputs of the last two stages are
    flattened, concatenated and routed into 10 class capsules, which are then
    masked and passed to the reconstruction decoder.
    """

    def __init__(self):
        super().__init__()

        def caps2d(size, n_i, n_j, stride, ch_i=32):
            # Every 2D capsule layer here shares ch_j=32, a 3x3 kernel and a
            # single routing iteration; only the sizes and stride vary.
            return Conv2DCaps(h=size, w=size, ch_i=ch_i, n_i=n_i, ch_j=32, n_j=n_j,
                              kernel_size=3, stride=stride, r_num=1)

        # Stem: ordinary convolution + batch norm, then conversion to capsules.
        self.conv2d = nn.Conv2d(in_channels=3, out_channels=128,
                                kernel_size=3, stride=1, padding=1)
        self.batchNorm = torch.nn.BatchNorm2d(num_features=128, eps=1e-08, momentum=0.99)
        self.toCaps = ConvertToCaps()

        # Stage 1: constructed for 32x32 input, 16x16 afterwards.
        self.conv2dCaps1_nj_4_strd_2 = caps2d(32, n_i=1, n_j=4, stride=2, ch_i=128)
        self.conv2dCaps1_nj_4_strd_1_1 = caps2d(16, n_i=4, n_j=4, stride=1)
        self.conv2dCaps1_nj_4_strd_1_2 = caps2d(16, n_i=4, n_j=4, stride=1)
        self.conv2dCaps1_nj_4_strd_1_3 = caps2d(16, n_i=4, n_j=4, stride=1)

        # Stage 2: 16x16 -> 8x8.
        self.conv2dCaps2_nj_8_strd_2 = caps2d(16, n_i=4, n_j=8, stride=2)
        self.conv2dCaps2_nj_8_strd_1_1 = caps2d(8, n_i=8, n_j=8, stride=1)
        self.conv2dCaps2_nj_8_strd_1_2 = caps2d(8, n_i=8, n_j=8, stride=1)
        self.conv2dCaps2_nj_8_strd_1_3 = caps2d(8, n_i=8, n_j=8, stride=1)

        # Stage 3: 8x8 -> 4x4.
        self.conv2dCaps3_nj_8_strd_2 = caps2d(8, n_i=8, n_j=8, stride=2)
        self.conv2dCaps3_nj_8_strd_1_1 = caps2d(4, n_i=8, n_j=8, stride=1)
        self.conv2dCaps3_nj_8_strd_1_2 = caps2d(4, n_i=8, n_j=8, stride=1)
        self.conv2dCaps3_nj_8_strd_1_3 = caps2d(4, n_i=8, n_j=8, stride=1)

        # Stage 4: 4x4 -> 2x2; its skip path is a 3D capsule convolution.
        self.conv2dCaps4_nj_8_strd_2 = caps2d(4, n_i=8, n_j=8, stride=2)
        self.conv3dCaps4_nj_8 = ConvCapsLayer3D(ch_i=32, n_i=8, ch_j=32, n_j=8,
                                                kernel_size=3, r_num=3)
        self.conv2dCaps4_nj_8_strd_1_1 = caps2d(2, n_i=8, n_j=8, stride=1)
        self.conv2dCaps4_nj_8_strd_1_2 = caps2d(2, n_i=8, n_j=8, stride=1)

        # Classification head, masking and reconstruction.
        self.decoder = Decoder_mnist32x32(caps_size=32, num_caps=1, img_size=32, img_channels=3)
        self.flatCaps = FlattenCaps()
        self.digCaps = CapsuleLayer(num_capsules=10, num_routes=64 * 10, in_channels=8,
                                    out_channels=32, routing_iters=3)
        self.capsToScalars = CapsToScalars()
        self.mask = Mask_CID()
        self.mse_loss = nn.MSELoss(reduction="none")

    @staticmethod
    def _skip_cell(x, downsample, skip_layer, main_first, main_second):
        """Run one stage: downsample, then sum the main path and the skip path."""
        x = downsample(x)
        shortcut = skip_layer(x)
        x = main_second(main_first(x))
        return x + shortcut

    def forward(self, x, target=None):
        """Run the network.

        Args:
            x: Input image batch.
            target: Optional value forwarded unchanged to the mask layer.

        Returns:
            Tuple ``(dig_caps, masked, decoded, indices)``: the class capsules,
            the mask layer's two outputs, and the decoder's reconstruction.
        """
        caps = self.toCaps(self.batchNorm(self.conv2d(x)))

        caps = self._skip_cell(caps,
                               self.conv2dCaps1_nj_4_strd_2,
                               self.conv2dCaps1_nj_4_strd_1_1,
                               self.conv2dCaps1_nj_4_strd_1_2,
                               self.conv2dCaps1_nj_4_strd_1_3)
        caps = self._skip_cell(caps,
                               self.conv2dCaps2_nj_8_strd_2,
                               self.conv2dCaps2_nj_8_strd_1_1,
                               self.conv2dCaps2_nj_8_strd_1_2,
                               self.conv2dCaps2_nj_8_strd_1_3)
        stage3 = self._skip_cell(caps,
                                 self.conv2dCaps3_nj_8_strd_2,
                                 self.conv2dCaps3_nj_8_strd_1_1,
                                 self.conv2dCaps3_nj_8_strd_1_2,
                                 self.conv2dCaps3_nj_8_strd_1_3)
        stage4 = self._skip_cell(stage3,
                                 self.conv2dCaps4_nj_8_strd_2,
                                 self.conv3dCaps4_nj_8,
                                 self.conv2dCaps4_nj_8_strd_1_1,
                                 self.conv2dCaps4_nj_8_strd_1_2)

        # stage3.shape : torch.Size([64, 32, 8, 4, 4]) | stage4.shape : torch.Size([64, 32, 8, 2, 2]) (for CIFAR10)
        flat_caps = torch.cat((self.flatCaps(stage3), self.flatCaps(stage4)), dim=-2)
        dig_caps = self.digCaps(flat_caps)

        # The scalar conversion is evaluated but its result is not used below.
        scalars = self.capsToScalars(dig_caps)
        masked, indices = self.mask(dig_caps, target)
        decoded = self.decoder(masked)

        return dig_caps, masked, decoded, indices

    def margin_loss(self, x, labels, lamda, m_plus, m_minus):
        """Per-sample margin loss computed from capsule lengths (norm over dim 2).

        ``labels`` weights the ``m_plus`` term and ``1 - labels`` (scaled by
        ``lamda``) weights the ``m_minus`` term; the result is summed over dim 1.
        """
        n_samples = x.shape[0]
        lengths = torch.norm(x, dim=2, keepdim=True)
        too_short = func.relu(m_plus - lengths).view(n_samples, -1) ** 2
        too_long = func.relu(lengths - m_minus).view(n_samples, -1) ** 2
        per_class = labels * too_short + lamda * (1 - labels) * too_long
        return per_class.sum(dim=1)

    def reconst_loss(self, recnstrcted, data):
        """Per-sample sum of squared reconstruction errors, scaled by 0.4."""
        n_samples = recnstrcted.shape[0]
        squared_err = self.mse_loss(recnstrcted.view(n_samples, -1),
                                    data.view(n_samples, -1))
        return 0.4 * squared_err.sum(dim=1)

    def loss(self, x, recnstrcted, data, labels, lamda=0.5, m_plus=0.9, m_minus=0.1):
        """Batch mean of margin loss plus reconstruction loss."""
        margin_term = self.margin_loss(x, labels, lamda, m_plus, m_minus)
        reconst_term = self.reconst_loss(recnstrcted, data)
        return (margin_term + reconst_term).mean()
1066
+
1067
+
1068
+ ####################################################################################################################################################
1069
+ ####################################################################################################################################################
903
1070
904
1071
905
1072
# In[14]:
@@ -929,6 +1096,10 @@ def loss(x, recnstrcted, data, labels, lamda=0.5, m_plus=0.9, m_minus=0.1):
929
1096
930
1097
931
1098
# Build the default model and move its parameters to the GPU.
model = Model ().cuda ()
1099
# For CIFAR10, uncomment the line below to use the 32x32 model instead.
1100
+ # model = Model32x32().cuda()
1101
+
1102
+
932
1103
# lr = 0.001
933
1104
# optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
934
1105
# # torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1)
@@ -947,6 +1118,11 @@ def loss(x, recnstrcted, data, labels, lamda=0.5, m_plus=0.9, m_minus=0.1):
947
1118
# Training split of FashionMNIST, passed through the project's `trans`
# transform and reshuffled by the loader.
train_loader = torch.utils.data.DataLoader(
    datasets.FashionMNIST(
        root='/home/mtech3/CODES/ankit/data/FashionMNIST/FashionMNIST/',
        train=True,
        download=True,
        transform=trans(rotation_range=0.1, translation_range=0.1, zoom_range=(0.1, 0.2)),
    ),
    batch_size=batch_size,
    shuffle=True,
)
# Held-out split of FashionMNIST; images are only converted to tensors.
test_loader = torch.utils.data.DataLoader(
    datasets.FashionMNIST(
        root='/home/mtech3/CODES/ankit/data/FashionMNIST/FashionMNIST/',
        train=False,
        download=True,
        transform=transforms.ToTensor(),
    ),
    batch_size=batch_size,
    shuffle=True,
)
949
1120
1121
+ ######################
1122
# For CIFAR10, uncomment the two loader definitions below.
1123
+ # train_loader = torch.utils.data.DataLoader(datasets.CIFAR10(root='./CIFAR10',train=True,download=True,transform=trans(rotation_range=0.1, translation_range=0.1, zoom_range=(0.1, 0.2))),batch_size=batch_size,shuffle=True)
1124
+ # test_loader = torch.utils.data.DataLoader(datasets.CIFAR10(root='./CIFAR10',train=False,download=True,transform=transforms.ToTensor()),batch_size=batch_size,shuffle=True)
1125
+ ######################
950
1126
951
1127
# In[17]:
952
1128
0 commit comments