modularize with src/diffuse_nnx directory #2
Thanks for this repository. I did this refactor to give it a more standard module organization. After installing, other projects can then do `from diffuse_nnx.networks.transformers.lightning_ddt_nnx import LightningDDT`, etc. Here is a `CLAUDE.md` in case you'd like to add it too.

# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Overview
DiffuseNNX is a JAX/NNX library for diffusion and flow matching generative models. It implements DiT (Diffusion Transformer) and variants for ImageNet training with various sampling strategies. Built on JAX and Flax NNX (with PyTorch-like syntax).
## Project Structure
The codebase uses the standard `src/` layout:
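A sketch of the layout, reconstructed from the module paths mentioned in this file; exact file placement is an assumption:

```
src/diffuse_nnx/
├── interfaces/        # diffusion/flow matching algorithms (SiT, EDM, MeanFlow, REPA)
├── networks/          # model architectures (DiT, LightningDiT, VAE encoders/decoders)
│   └── transformers/  # e.g. lightning_ddt_nnx.py
├── samplers/          # sampling strategies (Euler, Heun, Euler-Maruyama)
├── configs/           # ConfigDict presets (common_specs.py)
├── trainers/          # trainer modules (dit_imagenet.py)
└── utils/             # helpers (sharding_utils, ...)
commands/              # reference training scripts
tests/                 # *_tests.py files plus runner.py
```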
## Development Commands
### Installation

Install in editable mode with `pip install -e .` plus the `[gpu]` or `[tpu]` extra (see Platform & Dependencies below).
### Testing

Run the test suite via `tests/runner.py` (see Test Naming Convention below).
### Training
Reference training scripts are in `commands/`.

## Architecture
### Core Design Pattern
The codebase separates concerns into three main layers:
- **Interfaces** (`diffuse_nnx.interfaces`): Diffusion/flow matching algorithms (SiT, EDM, MeanFlow, REPA)
- **Networks** (`diffuse_nnx.networks`): Model architectures (DiT, LightningDiT, VAE encoders/decoders)
- **Samplers** (`diffuse_nnx.samplers`): Sampling strategies (Euler, Heun, Euler-Maruyama)
### Training Flow

1. `__main__.py` parses flags and creates the workdir (GCS bucket or local filesystem)
2. The config is loaded from `configs/*.py` (uses `ml_collections.ConfigDict`)
3. `get_trainers()` selects the trainer module (currently only `dit_imagenet`)

### Config System (`diffuse_nnx.configs`)

Configs use `ml_collections.ConfigDict` with presets from `common_specs.py`:

- `_imagenet_data_presets`: Dataset paths, image sizes, batch sizes
- `_imagenet_encoder_presets`: Encoder types (RGB, StabilityVAE, etc.)
- `_dit_network_presets`: Network architectures (hidden_size, depth, num_heads)
**Important**: Update the `_*_data_presets` entries in `common_specs.py` to point to your ImageNet data paths and FID statistics before training.
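A minimal sketch of how these presets compose, with hypothetical preset entries and a `get_config()` helper (the real values live in `common_specs.py`):

```python
import ml_collections

# Hypothetical presets in the style of common_specs.py; the real entries
# carry dataset paths, FID statistics, and network hyperparameters.
_imagenet_data_presets = {
    "imagenet256": ml_collections.ConfigDict({
        "data_dir": "gs://YOUR_BUCKET/imagenet",  # update before training
        "image_size": 256,
        "batch_size": 256,
    }),
}

_dit_network_presets = {
    "DiT-B": ml_collections.ConfigDict({
        "hidden_size": 768, "depth": 12, "num_heads": 12,
    }),
}

def get_config() -> ml_collections.ConfigDict:
    config = ml_collections.ConfigDict()
    config.data = _imagenet_data_presets["imagenet256"]
    config.network = _dit_network_presets["DiT-B"]
    return config
```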
### NNX Training Pattern

Training uses Flax NNX with in-place updates (see the sketch below):

- `nnx.GraphDef` and `nnx.State` split/merge pattern
- `optimizer.model` accesses the network
- EMA weights are kept as `ema_graph` and `ema_state`
- `optimizer.update(grads)` and `ema.update(model)` mutate state in place
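A minimal, self-contained sketch of this pattern using a stand-in `nnx.Linear` model and an MSE loss; it is not the library's trainer code:

```python
import jax.numpy as jnp
import optax
from flax import nnx

model = nnx.Linear(4, 4, rngs=nnx.Rngs(0))  # stand-in for the DiT network

# split/merge: a module is a static GraphDef plus a State pytree of arrays
graphdef, state = nnx.split(model)
model = nnx.merge(graphdef, state)
ema_graph, ema_state = nnx.split(model)  # EMA weights as a graphdef/state pair

optimizer = nnx.Optimizer(model, optax.adamw(1e-4))

def loss_fn(model, x, y):
    return jnp.mean((model(x) - y) ** 2)

x, y = jnp.ones((2, 4)), jnp.zeros((2, 4))
grads = nnx.grad(loss_fn)(optimizer.model, x, y)  # optimizer.model is the network
optimizer.update(grads)  # in-place parameter update
```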
diffuse_nnx.utils.sharding_utils)docs/utils/fsdp_in_jax_nnx.ipynbfor FSDP tutorialImport Conventions
## Import Conventions

All imports use the full package name with the `diffuse_nnx.` prefix:
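For example (the first import is from the refactor above; the other paths are inferred from file names mentioned in this document):

```python
from diffuse_nnx.networks.transformers.lightning_ddt_nnx import LightningDDT
from diffuse_nnx.configs import common_specs
from diffuse_nnx.utils import sharding_utils
```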
Platform & Dependencies
pip install -e .[gpu]instead of[tpu]Environment Setup
### Environment Setup

Required environment variables (store in `.env`, never commit; sample below):

- `WANDB_API_KEY`: For logging (get from https://wandb.ai/authorize)
- `WANDB_ENTITY`: Your W&B team/space
- `GOOGLE_APPLICATION_CREDENTIALS`: Path to GCP credentials JSON
- `GCS_BUCKET`: Google Cloud Storage bucket name
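A sample `.env` with placeholder values:

```
WANDB_API_KEY=xxxxxxxxxxxxxxxxxxxx
WANDB_ENTITY=my-team
GOOGLE_APPLICATION_CREDENTIALS=/path/to/gcp-credentials.json
GCS_BUCKET=my-bucket
```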
Run `gcloud auth application-default login` before using GCS.

### Flax Transformers Deprecation
Some parts still depend on the deprecated Flax `transformers` library. A reproduction/replacement is in progress.
### Code Style (from README)

- `snake_case` for modules/functions, `CamelCase` for classes
- `absl.logging` for logs, `ml_collections.ConfigDict` for configs
- See `trainers/dit_imagenet.py` for an import grouping example
### Test Naming Convention

Test files must follow the `*_tests.py` pattern for `tests/runner.py` to discover them.
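For illustration, a hypothetical `tests/sampler_tests.py` (file and function names invented) that matches the pattern:

```python
# tests/sampler_tests.py — matches *_tests.py, so runner.py discovers it
import jax.numpy as jnp

def euler_step(x, v, dt):
    """Toy Euler update; stands in for a real sampler under test."""
    return x + dt * v

def test_euler_step_moves_along_velocity():
    x, v = jnp.zeros(3), jnp.ones(3)
    assert jnp.allclose(euler_step(x, v, 0.5), 0.5 * jnp.ones(3))
```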