This model supports Hugging Face Accelerate for multi-GPU and mixed-precision training.

| Argument | Description | Default | Choices |
|---|---|---|---|
| --train | Train model | False | |
| --sample | Sample model | False | |
| --outlier_detection | Outlier detection | False | |
| --dataset | Dataset name | mnist | mnist, cifar10, fashionmnist, chestmnist, octmnist, tissuemnist, pneumoniamnist, svhn, tinyimagenet, cifar100, places365, dtd, imagenet |
| --no_wandb | Disable Wandb | False | |
| --out_dataset | Outlier dataset name | fashionmnist | mnist, cifar10, cifar100, places365, dtd, fashionmnist, chestmnist, octmnist, tissuemnist, pneumoniamnist, svhn, tinyimagenet, imagenet |
| --batch_size | Batch size | 128 | |
| --n_epochs | Number of epochs | 100 | |
| --lr | Learning rate | 5e-4 | |
| --patch_size | Patch size | 2 | |
| --dim | Dimension | 64 | |
| --n_layers | Number of layers | 6 | |
| --n_heads | Number of heads | 4 | |
| --multiple_of | Multiple of | 256 | |
| --ffn_dim_multiplier | FFN dim multiplier | None | |
| --norm_eps | Norm eps | 1e-5 | |
| --class_dropout_prob | Class dropout probability | 0.1 | |
| --sample_and_save_freq | Sample and save frequency | 5 | |
| --num_classes | Number of classes | 10 | |
| --checkpoint | Checkpoint path | None | |
| --num_workers | Number of workers for Dataloader | 0 | |
| --latent | Use latent version | False | |
| --warmup | Number of warmup epochs | 10 | |
| --decay | Decay rate | 1e-5 | |
| --ema_rate | Exponential moving average rate | 0.999 | |
| --conditional | Conditional model | False | |
| --size | Size of input image | None | |

To learn more about these parameters, see `util.py` or run the example script with `--help`:

    python RF.py --help

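
For readers who want a feel for how flags like these are usually declared, here is a minimal `argparse` sketch covering a handful of the arguments from the table. It is an illustration only: `build_parser` is a hypothetical name, the real definitions live in `util.py` and may differ, and using `store_true` for the boolean switches is an assumption based on their `False` defaults.

```python
import argparse

# Dataset names copied from the "Choices" column of the table.
DATASETS = [
    "mnist", "cifar10", "fashionmnist", "chestmnist", "octmnist",
    "tissuemnist", "pneumoniamnist", "svhn", "tinyimagenet",
    "cifar100", "places365", "dtd", "imagenet",
]


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring part of the table; not the code in util.py."""
    parser = argparse.ArgumentParser(description="Sketch of a few RF.py options")
    # Assumption: boolean switches are False unless the flag is present.
    parser.add_argument("--train", action="store_true", help="Train model")
    parser.add_argument("--sample", action="store_true", help="Sample model")
    parser.add_argument("--dataset", type=str, default="mnist",
                        choices=DATASETS, help="Dataset name")
    parser.add_argument("--batch_size", type=int, default=128, help="Batch size")
    parser.add_argument("--n_epochs", type=int, default=100, help="Number of epochs")
    parser.add_argument("--lr", type=float, default=5e-4, help="Learning rate")
    parser.add_argument("--checkpoint", type=str, default=None, help="Checkpoint path")
    return parser


# Parse the same flags used by the training command in this README.
args = build_parser().parse_args(["--train", "--dataset", "mnist"])
print(args.train, args.dataset, args.batch_size)
```

With `choices` set, a dataset name outside the list is rejected at parse time instead of failing later during data loading.
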
You can train this model with the following command:

    accelerate launch RF.py --train --dataset mnist

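
If you start runs from a script, the launch command can be assembled programmatically. The sketch below is a hypothetical helper (`build_launch_command` is not part of this repository). It optionally adds the Accelerate launcher options `--multi_gpu`, `--num_processes` and `--mixed_precision`. Whether you need them depends on your setup, since the same settings can also come from `accelerate config`; this README does not say which approach the author used.

```python
import shlex
from typing import List, Optional

# Precision modes this sketch accepts; the launcher may support others.
ALLOWED_PRECISION = ("no", "fp16", "bf16")


def build_launch_command(dataset: str = "mnist",
                         num_processes: Optional[int] = None,
                         mixed_precision: Optional[str] = None) -> List[str]:
    """Build an argv list for `accelerate launch RF.py --train` (hypothetical helper)."""
    cmd = ["accelerate", "launch"]
    # Only request multi-GPU mode when more than one process is asked for.
    if num_processes is not None and num_processes > 1:
        cmd += ["--multi_gpu", "--num_processes", str(num_processes)]
    if mixed_precision is not None:
        if mixed_precision not in ALLOWED_PRECISION:
            raise ValueError(f"unsupported mixed_precision: {mixed_precision!r}")
        cmd += ["--mixed_precision", mixed_precision]
    # Script name and flags as they appear in this README.
    cmd += ["RF.py", "--train", "--dataset", dataset]
    return cmd


def to_shell(cmd: List[str]) -> str:
    """Render the argv list as a copy-pasteable shell line."""
    return " ".join(shlex.quote(part) for part in cmd)


print(to_shell(build_launch_command()))
print(to_shell(build_launch_command("cifar10", num_processes=2, mixed_precision="fp16")))
```

The list returned by `build_launch_command` can be passed directly to `subprocess.run`, which avoids shell quoting issues.
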
To sample from a trained model, provide its checkpoint with `--checkpoint`:

    python RF.py --sample --dataset fashionmnist --checkpoint ./../../models/FlowMatching/FM_mnist.pt

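
The checkpoint path in the command above is relative, so it is presumably interpreted relative to the directory you run the script from. The following sketch shows one way to resolve such a path and fail early with a clear message if the file is missing. `resolve_checkpoint` and its `base_dir` parameter are hypothetical and not part of `RF.py`.

```python
from pathlib import Path


def resolve_checkpoint(checkpoint: str, base_dir: str = ".") -> Path:
    """Resolve a possibly relative checkpoint path against base_dir and verify it exists."""
    path = Path(checkpoint)
    if not path.is_absolute():
        # Interpret relative paths from base_dir (the working directory by default).
        path = Path(base_dir) / path
    # Collapse "." and ".." segments into a normalized absolute path.
    path = path.resolve()
    if not path.is_file():
        raise FileNotFoundError(f"Checkpoint not found: {path}")
    return path


if __name__ == "__main__":
    try:
        print(resolve_checkpoint("./../../models/FlowMatching/FM_mnist.pt"))
    except FileNotFoundError as err:
        print(err)
```

Checking the path before sampling surfaces a wrong working directory immediately, before any model or dataset is loaded.
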