Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 26 additions & 3 deletions docker/common/uv-pytorch.lock
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ overrides = [
{ name = "nvidia-nccl-cu12", marker = "sys_platform == 'never'" },
{ name = "torch", marker = "sys_platform == 'never'", index = "https://download.pytorch.org/whl/cpu" },
{ name = "torchao", marker = "sys_platform == 'never'" },
{ name = "torchvision", marker = "sys_platform == 'never'" },
{ name = "torchvision", marker = "sys_platform == 'never'", index = "https://download.pytorch.org/whl/cpu" },
{ name = "transformer-engine", marker = "sys_platform == 'never'" },
{ name = "transformer-engine-torch", marker = "sys_platform == 'never'" },
{ name = "triton", marker = "sys_platform == 'never'" },
Expand Down Expand Up @@ -2279,6 +2279,21 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/41/45/1a4ed80516f02155c51f51e8cedb3c1902296743db0bbc66608a0db2814f/jsonschema_specifications-2025.9.1-py3-none-any.whl", hash = "sha256:98802fee3a11ee76ecaca44429fda8a41bff98b00a0f2838151b113f210cc6fe", size = 18437, upload-time = "2025-09-08T01:34:57.871Z" },
]

[[package]]
name = "kernels"
version = "0.12.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "huggingface-hub" },
{ name = "packaging" },
{ name = "pyyaml" },
{ name = "tomli", marker = "python_full_version < '3.11'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/a9/07/d2b635e965b232cae1aa873c6e0458947196be8dca7bb02e64d3cd6e8d19/kernels-0.12.2.tar.gz", hash = "sha256:812fc43c2814f046cee655cbebf3918cddd489715773670bdb38cca3f5203b5b", size = 57108, upload-time = "2026-03-04T10:03:00.379Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/08/be/f5d6758b48633e4f6a28198fcf4bf9f763cc6a82e2335d9fe8802a5cb440/kernels-0.12.2-py3-none-any.whl", hash = "sha256:1289261804748cf3cf8e3afab80b505b0f1b28e4ec88379cdf08dc31e64964b8", size = 55205, upload-time = "2026-03-04T10:02:59.305Z" },
]

[[package]]
name = "kiwisolver"
version = "1.4.9"
Expand Down Expand Up @@ -3235,6 +3250,7 @@ all = [
{ name = "ftfy" },
{ name = "imageio" },
{ name = "imageio-ffmpeg" },
{ name = "kernels" },
{ name = "mamba-ssm" },
{ name = "mistral-common", extra = ["opencv"] },
{ name = "numba", version = "0.53.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.13'" },
Expand All @@ -3252,6 +3268,7 @@ all = [
{ name = "sentencepiece" },
{ name = "timm" },
{ name = "torchcodec", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin'" },
{ name = "torchvision", marker = "sys_platform == 'never'" },
{ name = "transformer-engine", marker = "sys_platform == 'never'" },
]
cuda = [
Expand All @@ -3271,7 +3288,9 @@ diffusion = [
{ name = "ftfy" },
{ name = "imageio" },
{ name = "imageio-ffmpeg" },
{ name = "kernels" },
{ name = "opencv-python-headless" },
{ name = "torchvision", marker = "sys_platform == 'never'" },
]
extra = [
{ name = "flash-linear-attention" },
Expand Down Expand Up @@ -3358,6 +3377,7 @@ requires-dist = [
{ name = "ftfy", marker = "extra == 'diffusion'" },
{ name = "imageio", marker = "extra == 'diffusion'" },
{ name = "imageio-ffmpeg", marker = "extra == 'diffusion'" },
{ name = "kernels", marker = "extra == 'diffusion'" },
{ name = "mamba-ssm", marker = "extra == 'cuda'" },
{ name = "megatron-fsdp", specifier = ">=0.2.3" },
{ name = "mistral-common", extras = ["audio", "hf-hub", "image", "sentencepiece"] },
Expand Down Expand Up @@ -3390,6 +3410,9 @@ requires-dist = [
{ name = "torchao" },
{ name = "torchcodec", marker = "platform_machine == 'x86_64' and sys_platform != 'darwin' and extra == 'vlm'" },
{ name = "torchdata" },
{ name = "torchvision", marker = "sys_platform == 'darwin' and extra == 'diffusion'", index = "https://pypi.org/simple" },
{ name = "torchvision", marker = "sys_platform != 'darwin' and sys_platform != 'linux' and extra == 'diffusion'", index = "https://download.pytorch.org/whl/cpu" },
{ name = "torchvision", marker = "sys_platform == 'linux' and extra == 'diffusion'", index = "https://download.pytorch.org/whl/cu129" },
{ name = "transformer-engine", extras = ["pytorch"], marker = "extra == 'cuda'", specifier = "<=2.11.0" },
{ name = "transformers", specifier = ">=5.0.0" },
{ name = "wandb" },
Expand Down Expand Up @@ -6409,8 +6432,8 @@ wheels = [

[[package]]
name = "torchvision"
version = "0.23.0"
source = { registry = "https://pypi.org/simple" }
version = "0.25.0+cpu"
source = { registry = "https://download.pytorch.org/whl/cpu" }
dependencies = [
{ name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.13' and sys_platform != 'darwin' and sys_platform != 'linux'" },
{ name = "numpy", version = "2.4.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 'linux'" },
Expand Down
2 changes: 1 addition & 1 deletion examples/diffusion/finetune/flux_t2i_flow.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ step_scheduler:

data:
dataloader:
_target_: nemo_automodel.components.datasets.diffusion.build_flux_multiresolution_dataloader
_target_: nemo_automodel.components.datasets.diffusion.build_text_to_image_multiresolution_dataloader
cache_dir: PATH_TO_YOUR_DATA
train_text_encoder: false
num_workers: 10
Expand Down
10 changes: 7 additions & 3 deletions examples/diffusion/finetune/wan2_1_t2v_flow.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,14 @@ step_scheduler:

data:
dataloader:
_target_: nemo_automodel.components.datasets.diffusion.build_dataloader
meta_folder: PATH_TO_YOUR_DATA
_target_: nemo_automodel.components.datasets.diffusion.build_video_multiresolution_dataloader
cache_dir: PATH_TO_YOUR_DATA
model_type: wan
base_resolution: [512, 512]
dynamic_batch_size: false
shuffle: true
drop_last: false
num_workers: 2
device: cpu

optim:
learning_rate: 5e-6
Expand Down
10 changes: 7 additions & 3 deletions examples/diffusion/finetune/wan2_1_t2v_flow_multinode.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,14 @@ step_scheduler:

data:
dataloader:
_target_: nemo_automodel.components.datasets.diffusion.build_dataloader
meta_folder: PATH_TO_YOUR_DATA
_target_: nemo_automodel.components.datasets.diffusion.build_video_multiresolution_dataloader
cache_dir: PATH_TO_YOUR_DATA
model_type: wan
base_resolution: [512, 512]
dynamic_batch_size: false
shuffle: true
drop_last: false
num_workers: 2
device: cpu


optim:
Expand Down
4 changes: 2 additions & 2 deletions examples/diffusion/generate/flux_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from diffusers import FluxPipeline

# Import the provided dataloader builder
from nemo_automodel.components.datasets.diffusion import build_flux_multiresolution_dataloader
from nemo_automodel.components.datasets.diffusion import build_text_to_image_multiresolution_dataloader


def parse_args():
Expand Down Expand Up @@ -187,7 +187,7 @@ def main():
print("=" * 80)
print(f"Initializing Multiresolution Dataloader: {args.data_path}")

dataloader, _ = build_flux_multiresolution_dataloader(
dataloader, _ = build_text_to_image_multiresolution_dataloader(
cache_dir=args.data_path, batch_size=1, num_workers=args.num_workers, dynamic_batch_size=True, shuffle=False
)
print(f"[INFO] Dataloader ready. Batches: {len(dataloader)}")
Expand Down
2 changes: 1 addition & 1 deletion examples/diffusion/pretrain/flux_t2i_flow.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ step_scheduler:

data:
dataloader:
_target_: nemo_automodel.components.datasets.diffusion.build_flux_multiresolution_dataloader
_target_: nemo_automodel.components.datasets.diffusion.build_text_to_image_multiresolution_dataloader
cache_dir: PATH_TO_YOUR_DATA
train_text_encoder: false
num_workers: 1
Expand Down
20 changes: 16 additions & 4 deletions nemo_automodel/components/datasets/diffusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,26 @@
import importlib

_LAZY_ATTRS = {
"MetaFilesDataset": (".meta_files_dataset", "MetaFilesDataset"),
# Dataset classes
"BaseMultiresolutionDataset": (".base_dataset", "BaseMultiresolutionDataset"),
"TextToImageDataset": (".text_to_image_dataset", "TextToImageDataset"),
"TextToVideoDataset": (".text_to_video_dataset", "TextToVideoDataset"),
"MetaFilesDataset": (".meta_files_dataset", "MetaFilesDataset"),
# Utilities
"MultiTierBucketCalculator": (".multi_tier_bucketing", "MultiTierBucketCalculator"),
"SequentialBucketSampler": (".sampler", "SequentialBucketSampler"),
"collate_fn_flux": (".collate_fns", "collate_fn_flux"),
"build_flux_multiresolution_dataloader": (".collate_fns", "build_flux_multiresolution_dataloader"),
"build_mock_dataloader": (".mock_dataloader", "build_mock_dataloader"),
"VIDEO_OPTIONAL_FIELDS": (".text_to_video_dataset", "VIDEO_OPTIONAL_FIELDS"),
# Collate functions
"collate_fn_text_to_image": (".collate_fns", "collate_fn_text_to_image"),
"collate_fn_video": (".collate_fns", "collate_fn_video"),
"collate_fn_production": (".collate_fns", "collate_fn_production"),
# Dataloader builders
"build_text_to_image_multiresolution_dataloader": (".collate_fns", "build_text_to_image_multiresolution_dataloader"),
"build_video_multiresolution_dataloader": (".collate_fns", "build_video_multiresolution_dataloader"),
# Legacy (non-multiresolution)
"build_dataloader": (".meta_files_dataset", "build_dataloader"),
# Mock/test
"build_mock_dataloader": (".mock_dataloader", "build_mock_dataloader"),
}

__all__ = sorted(_LAZY_ATTRS.keys())
Expand Down
133 changes: 133 additions & 0 deletions nemo_automodel/components/datasets/diffusion/base_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, List

from torch.utils.data import Dataset

from .multi_tier_bucketing import MultiTierBucketCalculator

logger = logging.getLogger(__name__)


class BaseMultiresolutionDataset(Dataset, ABC):
    """Abstract base class for multiresolution datasets with bucket-based sampling.

    Samples are grouped into "buckets" keyed by (aspect-ratio name, resolution)
    so a sampler can batch same-shaped tensors together. Subclasses implement
    ``__getitem__`` to actually load one sample from the preprocessed cache.
    """

    def __init__(self, cache_dir: str, quantization: int = 64):
        """
        Args:
            cache_dir: Directory containing preprocessed cache (metadata.json + shards)
            quantization: Resolution quantization factor (64 for images, 8 for video)
        """
        self.cache_dir = Path(cache_dir)

        # Load per-sample metadata records from metadata.json + shard files.
        self.metadata = self._load_metadata()

        logger.info("Loaded dataset with %d samples", len(self.metadata))

        # Group samples by (aspect, resolution) bucket for shape-homogeneous batching.
        self._group_by_bucket()

        # Initialize bucket calculator for dynamic batch sizes
        self.calculator = MultiTierBucketCalculator(quantization=quantization)

    def _load_metadata(self) -> List[Dict]:
        """Load and concatenate all shard metadata from the cache directory.

        Expects ``metadata.json`` with a "shards" key listing shard JSON files
        (each shard is a list of per-sample dicts).

        Returns:
            A flat list of per-sample metadata dicts.

        Raises:
            FileNotFoundError: If ``metadata.json`` is missing.
            ValueError: If ``metadata.json`` is not a dict with a "shards" key.
        """
        metadata_file = self.cache_dir / "metadata.json"

        if not metadata_file.exists():
            raise FileNotFoundError(f"No metadata.json found in {self.cache_dir}")

        with open(metadata_file, "r") as f:
            data = json.load(f)

        if not isinstance(data, dict) or "shards" not in data:
            raise ValueError(f"Invalid metadata format in {metadata_file}. Expected dict with 'shards' key.")

        # Load all shard files and concatenate their sample lists.
        metadata = []
        for shard_name in data["shards"]:
            shard_path = self.cache_dir / shard_name
            with open(shard_path, "r") as f:
                shard_data = json.load(f)
            metadata.extend(shard_data)

        return metadata

    def _aspect_ratio_to_name(self, aspect_ratio: float) -> str:
        """Convert an aspect ratio (width/height, presumably — confirm upstream) to a bucket name."""
        # Thresholds partition ratios into tall (<0.85), square, and wide (>1.18).
        if aspect_ratio < 0.85:
            return "tall"
        elif aspect_ratio > 1.18:
            return "wide"
        else:
            return "square"

    def _group_by_bucket(self):
        """Group samples by bucket (aspect_ratio + resolution).

        Populates ``self.bucket_groups`` (bucket key -> indices/metadata) and
        ``self.sorted_bucket_keys`` (buckets ordered by pixel count, ascending).
        """
        self.bucket_groups = {}

        # Empty dataset: leave bucket structures empty instead of crashing on metadata[0].
        if not self.metadata:
            self.sorted_bucket_keys = []
            logger.warning("Dataset is empty; no buckets created.")
            return

        # Support both bucket_resolution (video) and crop_resolution (image) keys.
        # NOTE(review): the key is probed on the first sample only — assumes a
        # homogeneous cache; mixed shards would raise KeyError below.
        resolution_key = "bucket_resolution" if "bucket_resolution" in self.metadata[0] else "crop_resolution"

        for idx, item in enumerate(self.metadata):
            aspect_ratio = item.get("aspect_ratio", 1.0)
            aspect_name = self._aspect_ratio_to_name(aspect_ratio)
            resolution = tuple(item[resolution_key])
            bucket_key = (aspect_name, resolution)

            if bucket_key not in self.bucket_groups:
                self.bucket_groups[bucket_key] = {
                    "indices": [],
                    "aspect_name": aspect_name,
                    "aspect_ratio": aspect_ratio,
                    "resolution": resolution,
                    "pixels": resolution[0] * resolution[1],
                }

            self.bucket_groups[bucket_key]["indices"].append(idx)

        # Sort buckets by resolution (low to high for optimal memory usage)
        self.sorted_bucket_keys = sorted(self.bucket_groups.keys(), key=lambda k: self.bucket_groups[k]["pixels"])

        logger.info("\nDataset organized into %d buckets:", len(self.bucket_groups))
        for key in self.sorted_bucket_keys:
            bucket = self.bucket_groups[key]
            aspect_name, resolution = key
            logger.info(
                "  %6s %4dx%4d: %5d samples", aspect_name, resolution[0], resolution[1], len(bucket["indices"])
            )

    def get_bucket_info(self) -> Dict:
        """Return a summary of bucket organization: total count and per-bucket sizes."""
        return {
            "total_buckets": len(self.bucket_groups),
            "buckets": {f"{k[0]}/{k[1][0]}x{k[1][1]}": len(v["indices"]) for k, v in self.bucket_groups.items()},
        }

    def __len__(self) -> int:
        """Total number of samples across all buckets."""
        return len(self.metadata)

    @abstractmethod
    def __getitem__(self, idx: int) -> Dict:
        """Load a single sample. Subclasses must implement."""
        ...
Loading