
Conversation


DBraun commented Oct 28, 2025

Thanks for this repository. This refactor gives it a more standard module organization. After installing, other projects can do from diffuse_nnx.networks.transformers.lightning_ddt_nnx import LightningDDT, etc. Here is a CLAUDE.md in case you'd like to add it too.

CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

Overview

DiffuseNNX is a JAX/NNX library for diffusion and flow matching generative models. It implements DiT (Diffusion Transformer) and variants for ImageNet training with various sampling strategies. Built on JAX and Flax NNX (with PyTorch-like syntax).

Project Structure

The codebase uses the standard src/ layout:

diffuse_nnx/
├── src/diffuse_nnx/          # Main package code
│   ├── __main__.py            # Entry point (python -m diffuse_nnx)
│   ├── configs/               # Configuration files
│   ├── data/                  # Data loading and preprocessing
│   ├── eval/                  # Evaluation metrics and tools
│   ├── interfaces/            # Diffusion/flow matching interfaces
│   ├── networks/              # Neural network architectures
│   ├── samplers/              # Sampling strategies
│   ├── trainers/              # Training loops
│   └── utils/                 # Utility functions
├── tests/                     # Test suite (outside package)
├── commands/                  # Reference training scripts
├── docs/                      # Documentation
└── setup.py                   # Package configuration

Development Commands

Installation

# For TPU (Google Cloud TPU)
pip install -e .[tpu]

# For GPU (CUDA 12)
pip install -e .[gpu]

# For CPU-only development
pip install -e .

# With documentation tools
pip install -e .[docs]

Testing

# Run all tests (discovers files matching *_tests.py)
python tests/runner.py

# Run individual test
python tests/<path-to-test-file>.py

Training

# With Google Cloud Storage
python -m diffuse_nnx \
    --config=src/diffuse_nnx/configs/dit_imagenet.py \
    --bucket=$GCS_BUCKET \
    --workdir=experiment_name

# Local filesystem (omit --bucket)
python -m diffuse_nnx \
    --config=src/diffuse_nnx/configs/dit_imagenet.py \
    --workdir=experiment_name

# Using the installed console script
diffuse-nnx \
    --config=src/diffuse_nnx/configs/dit_imagenet.py \
    --workdir=experiment_name

Reference training scripts are in commands/.

Architecture

Core Design Pattern

The codebase separates concerns into three main layers:

  1. Interfaces (diffuse_nnx.interfaces): Diffusion/flow matching algorithms (SiT, EDM, MeanFlow, REPA)
  2. Networks (diffuse_nnx.networks): Model architectures (DiT, LightningDiT, VAE encoders/decoders)
  3. Samplers (diffuse_nnx.samplers): Sampling strategies (Euler, Heun, Euler-Maruyama)

Training Flow

  1. __main__.py parses flags and creates workdir (GCS bucket or local filesystem)
  2. Loads config from configs/*.py (uses ml_collections.ConfigDict)
  3. Calls get_trainers() to select trainer module (currently only dit_imagenet)
  4. Trainer handles full train/eval loop with NNX models
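
The entry point most likely follows the standard absl + ml_collections flag pattern. A minimal sketch of that pattern (not the repository's actual __main__.py; the flag names mirror the training commands above, and the trainer dispatch is elided):

from absl import app, flags
from ml_collections import config_flags

_CONFIG = config_flags.DEFINE_config_file("config")
_WORKDIR = flags.DEFINE_string("workdir", None, "Experiment name / output directory.")
_BUCKET = flags.DEFINE_string("bucket", None, "Optional GCS bucket; omit to use the local filesystem.")

def main(argv):
    config = _CONFIG.value
    # Resolve the workdir (a GCS path when --bucket is given, otherwise local),
    # then hand off to the trainer selected by get_trainers() -- currently dit_imagenet.
    ...

if __name__ == "__main__":
    app.run(main)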

Config System (diffuse_nnx.configs)

Configs use ml_collections.ConfigDict with presets from common_specs.py:

  • _imagenet_data_presets: Dataset paths, image sizes, batch sizes
  • _imagenet_encoder_presets: Encoder types (RGB, StabilityVAE, etc.)
  • _dit_network_presets: Network architectures (hidden_size, depth, num_heads)

Important: Update _*_data_presets entries in common_specs.py to point to your ImageNet data paths and FID statistics before training.
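
For reference, a hypothetical config following this pattern (the field names and values below are illustrative only; the real configs in src/diffuse_nnx/configs/ pull these groups from the presets in common_specs.py):

import ml_collections

def get_config():
    config = ml_collections.ConfigDict()
    # Data group -- the real values come from _imagenet_data_presets.
    config.data = ml_collections.ConfigDict(dict(
        dataset="imagenet",
        image_size=256,
        batch_size=256,
    ))
    # Network group -- the real values come from _dit_network_presets.
    config.network = ml_collections.ConfigDict(dict(
        hidden_size=1152,
        depth=28,
        num_heads=16,
    ))
    return config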

NNX Training Pattern

Training uses Flax NNX with in-place updates:

  • Models use nnx.GraphDef and nnx.State split/merge pattern
  • Optimizer wraps model: optimizer.model accesses the network
  • EMA tracking with separate ema_graph and ema_state
  • All state updates happen in-place via optimizer.update(grads) and ema.update(model)
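
A minimal sketch of that pattern built only from public Flax NNX / Optax APIs (a toy model standing in for DiT, not the repository's trainer; assumes Flax NNX ~0.10, where nnx.Optimizer(model, tx) and optimizer.update(grads) are the in-place APIs):

import jax.numpy as jnp
import optax
from flax import nnx

model = nnx.Linear(4, 4, rngs=nnx.Rngs(0))            # toy stand-in for the DiT network
optimizer = nnx.Optimizer(model, optax.adamw(1e-4))   # optimizer wraps the model: optimizer.model

def loss_fn(model, batch):
    return jnp.mean((model(batch) - batch) ** 2)

@nnx.jit
def train_step(optimizer, batch):
    loss, grads = nnx.value_and_grad(loss_fn)(optimizer.model, batch)
    optimizer.update(grads)                            # in-place parameter update
    return loss

loss = train_step(optimizer, jnp.ones((8, 4)))

# split/merge: extract the pure pytree state (e.g. for checkpointing or EMA tracking),
# then rebuild the module from it.
graphdef, state = nnx.split(optimizer.model)
restored = nnx.merge(graphdef, state)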

Distributed Training (diffuse_nnx.utils.sharding_utils)

  • Supports TPU training with replicate and FSDP strategies
  • Uses JAX mesh and NamedSharding for distributed arrays
  • See docs/utils/fsdp_in_jax_nnx.ipynb for FSDP tutorial
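
A self-contained example of the mesh + NamedSharding machinery those helpers are built on (plain JAX APIs; the repository's sharding_utils adds the replicate/FSDP-specific logic on top):

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Single 'data' axis over all local devices; FSDP would add a second axis for parameters.
devices = mesh_utils.create_device_mesh((jax.device_count(),))
mesh = Mesh(devices, axis_names=("data",))

# Shard the batch dimension across 'data' and replicate the feature dimension.
batch_sharding = NamedSharding(mesh, PartitionSpec("data"))
x = jax.device_put(jnp.ones((jax.device_count() * 2, 32)), batch_sharding)
print(x.sharding)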

Import Conventions

All imports use the full package name with the diffuse_nnx. prefix:

from diffuse_nnx.data import local_imagenet_dataset
from diffuse_nnx.interfaces import continuous
from diffuse_nnx.networks.transformers import dit_nnx
from diffuse_nnx.utils import checkpoint as ckpt_utils

Important Notes

Platform & Dependencies

  • Designed for TPU (Google Cloud TPU machines); GPU support experimental
  • For GPU: Install with pip install -e .[gpu] instead of [tpu]
  • Requires Python ≥3.9, <3.12
  • Uses CPU-only PyTorch and TensorFlow (minimal overhead)

Environment Setup

Required environment variables (store in .env, never commit):

  • WANDB_API_KEY: For logging (get from https://wandb.ai/authorize)
  • WANDB_ENTITY: Your W&B team/space
  • GOOGLE_APPLICATION_CREDENTIALS: Path to GCP credentials JSON
  • GCS_BUCKET: Google Cloud Storage bucket name

Run gcloud auth application-default login before using GCS.

Flax Transformers Deprecation

Some parts still depend on the deprecated Flax transformers library; a replacement is in progress.

Code Style (from README)

  • PEP 8 with 4-space indentation
  • snake_case for modules/functions, CamelCase for classes
  • Import order: stdlib, third-party, local deps (alphabetically sorted within groups)
  • Module docstrings + type hints for public APIs
  • Use absl.logging for logs, ml_collections.ConfigDict for configs
  • See trainers/dit_imagenet.py for import grouping example
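
An illustrative block following that ordering (grouping only; the actual imports in trainers/dit_imagenet.py differ):

# Standard library
import functools
import os

# Third-party
from absl import logging
from flax import nnx
import jax
import jax.numpy as jnp
import ml_collections

# Local
from diffuse_nnx.data import local_imagenet_dataset
from diffuse_nnx.utils import checkpoint as ckpt_utils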

Test Naming Convention

Test files must follow *_tests.py pattern for tests/runner.py to discover them.
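
For example, a hypothetical tests/networks/dit_nnx_tests.py would be discovered purely because of its filename (the file name, function, and assertion below are illustrative; the real tests are structured differently):

# tests/networks/dit_nnx_tests.py  (hypothetical)
from diffuse_nnx.networks.transformers import dit_nnx

def import_tests():
    # Placeholder check; real tests exercise the actual models and training utilities.
    assert dit_nnx is not None

if __name__ == "__main__":
    import_tests()
    print("All tests passed.")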


willisma commented Nov 4, 2025

Thanks for the amazing work! The overall refactoring looks good to me! Have you tested whether this modularization has any effect on the training/evaluation code?


DBraun commented Nov 4, 2025

Thanks. I haven't tested it beyond basic installation and importing. I would suggest creating a GitHub Actions workflow to automate some of this. I can add one to this branch if you'd like.


willisma commented Nov 4, 2025

Yes that would be great!

DBraun force-pushed the feature/modularize branch from 6dab321 to 97a0e15 on November 9, 2025 at 18:53

DBraun commented Nov 9, 2025

@willisma There's now a GitHub Actions workflow that runs the tests, but there are some failures. It looks like Flax 0.10.7 gets installed; updating to 0.12.0 would definitely introduce more errors. Could you take it from here, since you're more familiar with the test code?

willisma commented:

Thanks for the amazing work! Yes let me take over!
