minoring/VAE

fd4fae1 · Nov 19, 2021


VAE for the celebA and MNIST datasets

Training a Variational Autoencoder (VAE) on the celebA and MNIST datasets.

To train on the celebA dataset, download img_align_celeba.zip and list_eval_partition.txt, unzip the archive, and place them under data/celebA/. For training, you can modify hyperparameters such as the learning rate, latent dimension, batch size, and learning-rate decay schedule (take a look at parser_utils.py). The learning curve is saved to "{model}_{dataset}.csv" by default (e.g. cnn_celebA.csv). During training, reconstruction results for the test dataset are saved in the results/{dataset}/{model} directory (e.g. results/celebA/cnn). A pretrained model is provided in "{model}_{dataset}.pt" format (e.g. cnn_celebA.pt).
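The naming conventions above can be sketched as a few small helpers. This is only an illustration of the patterns stated in this README; the function names are hypothetical and are not taken from the repository's code.

```python
from pathlib import Path


def learning_curve_path(model: str, dataset: str) -> Path:
    """CSV file the learning curve is written to, e.g. cnn_celebA.csv."""
    return Path(f"{model}_{dataset}.csv")


def checkpoint_path(model: str, dataset: str) -> Path:
    """Saved model file, e.g. cnn_celebA.pt."""
    return Path(f"{model}_{dataset}.pt")


def results_dir(model: str, dataset: str) -> Path:
    """Directory holding reconstruction images, e.g. results/celebA/cnn."""
    return Path("results") / dataset / model


if __name__ == "__main__":
    # Hypothetical helper names; only the path patterns come from the README.
    for model in ("cnn", "fc"):
        print(learning_curve_path(model, "celebA"))
        print(checkpoint_path(model, "celebA"))
        print(results_dir(model, "celebA").as_posix())
```

Keeping the patterns in one place makes it easier to pass the matching file to --saved-path or --learning-curve-csv in the commands below.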

celebA Dataset

Training with CNN-based Model

python train.py --model cnn --dataset celebA --num-epochs 300 --latent-dim 128

Reconstruction examples

Original image on the top, reconstruction on the bottom.

Generate Samples

Generation (decode) from random Gaussian distribution.

python generate_samples.py --model cnn --dataset celebA --saved-path cnn_celebA.pt --latent-dim 128
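Conceptually, generation draws latent vectors from a standard Gaussian and passes them through the trained decoder. The framework-free sketch below shows only that idea: `sample_latents`, `generate`, and `toy_decoder` are hypothetical names, and the toy decoder is a stand-in for the trained network, not the behavior of generate_samples.py.

```python
import math
import random


def sample_latents(num_samples, latent_dim, seed=None):
    """Draw num_samples vectors z ~ N(0, I) of size latent_dim."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 1.0) for _ in range(latent_dim)]
            for _ in range(num_samples)]


def toy_decoder(z):
    """Hypothetical stand-in for a trained decoder.

    Squashes each latent value into [0, 1] with a sigmoid so the output
    looks like pixel intensities. A real decoder is a neural network.
    """
    return [1.0 / (1.0 + math.exp(-v)) for v in z]


def generate(decoder, num_samples, latent_dim, seed=None):
    """Decode random Gaussian latents into samples."""
    return [decoder(z) for z in sample_latents(num_samples, latent_dim, seed)]


if __name__ == "__main__":
    samples = generate(toy_decoder, num_samples=4, latent_dim=128, seed=0)
    print(len(samples), len(samples[0]))
```

The latent dimension used for sampling has to match the one the model was trained with, which is why the command passes the same --latent-dim value as the training command.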

Plot Learning Curve

python plot_learning_curve.py --loss Loss --learning-curve-csv cnn_celebA.csv
python plot_learning_curve.py --loss Reconstruction_Loss --learning-curve-csv cnn_celebA.csv
python plot_learning_curve.py --loss KL_Loss --learning-curve-csv cnn_celebA.csv

Plots: Loss (Reconstruction + KL), Reconstruction Loss, KL Loss.

Both the training loss and the test loss decrease over the run, so you can likely train for more epochs if you want.
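For reference, the three plotted quantities correspond to the usual VAE objective. The form below is the standard textbook decomposition, assuming a diagonal Gaussian encoder with mean $\mu$ and variance $\sigma^2$, a standard normal prior, and latent dimension $d$; the exact reconstruction term and any weighting used in this repository's code are not stated here and may differ.

```latex
\mathcal{L}(x)
  = \underbrace{\mathbb{E}_{q_\phi(z \mid x)}\bigl[-\log p_\theta(x \mid z)\bigr]}_{\text{Reconstruction Loss}}
  + \underbrace{D_{\mathrm{KL}}\bigl(q_\phi(z \mid x) \,\|\, \mathcal{N}(0, I)\bigr)}_{\text{KL Loss}},
\qquad
D_{\mathrm{KL}}
  = -\frac{1}{2} \sum_{j=1}^{d} \bigl(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\bigr)
```

Under these assumptions the KL term is zero exactly when $\mu = 0$ and $\sigma^2 = 1$ in every latent dimension.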

Training with FC-based Model

python train.py --model fc --dataset celebA --num-epochs 300 --latent-dim 128

Reconstruction examples

Original image on the top, reconstruction on the bottom.

Generate Samples

Generation (decode) from random Gaussian distribution.

python generate_samples.py --model fc --dataset celebA --saved-path fc_celebA.pt --latent-dim 128

Plot Learning Curve

python plot_learning_curve.py --loss Loss --learning-curve-csv fc_celebA.csv
python plot_learning_curve.py --loss Reconstruction_Loss --learning-curve-csv fc_celebA.csv
python plot_learning_curve.py --loss KL_Loss --learning-curve-csv fc_celebA.csv

Plots: Loss (Reconstruction + KL), Reconstruction Loss, KL Loss.

MNIST Dataset

Training with CNN-based Model

python train.py --model cnn --dataset mnist --num-epochs 1000 --latent-dim 16

Reconstruction examples

Original image on the top, reconstruction on the bottom.

Generate Samples

Generation (decode) from random Gaussian distribution.

python generate_samples.py --model cnn --dataset mnist --saved-path cnn_mnist.pt --latent-dim 16

Plot Learning Curve

python plot_learning_curve.py --loss Loss --learning-curve-csv cnn_mnist.csv
python plot_learning_curve.py --loss Reconstruction_Loss --learning-curve-csv cnn_mnist.csv
python plot_learning_curve.py --loss KL_Loss --learning-curve-csv cnn_mnist.csv

Plots: Loss (Reconstruction + KL), Reconstruction Loss, KL Loss.

Plot Latent Space

Uses t-SNE to map the 16-dimensional latent space to 2D.

python plot_latent_space.py --saved-path cnn_mnist.pt --latent-dim 16 --dataset mnist --model cnn

Variational Autoencoder in PyTorch
