A PyTorch implementation of a variational autoencoder that achieves disentanglement following the FactorVAE approach [1]. Trained on ChestMNIST, a dataset of chest X-ray images.
A TU Darmstadt project by Elias Fiedler, Fynn Becker, Patrick Reidelbach.
A variational autoencoder consists of two neural networks: an encoder and a decoder. The encoder converts an input image into a representation in a latent space. This latent space is of smaller dimension than the input, so the encoder essentially compresses the input data.
The decoder attempts to reconstruct the input image from the latent representation provided by the encoder, performing something akin to decompression. By altering the latent representation of an input image, or by randomly sampling from the latent space and passing the sample to the decoder, new images can be generated.
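As a rough sketch of this encoder/decoder structure (layer sizes and names below are illustrative placeholders, not our actual architecture):

```python
import torch
import torch.nn as nn

class TinyVAE(nn.Module):
    # Illustrative only: a fully-connected VAE for 28x28 grayscale images.
    def __init__(self, latent_dim=10):
        super().__init__()
        self.enc = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 128), nn.ReLU())
        self.to_mu = nn.Linear(128, latent_dim)      # mean of q(z|x)
        self.to_logvar = nn.Linear(128, latent_dim)  # log-variance of q(z|x)
        self.dec = nn.Sequential(
            nn.Linear(latent_dim, 128), nn.ReLU(),
            nn.Linear(128, 28 * 28), nn.Sigmoid(),
        )

    def forward(self, x):
        h = self.enc(x)
        mu, logvar = self.to_mu(h), self.to_logvar(h)
        # Reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I)
        z = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)
        return self.dec(z).view(-1, 1, 28, 28), mu, logvar

x = torch.rand(4, 1, 28, 28)
recon, mu, logvar = TinyVAE()(x)  # recon has the same shape as x
```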
A traditional VAE uses the latent space in a way that is not interpretable by humans: the dimensions of the latent vector do not each encode a single, distinct piece of information. Disentanglement refers to the model learning a latent vector where each dimension represents a single piece of interpretable information (e.g. shape, orientation, size). This means each dimension controls a separate aspect of the image that can be adjusted independently of the others.
In their paper, Hyunjik Kim and Andriy Mnih [1] propose a variant of the VAE that achieves disentanglement by factorising. The resulting FactorVAE is an adversarial model consisting of a standard VAE and a third neural network: the discriminator. During training, the discriminator receives regular samples produced by the encoder as well as samples in which each latent dimension has been independently permuted across the batch, and attempts to tell the two apart. Based on the outcome of this classification, the VAE is penalized according to how strongly the dimensions of the latent space depend on each other. This creates an incentive for the VAE to become more disentangled.
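The permutation step from the paper can be sketched in a few lines (a minimal sketch, not our training code): shuffling each latent dimension independently across the batch approximates a sample from the product of the marginals, which is what the discriminator compares against.

```python
import torch

def permute_dims(z):
    # Independently shuffle each latent dimension across the batch.
    # The result approximates a sample from the product of marginals q(z_j),
    # which the discriminator tries to distinguish from real samples q(z).
    B, D = z.shape
    out = torch.empty_like(z)
    for j in range(D):
        out[:, j] = z[torch.randperm(B), j]
    return out

z = torch.randn(8, 5)
z_perm = permute_dims(z)
# Each column of z_perm contains the same values as the corresponding
# column of z, just in a shuffled batch order.
```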
We used a latent vector dimension of
The encoder uses 3 convolution layers, which generate
For exact reconstruction, we set
The discriminator is a network with
Architecture of our discriminator.
Reconstruction quality began to suffer as the model became more disentangled. This was partly mitigated through a beta-annealing strategy, but there is still room for improvement. Below, we have visualized 10 images from the dataset (first two rows) and our model's reconstructions (last two rows):
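A beta-annealing schedule can be as simple as ramping the KL weight linearly over the first epochs; the function below is an illustrative sketch (the warm-up length and maximum weight are placeholder values, not our tuned parameters):

```python
def beta_schedule(epoch, warmup_epochs=20, beta_max=1.0):
    # Linearly anneal the KL weight from 0 to beta_max over the warm-up
    # phase, then hold it constant. Keeping beta small early on lets the
    # model first learn good reconstructions before regularization kicks in.
    return min(1.0, epoch / warmup_epochs) * beta_max
```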
By randomly generating latent vectors, we are able to create synthetic images that were not present in the dataset. Below, we have visualized a grid of 20 synthetic images generated by our model:
Synthetically generated images (random sample).
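The random-sampling step boils down to drawing latent vectors from the standard normal prior and decoding them. A minimal sketch (the stand-in decoder and `latent_dim` below are assumptions, not our actual model):

```python
import torch
import torch.nn as nn

latent_dim = 10  # assumption: must match the trained model's latent size
# Stand-in decoder for illustration; in practice, use the trained decoder.
decoder = nn.Sequential(nn.Linear(latent_dim, 28 * 28), nn.Sigmoid())

z = torch.randn(20, latent_dim)             # sample 20 vectors from N(0, I)
synthetic = decoder(z).view(20, 1, 28, 28)  # decode into 20 synthetic images
```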
We were able to achieve decent disentanglement in our model. Each dimension in the latent space encodes
an independent piece of interpretable information. The visualization below demonstrates the disentanglement.
The same picture is used as a baseline for each row. In each row, only one dimension of the latent space is altered
by adding an offset; all other dimensions remain untouched. The offset ranges from
Latent space:
Latent space with interpretations.
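The traversal procedure described above can be sketched as follows (the `decode` callable, base vector, and offset range are assumptions for illustration, not our exact evaluation code):

```python
import torch

def traverse(decode, z_base, offsets=torch.linspace(-3, 3, 7)):
    # For each latent dimension, add each offset to that dimension only,
    # keeping all others fixed, and decode the result. Returns one row of
    # decoded outputs per latent dimension.
    rows = []
    for dim in range(z_base.numel()):
        row = []
        for off in offsets:
            z = z_base.clone()
            z[dim] += off  # alter a single latent dimension
            row.append(decode(z.unsqueeze(0)))
        rows.append(torch.cat(row))
    return torch.stack(rows)  # (latent_dim, n_offsets, ...)

# Example with a dummy "decoder" that just sums the latent vector:
imgs = traverse(lambda z: z.sum(dim=1, keepdim=True), torch.zeros(4))
```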
Upon closer inspection, a specific characteristic that is changed throughout the images can be found for each row in the visualization.
Increasing the latent vector dimension just slightly to
Latent vector dimension is too high.
As long as the model remained well enough disentangled, it was also interesting to see it find the
same (or very similar) attributes in the data (i.e. dimensions of the latent vector), even when changing parameters
such as
The generated images are still blurry; with additional training resources this could likely be improved. Incorporating a perceptual similarity metric (e.g. SSIM) into the reconstruction loss would likely also help.
To install required packages:
pip install -r requirements.txt
Please note that this does not install torch with CUDA support. If your machine supports CUDA, we strongly recommend installing the CUDA-enabled build.
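After installing a CUDA-enabled build, you can verify that PyTorch actually sees your GPU with a quick check (this returns `False` on CPU-only installs):

```python
import torch

# True if a CUDA-enabled build is installed and a compatible GPU is visible.
print("CUDA available:", torch.cuda.is_available())
```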
There are several ways to start the project. To train a model with your own parameters, use:
python main.py --mode train
When training is complete, a GIF will be saved to project_root/gifs with a reconstruction visualization from every
training epoch. Several metrics are logged to TensorBoard, which you can start as follows:
tensorboard --logdir [dir]
To evaluate a model by visualizing the latent vector or reconstruction quality, use:
python main.py --mode latent
python main.py --mode replicate
Lastly, to perform random sampling (i.e. generating synthetic images), use:
python main.py --mode random
A pre-trained model file (
This project is licensed under the MIT License.
[1]: Hyunjik Kim and Andriy Mnih. Disentangling by Factorising. arXiv preprint arXiv:1802.05983, 2019. https://arxiv.org/abs/1802.05983