
This project uses generative models to tackle few-shot learning, where only a small amount of labeled training data is available.


kzGarifullin/GenerativeClassification


GenerativeClassification

Few-Shot Generative Classification

Traditional supervised classification approaches limit the scalability and training efficiency of neural networks because labeling the data requires significant human effort and computational resources.

The main goal of this research is to develop a method that reduces the need for manual annotation by learning feature representations directly from unlabeled data.

Concept


Given a small labeled dataset and a larger unlabeled dataset, classification can be approached in three steps:

  • Train a generative model on the unlabeled dataset to learn the underlying data distribution.
  • Utilize the trained generative model to extract more representative features for the labeled images, effectively enriching the feature space.
  • Train a small neural network using the enriched features to make predictions for the corresponding labels.

This approach leverages the generative model to enable the small neural network to make accurate predictions despite the limited labeled data.
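The three steps above can be sketched as follows. This is a minimal NumPy illustration under stated assumptions, not the repository's code: `frozen_feature_extractor` is a hypothetical stand-in (a fixed random projection) for the encoder of a generative model trained on unlabeled data, and the "small neural network" is reduced to a softmax linear head trained by full-batch gradient descent.

```python
import numpy as np

def frozen_feature_extractor(images: np.ndarray, proj: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for a pretrained generative encoder (fixed random projection + tanh)."""
    flat = images.reshape(len(images), -1)
    return np.tanh(flat @ proj)

def train_linear_head(feats, labels, num_classes, lr=0.1, epochs=300):
    """Fit a softmax (multinomial logistic regression) head with full-batch gradient descent."""
    n, d = feats.shape
    W, b = np.zeros((d, num_classes)), np.zeros(num_classes)
    onehot = np.eye(num_classes)[labels]
    for _ in range(epochs):
        logits = feats @ W + b
        logits -= logits.max(axis=1, keepdims=True)  # numerical stability
        probs = np.exp(logits)
        probs /= probs.sum(axis=1, keepdims=True)
        grad = (probs - onehot) / n  # gradient of mean cross-entropy w.r.t. logits
        W -= lr * feats.T @ grad
        b -= lr * grad.sum(axis=0)
    return W, b

def predict(feats, W, b):
    return np.argmax(feats @ W + b, axis=1)

# Toy few-shot setup: 3 classes, 5 labeled 8x8 "images" per class.
# Step 1 (training the generative model on unlabeled data) is not shown.
rng = np.random.default_rng(0)
num_classes, per_class = 3, 5
prototypes = 2.0 * rng.normal(size=(num_classes, 8, 8))
labels = np.repeat(np.arange(num_classes), per_class)
images = prototypes[labels] + 0.3 * rng.normal(size=(len(labels), 8, 8))

proj = rng.normal(size=(64, 32)) / 8.0               # frozen "encoder" weights, never trained
feats = frozen_feature_extractor(images, proj)       # step 2: features for labeled images
W, b = train_linear_head(feats, labels, num_classes) # step 3: train the small head
train_acc = float((predict(feats, W, b) == labels).mean())
```

Keeping the extractor frozen means only the head's few parameters are fitted to the scarce labels, which is the point of the approach.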

Optimal timestep selection for the diffusion model

When the diffusion model is used for feature extraction, the diffusion timestep at which features are taken is a key hyperparameter: it determines the noise level of the input from which the features are extracted, and therefore strongly affects their quality.
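A minimal sketch of how an image can be noised to timestep t before features are read out, assuming the standard DDPM forward process $x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon$ with a linear beta schedule (T = 1000, betas from 1e-4 to 0.02). The schedule actually used in this repository is not stated here, and the `denoiser` call in the comment is hypothetical.

```python
import numpy as np

def linear_alpha_bar(T: int = 1000, beta_start: float = 1e-4, beta_end: float = 0.02) -> np.ndarray:
    """Cumulative products alpha_bar_t for a linear beta schedule (assumed, DDPM-style)."""
    betas = np.linspace(beta_start, beta_end, T)
    return np.cumprod(1.0 - betas)

def noise_to_timestep(x0: np.ndarray, t: int, alpha_bar: np.ndarray,
                      rng: np.random.Generator) -> np.ndarray:
    """Sample x_t ~ q(x_t | x_0) = N(sqrt(alpha_bar_t) x_0, (1 - alpha_bar_t) I); t is 0-indexed."""
    eps = rng.standard_normal(x0.shape)
    return np.sqrt(alpha_bar[t]) * x0 + np.sqrt(1.0 - alpha_bar[t]) * eps

rng = np.random.default_rng(0)
alpha_bar = linear_alpha_bar()
x0 = rng.uniform(-1.0, 1.0, size=(4, 3, 32, 32))  # batch of images scaled to [-1, 1]
x_t = noise_to_timestep(x0, t=50, alpha_bar=alpha_bar, rng=rng)
# features = denoiser.intermediate_activations(x_t, t=50)  # hypothetical, model-specific
```

Small t keeps the input close to the clean image; large t drives it toward pure noise, which is why the choice of t matters.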

[Figure panels: MNIST, CIFAR-10, CIFAR-100]

As the plots show, the optimal timesteps for MNIST and CIFAR-10 are $100$ and $50$, respectively; we fix these values for all subsequent experiments.
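The selection amounts to a sweep over candidate timesteps. In this sketch `score_fn` is a hypothetical callable, for example the validation accuracy of a small head trained on features extracted at timestep t; the toy score below peaks at 100 purely for illustration and is not real data.

```python
from typing import Callable, Dict, Iterable, Tuple

def select_timestep(candidates: Iterable[int],
                    score_fn: Callable[[int], float]) -> Tuple[int, Dict[int, float]]:
    """Score every candidate timestep and return the best one together with all scores."""
    scores = {t: score_fn(t) for t in candidates}
    best = max(scores, key=scores.get)  # first maximum wins on ties
    return best, scores

# Toy score peaking at t = 100; a real score_fn would train and validate a
# small head on features extracted at timestep t (hypothetical, not shown).
def toy_score(t: int) -> float:
    return -abs(t - 100) / 100.0

best_t, scores = select_timestep([10, 25, 50, 100, 200, 400], toy_score)
```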

Feature Quality

Assessing feature separability is an important step in evaluating how well a model has learned the internal structure of a dataset. To visually assess the quality of the features extracted from the generative models, we project them into 2D and 3D space using Uniform Manifold Approximation and Projection (UMAP).
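The UMAP projection itself is a one-liner with the umap-learn package (`umap.UMAP(n_components=2).fit_transform(features)`). As a complementary numeric check, which this repository does not report, the sketch below computes a trace-form Fisher criterion: between-class scatter divided by within-class scatter, where larger values indicate more separable features. The synthetic clusters are illustrative only.

```python
import numpy as np

def separability_ratio(feats: np.ndarray, labels: np.ndarray) -> float:
    """Trace-form Fisher criterion: between-class scatter / within-class scatter."""
    overall_mean = feats.mean(axis=0)
    between, within = 0.0, 0.0
    for c in np.unique(labels):
        fc = feats[labels == c]
        mc = fc.mean(axis=0)
        between += len(fc) * float(np.sum((mc - overall_mean) ** 2))
        within += float(np.sum((fc - mc) ** 2))
    return between / within

# Synthetic 2-D "features" for three classes: far-apart vs. nearly overlapping clusters.
rng = np.random.default_rng(0)
labels = np.repeat([0, 1, 2], 100)
centers = np.array([[0.0, 0.0], [5.0, 0.0], [0.0, 5.0]])
well_separated = centers[labels] + rng.normal(size=(300, 2))
overlapping = 0.2 * centers[labels] + rng.normal(size=(300, 2))
```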

Features for MNIST:

[Figure panels: Diffusion model, VAE, GAN]

Features for CIFAR-10:

[Figure panels: Diffusion model, VQ-VAE, GAN]

Features for CIFAR-100:

[Figure panels: Diffusion model, VQ-VAE, GAN]

Results

Results for the MNIST dataset

Comparison of three generative models used as feature extractors on the MNIST dataset.


With GAN features, the classifier outperforms the baseline on small labeled sets of up to 64 images per class. The diffusion model and the VAE also outperform the baseline, up to 32 and 12 labeled images per class, respectively.
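These experiments vary the number of labeled images per class. A class-balanced few-shot subset can be drawn as in the sketch below; the function and its seeding scheme are assumptions for illustration, with the parameter name borrowed from the `--size_per_class` option documented in Scripts Usage.

```python
import numpy as np

def sample_per_class(labels: np.ndarray, size_per_class: int, seed: int = 0) -> np.ndarray:
    """Indices of a class-balanced subset: `size_per_class` examples of every class, without replacement."""
    rng = np.random.default_rng(seed)
    picked = [
        rng.choice(np.flatnonzero(labels == c), size=size_per_class, replace=False)
        for c in np.unique(labels)
    ]
    return np.concatenate(picked)

# E.g. 10 classes with 100 labeled candidates each; keep 32 per class.
labels = np.repeat(np.arange(10), 100)
subset = sample_per_class(labels, size_per_class=32)
```

Fixing the seed makes each few-shot subset reproducible across runs, so different feature extractors can be compared on identical labeled images.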

Results for the CIFAR-10 dataset


Results for the CIFAR-100 dataset


Reproducibility of diffusion experiments

All experiments were performed in Google Colab. Links to all experiments:

Reproducibility of GAN experiments

All experiments were performed in Google Colab. Links to all experiments:

Scripts Usage

Train VAE/VQ-VAE

MNIST VAE
usage: train_vae_mnist.py [-h] [-d DEVICE] [-bs BATCH_SIZE] [-e EPOCHS]
                          [-lr LR] [-ld LATENT_DIM] [-pth PATH]

optional arguments:
  -h, --help            show this help message and exit
  -d DEVICE, --device DEVICE
                        Device for training
  -bs BATCH_SIZE, --batch_size BATCH_SIZE
                        Batch size
  -e EPOCHS, --epochs EPOCHS
                        Number of epochs
  -lr LR, --lr LR       Learning rate
  -ld LATENT_DIM, --latent_dim LATENT_DIM
                        Latent space dimension
  -pth PATH, --path PATH
                        Weights path
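The `--latent_dim` flag sets the size of the VAE latent space. Assuming a standard Gaussian VAE (the README does not describe the architecture), the sketch below shows the latent sampling step and the closed-form KL term of the loss in NumPy; `latent_dim = 16` is a hypothetical value, and the random `mu`/`log_var` stand in for encoder outputs.

```python
import numpy as np

def reparameterize(mu: np.ndarray, log_var: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """z = mu + sigma * eps with eps ~ N(0, I); in an autograd framework this keeps z
    differentiable with respect to mu and log_var."""
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(0.5 * log_var) * eps

def kl_to_standard_normal(mu: np.ndarray, log_var: np.ndarray) -> np.ndarray:
    """Closed-form KL(N(mu, diag(exp(log_var))) || N(0, I)) per sample, summed over latent dims."""
    return -0.5 * np.sum(1.0 + log_var - mu**2 - np.exp(log_var), axis=1)

latent_dim = 16  # hypothetical value for -ld / --latent_dim
rng = np.random.default_rng(0)
mu = rng.normal(size=(8, latent_dim))       # would come from the encoder
log_var = rng.normal(size=(8, latent_dim))  # would come from the encoder
z = reparameterize(mu, log_var, rng)
kl = kl_to_standard_normal(mu, log_var)
```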

CIFAR VAE
usage: train_vae_cifar.py [-h] [-d DEVICE] [-bs BATCH_SIZE] [-e EPOCHS]
                          [-lr LR] [-ld LATENT_DIM] [-ct CIFAR_TYPE]
                          [-pth PATH]

optional arguments:
  -h, --help            show this help message and exit
  -d DEVICE, --device DEVICE
                        Device for training
  -bs BATCH_SIZE, --batch_size BATCH_SIZE
                        Batch size
  -e EPOCHS, --epochs EPOCHS
                        Number of epochs
  -lr LR, --lr LR       Learning rate
  -ld LATENT_DIM, --latent_dim LATENT_DIM
                        Latent space dimension
  -ct CIFAR_TYPE, --cifar_type CIFAR_TYPE
                        CIFAR10 or CIFAR100
  -pth PATH, --path PATH
                        Weights path
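A parser exposing the same flags can be sketched with the standard-library `argparse` module. This is an illustrative reconstruction from the help text above, not the repository's source: every default value is hypothetical, and `NUM_CLASSES` is an assumed helper that maps `--cifar_type` to the number of labels.

```python
import argparse

NUM_CLASSES = {"CIFAR10": 10, "CIFAR100": 100}  # assumed helper, not from the repo

def build_parser() -> argparse.ArgumentParser:
    # Flag names follow the documented help text; default values are hypothetical.
    p = argparse.ArgumentParser(prog="train_vae_cifar.py")
    p.add_argument("-d", "--device", default="cpu", help="Device for training")
    p.add_argument("-bs", "--batch_size", type=int, default=128, help="Batch size")
    p.add_argument("-e", "--epochs", type=int, default=10, help="Number of epochs")
    p.add_argument("-lr", "--lr", type=float, default=1e-3, help="Learning rate")
    p.add_argument("-ld", "--latent_dim", type=int, default=128, help="Latent space dimension")
    p.add_argument("-ct", "--cifar_type", choices=sorted(NUM_CLASSES), default="CIFAR10",
                   help="CIFAR10 or CIFAR100")
    p.add_argument("-pth", "--path", default="vae_cifar.pt", help="Weights path")
    return p

args = build_parser().parse_args(["-ct", "CIFAR100", "-bs", "64"])
num_classes = NUM_CLASSES[args.cifar_type]
```

Restricting `--cifar_type` with `choices` makes a typo such as `CIFAR-10` fail at parse time instead of deep inside data loading.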

CIFAR VQ-VAE
usage: train_vqvae.py [-h] [-bs BATCH_SIZE] [-e EPOCHS] [-lr LR]
                      [-ld LATENT_DIM] [-ct CIFAR_TYPE] [-pth PATH]

optional arguments:
  -h, --help            show this help message and exit
  -bs BATCH_SIZE, --batch_size BATCH_SIZE
                        Batch size
  -e EPOCHS, --epochs EPOCHS
                        Number of epochs
  -lr LR, --lr LR       Learning rate
  -ld LATENT_DIM, --latent_dim LATENT_DIM
                        Latent space dimension
  -ct CIFAR_TYPE, --cifar_type CIFAR_TYPE
                        CIFAR10 or CIFAR100
  -pth PATH, --path PATH
                        Weights path
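The core VQ-VAE operation replaces each encoder output with its nearest codebook vector. The sketch below shows only that lookup in NumPy (no straight-through gradient, codebook loss, or commitment loss). This README does not say whether the feature-extraction scripts use the quantized vectors or the pre-quantization encoder outputs; the tiny codebook and the correspondence between D and `--latent_dim` are assumptions.

```python
import numpy as np

def vector_quantize(z_e: np.ndarray, codebook: np.ndarray):
    """Map each encoder vector to its nearest codebook entry (squared Euclidean distance).

    z_e: (N, D) encoder outputs; codebook: (K, D) embedding vectors.
    Returns indices of shape (N,) and quantized vectors of shape (N, D).
    """
    # ||z - e||^2 = ||z||^2 - 2 z.e + ||e||^2, computed for all (z, e) pairs at once
    d2 = (np.sum(z_e**2, axis=1, keepdims=True)
          - 2.0 * z_e @ codebook.T
          + np.sum(codebook**2, axis=1))
    idx = np.argmin(d2, axis=1)
    return idx, codebook[idx]

codebook = np.array([[0.0, 0.0], [1.0, 1.0], [-1.0, 2.0]])  # K = 3 entries, D = 2
z_e = np.array([[0.1, -0.1], [0.9, 1.2], [-0.8, 1.7]])
idx, z_q = vector_quantize(z_e, codebook)
```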

Extract Features From Different Models

MNIST VAE
usage: extract_features_train_smallnet_mnist.py [-h] [-p PATH] [-d DEVICE]
                                                [-s SIZE_PER_CLASS]
                                                [-e EPOCHS] [-hd HEAD]

optional arguments:
  -h, --help            show this help message and exit
  -p PATH, --path PATH  Path of weights
  -d DEVICE, --device DEVICE
                        Device for training
  -s SIZE_PER_CLASS, --size_per_class SIZE_PER_CLASS
                        Number of images per class
  -e EPOCHS, --epochs EPOCHS
                        Number of epochs
  -hd HEAD, --head HEAD
                        Type of head model: Lin or NonLin

CIFAR VAE
usage: extract_features_train_smallnet_cifar.py [-h] [-p PATH] [-d DEVICE]
                                                [-s SIZE_PER_CLASS]
                                                [-e EPOCHS] [-hd HEAD]
                                                [-ct CIFAR_TYPE]

optional arguments:
  -h, --help            show this help message and exit
  -p PATH, --path PATH  Path of weights
  -d DEVICE, --device DEVICE
                        Device for training
  -s SIZE_PER_CLASS, --size_per_class SIZE_PER_CLASS
                        Number of images per class
  -e EPOCHS, --epochs EPOCHS
                        Number of epochs
  -hd HEAD, --head HEAD
                        Type of head model: Lin or NonLin
  -ct CIFAR_TYPE, --cifar_type CIFAR_TYPE
                        CIFAR10 or CIFAR100

CIFAR VQ-VAE
usage: extract_features_train_smallnet_vqvae.py [-h] [-p PATH] [-d DEVICE]
                                                [-s SIZE_PER_CLASS]
                                                [-e EPOCHS] [-hd HEAD]
                                                [-nc NUM_CLASSES]
                                                [-ct CIFAR_TYPE]

optional arguments:
  -h, --help            show this help message and exit
  -p PATH, --path PATH  Path of weights
  -d DEVICE, --device DEVICE
                        Device for training
  -s SIZE_PER_CLASS, --size_per_class SIZE_PER_CLASS
                        Number of images per class
  -e EPOCHS, --epochs EPOCHS
                        Number of epochs
  -hd HEAD, --head HEAD
                        Type of head model: Lin or NonLin
  -nc NUM_CLASSES, --num_classes NUM_CLASSES
                        Number of classes
  -ct CIFAR_TYPE, --cifar_type CIFAR_TYPE
                        CIFAR10 or CIFAR100
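The `--head` option selects a linear or non-linear classifier on top of the extracted features. The README does not describe either architecture, so the sketch below assumes the simplest reading: `Lin` is a single affine layer and `NonLin` adds one hidden ReLU layer (the hidden width of 256 is hypothetical). `--num_classes` sets the output size. Only the NumPy forward pass is shown.

```python
import numpy as np

def init_head(kind: str, in_dim: int, num_classes: int, hidden: int = 256, seed: int = 0):
    """Create (W, b) layer parameters for a 'Lin' or 'NonLin' head (assumed architectures)."""
    rng = np.random.default_rng(seed)
    if kind == "Lin":
        return [(rng.normal(0.0, in_dim**-0.5, (in_dim, num_classes)), np.zeros(num_classes))]
    if kind == "NonLin":
        return [
            (rng.normal(0.0, in_dim**-0.5, (in_dim, hidden)), np.zeros(hidden)),
            (rng.normal(0.0, hidden**-0.5, (hidden, num_classes)), np.zeros(num_classes)),
        ]
    raise ValueError(f"unknown head type: {kind!r}")

def head_forward(params, x: np.ndarray) -> np.ndarray:
    """Return class logits; ReLU is applied between layers but not after the last one."""
    for i, (W, b) in enumerate(params):
        x = x @ W + b
        if i < len(params) - 1:
            x = np.maximum(x, 0.0)
    return x

rng = np.random.default_rng(1)
feats = rng.normal(size=(32, 64))  # e.g. 32 images with 64-dim extracted features
lin = init_head("Lin", in_dim=64, num_classes=10)
nonlin = init_head("NonLin", in_dim=64, num_classes=10)
```

A linear head probes how linearly separable the generative features already are, while the extra hidden layer can compensate for features that are informative but not linearly arranged.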

Developers
