added tinyimagenet and static drop
Arnav0400 committed Jul 25, 2021
1 parent f39e456 commit cae74c8
Showing 5 changed files with 112 additions and 38 deletions.
57 changes: 39 additions & 18 deletions main.py
@@ -1,28 +1,27 @@
import torch
import numpy as np
from torchvision import datasets, transforms
from torchvision.transforms import RandomRotation, RandomVerticalFlip, RandomHorizontalFlip, Pad, Resize, Compose, ToTensor
from torchvision import transforms
from torch import nn, optim
from torch.utils.data import DataLoader
import os
from PIL import Image
import argparse
from models.get_model import get_model
from utils.rotmnist import MnistRotDataset
from utils.tinyimagenet import TinyImageNet
from pytorch_lightning import Trainer, loggers, seed_everything
seed_everything(42)

import pytorch_lightning as pl

class CoolSystem(pl.LightningModule):

def __init__(self, model, dataset, batch_size=64, mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]):
def __init__(self, model, dataset, batch_size=64):
super().__init__()
self.batch_size = batch_size
self.dataset = dataset
self.model = model
self.mean = mean
self.std = std

def forward(self, x):
return self.model(x)
@@ -68,27 +67,40 @@ def configure_optimizers(self):
return [optimizer], [lr_scheduler]

def train_dataloader(self):
mean=[0.4914, 0.4822, 0.4465]
std=[0.2023, 0.1994, 0.2010]
transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(self.mean, self.std)])
transforms.Normalize(mean, std)])
if self.dataset == 'CIFAR10':
dataset = datasets.CIFAR10(root=os.getcwd(), train=True, transform=transform_train, download = True)
elif self.dataset == 'CIFAR100':
dataset = datasets.CIFAR100(root=os.getcwd(), train=True, transform=transform_train, download = True)
elif self.dataset == 'MNIST-rot':

train_transform = Compose([
Pad((0, 0, 1, 1), fill=0),
Resize(87),
RandomRotation(180, resample=Image.BILINEAR, expand=False),
RandomVerticalFlip(),
RandomHorizontalFlip(),
Resize(29),
ToTensor(),
train_transform = transforms.Compose([
transforms.Pad((0, 0, 1, 1), fill=0),
transforms.Resize(87),
transforms.RandomRotation(180, resample=Image.BILINEAR, expand=False),
transforms.RandomVerticalFlip(),
transforms.RandomHorizontalFlip(),
transforms.Resize(29),
transforms.ToTensor(),
])

dataset = MnistRotDataset(mode='train', transform=train_transform)
elif self.dataset == 'TINYIMNET':
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]
norm_transform = transforms.Normalize(norm_mean, norm_std)
train_transform = transforms.Compose([
transforms.RandomAffine(degrees=20.0, scale=(0.8, 1.2), shear=20.0),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
norm_transform,
])
dataset = TinyImageNet(os.getcwd(), train=True, transform=train_transform)
dataloader = DataLoader(dataset, batch_size=self.batch_size, num_workers=8, shuffle=True, drop_last=True, pin_memory=True)
return dataloader

Expand All @@ -100,11 +112,20 @@ def val_dataloader(self):
elif self.dataset == 'CIFAR100':
dataset = datasets.CIFAR100(root=os.getcwd(), train=False, transform=transform_val, download = True)
elif self.dataset == 'MNIST-rot':
test_transform = Compose([
Pad((0, 0, 1, 1), fill=0),
ToTensor(),
transform_val = transforms.Compose([
transforms.Pad((0, 0, 1, 1), fill=0),
transforms.ToTensor(),
])
dataset = MnistRotDataset(mode='test', transform=transform_val)
elif self.dataset == 'TINYIMNET':
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]
norm_transform = transforms.Normalize(norm_mean, norm_std)
transform_val = transforms.Compose([
transforms.ToTensor(),
norm_transform
])
dataset = MnistRotDataset(mode='test', transform=test_transform)
dataset = TinyImageNet(os.getcwd(), train=False, transform=transform_val)
dataloader = DataLoader(dataset, batch_size=self.batch_size, num_workers=8, pin_memory=True)
return dataloader

Expand All @@ -126,7 +147,7 @@ def parse_args():
if not os.path.exists('weights'):
os.mkdir('weights')

model = get_model(args.model_name, args.num_classes)
model = get_model(args.model_name, args.num_classes, args)
system = CoolSystem(model, args.dataset)

model_parameters = filter(lambda p: p.requires_grad, system.model.parameters())
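For reference, a minimal sketch of how the new TINYIMNET path might be driven from main.py. The args attribute names below (model_name, num_classes, dataset, weight_activation, drop_rate) are assumptions inferred from how args is consumed in this diff; the parse_args() body itself is not shown in the hunk.

import argparse
from models.get_model import get_model

# Hypothetical argument values -- flag names are inferred from usage, not from parse_args().
args = argparse.Namespace(
    model_name='resnet_16_4',
    num_classes=200,                    # Tiny ImageNet has 200 classes
    dataset='TINYIMNET',
    weight_activation='static_drop',    # read by Conv2dRepeat via args
    drop_rate=0.5,                      # each repeated weight is kept with probability 1 - drop_rate
)

model = get_model(args.model_name, args.num_classes, args)
# main.py then wraps this as: system = CoolSystem(model, args.dataset),
# and the Lightning Trainer picks the TINYIMNET branch of train_dataloader()/val_dataloader().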
18 changes: 14 additions & 4 deletions models/conv2d_repeat.py
@@ -4,16 +4,16 @@
from torch.nn.functional import conv2d

class Conv2dRepeat(nn.Module):
def __init__(self, original_weight_shape, repeated_weight_shape, previous_weight_shape=None, concat_dim=1, stride=1, padding=1, conv_type="intra", weight_activation='swish'):
def __init__(self, original_weight_shape, repeated_weight_shape, previous_weight_shape=None, concat_dim=1, stride=1, padding=1, conv_type="intra", args=None):
super(Conv2dRepeat, self).__init__()

self.args = args
self.ooc, self.oic, self.ok1, self.ok2 = original_weight_shape
self.roc, self.ric, self.rk1, self.rk2 = repeated_weight_shape
self.do_repeat = False if original_weight_shape==repeated_weight_shape else True
self.stride = stride
self.padding = padding
self.conv_type = conv_type
self.wactivation = weight_activation
self.wactivation = self.args.weight_activation
if previous_weight_shape is not None:
if concat_dim==0:
self.ooc+=previous_weight_shape[0]
@@ -52,6 +52,12 @@ def __init__(self, original_weight_shape, repeated_weight_shape, previous_weight

self.unfold = torch.nn.Unfold(kernel_size=(self.ooc, self.oic), stride=(self.ooc, self.oic))
self.fold = torch.nn.Fold(output_size=(self.ooc*self.r0, self.oic*self.r1), kernel_size=(self.ooc, self.oic), stride=(self.ooc, self.oic))

if self.wactivation=='static_drop':
self.drop_mask = torch.ones(self.roc, self.ric, self.rk1, self.rk2)*(1-self.args.drop_rate)
self.drop_mask = torch.bernoulli(self.drop_mask)
self.drop_mask = nn.Parameter(self.drop_mask)
self.drop_mask.requires_grad = False

def forward(self, x, weights=None):
if self.conv_type=="intra":
Expand All @@ -68,11 +74,15 @@ def forward(self, x, weights=None):
x = F.conv2d(x, weights, self.bias, stride=self.stride, padding = self.padding)
return x

def activation(self, weight, alphas, betas):
def activation(self, weight, alphas=None, betas=None):
if self.wactivation=="swish":
x = weight*alphas/(1+torch.exp(weight*betas))
elif self.wactivation=="fourier":
x = self.alphas[0]+self.alphas[1]*weight**1+self.alphas[2]*weight**2+self.alphas[3]*weight**3+self.alphas[4]*weight**4+self.alphas[5]*weight**5
elif self.wactivation=="static_drop":
x = weight*(self.drop_mask.reshape_as(weight))
elif self.wactivation==None:
x = weight
return x

def repeat(self, weights):
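A standalone sketch of the static-drop mechanism added above, to make the intent explicit: a Bernoulli(1 - drop_rate) mask is sampled once at construction, frozen as a non-trainable parameter, and multiplied into the repeated weight tensor on every forward pass. The shape and drop_rate below are illustrative, not values from the repo's configs.

import torch
import torch.nn as nn

drop_rate = 0.3
weight_shape = (64, 32, 3, 3)   # (out_channels, in_channels, k1, k2), example only

# Sample the mask once; each entry is kept with probability 1 - drop_rate.
keep_prob = torch.ones(weight_shape) * (1 - drop_rate)
drop_mask = nn.Parameter(torch.bernoulli(keep_prob), requires_grad=False)

# At forward time the mask simply gates the (repeated) weights, as in the static_drop branch of activation().
weights = torch.randn(weight_shape)
masked_weights = weights * drop_mask.reshape_as(weights)

print(f"fraction of weights kept: {drop_mask.mean().item():.3f}")  # roughly 1 - drop_rate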
10 changes: 5 additions & 5 deletions models/get_model.py
@@ -2,7 +2,7 @@
from .resnet_rep import *
from .group_equivariant import *

def get_model(model_name, num_classes):
def get_model(model_name, num_classes, args=None):
if model_name=='rep_vgg_4':
return CVGG11_4(num_classes)
elif model_name=='rep_vgg_5':
Expand All @@ -20,13 +20,13 @@ def get_model(model_name, num_classes):
elif model_name[:3]=='VGG':
return VGG(model_name, num_classes)
elif model_name=='resnet_16_1':
return resnet_rep(num_classes, 1)
return resnet_rep(num_classes, 1, args)
elif model_name=='resnet_16_4':
return resnet_rep(num_classes, 4)
return resnet_rep(num_classes, 4, args)
elif model_name=='resnet_16_8':
return resnet_rep(num_classes, 8)
return resnet_rep(num_classes, 8, args)
elif model_name=='resnet_16_10':
return resnet_rep(num_classes, 10)
return resnet_rep(num_classes, 10, args)
elif model_name=='c16':
return c16(num_classes)
elif model_name=='c8':
23 changes: 12 additions & 11 deletions models/resnet_rep.py
@@ -9,16 +9,16 @@ def conv3x3(in_planes, out_planes, stride=1):

class BasicBlock(nn.Module):
width=1
def __init__(self, inplanes, planes, stride=1, downsample=None):
def __init__(self, inplanes, planes, stride=1, downsample=None, args=None):
super(BasicBlock, self).__init__()
if inplanes==16:
self.conv1 = Conv2dRepeat((planes//self.width, inplanes, 3, 3), (planes, inplanes, 3, 3), stride=stride)
self.conv1 = Conv2dRepeat((planes//self.width, inplanes, 3, 3), (planes, inplanes, 3, 3), stride=stride, args=args)
else:
self.conv1 = Conv2dRepeat((planes//self.width, inplanes//self.width, 3, 3), (planes, inplanes, 3, 3), stride=stride)
self.conv1 = Conv2dRepeat((planes//self.width, inplanes//self.width, 3, 3), (planes, inplanes, 3, 3), stride=stride, args=args)

self.bn1 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = Conv2dRepeat((planes//self.width, planes//self.width, 3, 3), (planes, planes, 3, 3))
self.conv2 = Conv2dRepeat((planes//self.width, planes//self.width, 3, 3), (planes, planes, 3, 3), args=args)
self.bn2 = nn.BatchNorm2d(planes)
self.downsample = downsample
self.stride = stride
Expand All @@ -42,9 +42,10 @@ def forward(self, x):
return out

class ResNetCifar(nn.Module):
def __init__(self, block, layers, width=1, num_classes=10):
def __init__(self, block, layers, width=1, num_classes=10, args=None):
self.inplanes = 16
super(ResNetCifar, self).__init__()
self.args = args
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1,
bias=False)
self.bn1 = nn.BatchNorm2d(16)
Expand All @@ -68,20 +69,20 @@ def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1:
downsample = nn.Sequential(
Conv2dRepeat((planes//self.width, self.inplanes//self.width, 1, 1), (planes, self.inplanes, 1, 1), stride=stride, padding=0),
Conv2dRepeat((planes//self.width, self.inplanes//self.width, 1, 1), (planes, self.inplanes, 1, 1), stride=stride, padding=0, args=self.args),
nn.BatchNorm2d(planes),
)
elif self.inplanes != planes:
downsample = nn.Sequential(
Conv2dRepeat((planes//self.width, self.inplanes, 1, 1), (planes, self.inplanes, 1, 1), stride=stride, padding=0),
Conv2dRepeat((planes//self.width, self.inplanes, 1, 1), (planes, self.inplanes, 1, 1), stride=stride, padding=0, args=self.args),
nn.BatchNorm2d(planes),
)

layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
layers.append(block(self.inplanes, planes, stride, downsample, args=self.args))
self.inplanes = planes
for i in range(1, blocks):
layers.append(block(self.inplanes, planes))
layers.append(block(self.inplanes, planes, args = self.args))

return nn.Sequential(*layers)

Expand All @@ -97,7 +98,7 @@ def forward(self, x):

return x

def resnet_rep(num_classes=10, k=1):
model = ResNetCifar(BasicBlock, [2, 2, 2], width=k, num_classes=num_classes)
def resnet_rep(num_classes=10, k=1, args=None):
model = ResNetCifar(BasicBlock, [2, 2, 2], width=k, num_classes=num_classes, args = args)
return model
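Since the constructors above now thread args down to every Conv2dRepeat, here is a minimal construction sketch, assuming args only needs the two attributes consumed in conv2d_repeat.py (weight_activation and drop_rate); the values are illustrative.

from argparse import Namespace
from models.resnet_rep import resnet_rep

args = Namespace(weight_activation='static_drop', drop_rate=0.5)   # assumed attribute names
model = resnet_rep(num_classes=100, k=4, args=args)                # width-4 repeated ResNet
print(sum(p.numel() for p in model.parameters() if p.requires_grad))  # trainable parameter count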

42 changes: 42 additions & 0 deletions utils/tinyimagenet.py
@@ -0,0 +1,42 @@
from torch.utils.data import Dataset
import glob
import numpy as np
import os
from torchvision.datasets.folder import pil_loader
from torchvision.datasets.utils import download_and_extract_archive

class TinyImageNet(Dataset):
def __init__(self, root, train, transform, download=True):

self.url = "http://cs231n.stanford.edu/tiny-imagenet-200.zip"
self.root = root
if download:
if os.path.exists(f'{self.root}/tiny-imagenet-200/'):
print('File already downloaded')
else:
download_and_extract_archive(self.url, root, filename="tiny-imagenet-200.zip")

self.root = os.path.join(self.root, "tiny-imagenet-200")
self.train = train
self.transform = transform
self.ids_string = np.sort(np.loadtxt(f"{self.root}/wnids.txt", "str"))
self.ids = {class_string: i for i, class_string in enumerate(self.ids_string)}
if train:
self.paths = glob.glob(f"{self.root}/train/*/images/*")
self.label = [self.ids[path.split("/")[-3]] for path in self.paths]
else:
self.val_annotations = np.loadtxt(f"{self.root}/val/val_annotations.txt", "str")
self.paths = [f"{self.root}/val/images/{sample[0]}" for sample in self.val_annotations]
self.label = [self.ids[sample[1]] for sample in self.val_annotations]

def __len__(self):
return len(self.paths)

def __getitem__(self, idx):
image = pil_loader(self.paths[idx])

if self.transform is not None:
image = self.transform(image)

return image, self.label[idx]
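A quick usage sketch for the new dataset class: it downloads and extracts the archive into root on first use, then serves (image, label) pairs. The transform and loader settings mirror the TINYIMNET validation path in main.py; batch size and worker count are illustrative.

import os
from torch.utils.data import DataLoader
from torchvision import transforms
from utils.tinyimagenet import TinyImageNet

transform_val = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

val_set = TinyImageNet(os.getcwd(), train=False, transform=transform_val)
val_loader = DataLoader(val_set, batch_size=64, num_workers=8, pin_memory=True)

image, label = val_set[0]   # transformed tensor and an integer class id in [0, 199]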
