diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..c35e5d8 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,59 @@ +name: CI + +on: + push: + branches: [ main, feature/* ] + tags: + - 'v*' + pull_request: + branches: [ main ] + +jobs: + test: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.10", "3.11"] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e . + + - name: Run tests + run: python tests/runner.py + + publish: + if: startsWith(github.ref, 'refs/tags/v') + needs: test + runs-on: ubuntu-latest + permissions: + id-token: write + contents: read + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install build tools + run: | + python -m pip install --upgrade pip + pip install build + + - name: Build package + run: python -m build + + - name: Publish to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/README.md b/README.md index aee0d40..b343b7e 100644 --- a/README.md +++ b/README.md @@ -118,21 +118,24 @@ PY ## Installation -This codebase requires `Python<=3.11` to run. We do require both PyTorch and Tensorflow, but only the CPU-only version and should incur minimal overhead. +This codebase requires `Python>=3.9,<3.12` to run. We require both PyTorch and Tensorflow, but only the CPU-only versions to minimize overhead. ```bash # Clone the repository git clone cd diffuse_nnx -# Install dependencies -pip install -r requirements.txt +# Install for TPU (Google Cloud TPU machines) +pip install -e .[tpu] -# Install the codebase +# OR install for GPU (CUDA 12) +pip install -e .[gpu] + +# OR install base package (CPU-only JAX) pip install -e . ``` -**Note**: The codebase is mainly designed for using on Google Cloud TPU machines and haven't been extensively tested on GPUs. To use it on GPU, replace the `jax[tpu]==0.5.1` dependency in `requirements.txt` with `jax[cuda12]==0.5.1`. We greatly appreciate it if you can help validate the performances on GPU! +**Note**: The codebase is mainly designed for Google Cloud TPU machines and hasn't been extensively tested on GPUs. We greatly appreciate it if you can help validate the performance on GPU! ## Quick Start diff --git a/commands/dit_imagenet.sh b/commands/dit_imagenet.sh index 30c2a1a..9efe1f5 100755 --- a/commands/dit_imagenet.sh +++ b/commands/dit_imagenet.sh @@ -11,10 +11,10 @@ BUCKET="$GCS_BUCKET" export TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=8589934592 -WANDB_API_KEY="$WANDB_API_KEY" python main.py \ +WANDB_API_KEY="$WANDB_API_KEY" python -m diffuse_nnx \ --workdir=$WORKDIR \ --bucket=$BUCKET \ - --config=configs/$CONFIG.py:imagenet_256-XL_2 \ + --config=src/diffuse_nnx/configs/$CONFIG.py:imagenet_256-XL_2 \ --config.data.batch_size=$BATCH_SIZE \ --config.standalone_eval=False \ --config.project_name='diffuse_nnx' \ diff --git a/commands/dit_repa_imagenet.sh b/commands/dit_repa_imagenet.sh index b49c84d..9c7a847 100755 --- a/commands/dit_repa_imagenet.sh +++ b/commands/dit_repa_imagenet.sh @@ -10,10 +10,10 @@ BUCKET="$GCS_BUCKET" export TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=8589934592 -WANDB_API_KEY="$WANDB_API_KEY" python main.py \ +WANDB_API_KEY="$WANDB_API_KEY" python -m diffuse_nnx \ --workdir=$WORKDIR \ --bucket=$BUCKET \ - --config=configs/$CONFIG.py:imagenet_raw_256-XL_2 \ + --config=src/diffuse_nnx/configs/$CONFIG.py:imagenet_raw_256-XL_2 \ --config.data.batch_size=$BATCH_SIZE \ --config.standalone_eval=False \ --config.project_name='diffuse_nnx' \ diff --git a/commands/mf_imagenet.sh b/commands/mf_imagenet.sh index 00d5c6c..595f015 100755 --- a/commands/mf_imagenet.sh +++ b/commands/mf_imagenet.sh @@ -10,10 +10,10 @@ BUCKET="$GCS_BUCKET" export TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=8589934592 -WANDB_API_KEY="$WANDB_API_KEY" python main.py \ +WANDB_API_KEY="$WANDB_API_KEY" python -m diffuse_nnx \ --workdir=$WORKDIR \ --bucket=$BUCKET \ - --config=configs/$CONFIG.py:imagenet_256-XL_2 \ + --config=src/diffuse_nnx/configs/$CONFIG.py:imagenet_256-XL_2 \ --config.data.batch_size=$BATCH_SIZE \ --config.standalone_eval=False \ --config.project_name='diffuse_nnx' \ diff --git a/commands/rae_imagenet.sh b/commands/rae_imagenet.sh index 9bfb1ea..0fa403e 100755 --- a/commands/rae_imagenet.sh +++ b/commands/rae_imagenet.sh @@ -11,10 +11,10 @@ BUCKET="$GCS_BUCKET" export TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=8589934592 # only RAE eval is supported at the moment -WANDB_API_KEY="$WANDB_API_KEY" python main.py \ +WANDB_API_KEY="$WANDB_API_KEY" python -m diffuse_nnx \ --workdir=$WORKDIR \ --bucket=$BUCKET \ - --config=configs/$CONFIG.py:imagenet_raw_256-XL_1 \ + --config=src/diffuse_nnx/configs/$CONFIG.py:imagenet_raw_256-XL_1 \ --config.data.batch_size=$BATCH_SIZE \ --config.standalone_eval=True \ --config.project_name='diffuse_nnx' \ diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 5da85fe..0000000 --- a/requirements.txt +++ /dev/null @@ -1,33 +0,0 @@ -jax[tpu]==0.5.1 -# GPU version -# jax[cuda12]==0.5.1 -flax==0.10.2 -optax==0.2.4 -orbax-checkpoint==0.11.4 -tensorflow - -clu -einops -transformers -timm - -ml_collections -termcolor -wandb -webdataset - -google-api-core -google-cloud-core -google-cloud-storage - ---extra-index-url https://download.pytorch.org/whl/cpu -torch==2.5.1 -torchvision==0.20.1 - -# Documentation dependencies -sphinx>=7.0.0 -pydata-sphinx-theme>=0.15.0 -myst-parser>=2.0.0 -sphinx-copybutton>=0.5.0 -sphinx-design>=0.5.0 -linkify-it-py>=2.0.0 diff --git a/setup.py b/setup.py index d4184c6..746ea7a 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,71 @@ from setuptools import find_packages, setup +with open("README.md", "r", encoding="utf-8") as fh: + long_description = fh.read() + +# Core dependencies (without platform-specific JAX variants) +install_requires = [ + "jax>=0.5.1", + "jaxlib>=0.5.1", + "flax>=0.10.2", + "optax>=0.2.4", + "orbax-checkpoint>=0.11.4", + "tensorflow", + "torch>=2.5.0", + "torchvision>=0.20.0", + "clu", + "einops", + "transformers", + "timm", + "ml_collections", + "termcolor", + "wandb", + "webdataset", + "google-api-core", + "google-cloud-core", + "google-cloud-storage", +] + +# Optional dependencies +extras_require = { + "tpu": ["jax[tpu]==0.5.1"], + "gpu": ["jax[cuda12]==0.5.1"], + "docs": [ + "sphinx>=7.0.0", + "pydata-sphinx-theme>=0.15.0", + "myst-parser>=2.0.0", + "sphinx-copybutton>=0.5.0", + "sphinx-design>=0.5.0", + "linkify-it-py>=2.0.0", + ], +} + setup( name="diffuse_nnx", version="0.1.0", - packages=find_packages(), -) \ No newline at end of file + author="Nanye Ma", + description="A JAX/NNX Library for Diffusion and Flow Matching", + long_description=long_description, + long_description_content_type="text/markdown", + url="https://github.com/willisma/diffuse_nnx", + package_dir={"": "src"}, + packages=find_packages(where="src"), + python_requires=">=3.9,<3.12", + install_requires=install_requires, + extras_require=extras_require, + entry_points={ + "console_scripts": [ + "diffuse-nnx=diffuse_nnx.__main__:main", + ], + }, + classifiers=[ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + ], +) diff --git a/src/diffuse_nnx/__init__.py b/src/diffuse_nnx/__init__.py new file mode 100644 index 0000000..1168162 --- /dev/null +++ b/src/diffuse_nnx/__init__.py @@ -0,0 +1,3 @@ +"""DiffuseNNX: A JAX/NNX Library for Diffusion and Flow Matching.""" + +__version__ = "0.1.0" diff --git a/main.py b/src/diffuse_nnx/__main__.py similarity index 96% rename from main.py rename to src/diffuse_nnx/__main__.py index 6a6a0f1..14d25a4 100644 --- a/main.py +++ b/src/diffuse_nnx/__main__.py @@ -11,7 +11,7 @@ from ml_collections import config_flags # deps -from utils import logging_utils, gcloud_utils +from diffuse_nnx.utils import logging_utils, gcloud_utils FLAGS = flags.FLAGS @@ -43,7 +43,7 @@ def prepare_local_workdir(workdir): def get_trainers(trainer): """Get the trainers for the experiment.""" if trainer == 'DiT_ImageNet': - from trainers import dit_imagenet + from diffuse_nnx.trainers import dit_imagenet return dit_imagenet else: raise ValueError(f'Unknown trainer: {trainer}') diff --git a/configs/README.md b/src/diffuse_nnx/configs/README.md similarity index 100% rename from configs/README.md rename to src/diffuse_nnx/configs/README.md diff --git a/interfaces/__init__.py b/src/diffuse_nnx/configs/__init__.py similarity index 100% rename from interfaces/__init__.py rename to src/diffuse_nnx/configs/__init__.py diff --git a/configs/common_specs.py b/src/diffuse_nnx/configs/common_specs.py similarity index 100% rename from configs/common_specs.py rename to src/diffuse_nnx/configs/common_specs.py diff --git a/configs/dit_imagenet.py b/src/diffuse_nnx/configs/dit_imagenet.py similarity index 99% rename from configs/dit_imagenet.py rename to src/diffuse_nnx/configs/dit_imagenet.py index 282ba02..d8fbbf0 100644 --- a/configs/dit_imagenet.py +++ b/src/diffuse_nnx/configs/dit_imagenet.py @@ -6,7 +6,7 @@ import ml_collections # deps -from configs import common_specs +from diffuse_nnx.configs import common_specs def get_config(options='imagenet_64-B_2'): data_options, network_options = options.split('-') diff --git a/configs/dit_imagenet_repa.py b/src/diffuse_nnx/configs/dit_imagenet_repa.py similarity index 99% rename from configs/dit_imagenet_repa.py rename to src/diffuse_nnx/configs/dit_imagenet_repa.py index 66f5a32..56e2fe7 100755 --- a/configs/dit_imagenet_repa.py +++ b/src/diffuse_nnx/configs/dit_imagenet_repa.py @@ -6,7 +6,7 @@ import ml_collections # deps -from configs import common_specs +from diffuse_nnx.configs import common_specs def get_config(options='imagenet_64-B_2'): data_options, network_options = options.split('-') diff --git a/configs/lightning_ddt_imagenet.py b/src/diffuse_nnx/configs/lightning_ddt_imagenet.py similarity index 99% rename from configs/lightning_ddt_imagenet.py rename to src/diffuse_nnx/configs/lightning_ddt_imagenet.py index 044d3c6..9122ebb 100644 --- a/configs/lightning_ddt_imagenet.py +++ b/src/diffuse_nnx/configs/lightning_ddt_imagenet.py @@ -6,7 +6,7 @@ import ml_collections # deps -from configs import common_specs +from diffuse_nnx.configs import common_specs def get_config(options='imagenet_64-B_2'): data_options, network_options = options.split('-') diff --git a/configs/lightning_dit_imagenet.py b/src/diffuse_nnx/configs/lightning_dit_imagenet.py similarity index 99% rename from configs/lightning_dit_imagenet.py rename to src/diffuse_nnx/configs/lightning_dit_imagenet.py index ef69309..a46c094 100644 --- a/configs/lightning_dit_imagenet.py +++ b/src/diffuse_nnx/configs/lightning_dit_imagenet.py @@ -6,7 +6,7 @@ import ml_collections # deps -from configs import common_specs +from diffuse_nnx.configs import common_specs def get_config(options='imagenet_64-B_2'): data_options, network_options = options.split('-') diff --git a/configs/mf_imagenet.py b/src/diffuse_nnx/configs/mf_imagenet.py similarity index 96% rename from configs/mf_imagenet.py rename to src/diffuse_nnx/configs/mf_imagenet.py index 7ce0563..8347e8b 100644 --- a/configs/mf_imagenet.py +++ b/src/diffuse_nnx/configs/mf_imagenet.py @@ -6,8 +6,8 @@ import ml_collections # deps -from configs import common_specs -from configs import dit_imagenet +from diffuse_nnx.configs import common_specs +from diffuse_nnx.configs import dit_imagenet def get_config(options='imagenet_64-B_2'): diff --git a/configs/rae_imagenet.py b/src/diffuse_nnx/configs/rae_imagenet.py similarity index 99% rename from configs/rae_imagenet.py rename to src/diffuse_nnx/configs/rae_imagenet.py index 3dae8bc..7e18534 100644 --- a/configs/rae_imagenet.py +++ b/src/diffuse_nnx/configs/rae_imagenet.py @@ -6,7 +6,7 @@ import ml_collections # deps -from configs import common_specs +from diffuse_nnx.configs import common_specs def get_config(options='imagenet_64-B_2'): data_options, network_options = options.split('-') diff --git a/src/diffuse_nnx/data/__init__.py b/src/diffuse_nnx/data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/data/custom_wds_imagenet_dataset.py b/src/diffuse_nnx/data/custom_wds_imagenet_dataset.py similarity index 99% rename from data/custom_wds_imagenet_dataset.py rename to src/diffuse_nnx/data/custom_wds_imagenet_dataset.py index fe86aff..7b55393 100644 --- a/data/custom_wds_imagenet_dataset.py +++ b/src/diffuse_nnx/data/custom_wds_imagenet_dataset.py @@ -13,7 +13,7 @@ import webdataset as wds # deps -from data import utils, wds_imagenet_dataset +from diffuse_nnx.data import utils, wds_imagenet_dataset class IterableDatasetShard(torch.utils.data.IterableDataset): diff --git a/data/local_imagenet_dataset.py b/src/diffuse_nnx/data/local_imagenet_dataset.py similarity index 99% rename from data/local_imagenet_dataset.py rename to src/diffuse_nnx/data/local_imagenet_dataset.py index 901ad38..18c993c 100644 --- a/data/local_imagenet_dataset.py +++ b/src/diffuse_nnx/data/local_imagenet_dataset.py @@ -23,7 +23,7 @@ pyspng = None # deps -from data import utils +from diffuse_nnx.data import utils class LatentDataset(torch.utils.data.Dataset): diff --git a/data/utils.py b/src/diffuse_nnx/data/utils.py similarity index 98% rename from data/utils.py rename to src/diffuse_nnx/data/utils.py index c5f639f..cf93be8 100644 --- a/data/utils.py +++ b/src/diffuse_nnx/data/utils.py @@ -12,7 +12,7 @@ from torchvision import transforms # deps -from utils import sharding_utils +from diffuse_nnx.utils import sharding_utils class EasyDict(dict): def __init__(self, *args, **kwargs): diff --git a/data/val_labels.txt b/src/diffuse_nnx/data/val_labels.txt similarity index 100% rename from data/val_labels.txt rename to src/diffuse_nnx/data/val_labels.txt diff --git a/data/wds_imagenet_dataset.py b/src/diffuse_nnx/data/wds_imagenet_dataset.py similarity index 98% rename from data/wds_imagenet_dataset.py rename to src/diffuse_nnx/data/wds_imagenet_dataset.py index 81db63d..d12abf9 100644 --- a/data/wds_imagenet_dataset.py +++ b/src/diffuse_nnx/data/wds_imagenet_dataset.py @@ -13,7 +13,7 @@ import webdataset as wds # deps -from data import utils +from diffuse_nnx.data import utils # Main entry point for imagenet dataset diff --git a/eval/README.md b/src/diffuse_nnx/eval/README.md similarity index 100% rename from eval/README.md rename to src/diffuse_nnx/eval/README.md diff --git a/src/diffuse_nnx/eval/__init__.py b/src/diffuse_nnx/eval/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/eval/fid.py b/src/diffuse_nnx/eval/fid.py similarity index 98% rename from eval/fid.py rename to src/diffuse_nnx/eval/fid.py index 254ab08..2427186 100644 --- a/eval/fid.py +++ b/src/diffuse_nnx/eval/fid.py @@ -20,10 +20,10 @@ from tqdm import tqdm # deps -from data import utils as data_utils -from eval import utils -from samplers import samplers -from utils import wandb_utils, sharding_utils +from diffuse_nnx.data import utils as data_utils +from diffuse_nnx.eval import utils +from diffuse_nnx.samplers import samplers +from diffuse_nnx.utils import wandb_utils, sharding_utils def calculate_stats_for_iterable( diff --git a/eval/inception.py b/src/diffuse_nnx/eval/inception.py similarity index 99% rename from eval/inception.py rename to src/diffuse_nnx/eval/inception.py index 379cf76..c796517 100644 --- a/eval/inception.py +++ b/src/diffuse_nnx/eval/inception.py @@ -17,7 +17,7 @@ import flax.linen as nn # deps -from eval import utils +from diffuse_nnx.eval import utils PRNGKey = Any Array = Any diff --git a/eval/inception_v3_weights_fid.pickle b/src/diffuse_nnx/eval/inception_v3_weights_fid.pickle similarity index 100% rename from eval/inception_v3_weights_fid.pickle rename to src/diffuse_nnx/eval/inception_v3_weights_fid.pickle diff --git a/eval/utils.py b/src/diffuse_nnx/eval/utils.py similarity index 98% rename from eval/utils.py rename to src/diffuse_nnx/eval/utils.py index de270f9..2fd8669 100644 --- a/eval/utils.py +++ b/src/diffuse_nnx/eval/utils.py @@ -18,8 +18,8 @@ from tqdm import tqdm # deps -from eval import inception -from samplers import samplers +from diffuse_nnx.eval import inception +from diffuse_nnx.samplers import samplers def get(dictionary, key): diff --git a/interfaces/README.md b/src/diffuse_nnx/interfaces/README.md similarity index 100% rename from interfaces/README.md rename to src/diffuse_nnx/interfaces/README.md diff --git a/src/diffuse_nnx/interfaces/__init__.py b/src/diffuse_nnx/interfaces/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/interfaces/continuous.py b/src/diffuse_nnx/interfaces/continuous.py similarity index 100% rename from interfaces/continuous.py rename to src/diffuse_nnx/interfaces/continuous.py diff --git a/interfaces/discrete.py b/src/diffuse_nnx/interfaces/discrete.py similarity index 100% rename from interfaces/discrete.py rename to src/diffuse_nnx/interfaces/discrete.py diff --git a/interfaces/repa.py b/src/diffuse_nnx/interfaces/repa.py similarity index 98% rename from interfaces/repa.py rename to src/diffuse_nnx/interfaces/repa.py index c970a8b..77a19eb 100644 --- a/interfaces/repa.py +++ b/src/diffuse_nnx/interfaces/repa.py @@ -9,7 +9,7 @@ import numpy as np # deps -from networks.transformers import dit_nnx +from diffuse_nnx.networks.transformers import dit_nnx def build_mlp(hidden_size, projector_dim, feature_dim, rngs, dtype=jnp.float32): """Build a multi-layer perceptron for feature projection. diff --git a/networks/README.md b/src/diffuse_nnx/networks/README.md similarity index 100% rename from networks/README.md rename to src/diffuse_nnx/networks/README.md diff --git a/src/diffuse_nnx/networks/__init__.py b/src/diffuse_nnx/networks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/diffuse_nnx/networks/decoders/__init__.py b/src/diffuse_nnx/networks/decoders/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/networks/decoders/utils.py b/src/diffuse_nnx/networks/decoders/utils.py similarity index 100% rename from networks/decoders/utils.py rename to src/diffuse_nnx/networks/decoders/utils.py diff --git a/networks/decoders/vit.py b/src/diffuse_nnx/networks/decoders/vit.py similarity index 99% rename from networks/decoders/vit.py rename to src/diffuse_nnx/networks/decoders/vit.py index f435f75..4eb4df4 100644 --- a/networks/decoders/vit.py +++ b/src/diffuse_nnx/networks/decoders/vit.py @@ -21,7 +21,7 @@ from PIL import Image import torch -from networks.decoders.utils import ACT2FN, ModelOutput, ViTMAEConfig, get_2d_sincos_pos_embed, convert_weights +from diffuse_nnx.networks.decoders.utils import ACT2FN, ModelOutput, ViTMAEConfig, get_2d_sincos_pos_embed, convert_weights Array = jnp.ndarray diff --git a/src/diffuse_nnx/networks/encoders/__init__.py b/src/diffuse_nnx/networks/encoders/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/networks/encoders/config.json b/src/diffuse_nnx/networks/encoders/config.json similarity index 100% rename from networks/encoders/config.json rename to src/diffuse_nnx/networks/encoders/config.json diff --git a/networks/encoders/dino.py b/src/diffuse_nnx/networks/encoders/dino.py similarity index 100% rename from networks/encoders/dino.py rename to src/diffuse_nnx/networks/encoders/dino.py diff --git a/networks/encoders/dino_w_register.py b/src/diffuse_nnx/networks/encoders/dino_w_register.py similarity index 100% rename from networks/encoders/dino_w_register.py rename to src/diffuse_nnx/networks/encoders/dino_w_register.py diff --git a/src/diffuse_nnx/networks/encoders/mae/__init__.py b/src/diffuse_nnx/networks/encoders/mae/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/networks/encoders/mae/mae.py b/src/diffuse_nnx/networks/encoders/mae/mae.py similarity index 99% rename from networks/encoders/mae/mae.py rename to src/diffuse_nnx/networks/encoders/mae/mae.py index 8e3c818..0cdf36b 100644 --- a/networks/encoders/mae/mae.py +++ b/src/diffuse_nnx/networks/encoders/mae/mae.py @@ -12,7 +12,7 @@ import numpy as np # deps -from networks.encoders.mae import utils +from diffuse_nnx.networks.encoders.mae import utils fixed_gaussian_init = nn.initializers.normal(stddev=0.02) clstoken_init = fixed_gaussian_init diff --git a/networks/encoders/mae/utils.py b/src/diffuse_nnx/networks/encoders/mae/utils.py similarity index 100% rename from networks/encoders/mae/utils.py rename to src/diffuse_nnx/networks/encoders/mae/utils.py diff --git a/networks/encoders/pos_embed_base.npy b/src/diffuse_nnx/networks/encoders/pos_embed_base.npy similarity index 100% rename from networks/encoders/pos_embed_base.npy rename to src/diffuse_nnx/networks/encoders/pos_embed_base.npy diff --git a/networks/encoders/rae.py b/src/diffuse_nnx/networks/encoders/rae.py similarity index 96% rename from networks/encoders/rae.py rename to src/diffuse_nnx/networks/encoders/rae.py index d96a675..22e8131 100644 --- a/networks/encoders/rae.py +++ b/src/diffuse_nnx/networks/encoders/rae.py @@ -17,8 +17,8 @@ from transformers import AutoImageProcessor # deps -from networks.encoders.dino_w_register import DinoWithRegisters -from networks.decoders.vit import ViTMAEConfig, GeneralDecoder +from diffuse_nnx.networks.encoders.dino_w_register import DinoWithRegisters +from diffuse_nnx.networks.decoders.vit import ViTMAEConfig, GeneralDecoder def _load_config(config_path: Path | None) -> ViTMAEConfig: diff --git a/networks/encoders/rgb.py b/src/diffuse_nnx/networks/encoders/rgb.py similarity index 100% rename from networks/encoders/rgb.py rename to src/diffuse_nnx/networks/encoders/rgb.py diff --git a/networks/encoders/sd_vae.py b/src/diffuse_nnx/networks/encoders/sd_vae.py similarity index 99% rename from networks/encoders/sd_vae.py rename to src/diffuse_nnx/networks/encoders/sd_vae.py index 20f6d60..0cdf390 100755 --- a/networks/encoders/sd_vae.py +++ b/src/diffuse_nnx/networks/encoders/sd_vae.py @@ -20,7 +20,7 @@ from transformers import PretrainedConfig # deps -from networks.encoders import utils +from diffuse_nnx.networks.encoders import utils VAE_PRECISION = None diff --git a/src/diffuse_nnx/networks/encoders/stats/__init__.py b/src/diffuse_nnx/networks/encoders/stats/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/diffuse_nnx/networks/encoders/stats/wReg_base/__init__.py b/src/diffuse_nnx/networks/encoders/stats/wReg_base/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/networks/encoders/stats/wReg_base/stat.pt b/src/diffuse_nnx/networks/encoders/stats/wReg_base/stat.pt similarity index 100% rename from networks/encoders/stats/wReg_base/stat.pt rename to src/diffuse_nnx/networks/encoders/stats/wReg_base/stat.pt diff --git a/networks/encoders/utils.py b/src/diffuse_nnx/networks/encoders/utils.py similarity index 100% rename from networks/encoders/utils.py rename to src/diffuse_nnx/networks/encoders/utils.py diff --git a/src/diffuse_nnx/networks/transformers/__init__.py b/src/diffuse_nnx/networks/transformers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/networks/transformers/dit.py b/src/diffuse_nnx/networks/transformers/dit.py similarity index 99% rename from networks/transformers/dit.py rename to src/diffuse_nnx/networks/transformers/dit.py index 2d64cb6..f231ef7 100644 --- a/networks/transformers/dit.py +++ b/src/diffuse_nnx/networks/transformers/dit.py @@ -9,7 +9,7 @@ import flax import flax.linen as nn -from networks.transformers import utils +from diffuse_nnx.networks.transformers import utils from einops import rearrange class TimestepEmbedder(nn.Module): diff --git a/networks/transformers/dit_nnx.py b/src/diffuse_nnx/networks/transformers/dit_nnx.py similarity index 99% rename from networks/transformers/dit_nnx.py rename to src/diffuse_nnx/networks/transformers/dit_nnx.py index bef15b4..a4f5a4e 100644 --- a/networks/transformers/dit_nnx.py +++ b/src/diffuse_nnx/networks/transformers/dit_nnx.py @@ -12,7 +12,7 @@ import numpy as np # deps -from networks.transformers import utils +from diffuse_nnx.networks.transformers import utils PRECISION = None diff --git a/networks/transformers/lightning_ddt_nnx.py b/src/diffuse_nnx/networks/transformers/lightning_ddt_nnx.py similarity index 99% rename from networks/transformers/lightning_ddt_nnx.py rename to src/diffuse_nnx/networks/transformers/lightning_ddt_nnx.py index 2db09a6..ee5f701 100644 --- a/networks/transformers/lightning_ddt_nnx.py +++ b/src/diffuse_nnx/networks/transformers/lightning_ddt_nnx.py @@ -12,7 +12,7 @@ import numpy as np # deps -from networks.transformers import utils, dit_nnx, lightning_dit_nnx +from diffuse_nnx.networks.transformers import utils, dit_nnx, lightning_dit_nnx PRECISION = None diff --git a/networks/transformers/lightning_dit_nnx.py b/src/diffuse_nnx/networks/transformers/lightning_dit_nnx.py similarity index 99% rename from networks/transformers/lightning_dit_nnx.py rename to src/diffuse_nnx/networks/transformers/lightning_dit_nnx.py index b650e41..76ebc65 100644 --- a/networks/transformers/lightning_dit_nnx.py +++ b/src/diffuse_nnx/networks/transformers/lightning_dit_nnx.py @@ -14,7 +14,7 @@ import numpy as np # deps -from networks.transformers import utils, dit_nnx +from diffuse_nnx.networks.transformers import utils, dit_nnx PRECISION = None diff --git a/networks/transformers/port_nnx_to_torch.py b/src/diffuse_nnx/networks/transformers/port_nnx_to_torch.py similarity index 100% rename from networks/transformers/port_nnx_to_torch.py rename to src/diffuse_nnx/networks/transformers/port_nnx_to_torch.py diff --git a/networks/transformers/port_torch_to_nnx.py b/src/diffuse_nnx/networks/transformers/port_torch_to_nnx.py similarity index 100% rename from networks/transformers/port_torch_to_nnx.py rename to src/diffuse_nnx/networks/transformers/port_torch_to_nnx.py diff --git a/networks/transformers/utils.py b/src/diffuse_nnx/networks/transformers/utils.py similarity index 100% rename from networks/transformers/utils.py rename to src/diffuse_nnx/networks/transformers/utils.py diff --git a/samplers/README.md b/src/diffuse_nnx/samplers/README.md similarity index 100% rename from samplers/README.md rename to src/diffuse_nnx/samplers/README.md diff --git a/src/diffuse_nnx/samplers/__init__.py b/src/diffuse_nnx/samplers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/samplers/samplers.py b/src/diffuse_nnx/samplers/samplers.py similarity index 100% rename from samplers/samplers.py rename to src/diffuse_nnx/samplers/samplers.py diff --git a/trainers/README.md b/src/diffuse_nnx/trainers/README.md similarity index 100% rename from trainers/README.md rename to src/diffuse_nnx/trainers/README.md diff --git a/src/diffuse_nnx/trainers/__init__.py b/src/diffuse_nnx/trainers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/trainers/dit_imagenet.py b/src/diffuse_nnx/trainers/dit_imagenet.py similarity index 98% rename from trainers/dit_imagenet.py rename to src/diffuse_nnx/trainers/dit_imagenet.py index aab27da..851b1ac 100644 --- a/trainers/dit_imagenet.py +++ b/src/diffuse_nnx/trainers/dit_imagenet.py @@ -16,10 +16,10 @@ import numpy as np # deps -from data import local_imagenet_dataset, utils as data_utils -from eval import fid -from interfaces import continuous -from utils import ( +from diffuse_nnx.data import local_imagenet_dataset, utils as data_utils +from diffuse_nnx.eval import fid +from diffuse_nnx.interfaces import continuous +from diffuse_nnx.utils import ( checkpoint as ckpt_utils, initialize as init_utils, logging_utils, diff --git a/src/diffuse_nnx/utils/__init__.py b/src/diffuse_nnx/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/checkpoint.py b/src/diffuse_nnx/utils/checkpoint.py similarity index 100% rename from utils/checkpoint.py rename to src/diffuse_nnx/utils/checkpoint.py diff --git a/utils/ema.py b/src/diffuse_nnx/utils/ema.py similarity index 100% rename from utils/ema.py rename to src/diffuse_nnx/utils/ema.py diff --git a/utils/gcloud_utils.py b/src/diffuse_nnx/utils/gcloud_utils.py similarity index 100% rename from utils/gcloud_utils.py rename to src/diffuse_nnx/utils/gcloud_utils.py diff --git a/utils/initialize.py b/src/diffuse_nnx/utils/initialize.py similarity index 95% rename from utils/initialize.py rename to src/diffuse_nnx/utils/initialize.py index a5735c6..eb6a098 100644 --- a/utils/initialize.py +++ b/src/diffuse_nnx/utils/initialize.py @@ -12,12 +12,12 @@ import optax # deps -from interfaces import continuous, discrete, repa -from networks.transformers import dit_nnx, lightning_dit_nnx, lightning_ddt_nnx -from networks.encoders import dino, rae -from samplers import samplers -from networks.encoders import sd_vae, rgb -from utils import ema +from diffuse_nnx.interfaces import continuous, discrete, repa +from diffuse_nnx.networks.transformers import dit_nnx, lightning_dit_nnx, lightning_ddt_nnx +from diffuse_nnx.networks.encoders import dino, rae +from diffuse_nnx.samplers import samplers +from diffuse_nnx.networks.encoders import sd_vae, rgb +from diffuse_nnx.utils import ema ENCODER_REGISTRY = { @@ -332,7 +332,7 @@ def build_models(config: ml_collections.ConfigDict): return encoder, model, optimizer, sampler, ema, learning_rate_fn -from configs import dit_imagenet +from diffuse_nnx.configs import dit_imagenet if __name__ == "__main__": config = dit_imagenet.get_config() diff --git a/utils/logging_utils.py b/src/diffuse_nnx/utils/logging_utils.py similarity index 100% rename from utils/logging_utils.py rename to src/diffuse_nnx/utils/logging_utils.py diff --git a/utils/sharding_utils.py b/src/diffuse_nnx/utils/sharding_utils.py similarity index 99% rename from utils/sharding_utils.py rename to src/diffuse_nnx/utils/sharding_utils.py index d8368ec..b89d471 100644 --- a/utils/sharding_utils.py +++ b/src/diffuse_nnx/utils/sharding_utils.py @@ -17,7 +17,7 @@ import numpy as np # deps -from utils import ema +from diffuse_nnx.utils import ema def flatten_state( diff --git a/utils/visualize.py b/src/diffuse_nnx/utils/visualize.py similarity index 97% rename from utils/visualize.py rename to src/diffuse_nnx/utils/visualize.py index 40e2ede..4334f74 100644 --- a/utils/visualize.py +++ b/src/diffuse_nnx/utils/visualize.py @@ -12,8 +12,8 @@ import ml_collections # deps -from utils import wandb_utils, sharding_utils -from samplers import samplers +from diffuse_nnx.utils import wandb_utils, sharding_utils +from diffuse_nnx.samplers import samplers @functools.partial( diff --git a/utils/wandb_utils.py b/src/diffuse_nnx/utils/wandb_utils.py similarity index 100% rename from utils/wandb_utils.py rename to src/diffuse_nnx/utils/wandb_utils.py diff --git a/tests/data_tests/dataloader_benchmark.py b/tests/data_tests/dataloader_benchmark.py index 58503f8..c9b65c5 100644 --- a/tests/data_tests/dataloader_benchmark.py +++ b/tests/data_tests/dataloader_benchmark.py @@ -11,8 +11,8 @@ from tqdm import tqdm # deps -from data import utils -from data import wds_imagenet_dataset as wds +from diffuse_nnx.data import utils +from diffuse_nnx.data import wds_imagenet_dataset as wds class TestWDS(unittest.TestCase): diff --git a/tests/data_tests/dataloader_ddp_tests.py b/tests/data_tests/dataloader_ddp_tests.py index 38c5c17..dbb7e24 100644 --- a/tests/data_tests/dataloader_ddp_tests.py +++ b/tests/data_tests/dataloader_ddp_tests.py @@ -13,9 +13,9 @@ from tqdm import tqdm # deps -from data import utils -from data import wds_imagenet_dataset as wds -from data import custom_wds_imagenet_dataset as cwds +from diffuse_nnx.data import utils +from diffuse_nnx.data import wds_imagenet_dataset as wds +from diffuse_nnx.data import custom_wds_imagenet_dataset as cwds # suppress resource warning diff --git a/tests/data_tests/dataloader_tests.py b/tests/data_tests/dataloader_tests.py index 8821356..4e910b3 100644 --- a/tests/data_tests/dataloader_tests.py +++ b/tests/data_tests/dataloader_tests.py @@ -9,9 +9,9 @@ import ml_collections # deps -from data import utils -from data import local_imagenet_dataset as lds -from data import wds_imagenet_dataset as wds +from diffuse_nnx.data import utils +from diffuse_nnx.data import local_imagenet_dataset as lds +from diffuse_nnx.data import wds_imagenet_dataset as wds class TestWDS(unittest.TestCase): diff --git a/tests/interface_tests/continuous_tests.py b/tests/interface_tests/continuous_tests.py index ce1d9a8..711a187 100644 --- a/tests/interface_tests/continuous_tests.py +++ b/tests/interface_tests/continuous_tests.py @@ -11,7 +11,7 @@ import jax.numpy as jnp # deps -from interfaces.continuous import SiTInterface, EDMInterface, TrainingTimeDistType, MeanFlowInterface +from diffuse_nnx.interfaces.continuous import SiTInterface, EDMInterface, TrainingTimeDistType, MeanFlowInterface class DummyMlp1(nn.Module): in_dim: int diff --git a/tests/network_tests/ddt/ddt_port.py b/tests/network_tests/ddt/ddt_port.py index 3f299e3..262ef1b 100644 --- a/tests/network_tests/ddt/ddt_port.py +++ b/tests/network_tests/ddt/ddt_port.py @@ -15,8 +15,8 @@ # deps from tests.network_tests.ddt.ddt_torch import DiTwDDTHead -from networks.transformers import lightning_ddt_nnx -from networks.transformers import dit_nnx, port_torch_to_nnx as port +from diffuse_nnx.networks.transformers import lightning_ddt_nnx +from diffuse_nnx.networks.transformers import dit_nnx, port_torch_to_nnx as port if __name__ == "__main__": diff --git a/tests/network_tests/ddt/ddt_tests.py b/tests/network_tests/ddt/ddt_tests.py index 7752710..f5a44d3 100644 --- a/tests/network_tests/ddt/ddt_tests.py +++ b/tests/network_tests/ddt/ddt_tests.py @@ -12,7 +12,7 @@ import torch # deps -from networks.transformers import lightning_ddt_nnx, port_torch_to_nnx as port +from diffuse_nnx.networks.transformers import lightning_ddt_nnx, port_torch_to_nnx as port from tests.network_tests.ddt.ddt_torch import DiTwDDTHead diff --git a/tests/network_tests/decoders/convert_weights.py b/tests/network_tests/decoders/convert_weights.py index 5f2b39d..e8beb9d 100644 --- a/tests/network_tests/decoders/convert_weights.py +++ b/tests/network_tests/decoders/convert_weights.py @@ -21,7 +21,7 @@ import torch # deps -from networks.decoders.vit import GeneralDecoder, ViTMAEConfig +from diffuse_nnx.networks.decoders.vit import GeneralDecoder, ViTMAEConfig Array = jnp.ndarray diff --git a/tests/network_tests/decoders/rae_tests.py b/tests/network_tests/decoders/rae_tests.py index 8a4da3c..be4c944 100644 --- a/tests/network_tests/decoders/rae_tests.py +++ b/tests/network_tests/decoders/rae_tests.py @@ -14,7 +14,7 @@ import torch # deps -from networks.decoders.vit import GeneralDecoder as FlaxGeneralDecoder, ViTMAEConfig +from diffuse_nnx.networks.decoders.vit import GeneralDecoder as FlaxGeneralDecoder, ViTMAEConfig from tests.network_tests.decoders.vit_torch import GeneralDecoder as TorchGeneralDecoder diff --git a/tests/network_tests/decoders/vit_torch.py b/tests/network_tests/decoders/vit_torch.py index a64b561..06eac14 100644 --- a/tests/network_tests/decoders/vit_torch.py +++ b/tests/network_tests/decoders/vit_torch.py @@ -27,7 +27,7 @@ from transformers.modeling_outputs import BaseModelOutput from transformers.utils import ModelOutput -from networks.decoders.utils import ViTMAEConfig +from diffuse_nnx.networks.decoders.utils import ViTMAEConfig diff --git a/tests/network_tests/dit/dit_linen_tests.py b/tests/network_tests/dit/dit_linen_tests.py index 31be58d..de4871e 100644 --- a/tests/network_tests/dit/dit_linen_tests.py +++ b/tests/network_tests/dit/dit_linen_tests.py @@ -9,8 +9,8 @@ import jax import jax.numpy as jnp from jax import random, jit -from networks.transformers.dit import DiT as DiT_jax -from networks.transformers.utils import get_2d_sincos_pos_embed, to_2tuple +from diffuse_nnx.networks.transformers.dit import DiT as DiT_jax +from diffuse_nnx.networks.transformers.utils import get_2d_sincos_pos_embed, to_2tuple from .dit_torch import DiT as DiT key = random.PRNGKey(0) diff --git a/tests/network_tests/dit/dit_tests.py b/tests/network_tests/dit/dit_tests.py index df5d252..150fda9 100644 --- a/tests/network_tests/dit/dit_tests.py +++ b/tests/network_tests/dit/dit_tests.py @@ -12,7 +12,7 @@ import torch # deps -from networks.transformers import dit_nnx, port_nnx_to_torch as port +from diffuse_nnx.networks.transformers import dit_nnx, port_nnx_to_torch as port from tests.network_tests.dit.dit_torch import DiT_B_2 diff --git a/tests/network_tests/encoders/dino_tests.py b/tests/network_tests/encoders/dino_tests.py index cc6aa47..777153b 100644 --- a/tests/network_tests/encoders/dino_tests.py +++ b/tests/network_tests/encoders/dino_tests.py @@ -16,10 +16,10 @@ import torch # deps -from configs import dit_imagenet -from data import local_imagenet_dataset -from networks.encoders.dino import DINO -from utils import initialize as init_utils +from diffuse_nnx.configs import dit_imagenet +from diffuse_nnx.data import local_imagenet_dataset +from diffuse_nnx.networks.encoders.dino import DINO +from diffuse_nnx.utils import initialize as init_utils def indexing(tree, key, value): diff --git a/tests/network_tests/encoders/dino_w_reg_tests.py b/tests/network_tests/encoders/dino_w_reg_tests.py index fcd5e95..9069955 100644 --- a/tests/network_tests/encoders/dino_w_reg_tests.py +++ b/tests/network_tests/encoders/dino_w_reg_tests.py @@ -12,7 +12,7 @@ from transformers import Dinov2WithRegistersModel # deps -from networks.encoders.dino_w_register import DinoWithRegisters, Dinov2WithRegistersConfig +from diffuse_nnx.networks.encoders.dino_w_register import DinoWithRegisters, Dinov2WithRegistersConfig class TestDinoWithRegistersEncoder(unittest.TestCase): diff --git a/tests/network_tests/encoders/sd_vae_tests.py b/tests/network_tests/encoders/sd_vae_tests.py index 4960d4e..c6beb86 100644 --- a/tests/network_tests/encoders/sd_vae_tests.py +++ b/tests/network_tests/encoders/sd_vae_tests.py @@ -12,9 +12,9 @@ import PIL # deps -from configs import dit_imagenet -from networks.encoders.sd_vae import StabilityVAE -from utils import initialize as init_utils +from diffuse_nnx.configs import dit_imagenet +from diffuse_nnx.networks.encoders.sd_vae import StabilityVAE +from diffuse_nnx.utils import initialize as init_utils if __name__ == "__main__": diff --git a/tests/network_tests/nnx/nnx_sharding_tests.py b/tests/network_tests/nnx/nnx_sharding_tests.py index 4ee8719..ad44f31 100644 --- a/tests/network_tests/nnx/nnx_sharding_tests.py +++ b/tests/network_tests/nnx/nnx_sharding_tests.py @@ -18,10 +18,10 @@ import optax # deps -from configs import dit_imagenet -from networks.transformers import dit_nnx -from utils.sharding_utils import create_device_mesh, flatten_state, infer_sharding -from utils import initialize as init_utils +from diffuse_nnx.configs import dit_imagenet +from diffuse_nnx.networks.transformers import dit_nnx +from diffuse_nnx.utils.sharding_utils import create_device_mesh, flatten_state, infer_sharding +from diffuse_nnx.utils import initialize as init_utils @dataclasses.dataclass(unsafe_hash=True) diff --git a/tests/runner.py b/tests/runner.py index 1970a9f..ef36d63 100644 --- a/tests/runner.py +++ b/tests/runner.py @@ -1,9 +1,15 @@ from pathlib import Path +import sys import unittest ROOT = Path(__file__).resolve().parent.parent TESTS = ROOT / "tests" +SRC = ROOT / "src" + +# Add src to path so tests can import diffuse_nnx +if str(SRC) not in sys.path: + sys.path.insert(0, str(SRC)) if __name__ == "__main__": loader = unittest.TestLoader() @@ -12,4 +18,4 @@ ) runner = unittest.TextTestRunner(verbosity=2) - runner.run(suite) \ No newline at end of file + runner.run(suite) diff --git a/tests/sampler_tests/sampler_tests.py b/tests/sampler_tests/sampler_tests.py index 8acca3c..0ced4a9 100644 --- a/tests/sampler_tests/sampler_tests.py +++ b/tests/sampler_tests/sampler_tests.py @@ -11,8 +11,8 @@ import jax.numpy as jnp # deps -from interfaces.continuous import SiTInterface, TrainingTimeDistType -from samplers.samplers import EulerSampler, HeunSampler, SamplingTimeDistType +from diffuse_nnx.interfaces.continuous import SiTInterface, TrainingTimeDistType +from diffuse_nnx.samplers.samplers import EulerSampler, HeunSampler, SamplingTimeDistType class DummyMlp(nn.Module): in_dim: int diff --git a/tests/util_tests/ema_tests.py b/tests/util_tests/ema_tests.py index 85ca940..ffa3440 100644 --- a/tests/util_tests/ema_tests.py +++ b/tests/util_tests/ema_tests.py @@ -9,7 +9,7 @@ import jax.numpy as jnp # deps -from utils import ema as ema_lib +from diffuse_nnx.utils import ema as ema_lib class TestEMA(unittest.TestCase):