# GPT

This model is a Transformer-based GPT trained on the discrete image tokens produced by a pre-trained VQ-GAN. It treats image generation as a language-modeling task over codebook indices, enabling autoregressive synthesis of high-fidelity images.

Note: A trained VQ-GAN model is required. This GPT does not operate directly on pixel space but on the quantized latent tokens produced by VQ-GAN.
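A minimal sketch of this two-stage pipeline (the helper name `encode_to_indices` and the returned shapes are assumptions for illustration, not the repository's exact API):

```python
import torch

@torch.no_grad()
def images_to_token_sequences(vqgan, images):
    """Map a batch of images to flat sequences of VQ-GAN codebook indices."""
    # The frozen VQ-GAN encoder + quantizer produce an H' x W' grid of
    # codebook indices per image; flattening it row-major yields the 1-D
    # token sequence that the GPT models autoregressively.
    indices = vqgan.encode_to_indices(images)   # hypothetical helper, (B, H', W')
    return indices.reshape(images.size(0), -1)  # (B, H' * W')
```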

## Parameters

| Argument | Description | Default | Choices |
|----------|-------------|---------|---------|
| `--train` | Train the Transformer model | False | |
| `--sample` | Sample images using the Transformer | False | |
| `--dataset` | Dataset name | mnist | mnist, cifar10, cifar100, places365, dtd, fashionmnist, chestmnist, bloodmnist, dermamnist, octmnist, tissuemnist, pneumoniamnist, retinamnist, svhn, tinyimagenet, imagenet, celeba |
| `--batch_size` | Batch size | 256 | |
| `--n_epochs` | Number of training epochs | 100 | |
| `--size` | Image size (overrides dataset default) | None | |
| `--num_workers` | Number of dataloader workers | 0 | |
| `--warmup` | Warmup epochs (may affect VQ-GAN if jointly trained) | 10 | |
| `--channels` | Base channels in encoder/decoder (used for VQ-GAN context) | 64 | |
| `--z_channels` | Latent channels (VQ-GAN latent space) | 64 | |
| `--ch_mult` | Channel multipliers per resolution (VQ-GAN) | [1, 2, 2] | |
| `--num_res_blocks` | Residual blocks per resolution | 2 | |
| `--attn_resolutions` | Attention at specific resolutions | [16] | |
| `--dropout` | Dropout in encoder/decoder (VQ-GAN) | 0.0 | |
| `--double_z` | Use double latent encoding | False | |
| `--disc_start` | Step at which the VQ-GAN discriminator loss activates | 10000 | |
| `--disc_weight` | Discriminator loss weight | 0.8 | |
| `--codebook_weight` | Codebook loss weight (VQ-GAN) | 1.0 | |
| `--n_embed` | Number of embeddings in the codebook | 128 | |
| `--embed_dim` | Embedding dimension (VQ-GAN) | 64 | |
| `--embed_dim_t` | Embedding dimension (Transformer) | 64 | |
| `--remap` | Remap codebook indices | None | |
| `--sane_index_shape` | Use sane index shape for the quantizer | False | |
| `--checkpoint_vae` | Path to pre-trained VQ-GAN checkpoint | None | |
| `--checkpoint_gpt` | Path to Transformer checkpoint | None | |
| `--colorize_nlabels` | Labels used for colorization tasks | None | |
| `--lr` | Learning rate | 1e-4 | |
| `--no_wandb` | Disable Wandb logging | False | |
| `--sample_and_save_freq` | Sample/save frequency (in epochs) | 20 | |
| `--n_layer` | Number of Transformer layers | 6 | |
| `--n_head` | Number of attention heads | 8 | |
| `--block_size` | Length of the token sequence (i.e., number of spatial tokens) | 64 | |
| `--bias` | Use bias terms in Transformer layers | False | |
| `--dropout_t` | Dropout in the Transformer | 0.1 | |
| `--betas` | Adam optimizer betas | [0.9, 0.95] | |
| `--weight_decay` | Weight decay for the optimizer | 0.1 | |
| `--num_samples` | Number of samples to generate | 16 | |
| `--temperature` | Sampling temperature | 1.0 | |
| `--top_k` | Top-k filtering for sampling | None | |
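As a sanity check on `--block_size`, assuming the usual VQ-GAN convention that each entry of `--ch_mult` beyond the first halves the spatial resolution, the defaults imply an 8×8 token grid for 32×32 inputs:

```python
# Hedged arithmetic, assuming 2x downsampling per --ch_mult entry beyond the first.
ch_mult = [1, 2, 2]                  # default --ch_mult
image_size = 32                      # e.g. CIFAR-10
factor = 2 ** (len(ch_mult) - 1)     # 4x total downsampling
latent_side = image_size // factor   # 8
print(latent_side * latent_side)     # 64, matching the default --block_size
```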

For more information, see `util.py` or run:

```sh
python GPT.py --help
```

## Training

Before training the GPT, you must have a trained VQ-GAN to encode images into discrete tokens. The GPT is then trained autoregressively over these token sequences:

```sh
python GPT.py --train --dataset cifar10 --checkpoint_vae ./../../models/VQGAN/VQGAN_cifar10.pt
```
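Under the hood, the training objective is next-token cross-entropy over the VQ-GAN indices. A minimal sketch of one step, reusing the hypothetical `images_to_token_sequences` helper above (conditioning and start-token handling are omitted, and the names are assumptions):

```python
import torch
import torch.nn.functional as F

def train_step(gpt, vqgan, images, optimizer):
    """One next-token prediction step over VQ-GAN codebook indices (sketch)."""
    with torch.no_grad():
        tokens = images_to_token_sequences(vqgan, images)  # (B, T); VQ-GAN stays frozen
    inputs, targets = tokens[:, :-1], tokens[:, 1:]        # shift by one position
    logits = gpt(inputs)                                   # (B, T-1, n_embed)
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                           targets.reshape(-1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```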

## Sample

To sample images using a trained GPT model and a VQ-GAN decoder:

```sh
python GPT.py --sample --dataset cifar10 --checkpoint_vae ./../../models/VQGAN/VQGAN_cifar10.pt --checkpoint_gpt ./../../models/GPT/GPT_cifar10.pt
```
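Sampling draws tokens one at a time with temperature scaling and optional top-k filtering, then decodes the finished index grid through the VQ-GAN decoder. A sketch under the same assumptions as above (`decode_from_indices` and the single zero start token are illustrative, and the 8×8 grid assumes the default `--ch_mult` on 32×32 images):

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def sample_images(gpt, vqgan, num_samples=16, block_size=64,
                  temperature=1.0, top_k=None):
    """Autoregressively sample token grids and decode them to images (sketch)."""
    device = next(gpt.parameters()).device
    tokens = torch.zeros(num_samples, 1, dtype=torch.long, device=device)  # start token
    for _ in range(block_size):
        logits = gpt(tokens)[:, -1, :] / temperature     # next-token logits
        if top_k is not None:
            v, _ = torch.topk(logits, top_k)
            logits[logits < v[:, [-1]]] = -float("inf")  # keep only the top k logits
        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        tokens = torch.cat([tokens, next_token], dim=1)
    grid = tokens[:, 1:].reshape(num_samples, 8, 8)      # assumed 8x8 latent grid
    return vqgan.decode_from_indices(grid)               # hypothetical decoder call
```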