Ground truth (top), VQ-VAE (middle), DC-VAE (bottom)
Metroplex is a minimalistic JAX/Flax-based VAE codebase. A model can be customized by choosing between:
- Discrete VAE (VQ-VAE) vs. continuous VAE (VD-VAE)
- Unsupervised image modeling vs. DALL-E-like text-to-image modeling
- Pixel-level loss (e.g. L1 loss) vs. non-pixel-level loss (e.g. DC-VAE loss)
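To illustrate the first choice, here is a minimal sketch of the vector-quantization step that distinguishes a discrete VAE from a continuous one: each latent vector is replaced by its nearest entry in a codebook. It uses plain Python lists for readability. The names `quantize` and `codebook` and the toy values are assumptions made for illustration and are not taken from the Metroplex code, which operates on JAX arrays.

```python
from typing import List, Sequence, Tuple

Vector = Sequence[float]


def squared_distance(a: Vector, b: Vector) -> float:
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def quantize(
    latents: Sequence[Vector], codebook: Sequence[Vector]
) -> Tuple[List[int], List[Vector]]:
    """Map each latent vector to the index and value of its nearest codebook entry."""
    indices = [
        min(range(len(codebook)), key=lambda k: squared_distance(z, codebook[k]))
        for z in latents
    ]
    quantized = [codebook[k] for k in indices]
    return indices, quantized


# Hypothetical toy codebook with three 2-D entries (illustration only).
codebook = [(0.0, 0.0), (1.0, 1.0), (-1.0, 1.0)]
latents = [(0.9, 1.2), (-0.1, 0.2), (-0.8, 0.7)]

indices, quantized = quantize(latents, codebook)
print(indices)  # [1, 0, 2]
```

A continuous VAE skips this lookup and instead works with real-valued latents sampled from a learned distribution.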
In addition to TensorFlow 2, JAX, and Flax, you also need to install pillow, tensorflow_datasets, and einops:
pip install pillow
pip install tensorflow_datasets
pip install einops
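After installing, a quick check like the following can confirm that the dependencies are importable. This is a sketch and not part of the repo. The helper name `missing_packages` is hypothetical, and the mapping assumes the usual module names (for example, pillow is imported as `PIL`).

```python
import importlib.util
from typing import Dict, List

# Maps pip package name -> importable module name (assumed standard names).
REQUIRED: Dict[str, str] = {
    "tensorflow": "tensorflow",
    "jax": "jax",
    "flax": "flax",
    "pillow": "PIL",
    "tensorflow_datasets": "tensorflow_datasets",
    "einops": "einops",
}


def missing_packages(required: Dict[str, str]) -> List[str]:
    """Return the pip names of packages whose module cannot be found."""
    return [
        pip_name
        for pip_name, module_name in required.items()
        if importlib.util.find_spec(module_name) is None
    ]


if __name__ == "__main__":
    missing = missing_packages(REQUIRED)
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All required packages are importable.")
```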
Dataset preparation follows the conventions of TensorFlow Datasets.
The code runs on TPUs. It is untested on GPUs but should work with minimal modifications. The example configs are designed to run on a TPU v3-8 pod.
To set up TPUs, sign up for Google Cloud Platform and create a storage bucket.
Create your VM through Google Cloud Shell (https://ssh.cloud.google.com/) with ctpu up --vm-only so that it can connect to your Google bucket and TPUs, then set up the repo as above.
To run training, adjust the params in a JSON file of your choice in jsons and run:
python3 train.py --model json_file_path
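For readers unfamiliar with this pattern, the sketch below shows one way a `--model` flag pointing at a JSON file could be turned into a parameter dictionary. It is an illustration under stated assumptions, not the actual contents of train.py: the function names `parse_args`, `load_params`, and `main` are hypothetical, and no specific parameter keys are assumed.

```python
import argparse
import json
from typing import Any, Dict, List, Optional


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse the command line; --model is the path to a JSON params file."""
    parser = argparse.ArgumentParser(
        description="Sketch of a JSON-configured training entry point."
    )
    parser.add_argument(
        "--model", required=True, help="Path to a JSON file with the run parameters."
    )
    return parser.parse_args(argv)


def load_params(path: str) -> Dict[str, Any]:
    """Read the JSON file and check that it holds a top-level object."""
    with open(path, "r", encoding="utf-8") as f:
        params = json.load(f)
    if not isinstance(params, dict):
        raise ValueError(
            f"Expected a JSON object in {path}, got {type(params).__name__}"
        )
    return params


def main(argv: Optional[List[str]] = None) -> Dict[str, Any]:
    """Load the params and print them; a real script would start training here."""
    args = parse_args(argv)
    params = load_params(args.model)
    for key in sorted(params):
        print(f"{key} = {params[key]!r}")
    return params
```

Calling `main(["--model", "path/to/config.json"])` mirrors the command line above. Validating that the file holds a JSON object up front gives a clearer error than a failure deep inside training.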
We believe that adjusting the params properly is not yet easy for users. Please wait until this issue is resolved.
- Test core components of Metroplex
- Add the option for sampling with Transformer and collecting the latents
- Perform large-scale training
- Add WIT as tfds for training our first DALL-E-like model
- Add OWIT as tfds for training our final DALL-E-like model
- Add evaluation metrics (e.g. Precision & Recall, FID)
- Add options for other modalities, such as audio and video
- We would like to thank EleutherAI for letting us use their computational resources.
- This repo borrows heavily from vdvae-jax by James Townsend, including various utility functions, the training pipeline, and VD-VAE-specific components.