
Implementing Generative Adversarial Network (GAN) training #221

Open
@elahpeca

Description


neural-fortran in its current state effectively supports supervised learning: gradients are computed by the network%backward() method, which starts from the derivative of the loss function with respect to the network's output (evaluated against explicit target labels) and propagates it back through each layer. Training a GAN, however, requires computing the generator's gradients from the discriminator's feedback, because there are no explicit target labels for the generated data.
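For contrast, the supervised case can be illustrated with a minimal Python sketch (plain Python, not neural-fortran's API). It assumes a sum-of-squares loss `L = 0.5 * sum_i (y_i - t_i)^2`; the function names are illustrative only. The point is that the vector seeding backpropagation, `dL/dy_i = y_i - t_i`, cannot be formed without a target `t`.

```python
def mse_loss(output, target):
    # L = 0.5 * sum_i (y_i - t_i)^2
    return 0.5 * sum((y - t) ** 2 for y, t in zip(output, target))


def mse_loss_gradient(output, target):
    # dL/dy_i = y_i - t_i: the vector a supervised backward pass starts from.
    # It cannot be formed without an explicit target t.
    return [y - t for y, t in zip(output, target)]


y = [1.0, 2.0]  # network output
t = [0.5, 2.5]  # target labels
print(mse_loss(y, t))           # 0.25
print(mse_loss_gradient(y, t))  # [0.5, -0.5]
```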

The standard approach to GAN training involves:

  1. Generating data with the generator.
  2. Evaluating the generated data with the discriminator.
  3. Calculating the generator's loss from the discriminator's output (the generator's goal is to fool the discriminator).
  4. Calculating the gradient of this loss with respect to the generator's output.
  5. Backpropagating this gradient through the generator to update its weights.
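The five steps above can be sketched for a single generator update using scalar stand-ins for both networks. This is a minimal Python illustration, not neural-fortran code: it assumes a linear generator `x = w_g*z + b_g`, a logistic discriminator `D(x) = sigmoid(w_d*x + b_d)`, and the common non-saturating generator loss `-log D(G(z))`; all names are hypothetical.

```python
import math


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


def generator(z, w_g, b_g):
    # Toy generator: maps a noise sample z to a generated sample x
    return w_g * z + b_g


def discriminator(x, w_d, b_d):
    # Toy discriminator: probability that x is real
    return sigmoid(w_d * x + b_d)


def generator_step(z, w_g, b_g, w_d, b_d, lr=0.1):
    x_fake = generator(z, w_g, b_g)      # 1. generate
    d = discriminator(x_fake, w_d, b_d)  # 2. evaluate
    loss = -math.log(d)                  # 3. non-saturating generator loss
    # 4. gradient w.r.t. the generator output:
    #    dL/dd = -1/d and dd/dx = d*(1-d)*w_d, so dL/dx = -(1-d)*w_d
    grad_x = -(1.0 - d) * w_d
    # 5. backpropagate through x = w_g*z + b_g; only the generator's
    #    parameters are updated, the discriminator is held fixed
    grad_w_g = grad_x * z
    grad_b_g = grad_x
    return w_g - lr * grad_w_g, b_g - lr * grad_b_g, loss, grad_x


w_g, b_g = 0.5, 0.0   # generator parameters
w_d, b_d = 2.0, -1.0  # fixed discriminator parameters
z = 0.3               # noise sample
new_w_g, new_b_g, loss, grad_x = generator_step(z, w_g, b_g, w_d, b_d)
```

Only the generator's parameters change in this step; the discriminator's parameters act as constants, which is why step 4 needs `dL/dx` rather than a target label.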

Currently, the framework does not appear to provide a built-in mechanism for automatically computing the gradient of the loss with respect to intermediate network outputs (in this case, the generator's output, which is also the discriminator's input). The backward method in the current implementation is tightly coupled to comparing the network's output with explicit target labels.
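Concretely, the quantity that would need to be exposed is the gradient of the generator loss with respect to the generated sample. Writing $\tilde{x} = G(z; \theta_G)$ and assuming a discriminator $D(x) = \sigma(f(x))$ with logit $f$ and the non-saturating loss $L_G = -\log D(\tilde{x})$ (one common choice, assumed here), the chain rule gives:

```latex
\begin{aligned}
\frac{\partial L_G}{\partial \tilde{x}}
  &= -\frac{1}{D(\tilde{x})}\,\frac{\partial D(\tilde{x})}{\partial \tilde{x}}
   = -\bigl(1 - D(\tilde{x})\bigr)\,\frac{\partial f(\tilde{x})}{\partial \tilde{x}}, \\
\frac{\partial L_G}{\partial \theta_G}
  &= \left(\frac{\partial \tilde{x}}{\partial \theta_G}\right)^{\!\top}
     \frac{\partial L_G}{\partial \tilde{x}}.
\end{aligned}
```

The first line is what a backward pass through the discriminator yields at its input; the second is an ordinary backward pass through the generator, seeded with that vector instead of a loss derivative computed against targets.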

What could potentially be implemented:

  1. Allow the backward method (or a new, dedicated method) to accept a pre-computed gradient from an external source. This would enable users to calculate gradients using custom logic (e.g., based on the output of another network like the discriminator in a GAN) and inject them for backpropagation.

  2. A rudimentary automatic differentiation system that tracks operations and computes gradients of arbitrary tensors with respect to others. This could start with support for a limited set of fundamental operations and expand gradually.
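The first approach (accepting a pre-computed gradient) could look roughly like the following Python sketch, which illustrates the idea rather than proposing a neural-fortran interface. `Dense`, `Network` and `backward_from_gradient` are hypothetical names; layers are purely linear (activations omitted for brevity). Each `backward` takes an upstream gradient and returns the gradient with respect to its input, so the discriminator's input gradient can be injected into the generator.

```python
import math
import random


class Dense:
    """Linear layer y = W x + b (activation omitted for brevity)."""

    def __init__(self, n_in, n_out, rng):
        self.w = [[rng.uniform(-1.0, 1.0) for _ in range(n_in)] for _ in range(n_out)]
        self.b = [0.0] * n_out

    def forward(self, x):
        self.x = list(x)  # cache the input for the backward pass
        return [sum(wij * xj for wij, xj in zip(row, x)) + bi
                for row, bi in zip(self.w, self.b)]

    def backward(self, grad_y):
        # Parameter gradients from the upstream gradient dL/dy
        self.grad_w = [[gi * xj for xj in self.x] for gi in grad_y]
        self.grad_b = list(grad_y)
        # Gradient w.r.t. the input: dL/dx_j = sum_i W_ij * dL/dy_i
        return [sum(self.w[i][j] * grad_y[i] for i in range(len(grad_y)))
                for j in range(len(self.x))]

    def update(self, lr):
        for i, row in enumerate(self.w):
            for j in range(len(row)):
                row[j] -= lr * self.grad_w[i][j]
            self.b[i] -= lr * self.grad_b[i]


class Network:
    def __init__(self, layers):
        self.layers = layers

    def forward(self, x):
        for layer in self.layers:
            x = layer.forward(x)
        return x

    def backward_from_gradient(self, grad_output):
        # Hypothetical entry point: backpropagate an externally supplied
        # dL/d(output) and return dL/d(input) so it can be chained further
        g = grad_output
        for layer in reversed(self.layers):
            g = layer.backward(g)
        return g


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


rng = random.Random(0)
generator = Network([Dense(2, 3, rng), Dense(3, 1, rng)])
discriminator = Network([Dense(1, 1, rng)])  # outputs a single logit

z = [0.4, -0.7]  # noise input
logit = discriminator.forward(generator.forward(z))[0]
loss = -math.log(sigmoid(logit))  # non-saturating generator loss
# For L = -log(sigmoid(logit)), dL/dlogit = -(1 - sigmoid(logit))
grad_fake = discriminator.backward_from_gradient([-(1.0 - sigmoid(logit))])
generator.backward_from_gradient(grad_fake)  # seed the generator with dL/dx_fake
for layer in generator.layers:  # update the generator only
    layer.update(lr=0.1)
```

In this design `backward` never sees a target: the loss-specific part is reduced to the single vector passed in, and supervised training becomes the special case where that vector is the loss derivative with respect to the output.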
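The second approach amounts to reverse-mode automatic differentiation. The following minimal scalar sketch in Python assumes a hypothetical `Value` type supporting only `+`, `*`, negation, `sigmoid` and `log`; it is not a design for neural-fortran's internals. It shows how recording operations lets the gradient with respect to an intermediate result (here the generated sample) fall out of a single `backward()` call.

```python
import math


class Value:
    """Scalar node in a computation graph that records how it was produced."""

    def __init__(self, data, parents=()):
        self.data = data
        self.grad = 0.0
        self._parents = parents
        self._backward_fn = None  # propagates self.grad to the parents

    def __add__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        out = Value(self.data + other.data, (self, other))
        def fn():
            self.grad += out.grad
            other.grad += out.grad
        out._backward_fn = fn
        return out

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        out = Value(self.data * other.data, (self, other))
        def fn():
            self.grad += other.data * out.grad
            other.grad += self.data * out.grad
        out._backward_fn = fn
        return out

    __rmul__ = __mul__

    def __neg__(self):
        return self * -1.0

    def sigmoid(self):
        s = 1.0 / (1.0 + math.exp(-self.data))
        out = Value(s, (self,))
        def fn():
            self.grad += s * (1.0 - s) * out.grad
        out._backward_fn = fn
        return out

    def log(self):
        out = Value(math.log(self.data), (self,))
        def fn():
            self.grad += out.grad / self.data
        out._backward_fn = fn
        return out

    def backward(self):
        # Topologically sort the graph, then apply the chain rule in reverse
        order, seen = [], set()
        def visit(v):
            if id(v) not in seen:
                seen.add(id(v))
                for p in v._parents:
                    visit(p)
                order.append(v)
        visit(self)
        self.grad = 1.0
        for v in reversed(order):
            if v._backward_fn is not None:
                v._backward_fn()


# Generator x = w_g*z + b_g, discriminator D(x) = sigmoid(w_d*x + b_d)
w_g, b_g, w_d, b_d = Value(0.5), Value(0.0), Value(2.0), Value(-1.0)
z = 0.3
x_fake = w_g * z + b_g
loss = -(w_d * x_fake + b_d).sigmoid().log()  # non-saturating generator loss
loss.backward()
# x_fake.grad now holds dL/dx_fake; w_g.grad and b_g.grad the generator gradients
```

A tensor version would follow the same pattern, with each operation storing a rule for its vector-Jacobian product instead of a scalar derivative.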
