Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vae baseline new #131

Open
wants to merge 37 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
3e460d4
Formatting.
msbartlett Oct 17, 2019
981416b
Format with black.
msbartlett Oct 17, 2019
058d2bb
VAE baseline.
msbartlett Oct 21, 2019
15660c6
VAE baseline
msbartlett Oct 21, 2019
acb0cc1
Forgot to add files before committing...
msbartlett Oct 21, 2019
523e883
Run black... again
msbartlett Oct 21, 2019
be8d5df
Run black... again again
msbartlett Oct 21, 2019
ee5c8e5
Log images with wandb.
msbartlett Oct 21, 2019
fefa292
Remove unused imports.
msbartlett Oct 21, 2019
b74733a
Pretrain on KMNIST
MylesBartlett Oct 24, 2019
ed8f936
Call main in start_vae
MylesBartlett Oct 24, 2019
d194ec7
Remove clamping during evaluation.
MylesBartlett Oct 24, 2019
1d5e4a6
Don't normalize celeba.
MylesBartlett Oct 24, 2019
714da6f
Partition encoding.
MylesBartlett Oct 25, 2019
c24317a
Hyphens in CL arguments.
MylesBartlett Oct 25, 2019
e2ac089
Corrected sizes in tensor splitting.
MylesBartlett Oct 25, 2019
e32100d
Don't allow gradients of reconstruction loss to flow through zs parti…
MylesBartlett Oct 28, 2019
1f1a407
Recon not defined in VAE routine.
MylesBartlett Oct 28, 2019
e354ac6
Recon loss argument.
MylesBartlett Oct 28, 2019
0a846de
Rewrite validation to match training.
MylesBartlett Oct 28, 2019
21f800e
Typo.
MylesBartlett Oct 28, 2019
b8b232e
Correct validation procedure.
MylesBartlett Oct 28, 2019
3bd943f
Don't include zs in ELBO.
MylesBartlett Oct 29, 2019
1f4cb04
Revert previous change.
MylesBartlett Oct 29, 2019
7073002
Update discriminator's weights.
MylesBartlett Oct 29, 2019
7d4e3c7
Stop gradient in kl.
MylesBartlett Oct 29, 2019
4c987eb
Error in cat.
MylesBartlett Oct 29, 2019
5b4141a
Remove stop grad in kl computation.
MylesBartlett Oct 29, 2019
ee53613
Prelearn s.
MylesBartlett Oct 29, 2019
29014d3
Fix s after pretraining.
MylesBartlett Oct 29, 2019
f5af9c4
Set loss instead of compound.
MylesBartlett Oct 29, 2019
4d076bb
Logging.
MylesBartlett Oct 29, 2019
985708c
Remove zs pretraining.
MylesBartlett Oct 29, 2019
a7f68c1
Missing brackets.
MylesBartlett Oct 29, 2019
f3509dd
Log s-inverted and mean images.
msbartlett Oct 30, 2019
cbf6bae
Log all images to wandb.
msbartlett Oct 30, 2019
ec3fa2a
Correctly zero out s to obtain reconstructions if not partitioning.
msbartlett Oct 30, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ venv.bak/

# IDEs
.idea/
.vscode/settings.json

# project-specific
experiments/
Expand Down
5 changes: 3 additions & 2 deletions nosinn/data/data_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from ethicml.data import Adult
from torch.utils.data import Dataset, random_split
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.datasets import MNIST, KMNIST

from .celeba import CelebA
from nosinn.data.dataset_wrappers import DataTupleDataset, LdAugmentedDataset
Expand Down Expand Up @@ -35,6 +35,7 @@ def load_dataset(args) -> DatasetTriplet:
base_aug.insert(0, transforms.Pad(args.padding))
train_data = MNIST(root=args.root, download=True, train=True)
pretrain_data, train_data = random_split(train_data, lengths=(50000, 10000))
# pretrain_data = KMNIST(root=args.root, download=True, train=True)
test_data = MNIST(root=args.root, download=True, train=False)

colorizer = LdColorizer(
Expand Down Expand Up @@ -83,7 +84,7 @@ def load_dataset(args) -> DatasetTriplet:
transforms.Resize(image_size),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
# transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
)

Expand Down
74 changes: 50 additions & 24 deletions nosinn/models/autoencoder.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import torch
import torch.distributions as td
from tqdm import tqdm
import torch.nn as nn
Expand All @@ -7,24 +8,38 @@


class AutoEncoder(nn.Module):
def __init__(self, encoder, decoder, optimizer_args=None):
def __init__(self, encoder, decoder, decode_with_s=False, optimizer_args=None):
super(AutoEncoder, self).__init__()

self.encoder: ModelBase = ModelBase(encoder, optimizer_args=optimizer_args)
self.decoder: ModelBase = ModelBase(decoder, optimizer_args=optimizer_args)
self.decode_with_s = decode_with_s

def encode(self, inputs):
return self.encoder(inputs)

def decode(self, encoding):
decoding = self.decoder(encoding)
def reconstruct(self, encoding, s=None):
    """Decode ``encoding`` into an image-space reconstruction.

    Args:
        encoding: Latent code produced by :meth:`encode`.
        s: Optional conditioning tensor forwarded to :meth:`decode`
            (used only when ``decode_with_s`` is set).

    Returns:
        The decoded tensor. When the decoder emits per-pixel class
        logits (a 4-d tensor with more than 3 channels), the logits are
        converted to a single image by taking the arg-max over the
        256-way class dimension and rescaling to [0, 1].
    """
    decoding = self.decode(encoding, s)

    if decoding.dim() == 4 and decoding.size(1) > 3:
        # Cross-entropy-style decoder output: reshape to
        # (batch, 256 classes, channels, H, W), then collapse the class
        # dimension via arg-max. Only the first 64 samples are kept —
        # presumably for image logging; TODO confirm with callers.
        num_classes = 256
        decoding = decoding[:64].view(decoding.size(0), num_classes, -1, *decoding.shape[-2:])
        fac = num_classes - 1
        decoding = decoding.max(dim=1)[1].float() / fac

    return decoding

def decode(self, encoding, s=None):
    """Run the decoder, optionally concatenating ``s`` onto the encoding.

    When ``decode_with_s`` is enabled and ``s`` is given, ``s`` is
    appended to the decoder input along the channel dimension; for 4-d
    (image-shaped) encodings it is first broadcast over the spatial
    dimensions.
    """
    decoder_input = encoding
    if s is not None and self.decode_with_s:
        if encoding.dim() == 4:
            # Reshape s to (batch, channels, 1, 1) and tile it across
            # the encoding's spatial grid so it concatenates cleanly.
            s = s.view(s.size(0), -1, 1, 1).float()
            s = s.expand(-1, -1, decoder_input.size(-2), decoder_input.size(-1))
        decoder_input = torch.cat([decoder_input, s], dim=1)

    return self.decoder(decoder_input)

def forward(self, inputs, reverse: bool = True):
if reverse:
return self.decode(inputs)
Expand All @@ -39,25 +54,29 @@ def step(self):
self.encoder.step()
self.decoder.step()

def routine(self, inputs, loss_fn, s=None):
    """Encode and decode ``inputs`` and return the batch-averaged loss.

    Args:
        inputs: Batch of inputs to reconstruct.
        loss_fn: Reconstruction loss (assumed sum-reduced, since the
            result is divided by the batch size — TODO confirm).
        s: Optional conditioning tensor for the decoder.

    Returns:
        Scalar reconstruction loss averaged over the batch.
    """
    encoding = self.encoder(inputs)
    # Go through ``decode`` rather than ``self.decoder`` directly so the
    # optional ``s`` conditioning is applied; the raw decoder wrapper
    # does not accept an ``s`` argument.
    decoding = self.decode(encoding, s=s)
    loss = loss_fn(decoding, inputs)
    loss /= inputs.size(0)

    return loss

def fit(self, train_data, epochs, device, loss_fn=nn.MSELoss()):
def fit(self, train_data, epochs, device, loss_fn):

self.train()

with tqdm(total=epochs * len(train_data)) as pbar:
for epoch in range(epochs):

for x, _, _ in train_data:
for x, s, _ in train_data:

x = x.to(device)
if self.decode_with_s:
s = s.to(device)

self.zero_grad()
loss = self.routine(x, loss_fn=loss_fn)
loss /= x.size(0)
loss = self.routine(x, loss_fn=loss_fn, s=s)

loss.backward()
self.step()
Expand All @@ -67,10 +86,13 @@ def fit(self, train_data, epochs, device, loss_fn=nn.MSELoss()):


class VAE(AutoEncoder):
def __init__(self, encoder, decoder, kl_weight=0.1, optimizer_args=None):
super(AutoEncoder, self).__init__()

super().__init__(encoder=encoder, decoder=decoder, optimizer_args=optimizer_args)
def __init__(self, encoder, decoder, kl_weight=0.1, decode_with_s=False, optimizer_args=None):
super().__init__(
encoder=encoder,
decoder=decoder,
decode_with_s=decode_with_s,
optimizer_args=optimizer_args,
)
self.encoder: ModelBase = ModelBase(encoder, optimizer_args=optimizer_args)
self.decoder: ModelBase = ModelBase(decoder, optimizer_args=optimizer_args)

Expand All @@ -86,7 +108,7 @@ def compute_divergence(self, sample, posterior: td.Distribution):

return kl

def encode(self, x, stochastic=True, return_posterior=False):
def encode(self, x, stochastic=False, return_posterior=False):
loc, scale = self.encoder(x).chunk(2, dim=1)

if stochastic or return_posterior:
Expand All @@ -100,32 +122,36 @@ def encode(self, x, stochastic=True, return_posterior=False):
else:
return sample

def routine(self, x, recon_loss_fn, s=None):
    """Run one VAE training step's forward pass.

    Args:
        x: Input batch.
        recon_loss_fn: Reconstruction loss (assumed sum-reduced, since
            the result is divided by the batch size — TODO confirm).
        s: Optional conditioning tensor for the decoder.

    Returns:
        Tuple ``(sample, recon, loss)`` — the latent sample, the
        reconstruction, and the ELBO-style loss
        ``recon_loss + kl_weight * kl`` (both terms batch-averaged).
    """
    sample, posterior = self.encode(x, stochastic=True, return_posterior=True)
    kl = self.compute_divergence(sample, posterior)

    decoder_input = sample
    recon = self.decode(decoder_input, s)
    # Reconstruction loss must be measured against the input x;
    # comparing against s was a bug.
    recon_loss = recon_loss_fn(recon, x)

    recon_loss /= x.size(0)
    kl /= x.size(0)

    loss = recon_loss + self.kl_weight * kl

    return sample, recon, loss

def fit(self, train_data, epochs, device, loss_fn=nn.MSELoss()):
def fit(self, train_data, epochs, device, loss_fn):

self.train()

with tqdm(total=epochs * len(train_data)) as pbar:
for epoch in range(epochs):

for x, _, _ in train_data:
for x, s, _ in train_data:

x = x.to(device)
if self.decode_with_s:
s = s.to(device)

self.zero_grad()
loss = self.routine(x, recon_loss_fn=loss_fn)
_, _, loss = self.routine(x, recon_loss_fn=loss_fn, s=s)
loss.backward()
self.step()

Expand Down
10 changes: 6 additions & 4 deletions nosinn/models/configs/autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ def gated_up_conv(in_channels, out_channels, kernel_size, stride, padding, outpu
)


def conv_autoencoder(input_shape, initial_hidden_channels, levels, encoding_dim, decoding_dim, vae):
def conv_autoencoder(
input_shape, initial_hidden_channels, levels, encoding_dim, decoding_dim, vae, s_dim=0
):
encoder = []
decoder = []
c_in, h, w = input_shape
Expand All @@ -49,7 +51,7 @@ def conv_autoencoder(input_shape, initial_hidden_channels, levels, encoding_dim,
encoder_out_dim = 2 * encoding_dim if vae else encoding_dim

encoder += [nn.Conv2d(c_out, encoder_out_dim, kernel_size=1, stride=1, padding=0)]
decoder += [nn.Conv2d(encoding_dim, c_out, kernel_size=1, stride=1, padding=0)]
decoder += [nn.Conv2d(encoding_dim + s_dim, c_out, kernel_size=1, stride=1, padding=0)]
decoder = decoder[::-1]
decoder += [nn.Conv2d(input_shape[0], decoding_dim, kernel_size=1, stride=1, padding=0)]

Expand All @@ -65,7 +67,7 @@ def _linear_block(in_channels, out_channels):
return nn.Sequential(nn.SELU(), nn.Linear(in_channels, out_channels))


def fc_autoencoder(input_shape, hidden_channels, levels, encoding_dim, vae):
def fc_autoencoder(input_shape, hidden_channels, levels, encoding_dim, vae, s_dim=0):
encoder = []
decoder = []

Expand All @@ -80,7 +82,7 @@ def fc_autoencoder(input_shape, hidden_channels, levels, encoding_dim, vae):
encoder_out_dim = 2 * encoding_dim if vae else encoding_dim

encoder += [_linear_block(c_out, encoder_out_dim)]
decoder += [_linear_block(encoding_dim, c_out)]
decoder += [_linear_block(encoding_dim + s_dim, c_out)]
decoder = decoder[::-1]

encoder = nn.Sequential(*encoder)
Expand Down
3 changes: 2 additions & 1 deletion nosinn/models/configs/classifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ def linear_disciminator(in_dim, target_dim, hidden_channels=512, num_blocks=4, u

act = F.relu if use_bn else F.selu
layers = [
nn.Flatten(),
ResidualNet(
in_features=in_dim,
out_features=target_dim,
Expand All @@ -17,7 +18,7 @@ def linear_disciminator(in_dim, target_dim, hidden_channels=512, num_blocks=4, u
activation=act,
dropout_probability=0.0,
use_batch_norm=use_bn,
)
),
]
return nn.Sequential(*layers)

Expand Down
3 changes: 2 additions & 1 deletion nosinn/optimisation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .loss import *
from .train import *
from .train_nosinn import *
from .train_vae_baseline import *
from .config import *
from .utils import *
117 changes: 114 additions & 3 deletions nosinn/optimisation/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def __call__(self, parser, namespace, values, option_string=None):
setattr(namespace, self.dest, my_dict)


def parse_arguments(raw_args=None):
def nosinn_args(raw_args=None):
parser = argparse.ArgumentParser()

# General data set settings
Expand Down Expand Up @@ -97,8 +97,8 @@ def parse_arguments(raw_args=None):

# Discriminator settings
parser.add_argument("--disc-lr", type=float, default=3e-4)
parser.add_argument("--disc-depth", type=int, default=2)
parser.add_argument("--disc-channels", type=int, default=512)
parser.add_argument("--disc-depth", type=int, default=1)
parser.add_argument("--disc-channels", type=int, default=256)
parser.add_argument("--disc-hidden-dims", nargs="*", type=int, default=[])

# Optimization settings
Expand Down Expand Up @@ -143,6 +143,117 @@ def parse_arguments(raw_args=None):
choices=[True, False],
help="Train classifier on encodings as part of validation step.",
)
parser.add_argument("--val-freq", type=int, default=5)
parser.add_argument("--log-freq", type=int, default=10)
parser.add_argument("--root", type=str, default="data")
parser.add_argument(
"--results-csv", type=str, default="", help="name of CSV file to save results to"
)

return parser.parse_args(raw_args)


def vae_args(raw_args=None):
parser = argparse.ArgumentParser()

# General data set settings
parser.add_argument("--dataset", choices=["adult", "cmnist", "celeba"], default="cmnist")
parser.add_argument(
"--data-pcnt",
type=restricted_float,
metavar="P",
default=1.0,
help="data %% should be a real value > 0, and up to 1",
)
parser.add_argument(
"--task-mixing-factor",
type=float,
metavar="P",
default=0.0,
help="How much of meta train should be mixed into task train?",
)
parser.add_argument(
"--pretrain",
type=eval,
default=True,
choices=[True, False],
help="Whether to perform unsupervised pre-training.",
)
parser.add_argument("--pretrain-pcnt", type=float, default=0.4)
parser.add_argument("--task-pcnt", type=float, default=0.2)

# Adult data set feature settings
parser.add_argument("--drop-native", type=eval, default=True, choices=[True, False])
parser.add_argument("--drop-discrete", type=eval, default=False)

# Colored MNIST settings
parser.add_argument("--scale", type=float, default=0.02)
parser.add_argument("--greyscale", type=eval, default=False, choices=[True, False])
parser.add_argument("-bg", "--background", type=eval, default=False, choices=[True, False])
parser.add_argument("--black", type=eval, default=True, choices=[True, False])
parser.add_argument("--binarize", type=eval, default=True, choices=[True, False])
parser.add_argument("--rotate-data", type=eval, default=False, choices=[True, False])
parser.add_argument("--shift-data", type=eval, default=False, choices=[True, False])
parser.add_argument("--padding", type=int, default=2)

# VAEsettings
parser.add_argument("--levels", type=int, default=4)
parser.add_argument("--enc-y-dim", type=int, default=64)
parser.add_argument("--enc-s-dim", type=int, default=0)
parser.add_argument("--cond-decoder", type=eval, choices=[True, False], default=True)
parser.add_argument("--init-channels", type=int, default=32)
parser.add_argument("--recon-loss", type=str, choices=["l1", "l2", "huber", "ce"], default="l2")

# Discriminator settings
parser.add_argument("--disc-enc-y-depth", type=int, default=1)
parser.add_argument("--disc-enc-y-channels", type=int, default=256)
parser.add_argument("--disc-enc-s-depth", type=int, default=1)
parser.add_argument("--disc-enc-s-channels", type=int, default=128)

# Optimization settings
parser.add_argument("--early-stopping", type=int, default=30)
parser.add_argument("--epochs", type=int, default=250)
parser.add_argument("--batch-size", type=int, default=128)
parser.add_argument("--test-batch-size", type=int, default=None)
parser.add_argument("--lr", type=float, default=1e-3)
parser.add_argument("--disc-lr", type=float, default=1e-3)
parser.add_argument("--weight-decay", type=float, default=0)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--data-split-seed", type=int, default=888)
parser.add_argument("--warmup-steps", type=int, default=0)
parser.add_argument(
"--gamma",
type=float,
default=1.0,
help="Gamma value for Exponential Learning Rate scheduler.",
)
parser.add_argument(
"--train-on-recon",
type=eval,
default=False,
choices=[True, False],
help="whether to train the discriminator on the reconstructions" "of the encodings.",
)
parser.add_argument("--kl-weight", type=float, default=0.1)
parser.add_argument("--elbo-weight", type=float, default=1)
parser.add_argument("--pred-s-weight", type=float, default=1)

# Evaluation settings
parser.add_argument("--eval-epochs", type=int, metavar="N", default=40)
parser.add_argument("--eval-lr", type=float, default=1e-3)

# Misc
parser.add_argument("--gpu", type=int, default=0, help="which GPU to use (if available)")
parser.add_argument("--resume", type=str, default=None)
parser.add_argument("--save", type=str, default="experiments/finn")
parser.add_argument("--evaluate", action="store_true")
parser.add_argument(
"--super-val",
type=eval,
default=True,
choices=[True, False],
help="Train classifier on encodings as part of validation step.",
)
parser.add_argument("--val-freq", type=int, default=4)
parser.add_argument("--log-freq", type=int, default=10)
parser.add_argument("--root", type=str, default="data")
Expand Down
Loading