59 changes: 59 additions & 0 deletions .github/workflows/ci.yml
@@ -0,0 +1,59 @@
+name: CI
+
+on:
+  push:
+    branches: [ main, feature/* ]
+    tags:
+      - 'v*'
+  pull_request:
+    branches: [ main ]
+
+jobs:
+  test:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ["3.10", "3.11"]
+
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install -e .
+
+      - name: Run tests
+        run: python tests/runner.py
+
+  publish:
+    if: startsWith(github.ref, 'refs/tags/v')
+    needs: test
+    runs-on: ubuntu-latest
+    permissions:
+      id-token: write
+      contents: read
+
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.11"
+
+      - name: Install build tools
+        run: |
+          python -m pip install --upgrade pip
+          pip install build
+
+      - name: Build package
+        run: python -m build
+
+      - name: Publish to PyPI
+        uses: pypa/gh-action-pypi-publish@release/v1
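Reviewer sketch (not part of the diff): the `publish` job is gated twice, by `needs: test` and by the `if:` check on the tag ref. Below is a minimal Python model of that gate. It assumes GitHub's documented behaviour that `startsWith` ignores case and that a job with `needs` is skipped unless the needed job succeeded. `publish_job_runs` is a hypothetical name used only for illustration.

```python
def publish_job_runs(ref: str, test_result: str) -> bool:
    """Model of the `publish` job gate in ci.yml (illustrative only).

    ref:         value of `github.ref`, e.g. "refs/tags/v0.1.0"
    test_result: overall result of the `test` job, e.g. "success" or "failure"
    """
    # `needs: test` -> the job is skipped unless the test job succeeded.
    if test_result != "success":
        return False
    # `startsWith(github.ref, 'refs/tags/v')`; both sides are lowercased
    # because the expression function is documented as case-insensitive.
    return ref.lower().startswith("refs/tags/v")


if __name__ == "__main__":
    for ref in ("refs/tags/v0.1.0", "refs/heads/main", "refs/heads/feature/x"):
        print(ref, publish_job_runs(ref, "success"))
```

Under this model, a push to `main` or a `feature/*` branch only runs the tests, while a pushed `v*` tag runs the tests and then publishes.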
13 changes: 8 additions & 5 deletions README.md
@@ -118,21 +118,24 @@ PY

 ## Installation

-This codebase requires `Python<=3.11` to run. We do require both PyTorch and Tensorflow, but only the CPU-only version and should incur minimal overhead.
+This codebase requires `Python>=3.9,<3.12` to run. We require both PyTorch and Tensorflow, but only the CPU-only versions to minimize overhead.

 ```bash
 # Clone the repository
 git clone <repository-url>
 cd diffuse_nnx

-# Install dependencies
-pip install -r requirements.txt
+# Install for TPU (Google Cloud TPU machines)
+pip install -e .[tpu]

-# Install the codebase
+# OR install for GPU (CUDA 12)
+pip install -e .[gpu]
+
+# OR install base package (CPU-only JAX)
 pip install -e .
 ```

-**Note**: The codebase is mainly designed for using on Google Cloud TPU machines and haven't been extensively tested on GPUs. To use it on GPU, replace the `jax[tpu]==0.5.1` dependency in `requirements.txt` with `jax[cuda12]==0.5.1`. We greatly appreciate it if you can help validate the performances on GPU!
+**Note**: The codebase is mainly designed for Google Cloud TPU machines and hasn't been extensively tested on GPUs. We greatly appreciate it if you can help validate the performance on GPU!

 ## Quick Start

4 changes: 2 additions & 2 deletions commands/dit_imagenet.sh
@@ -11,10 +11,10 @@ BUCKET="$GCS_BUCKET"

 export TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=8589934592

-WANDB_API_KEY="$WANDB_API_KEY" python main.py \
+WANDB_API_KEY="$WANDB_API_KEY" python -m diffuse_nnx \
     --workdir=$WORKDIR \
     --bucket=$BUCKET \
-    --config=configs/$CONFIG.py:imagenet_256-XL_2 \
+    --config=src/diffuse_nnx/configs/$CONFIG.py:imagenet_256-XL_2 \
     --config.data.batch_size=$BATCH_SIZE \
     --config.standalone_eval=False \
     --config.project_name='diffuse_nnx' \
4 changes: 2 additions & 2 deletions commands/dit_repa_imagenet.sh
@@ -10,10 +10,10 @@ BUCKET="$GCS_BUCKET"

 export TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=8589934592

-WANDB_API_KEY="$WANDB_API_KEY" python main.py \
+WANDB_API_KEY="$WANDB_API_KEY" python -m diffuse_nnx \
     --workdir=$WORKDIR \
     --bucket=$BUCKET \
-    --config=configs/$CONFIG.py:imagenet_raw_256-XL_2 \
+    --config=src/diffuse_nnx/configs/$CONFIG.py:imagenet_raw_256-XL_2 \
     --config.data.batch_size=$BATCH_SIZE \
     --config.standalone_eval=False \
     --config.project_name='diffuse_nnx' \
4 changes: 2 additions & 2 deletions commands/mf_imagenet.sh
@@ -10,10 +10,10 @@ BUCKET="$GCS_BUCKET"

 export TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=8589934592

-WANDB_API_KEY="$WANDB_API_KEY" python main.py \
+WANDB_API_KEY="$WANDB_API_KEY" python -m diffuse_nnx \
     --workdir=$WORKDIR \
     --bucket=$BUCKET \
-    --config=configs/$CONFIG.py:imagenet_256-XL_2 \
+    --config=src/diffuse_nnx/configs/$CONFIG.py:imagenet_256-XL_2 \
     --config.data.batch_size=$BATCH_SIZE \
     --config.standalone_eval=False \
     --config.project_name='diffuse_nnx' \
4 changes: 2 additions & 2 deletions commands/rae_imagenet.sh
@@ -11,10 +11,10 @@ BUCKET="$GCS_BUCKET"
 export TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=8589934592

 # only RAE eval is supported at the moment
-WANDB_API_KEY="$WANDB_API_KEY" python main.py \
+WANDB_API_KEY="$WANDB_API_KEY" python -m diffuse_nnx \
     --workdir=$WORKDIR \
     --bucket=$BUCKET \
-    --config=configs/$CONFIG.py:imagenet_raw_256-XL_1 \
+    --config=src/diffuse_nnx/configs/$CONFIG.py:imagenet_raw_256-XL_1 \
     --config.data.batch_size=$BATCH_SIZE \
     --config.standalone_eval=True \
     --config.project_name='diffuse_nnx' \
33 changes: 0 additions & 33 deletions requirements.txt

This file was deleted.

68 changes: 66 additions & 2 deletions setup.py
@@ -1,7 +1,71 @@
 from setuptools import find_packages, setup

+with open("README.md", "r", encoding="utf-8") as fh:
+    long_description = fh.read()
+
+# Core dependencies (without platform-specific JAX variants)
+install_requires = [
+    "jax>=0.5.1",
+    "jaxlib>=0.5.1",
+    "flax>=0.10.2",
+    "optax>=0.2.4",
+    "orbax-checkpoint>=0.11.4",
+    "tensorflow",
+    "torch>=2.5.0",
+    "torchvision>=0.20.0",
+    "clu",
+    "einops",
+    "transformers",
+    "timm",
+    "ml_collections",
+    "termcolor",
+    "wandb",
+    "webdataset",
+    "google-api-core",
+    "google-cloud-core",
+    "google-cloud-storage",
+]
+
+# Optional dependencies
+extras_require = {
+    "tpu": ["jax[tpu]==0.5.1"],
+    "gpu": ["jax[cuda12]==0.5.1"],
+    "docs": [
+        "sphinx>=7.0.0",
+        "pydata-sphinx-theme>=0.15.0",
+        "myst-parser>=2.0.0",
+        "sphinx-copybutton>=0.5.0",
+        "sphinx-design>=0.5.0",
+        "linkify-it-py>=2.0.0",
+    ],
+}
+
 setup(
     name="diffuse_nnx",
     version="0.1.0",
-    packages=find_packages(),
-)
+    author="Nanye Ma",
+    description="A JAX/NNX Library for Diffusion and Flow Matching",
+    long_description=long_description,
+    long_description_content_type="text/markdown",
+    url="https://github.com/willisma/diffuse_nnx",
+    package_dir={"": "src"},
+    packages=find_packages(where="src"),
+    python_requires=">=3.9,<3.12",
+    install_requires=install_requires,
+    extras_require=extras_require,
+    entry_points={
+        "console_scripts": [
+            "diffuse-nnx=diffuse_nnx.__main__:main",
+        ],
+    },
+    classifiers=[
+        "Development Status :: 3 - Alpha",
+        "Intended Audience :: Developers",
+        "Intended Audience :: Science/Research",
+        "Topic :: Scientific/Engineering :: Artificial Intelligence",
+        "Programming Language :: Python :: 3",
+        "Programming Language :: Python :: 3.9",
+        "Programming Language :: Python :: 3.10",
+        "Programming Language :: Python :: 3.11",
+    ],
+)
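Reviewer sketch (not part of the diff): the `console_scripts` entry `diffuse-nnx=diffuse_nnx.__main__:main` follows the `name=module:attr` form. The snippet below is a rough illustration of how such a string maps to a callable. `parse_entry_point` and `load_entry_point` are hypothetical helpers written for this note, not setuptools APIs, and the stdlib `json:dumps` target stands in for the real package so the example runs without `diffuse_nnx` installed.

```python
import importlib


def parse_entry_point(spec: str) -> tuple[str, str, str]:
    """Split 'name=module:attr' into (name, module, attr)."""
    name, _, target = spec.partition("=")
    module, _, attr = target.partition(":")
    return name.strip(), module.strip(), attr.strip()


def load_entry_point(spec: str):
    """Import the module named in the spec and return the referenced object."""
    _, module_name, attr = parse_entry_point(spec)
    module = importlib.import_module(module_name)
    return getattr(module, attr)


if __name__ == "__main__":
    # The entry point declared in setup.py, parsed but not imported.
    print(parse_entry_point("diffuse-nnx=diffuse_nnx.__main__:main"))
    # A stdlib stand-in (assumption for the demo) resolved to a real callable.
    dumps = load_entry_point("demo=json:dumps")
    print(dumps({"ok": True}))
```

The generated `diffuse-nnx` script calls the resolved object with no arguments, so `diffuse_nnx.__main__` needs a module-level `main` that can be called as `main()`. The hunks shown for `__main__.py` do not include that definition, so this is worth confirming.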
3 changes: 3 additions & 0 deletions src/diffuse_nnx/__init__.py
@@ -0,0 +1,3 @@
+"""DiffuseNNX: A JAX/NNX Library for Diffusion and Flow Matching."""
+
+__version__ = "0.1.0"
4 changes: 2 additions & 2 deletions main.py → src/diffuse_nnx/__main__.py
@@ -11,7 +11,7 @@
 from ml_collections import config_flags

 # deps
-from utils import logging_utils, gcloud_utils
+from diffuse_nnx.utils import logging_utils, gcloud_utils

 FLAGS = flags.FLAGS

@@ -43,7 +43,7 @@ def prepare_local_workdir(workdir):
 def get_trainers(trainer):
     """Get the trainers for the experiment."""
     if trainer == 'DiT_ImageNet':
-        from trainers import dit_imagenet
+        from diffuse_nnx.trainers import dit_imagenet
         return dit_imagenet
     else:
         raise ValueError(f'Unknown trainer: {trainer}')
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -6,7 +6,7 @@
 import ml_collections

 # deps
-from configs import common_specs
+from diffuse_nnx.configs import common_specs

 def get_config(options='imagenet_64-B_2'):
     data_options, network_options = options.split('-')
@@ -6,7 +6,7 @@
 import ml_collections

 # deps
-from configs import common_specs
+from diffuse_nnx.configs import common_specs

 def get_config(options='imagenet_64-B_2'):
     data_options, network_options = options.split('-')
@@ -6,7 +6,7 @@
 import ml_collections

 # deps
-from configs import common_specs
+from diffuse_nnx.configs import common_specs

 def get_config(options='imagenet_64-B_2'):
     data_options, network_options = options.split('-')
@@ -6,7 +6,7 @@
 import ml_collections

 # deps
-from configs import common_specs
+from diffuse_nnx.configs import common_specs

 def get_config(options='imagenet_64-B_2'):
     data_options, network_options = options.split('-')
@@ -6,8 +6,8 @@
 import ml_collections

 # deps
-from configs import common_specs
-from configs import dit_imagenet
+from diffuse_nnx.configs import common_specs
+from diffuse_nnx.configs import dit_imagenet

 def get_config(options='imagenet_64-B_2'):

@@ -6,7 +6,7 @@
 import ml_collections

 # deps
-from configs import common_specs
+from diffuse_nnx.configs import common_specs

 def get_config(options='imagenet_64-B_2'):
     data_options, network_options = options.split('-')
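Reviewer sketch (not part of the diff): the launch scripts pass `--config=<path>.py:<options>`, and each `get_config` unpacks `options.split('-')` into a data part and a network part. The snippet below is a minimal illustration of that two-stage split. `split_config_flag` is a hypothetical helper that only imitates the `path:options` convention of `ml_collections.config_flags`, and the `dit_imagenet.py` file name is an assumed value for `$CONFIG`.

```python
def split_config_flag(value: str):
    """Split a '--config' value of the form 'path.py:options' at the first colon."""
    path, _, options = value.partition(":")
    return path, options


def split_options(options: str = "imagenet_64-B_2"):
    """Same unpacking as in the config files: exactly one '-' is expected."""
    data_options, network_options = options.split("-")
    return data_options, network_options


if __name__ == "__main__":
    # Assumed example value; the scripts build the path from $CONFIG.
    flag = "src/diffuse_nnx/configs/dit_imagenet.py:imagenet_256-XL_2"
    path, options = split_config_flag(flag)
    print(path)
    print(split_options(options))
```

Because the result is unpacked into exactly two names, an option string with no hyphen or with more than one hyphen raises `ValueError` in this sketch, and the same holds for the unpacking line in the config files.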
Empty file.
@@ -13,7 +13,7 @@
 import webdataset as wds

 # deps
-from data import utils, wds_imagenet_dataset
+from diffuse_nnx.data import utils, wds_imagenet_dataset


 class IterableDatasetShard(torch.utils.data.IterableDataset):
@@ -23,7 +23,7 @@
 pyspng = None

 # deps
-from data import utils
+from diffuse_nnx.data import utils


 class LatentDataset(torch.utils.data.Dataset):
2 changes: 1 addition & 1 deletion data/utils.py → src/diffuse_nnx/data/utils.py
@@ -12,7 +12,7 @@
 from torchvision import transforms

 # deps
-from utils import sharding_utils
+from diffuse_nnx.utils import sharding_utils

 class EasyDict(dict):
     def __init__(self, *args, **kwargs):