| Argument | Description | Default | Choices |
|---|---|---|---|
| `--train` | Train model | False | |
| `--sample` | Sample from model | False | |
| `--outlier_detection` | Outlier detection | False | |
| `--dataset` | Dataset name | mnist | mnist, cifar10, fashionmnist, chestmnist, octmnist, tissuemnist, pneumoniamnist, svhn, tinyimagenet, cifar100, places365, dtd, imagenet |
| `--no_wandb` | Disable Weights & Biases (wandb) logging | False | |
| `--out_dataset` | Outlier dataset name | fashionmnist | mnist, cifar10, fashionmnist, chestmnist, octmnist, tissuemnist, pneumoniamnist, svhn, tinyimagenet, cifar100, places365, dtd, imagenet |
| `--batch_size` | Batch size | 128 | |
| `--n_epochs` | Number of epochs | 100 | |
| `--lr` | Learning rate | 1e-3 | |
| `--gamma` | Gamma for the learning-rate scheduler | 0.99 | |
| `--sample_and_save_freq` | Sample and save frequency | 5 | |
| `--hidden_channels` | Number of channels in the convolutional layers | 64 | |
| `--checkpoint` | Checkpoint path | None | |
| `--num_workers` | Number of workers for the DataLoader | 0 | |
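
The table maps naturally onto an `argparse` parser. The sketch below is a hypothetical reconstruction, not the actual code in P-CNN.py: the flag names and defaults come from the table, while the value types (for example `float` for `--lr`) and the use of `store_true` for the boolean switches are assumptions.

```python
import argparse

# Dataset names listed in the argument table.
DATASETS = [
    "mnist", "cifar10", "fashionmnist", "chestmnist", "octmnist",
    "tissuemnist", "pneumoniamnist", "svhn", "tinyimagenet",
    "cifar100", "places365", "dtd", "imagenet",
]


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the documented flags and defaults."""
    parser = argparse.ArgumentParser(description="PixelCNN")
    # Mode switches: boolean flags that default to False (assumed store_true).
    parser.add_argument("--train", action="store_true", help="Train model")
    parser.add_argument("--sample", action="store_true", help="Sample from model")
    parser.add_argument("--outlier_detection", action="store_true", help="Outlier detection")
    parser.add_argument("--no_wandb", action="store_true", help="Disable wandb logging")
    # In-distribution and outlier datasets.
    parser.add_argument("--dataset", type=str, default="mnist", choices=DATASETS)
    parser.add_argument("--out_dataset", type=str, default="fashionmnist", choices=DATASETS)
    # Optimisation hyperparameters (types are assumptions).
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--n_epochs", type=int, default=100)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--sample_and_save_freq", type=int, default=5)
    # Model and runtime settings.
    parser.add_argument("--hidden_channels", type=int, default=64)
    parser.add_argument("--checkpoint", type=str, default=None)
    parser.add_argument("--num_workers", type=int, default=0)
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args(["--sample", "--checkpoint", "model.pt"])
    print(args.sample, args.checkpoint, args.dataset, args.lr)
```

Unspecified flags fall back to the table defaults, and `choices` makes argparse reject dataset names outside the listed set.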

The PixelCNN can be trained with:

    python P-CNN.py --train
To sample from the model, you must provide a checkpoint:

    python P-CNN.py --sample --checkpoint ./../../models/PixelCNN/PixelCNN_mnist.pt
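
PixelCNN sampling is autoregressive: pixels are generated one at a time in raster order, each drawn from the model's conditional distribution given the pixels already generated, so a real sampler typically runs the network once per pixel. The pure-Python sketch below illustrates only that loop, on binary pixels. `sample_autoregressive` and `toy_model` are hypothetical names, and `toy_model` is a hand-written stand-in, not the network P-CNN.py loads from the checkpoint.

```python
import random
from typing import Callable, List

Image = List[List[int]]


def sample_autoregressive(
    height: int,
    width: int,
    prob_on: Callable[[Image, int, int], float],
    rng: random.Random,
) -> Image:
    """Generate a binary image pixel by pixel in raster order (hypothetical sketch)."""
    img: Image = [[0] * width for _ in range(height)]
    for i in range(height):      # rows, top to bottom
        for j in range(width):   # columns, left to right
            # Ask the model for P(pixel (i, j) = 1 | pixels already generated).
            # In a real PixelCNN, masked convolutions make this probability
            # depend only on pixels that come before (i, j) in raster order.
            p = prob_on(img, i, j)
            img[i][j] = 1 if rng.random() < p else 0
    return img


def toy_model(img: Image, i: int, j: int) -> float:
    """Hypothetical stand-in for a trained network: tends to copy the left neighbour."""
    if j == 0:
        return 0.5
    return 0.9 if img[i][j - 1] == 1 else 0.1


if __name__ == "__main__":
    sample = sample_autoregressive(4, 8, toy_model, random.Random(0))
    for row in sample:
        print("".join("#" if px else "." for px in row))
```

Passing an explicit `random.Random` instance keeps the sketch reproducible for a fixed seed.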
Outlier detection uses the negative log-likelihood (NLL) scores produced by the model and is run from a trained checkpoint:

    python P-CNN.py --outlier_detection --checkpoint ./../../models/PixelCNN/PixelCNN_mnist.pt
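
Images from the in-distribution `--dataset` should receive low NLL under the trained model, while `--out_dataset` images are presumably expected to score higher. A common way to summarise how well the scores separate the two sets is the area under the ROC curve (AUROC). The sketch below computes it with the pairwise (Mann-Whitney) formulation, assuming higher NLL means "more likely an outlier". `auroc_from_nll` is a hypothetical helper, and this README does not say which metric P-CNN.py actually reports.

```python
from typing import Sequence


def auroc_from_nll(in_nll: Sequence[float], out_nll: Sequence[float]) -> float:
    """AUROC for separating outliers from inliers, using NLL as the outlier score.

    Equals the probability that a randomly chosen outlier receives a higher
    NLL than a randomly chosen inlier; ties count as one half.
    """
    if not in_nll or not out_nll:
        raise ValueError("both score lists must be non-empty")
    wins = 0.0
    for o in out_nll:
        for x in in_nll:
            if o > x:
                wins += 1.0
            elif o == x:
                wins += 0.5
    return wins / (len(in_nll) * len(out_nll))


if __name__ == "__main__":
    # Made-up per-image NLL values, for illustration only.
    inliers = [1.1, 1.3, 1.2, 1.5]
    outliers = [2.4, 1.4, 3.0]
    print(round(auroc_from_nll(inliers, outliers), 3))
```

An AUROC of 1.0 means every outlier scores above every inlier, and 0.5 means the scores do not separate the two sets. The double loop is quadratic in the number of images, so for large test sets a rank-based or library implementation would be preferable.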