This model is a Transformer-based GPT architecture trained on discrete image tokens produced by a pre-trained VQ-GAN. It treats image generation as a language modeling task over the codebook indices, enabling autoregressive synthesis and reconstruction of high-fidelity images.

Note: A trained VQ-GAN model is required. The GPT does not operate directly in pixel space but on the quantized latent tokens produced by the VQ-GAN.
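The training objective can be sketched as follows. This is a minimal illustration of the idea, not the repository's actual code: the `encode_to_indices` helper, the `gpt` call signature, and the tensor shapes are assumptions made for exposition.

```python
import torch
import torch.nn.functional as F

def training_step(vqgan, gpt, images):
    """One GPT training step on VQ-GAN token sequences (illustrative sketch)."""
    # Assumption: the frozen VQ-GAN exposes a way to map images to flattened
    # codebook indices; here it is written as a hypothetical `encode_to_indices`.
    with torch.no_grad():
        tokens = vqgan.encode_to_indices(images)       # (B, block_size) integer tokens

    # Next-token prediction: each position is predicted from all earlier tokens.
    inputs, targets = tokens[:, :-1], tokens[:, 1:]
    logits = gpt(inputs)                               # (B, block_size - 1, n_embed)
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
    return loss
```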
| Argument | Description | Default | Choices |
|---|---|---|---|
| --train | Train transformer model | False | |
| --sample | Sample images using transformer | False | |
| --dataset | Dataset name | mnist | mnist, cifar10, cifar100, places365, dtd, fashionmnist, chestmnist, bloodmnist, dermamnist, octmnist, tissuemnist, pneumoniamnist, retinamnist, svhn, tinyimagenet, imagenet, celeba |
| --batch_size | Batch size | 256 | |
| --n_epochs | Number of training epochs | 100 | |
| --size | Image size (overrides dataset default) | None | |
| --num_workers | Number of dataloader workers | 0 | |
| --warmup | Warmup epochs (may affect VQ-GAN if jointly trained) | 10 | |
| --channels | Base channels in encoder/decoder (used for VQ-GAN context) | 64 | |
| --z_channels | Latent channels (VQ-GAN latent space) | 64 | |
| --ch_mult | Channel multipliers per resolution (VQ-GAN) | [1, 2, 2] | |
| --num_res_blocks | Residual blocks per resolution | 2 | |
| --attn_resolutions | Attention at specific resolutions | [16] | |
| --dropout | Dropout in encoder/decoder (VQ-GAN) | 0.0 | |
| --double_z | Use double latent encoding | False | |
| --disc_start | When to activate VQ-GAN discriminator loss | 10000 | |
| --disc_weight | Discriminator loss weight | 0.8 | |
| --codebook_weight | Codebook loss weight (VQ-GAN) | 1.0 | |
| --n_embed | Number of embeddings in codebook | 128 | |
| --embed_dim | Embedding dimension (VQ-GAN) | 64 | |
| --embed_dim_t | Embedding dimension (Transformer) | 64 | |
| --remap | Remap codebook indices | None | |
| --sane_index_shape | Use sane index shape for quantizer | False | |
| --checkpoint_vae | Path to pre-trained VQ-GAN checkpoint | None | |
| --checkpoint_gpt | Path to Transformer checkpoint | None | |
| --colorize_nlabels | Labels used for colorization tasks | None | |
| --lr | Learning rate | 1e-4 | |
| --no_wandb | Disable Wandb logging | False | |
| --sample_and_save_freq | Sample/save frequency (in epochs) | 20 | |
| --n_layer | Number of Transformer layers | 6 | |
| --n_head | Number of attention heads | 8 | |
| --block_size | Length of token sequence (i.e., spatial token length) | 64 | |
| --bias | Use bias terms in transformer layers | False | |
| --dropout_t | Dropout in Transformer | 0.1 | |
| --betas | Adam optimizer betas | [0.9, 0.95] | |
| --weight_decay | Weight decay for optimizer | 0.1 | |
| --num_samples | Number of samples to generate | 16 | |
| --temperature | Sampling temperature | 1.0 | |
| --top_k | Top-k filtering for sampling | None | |
For more information, see `util.py` or run:

```bash
python GPT.py --help
```

Before training the GPT, you must have a trained VQ-GAN to encode images into discrete tokens. The GPT is trained autoregressively over these tokens using a Transformer. To train on CIFAR-10:

```bash
python GPT.py --train --dataset cifar10 --checkpoint_vae ./../../models/VQGAN/VQGAN_cifar10.pt
```

To sample images using a trained GPT model and a VQ-GAN decoder:

```bash
python GPT.py --sample --dataset cifar10 --checkpoint_vae ./../../models/VQGAN/VQGAN_cifar10.pt --checkpoint_gpt ./../../models/GPT/GPT_cifar10.pt
```
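Sampling amounts to drawing `block_size` tokens autoregressively (with `--temperature` and `--top_k` shaping the categorical distribution at each step) and decoding the resulting indices back to pixels with the VQ-GAN decoder. The sketch below is a simplified illustration under assumptions: `gpt`, the start token, and the `decode_from_indices` helper are hypothetical names, not the repository's exact API.

```python
import torch

@torch.no_grad()
def sample_images(gpt, vqgan, num_samples=16, block_size=64, temperature=1.0, top_k=None):
    """Draw token sequences autoregressively, then decode them with the VQ-GAN (illustrative)."""
    tokens = torch.zeros(num_samples, 1, dtype=torch.long)        # assumed start token 0
    for _ in range(block_size):
        logits = gpt(tokens)[:, -1, :] / temperature              # logits for the next position
        if top_k is not None:
            # Keep only the top-k logits; everything else gets zero probability.
            kth = torch.topk(logits, top_k, dim=-1).values[:, -1:]
            logits = logits.masked_fill(logits < kth, float("-inf"))
        probs = torch.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, 1)                  # sample one token per sequence
        tokens = torch.cat([tokens, next_token], dim=1)

    # Drop the start token and map indices back to images through the VQ-GAN decoder
    # (hypothetical `decode_from_indices` helper; the real checkpoint's API may differ).
    return vqgan.decode_from_indices(tokens[:, 1:])
```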