[DFM Migration] Migration DFM -> Bridge #2534

Merged
huvunvidia merged 20 commits into main from migration/dfm on Mar 14, 2026

Conversation

@abhinavg4
Contributor

@abhinavg4 abhinavg4 commented Feb 25, 2026

What does this PR do?

Migrate DFM to MB.

PR passing all tests in DFM is here: NVIDIA-NeMo/DFM#105

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

If you haven't finished some of the above items, you can still open a "Draft" PR.

Additional Information

  • Related to # (issue)

Summary by CodeRabbit

  • New Features

    • Added FLUX diffusion model support with pretraining, finetuning, and inference capabilities
    • Added WAN diffusion model support with pretraining and inference workflows
    • Added data modules for diffusion model training with energon integration
    • Added checkpoint conversion utilities between HuggingFace and Megatron formats
    • Added comprehensive training recipes and example configurations
  • Documentation

    • Added documentation for diffusion models, base classes, and examples
  • Tests

    • Added unit tests for diffusion models, data handling, and training pipelines
    • Added functional tests for FLUX and WAN pretraining workflows

For original authorship, see the DFM repo: https://github.com/NVIDIA-NeMo/DFM
Original Major Contributors for Megatron path: @abhinavg4, @huvunvidia, @sajadn, and @suiyoubi

Migrate the Megatron-based diffusion model code from the DFM repository
(commit 013ceca) into Megatron-Bridge as a self-contained `diffusion/`
module, following the shallow integration plan.

Source mapping:
- dfm/src/megatron/         -> src/megatron/bridge/diffusion/
- dfm/src/common/ (utils)   -> src/megatron/bridge/diffusion/common/
- examples/megatron/        -> examples/diffusion/
- tests/unit_tests/megatron -> tests/diffusion/unit_tests/
- tests/functional_tests/mcore -> tests/diffusion/functional_tests/

Key structural changes from DFM:
- model/ renamed to models/ (matches MB convention)
- model/*/conversion/ extracted to top-level conversion/ (separates
  bridge/checkpoint-conversion code from model implementation)
- All dfm.src.megatron.* imports rewritten to megatron.bridge.diffusion.*
- All dfm.src.common.* imports rewritten to megatron.bridge.diffusion.common.*
- dfm.src.automodel.* imports left as-is (automodel migrating separately)

Models included: DiT, FLUX, WAN (video generation)
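The import rewrite listed above was mechanical; a hypothetical sketch of the same transformation with sed (the demo file and paths are illustrative, not the actual migration script):

```shell
#!/bin/bash
set -euo pipefail

# Hypothetical demo file standing in for one migrated source file.
demo=/tmp/dfm_import_rewrite_demo.py
cat > "$demo" <<'EOF'
from dfm.src.megatron.models.wan.wan_model import WanModel
from dfm.src.common.utils import misc
EOF

# dfm.src.common.*   -> megatron.bridge.diffusion.common.*
# dfm.src.megatron.* -> megatron.bridge.diffusion.*
sed -i \
  -e 's/dfm\.src\.common/megatron.bridge.diffusion.common/g' \
  -e 's/dfm\.src\.megatron/megatron.bridge.diffusion/g' \
  "$demo"

cat "$demo"
```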
@copy-pr-bot

copy-pr-bot bot commented Feb 25, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@abhinavg4
Contributor Author

/ok to test a63854c

  F841 (unused variables) -- removed where clearly safe:
  • neg_t5_embed in inference_dit_model.py
  • device in inference_flux.py
  • error in wan/inference/utils.py (cache_image)
  • n -> _ in wan/rope_utils.py
  • data_batch, x0_from_data_batch in test_edm_pipeline.py
  • original_base in test_flux_hf_pretrained.py
  • total_layers, p_size, vp_size in test_flux_provider.py

  F841 -- kept with `# noqa: F841` (uncertain intent):
  • config = get_model_config(model) in flux_step_with_automodel.py
  • video_latents, loss_mask in flow_matching_pipeline_wan.py

  D101/D103 (missing docstrings) -- added `# noqa` markers to all 36 class/function definitions across 24 files.
  Markdown filename -- renamed README_perf_test.md to README-perf-test.md.
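The F841 handling above reduces to two patterns; a minimal illustration (function names are hypothetical, not from the PR):

```python
def cleaned_up():
    # Pattern 1: clearly-safe removal -- the dead binding is simply deleted,
    # as was done for neg_t5_embed, device, and the test-file variables.
    return sum(range(10))


def kept_with_noqa():
    # Pattern 2: uncertain intent -- the assignment is kept and ruff's F841
    # ("local variable assigned but never used") is suppressed inline,
    # as was done for config and video_latents/loss_mask.
    snapshot = list(range(10))  # noqa: F841
    return "ok"
```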
@abhinavg4
Contributor Author

/ok to test 9242693

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 10

Note

Due to the large number of review comments, Critical severity comments were prioritized as inline comments.

🟡 Minor comments (34)
tests/functional_tests/L2_Mcore_Mock_Tests_GPU.sh-1-14 (1)

1-14: ⚠️ Potential issue | 🟡 Minor

Add shebang and use uv run for script execution.

The script is missing a shebang (required by Google Shell Style Guide and flagged by static analysis SC2148). Additionally, per coding guidelines, use uv run to execute pytest instead of calling it directly.

Proposed fix
+#!/bin/bash
 # Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-CUDA_VISIBLE_DEVICES="0,1" pytest tests/functional_tests/diffusion/recipes -m "not pleasefixme" --with_downloads -v
+
+set -euo pipefail
+
+CUDA_VISIBLE_DEVICES="0,1" uv run --no-sync pytest tests/functional_tests/diffusion/recipes -m "not pleasefixme" --with_downloads -v

As per coding guidelines: "Use 'uv run' to execute scripts instead of activating a virtual environment and calling 'python' directly." Based on learnings: prefer --no-sync with uv run if dependencies are already installed in the environment (e.g., in CI containers).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/functional_tests/L2_Mcore_Mock_Tests_GPU.sh` around lines 1 - 14, Add a
POSIX shebang to the top of the script and replace the direct pytest invocation
with uv run; specifically, ensure the script (currently setting
CUDA_VISIBLE_DEVICES and calling pytest) begins with "#!/usr/bin/env bash" and
then use "uv run --no-sync pytest tests/functional_tests/diffusion/recipes -m
'not pleasefixme' --with_downloads -v" (preserving the CUDA_VISIBLE_DEVICES
environment variable) so static analysis SC2148 is satisfied and the project
guideline to use uv run is followed.
tests/unit_tests/diffusion/conftest.py-76-79 (1)

76-79: ⚠️ Potential issue | 🟡 Minor

Update the marker description for clarity.

config.addinivalue_line(...) only registers the marker; the actual test exclusion happens via the -m "not pleasefixme" flag in CI shell scripts (e.g., tests/unit_tests/Launch_Unit_Tests.sh), not through the pytest configuration. The description is accurate for CI but doesn't explain this mechanism. Consider revising to: "pleasefixme: marks tests that are broken and need fixing (excluded via -m flag in CI scripts)" to avoid confusion about how the skipping actually works.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unit_tests/diffusion/conftest.py` around lines 76 - 79, Update the
pytest marker description passed to config.addinivalue_line so it clearly states
that tests are only excluded via the CI -m flag rather than by pytest config;
locate the call to config.addinivalue_line registering "pleasefixme" and change
the message to something like "pleasefixme: marks tests that are broken and need
fixing (excluded via -m flag in CI scripts)" so the behavior and mechanism are
unambiguous.
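A minimal sketch of the suggested wording, assuming the standard `pytest_configure` hook shape used in conftest.py:

```python
def pytest_configure(config):
    # Registering the marker only silences "unknown marker" warnings;
    # the actual exclusion happens via `-m "not pleasefixme"` in CI scripts.
    config.addinivalue_line(
        "markers",
        "pleasefixme: marks tests that are broken and need fixing "
        "(excluded via the -m flag in CI scripts)",
    )
```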
tests/unit_tests/diffusion/conftest.py-57-62 (1)

57-62: ⚠️ Potential issue | 🟡 Minor

Raise an explicit error for unsupported run_only_on targets instead of silently ignoring them.

The marker is documented as supporting CPU/GPU (line 82), but the fixture only enforces "gpu" (line 60). Unsupported targets like @pytest.mark.run_only_on("cpu") or typos currently become silent no-ops, risking tests running under the wrong environment. Either narrow the marker contract to GPU-only or raise NotImplementedError on unsupported values to fail fast.

🛠️ Proposed fix
 def check_gpu_requirements(request):
     """Fixture to skip tests that require GPU when CUDA is not available"""
     marker = request.node.get_closest_marker("run_only_on")
-    if marker and "gpu" in [arg.lower() for arg in marker.args]:
-        if not torch.cuda.is_available():
-            pytest.skip("Test requires GPU but CUDA is not available")
+    if not marker:
+        return
+
+    targets = {str(arg).lower() for arg in marker.args}
+    if "gpu" in targets and not torch.cuda.is_available():
+        pytest.skip("Test requires GPU but CUDA is not available")
+
+    unsupported_targets = targets - {"gpu"}
+    if unsupported_targets:
+        raise NotImplementedError(
+            f"Unsupported run_only_on target(s): {sorted(unsupported_targets)}. "
+            "Supported target(s): ['gpu']."
+        )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unit_tests/diffusion/conftest.py` around lines 57 - 62, The fixture
check_gpu_requirements currently only handles the "gpu" marker and silently
ignores other values; update it to validate marker.args explicitly and raise a
NotImplementedError for any unsupported/unknown targets (e.g., "cpu" or typos)
instead of no-op; keep the existing behavior for "gpu" by checking
torch.cuda.is_available() and calling pytest.skip when unavailable, and
reference request.node.get_closest_marker / marker.args /
torch.cuda.is_available / pytest.skip / NotImplementedError to locate and
implement the change.
src/megatron/bridge/diffusion/models/wan/inference/__init__.py-1-4 (1)

1-4: ⚠️ Potential issue | 🟡 Minor

Add required NVIDIA copyright header.

Per coding guidelines, all Python files must include the NVIDIA copyright header at the top.

📝 Proposed fix
+# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import os

As per coding guidelines: "Add NVIDIA copyright header to all Python files and shell scripts at the top of the file."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/megatron/bridge/diffusion/models/wan/inference/__init__.py` around lines
1 - 4, This file is missing the required NVIDIA copyright header; add the
standard NVIDIA copyright header comment block at the very top of this Python
module (before any imports or code) in
src/megatron/bridge/diffusion/models/wan/inference/__init__.py so it precedes
the existing import os and the os.environ["TOKENIZERS_PARALLELISM"] assignment;
ensure the header matches the project's canonical header format and licensing
text.
examples/diffusion/recipes/wan/README-perf-test.md-36-37 (1)

36-37: ⚠️ Potential issue | 🟡 Minor

Use official repository instead of personal fork.

The clone command references a personal fork (huvunvidia/Megatron-Bridge) rather than the official NVIDIA-NeMo repository. This should point to the official repo for consistency and maintainability.

📝 Proposed fix
-git clone --no-checkout https://github.com/huvunvidia/Megatron-Bridge.git
+git clone --no-checkout https://github.com/NVIDIA-NeMo/Megatron-Bridge.git
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/diffusion/recipes/wan/README-perf-test.md` around lines 36 - 37,
Replace the personal fork URL in the clone command with the official NVIDIA
repo: update the git clone line that references "huvunvidia/Megatron-Bridge" so
it clones "NVIDIA/NeMo" or the official "NVIDIA/Megatron-Bridge" repository
(keeping the subsequent git -C Megatron-Bridge checkout
713ab548e4bfee307eb94a7bb3f57c17dbb31b50 line intact) so the README uses the
official upstream source for Megatron-Bridge.
examples/diffusion/recipes/wan/README-perf-test.md-152-154 (1)

152-154: ⚠️ Potential issue | 🟡 Minor

Replace hardcoded internal paths with placeholders.

Similar to the pretraining section, these inference paths reference internal infrastructure and should use placeholders.

📝 Proposed fix
-T5_DIR="/lustre/fsw/coreai_dlalgo_genai/huvu/data/nemo_vfm/wan_checkpoints/t5"
-VAE_DIR="/lustre/fsw/coreai_dlalgo_genai/huvu/data/nemo_vfm/wan_checkpoints/vae"
-CKPT_DIR="/lustre/fsw/coreai_dlalgo_genai/huvu/data/nemo_vfm/datasets/shared_checkpoints/megatron_checkpoint_1.3B"
+T5_DIR="<your_t5_checkpoint_dir>"
+VAE_DIR="<your_vae_checkpoint_dir>"
+CKPT_DIR="<your_megatron_checkpoint_dir>"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/diffusion/recipes/wan/README-perf-test.md` around lines 152 - 154,
The three hardcoded variables T5_DIR, VAE_DIR, and CKPT_DIR contain internal
absolute paths; replace them with configurable placeholders (e.g.,
environment-variable or template placeholders) so the recipe is portable—update
the assignment of T5_DIR, VAE_DIR, and CKPT_DIR to reference placeholder values
(for example ${T5_DIR}, ${VAE_DIR}, ${CKPT_DIR} or read from environment
variables) and add a short comment or README note explaining the expected
placeholder/env var format.
examples/diffusion/recipes/wan/README-perf-test.md-56-58 (1)

56-58: ⚠️ Potential issue | 🟡 Minor

Replace hardcoded internal paths with placeholders.

These paths reference internal NVIDIA infrastructure and a specific user's directory. Use placeholder notation consistent with other variables in this guide.

📝 Proposed fix
-DATASET_PATH="/lustre/fsw/coreai_dlalgo_genai/huvu/data/nemo_vfm/datasets/shared_datasets/processed_arrietty_scene_automodel"
+DATASET_PATH="<your_dataset_path>"
 EXP_NAME=wan_debug_perf
-CHECKPOINT_DIR="/lustre/fsw/coreai_dlalgo_genai/huvu/data/nemo_vfm/results/wan_finetune/${EXP_NAME}"
+CHECKPOINT_DIR="<your_checkpoint_dir>/${EXP_NAME}"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/diffusion/recipes/wan/README-perf-test.md` around lines 56 - 58,
Replace the hardcoded internal NVIDIA user paths by using placeholder variables
for DATASET_PATH and CHECKPOINT_DIR (leave EXP_NAME as-is); update DATASET_PATH
to a generic placeholder like /path/to/dataset or ${DATASET_PATH} and set
CHECKPOINT_DIR to reference the experiment name via a placeholder such as
/path/to/checkpoints/${EXP_NAME} or ${CHECKPOINT_DIR} so the README no longer
contains user- or infra-specific paths (search for the DATASET_PATH, EXP_NAME
and CHECKPOINT_DIR assignments to make the change).
tests/unit_tests/diffusion/model/wan/inference/test_inference_utils.py-33-38 (1)

33-38: ⚠️ Potential issue | 🟡 Minor

Replace assert False with raise AssertionError() or use pytest.raises.

assert False is removed when Python runs with -O flag, which would silently pass the test even when the exception isn't raised.

🐛 Proposed fix using pytest.raises
-    try:
-        inf_utils.str2bool("maybe")
-    except argparse.ArgumentTypeError:
-        pass
-    else:
-        assert False, "Expected argparse.ArgumentTypeError for invalid boolean string"
+    import pytest
+    with pytest.raises(argparse.ArgumentTypeError):
+        inf_utils.str2bool("maybe")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unit_tests/diffusion/model/wan/inference/test_inference_utils.py`
around lines 33 - 38, The test uses `assert False` which is removed under Python
-O; update the test in
`tests/unit_tests/diffusion/model/wan/inference/test_inference_utils.py` to fail
reliably by either replacing the `else: assert False, "Expected ..."` branch
with `raise AssertionError("Expected argparse.ArgumentTypeError for invalid
boolean string")` or, preferably, restructure the block to use pytest's context
manager (`pytest.raises(argparse.ArgumentTypeError)`) around
`inf_utils.str2bool("maybe")`; reference `inf_utils.str2bool` in the change so
the test asserts that this call raises `argparse.ArgumentTypeError`.
src/megatron/bridge/diffusion/models/wan/inference/utils.py-70-84 (1)

70-84: ⚠️ Potential issue | 🟡 Minor

Add explicit return None and consider logging failed retries.

cache_image implicitly returns None when all retries fail, unlike cache_video which explicitly returns None. Also, silently swallowing exceptions makes debugging difficult.

🐛 Proposed fix
+import logging
+
+logger = logging.getLogger(__name__)
+
 def cache_image(tensor, save_file, nrow=8, normalize=True, value_range=(-1, 1), retry=5):  # noqa: D103
     # cache file
     suffix = osp.splitext(save_file)[1]
     if suffix.lower() not in [".jpg", ".jpeg", ".png", ".tiff", ".gif", ".webp"]:
         suffix = ".png"
 
     # save to cache
+    error = None
     for _ in range(retry):
         try:
             tensor = tensor.clamp(min(value_range), max(value_range))
             torchvision.utils.save_image(tensor, save_file, nrow=nrow, normalize=normalize, value_range=value_range)
             return save_file
-        except Exception:
+        except Exception as e:
+            error = e
             continue
+    else:
+        logger.warning(f"cache_image failed after {retry} retries, last error: {error}")
+        return None
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/megatron/bridge/diffusion/models/wan/inference/utils.py` around lines 70
- 84, The cache_image function currently swallows exceptions and implicitly
returns None after retries; update cache_image to mirror cache_video by
explicitly returning None when all retries fail, and add logging of each caught
exception (including exception details) inside the except block before retrying
so failures are visible; reference the cache_image function, the retry loop, the
torchvision.utils.save_image call, and the value_range clamp to locate where to
add the logging and final explicit return None.
tests/unit_tests/diffusion/model/wan/test_utils.py-15-20 (1)

15-20: ⚠️ Potential issue | 🟡 Minor

Mark this module as unit-test coverage.

Add a module-level unit mark so test selection stays consistent with the rest of tests/unit_tests. As per coding guidelines, "tests/**/*.py: Use pytest.mark to categorize tests (unit, integration, system)."

♻️ Proposed fix
+import pytest
 import torch
 
 from megatron.bridge.diffusion.models.wan.utils import grid_sizes_calculation, patchify, unpatchify
 
+pytestmark = pytest.mark.unit
+
 
 def test_grid_sizes_calculation_basic():
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unit_tests/diffusion/model/wan/test_utils.py` around lines 15 - 20, Add
a module-level pytest mark to classify this file as a unit test: import pytest
at the top of tests/unit_tests/diffusion/model/wan/test_utils.py and set
pytestmark = pytest.mark.unit (so the module is discovered as unit tests
alongside the existing imports and functions like
test_grid_sizes_calculation_basic and the imported functions
grid_sizes_calculation, patchify, unpatchify).
tests/unit_tests/diffusion/model/wan/flow_matching/test_time_shift_utils.py-15-24 (1)

15-24: ⚠️ Potential issue | 🟡 Minor

Mark this module as unit-test coverage.

Add a module-level pytestmark = pytest.mark.unit so CI can select this file consistently. As per coding guidelines, "tests/**/*.py: Use pytest.mark to categorize tests (unit, integration, system)."

♻️ Proposed fix
+import pytest
 import torch
 
 from megatron.bridge.diffusion.models.wan.flow_matching.time_shift_utils import (
     compute_density_for_timestep_sampling,
     get_flow_match_loss_weight,
     time_shift,
 )
 
+pytestmark = pytest.mark.unit
+
 
 def test_time_shift_constant_linear_sqrt_bounds_and_monotonic():
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unit_tests/diffusion/model/wan/flow_matching/test_time_shift_utils.py`
around lines 15 - 24, Add module-level pytest marker by importing pytest and
defining pytestmark = pytest.mark.unit at the top of the test module so CI can
select it; update
tests/unit_tests/diffusion/model/wan/flow_matching/test_time_shift_utils.py to
include an import pytest and a module-level assignment pytestmark =
pytest.mark.unit (place it above or just below the existing imports) to mark the
file as unit-test coverage.
src/megatron/bridge/diffusion/recipes/flux/flux.py-102-103 (1)

102-103: ⚠️ Potential issue | 🟡 Minor

data_paths=None does not do what the docstring says.

Lines 144-145 promise mock data when data_paths is None, but the implementation only selects FluxMockDataModuleConfig when mock is True; otherwise pretrain_config() still constructs FluxDataModuleConfig(path=None). That makes the public API ambiguous and leaves the default config dependent on later overrides.

Also applies to: 144-145, 213-233

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/megatron/bridge/diffusion/recipes/flux/flux.py` around lines 102 - 103,
The code currently chooses FluxMockDataModuleConfig only when mock is True, but
the docstring promises mock data when data_paths is None; update the logic in
pretrain_config (and the similar block around lines 213-233) so that if
data_paths is None OR mock is True you construct FluxMockDataModuleConfig,
otherwise construct FluxDataModuleConfig with the provided data_paths; ensure
the parameter handling in the function signature (data_paths:
Optional[List[str]] = None, mock: bool = False) is respected and any downstream
callers still receive the intended config type when data_paths is None.
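A sketch of the selection logic the docstring promises, with stand-in config classes (the real ones live in flux.py; only the branch condition matters here):

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class FluxMockDataModuleConfig:  # stand-in for the real class
    pass


@dataclass
class FluxDataModuleConfig:  # stand-in for the real class
    path: List[str]


def select_data_config(data_paths: Optional[List[str]] = None, mock: bool = False):
    # Treat missing data_paths the same as an explicit mock request,
    # so the default config no longer depends on later overrides.
    if mock or data_paths is None:
        return FluxMockDataModuleConfig()
    return FluxDataModuleConfig(path=data_paths)
```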
examples/diffusion/recipes/flux/finetune_flux.py-160-160 (1)

160-160: ⚠️ Potential issue | 🟡 Minor

--debug currently has no effect.

The flag is parsed, but main() never sets the root/logger level from args.debug, so every logger.debug(...) call stays silent. Configure logging once right after argument parsing so the command-line interface behaves as advertised.

🔧 Proposed fix
 def main() -> None:
@@
     args, cli_overrides = parse_cli_args()
+    logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO)
+    logger.setLevel(logging.DEBUG if args.debug else logging.INFO)
 
     logger.info("Megatron-Bridge FLUX Fine-tuning Script with YAML & CLI Overrides")

Also applies to: 226-230

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/diffusion/recipes/flux/finetune_flux.py` at line 160, The --debug
flag is parsed but never applied, so logger.debug calls are silent; after
parsing args in main() (where parser.add_argument("--debug", ... ) is defined)
configure the logging root/logger level based on args.debug (e.g., call
logging.basicConfig(...) or logging.getLogger().setLevel(logging.DEBUG if
args.debug else logging.INFO) and adjust the module logger used in the script)
so debug messages are emitted; apply the same fix to the other parsing site
around the symbols handling lines 226-230 (the other parser/args block) to
ensure both CLI entry points respect args.debug.
src/megatron/bridge/diffusion/recipes/flux/flux.py-135-136 (1)

135-136: ⚠️ Potential issue | 🟡 Minor

precision_config=None will crash despite the current type hint.

After the string-conversion branch, Lines 211-212 unconditionally mutate precision_config. If callers follow the Optional[...] annotation and pass None, this becomes an AttributeError. Either drop None from the type or materialize a default MixedPrecisionConfig before mutating it.

Also applies to: 208-212

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/megatron/bridge/diffusion/recipes/flux/flux.py` around lines 135 - 136,
The function treats precision_config (and similarly comm_overlap_config) as if
it is always an object and mutates it after the "string-conversion" branch, but
the signature allows None which will cause AttributeError; fix by either
removing None from the type hint (make precision_config:
Union[MixedPrecisionConfig, str] = "bf16_mixed") or, preferably, materialize
defaults up-front: if precision_config is None set precision_config =
MixedPrecisionConfig() (and if comm_overlap_config is None set
comm_overlap_config = CommOverlapConfig()) before any in-place mutations or
attribute access so subsequent string-conversion and mutation code can safely
operate on an object.
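A sketch of the "materialize a default up-front" option, with a stand-in dataclass (field names are illustrative, not the real MixedPrecisionConfig API):

```python
from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class MixedPrecisionConfig:  # stand-in for the real class
    name: str = "bf16_mixed"
    grad_reduce_in_fp32: bool = False


def resolve_precision(
    precision_config: Optional[Union[MixedPrecisionConfig, str]] = None,
) -> MixedPrecisionConfig:
    if precision_config is None:
        # Materialize a default before any attribute access so None never crashes.
        precision_config = MixedPrecisionConfig()
    elif isinstance(precision_config, str):
        # String names are converted to a config object, as in the recipe.
        precision_config = MixedPrecisionConfig(name=precision_config)
    # From here on, in-place mutation is always safe.
    precision_config.grad_reduce_in_fp32 = True
    return precision_config
```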
tests/functional_tests/diffusion/recipes/test_wan_pretrain.py-97-106 (1)

97-106: ⚠️ Potential issue | 🟡 Minor

Preserve subprocess output on timeout.

TimeoutExpired carries partial stdout/stderr, but Line 98 fails before result is populated, so the finally block prints nothing in the timeout case. That is usually the most important failure mode to debug.

Suggested fix
-        except subprocess.TimeoutExpired:
-            pytest.fail("WAN pretrain mock run exceeded timeout of 1800 seconds (30 minutes)")
+        except subprocess.TimeoutExpired as e:
+            result = e
+            pytest.fail("WAN pretrain mock run exceeded timeout of 1800 seconds (30 minutes)")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/functional_tests/diffusion/recipes/test_wan_pretrain.py` around lines
97 - 106, The TimeoutExpired handler loses partial output because `result` isn't
set before the finally block; update the exception handling so the
TimeoutExpired exception is captured into `result` (e.g., assign the caught
exception object in the `except subprocess.TimeoutExpired as e:` block) or
initialize `result` before the try, then set `result = e` in both `except
subprocess.TimeoutExpired as e:` and `except subprocess.CalledProcessError as
e:` so the `finally` block that prints `result.stdout` and `result.stderr`
always sees the partial output; adjust references to `e.stdout`/`e.stderr` if
needed when populating `result`.
tests/unit_tests/diffusion/model/wan/inference/test_inference_init.py-37-40 (1)

37-40: ⚠️ Potential issue | 🟡 Minor

Rename the unused loop key to keep Ruff clean.

model_key is never referenced in the loop body, so this trips Ruff B007.

Suggested fix
-    for model_key, sizes in SUPPORTED_SIZES.items():
+    for _model_key, sizes in SUPPORTED_SIZES.items():

As per coding guidelines "Use ruff for linting and formatting Python code".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unit_tests/diffusion/model/wan/inference/test_inference_init.py` around
lines 37 - 40, The loop over SUPPORTED_SIZES uses an unused loop variable
model_key which triggers Ruff B007; rename it to _model_key (or simply _) in the
for statement that iterates SUPPORTED_SIZES.items() so only the used variable
sizes remains referenced, ensuring the checks for isinstance(sizes, tuple) and
membership of s in SIZE_CONFIGS remain unchanged.
tests/unit_tests/diffusion/model/wan/test_rope_utils.py-21-45 (1)

21-45: ⚠️ Potential issue | 🟡 Minor

Align the CUDA skip with the device you actually exercise.

This test is skipped on CPU-only runners, but Line 44 still forces device=torch.device("cpu"). That means CPU coverage is dropped unnecessarily, and CUDA-specific behavior is not exercised either. Please either remove the CUDA gate or run the test on CUDA consistently.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unit_tests/diffusion/model/wan/test_rope_utils.py` around lines 21 -
45, The test test_wan3d_rope_embeddings_shapes_and_padding is gated by
pytest.mark.skipif checking torch.cuda.is_available() but still forces
device=torch.device("cpu") when calling Wan3DRopeEmbeddings(...). Fix by
aligning the skip condition with the exercised device: either remove the skipif
to run on CPU by default, or change the device argument to select CUDA when
available (e.g., device=torch.device("cuda") if torch.cuda.is_available() else
"cpu") and keep the skipif for CPU-only removal; update the call site that
passes device to Wan3DRopeEmbeddings and the test decorator accordingly so the
runtime device and skip logic match.
tests/unit_tests/diffusion/model/common/test_normalization.py-15-17 (1)

15-17: ⚠️ Potential issue | 🟡 Minor

Add pytest.mark.unit marker to this test module.

This unit test file is missing the required pytestmark = [pytest.mark.unit] declaration. According to coding guidelines, all tests under tests/**/*.py must use pytest markers for categorization. Add the marker after imports:

✅ Minimal fix
+import pytest
 import torch
 
 from megatron.bridge.diffusion.models.common.normalization import RMSNorm
 
+pytestmark = [pytest.mark.unit]
+
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unit_tests/diffusion/model/common/test_normalization.py` around lines
15 - 17, The test module test_normalization.py is missing the pytest unit
marker; add a module-level declaration pytestmark = [pytest.mark.unit]
immediately after the imports (after the import torch and RMSNorm import) so the
file is properly categorized; ensure pytest is imported or referenced (use
pytestmark exactly as shown) and do not modify any test functions like those
referencing RMSNorm.
tests/unit_tests/diffusion/data/wan/test_wan_mock_datamodule.py-15-18 (1)

15-18: ⚠️ Potential issue | 🟡 Minor

Add the module-level unit marker.

This file is located in tests/unit_tests/ and must include pytestmark = [pytest.mark.unit] to enable marker-based test selection. All other unit tests in the repository follow this pattern.

Suggested fix
+import pytest
 import torch
 
 from megatron.bridge.diffusion.data.wan.wan_mock_datamodule import WanMockDataModuleConfig
 
+pytestmark = [pytest.mark.unit]
+
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unit_tests/diffusion/data/wan/test_wan_mock_datamodule.py` around lines
15 - 18, This test module is missing the module-level unit marker; add an import
for pytest and a top-level assignment pytestmark = [pytest.mark.unit] near the
existing imports (e.g., alongside the torch and WanMockDataModuleConfig imports)
so the test runner can select unit tests; ensure the marker is at module scope
(not inside any function or class).
tests/unit_tests/diffusion/data/wan/test_wan_energon_datamodule.py-15-17 (1)

15-17: ⚠️ Potential issue | 🟡 Minor

Add the missing unit marker.

This module lives under tests/unit_tests but lacks pytest.mark.unit, preventing marker-based test selection from discovering it. Add the pytest import and module-level pytestmark.

Minimal fix
+import pytest
+
 from megatron.bridge.diffusion.data.wan import wan_energon_datamodule as wan_dm_mod
 from megatron.bridge.diffusion.data.wan.wan_taskencoder import WanTaskEncoder
 
+pytestmark = [pytest.mark.unit]
+

Per coding guidelines: "tests/**/*.py: Use 'pytest.mark' to categorize tests (unit, integration, system)".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unit_tests/diffusion/data/wan/test_wan_energon_datamodule.py` around
lines 15 - 17, The test module currently imports wan_energon_datamodule as
wan_dm_mod and WanTaskEncoder but lacks a pytest unit marker; add an import for
pytest and set a module-level pytestmark = pytest.mark.unit (so the module is
discoverable by marker-based selection) near the top of the file alongside the
existing imports.
tests/unit_tests/diffusion/model/wan/test_wan_provider.py-15-22 (1)

15-22: ⚠️ Potential issue | 🟡 Minor

Add the missing unit marker.

The file is located under tests/unit_tests/ but lacks both the pytest import and the pytestmark = [pytest.mark.unit] declaration. According to coding guidelines, all tests in tests/**/*.py must be categorized using pytest markers. Other unit tests in the same directory follow this pattern consistently.

✅ Minimal fix
+import pytest
 import torch
 import torch.nn as nn
 from megatron.core import parallel_state
 
 import megatron.bridge.diffusion.models.wan.wan_model as wan_model_module
 from megatron.bridge.diffusion.models.wan.wan_model import WanModel
 from megatron.bridge.diffusion.models.wan.wan_provider import WanModelProvider
 
+pytestmark = [pytest.mark.unit]
+
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unit_tests/diffusion/model/wan/test_wan_provider.py` around lines 15 -
22, This test file is missing the pytest unit marker and pytest import; add an
import for pytest near the existing imports and declare pytestmark =
[pytest.mark.unit] at module scope (e.g., just below the imports) so the test is
categorized as a unit test; update the top of the file that imports torch,
WanModel, WanModelProvider, etc., to include pytest and the pytestmark
declaration.
src/megatron/bridge/diffusion/data/common/diffusion_sample.py-27-30 (1)

27-30: ⚠️ Potential issue | 🟡 Minor

Align the docstring with the actual text-field names.

The class docstring still documents t5_text_embeddings and t5_text_mask, but the public dataclass fields are context_embeddings and context_mask. That makes the interface misleading for callers.

As per coding guidelines "For interfaces that may be used outside a file, prefer docstrings over comments".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/megatron/bridge/diffusion/data/common/diffusion_sample.py` around lines
27 - 30, Update the class docstring for DiffusionSample to document the actual
public fields: replace references to t5_text_embeddings and t5_text_mask with
context_embeddings (torch.Tensor, shape (S, D)) and context_mask (torch.Tensor)
respectively, and keep the existing video description; ensure the docstring
names and types exactly match the dataclass attributes (context_embeddings,
context_mask, video) so external callers see the correct interface.
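A minimal sketch of the corrected interface, using a simplified stand-in dataclass (field types and shapes are assumptions, not the real `DiffusionSample` definition); the point is that docstring names can be checked mechanically against the dataclass fields:

```python
# Hypothetical stand-in for DiffusionSample: the docstring must name the
# actual public fields (context_embeddings/context_mask), not the old t5_* names.
from dataclasses import dataclass, fields
from typing import Any


@dataclass
class DiffusionSampleSketch:
    """Single diffusion training sample.

    Attributes:
        video: Video tensor of shape (T, C, H, W).
        context_embeddings: Text-encoder embeddings of shape (S, D).
        context_mask: Attention mask over the S context tokens.
    """

    video: Any = None
    context_embeddings: Any = None
    context_mask: Any = None


# The documented names can be verified against the fields mechanically:
field_names = {f.name for f in fields(DiffusionSampleSketch)}
assert field_names == {"video", "context_embeddings", "context_mask"}
```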
src/megatron/bridge/diffusion/data/common/sequence_packing_utils.py-45-47 (1)

45-47: ⚠️ Potential issue | 🟡 Minor

Fix the return contract in this docstring.

The implementation appends s values directly, so each bin contains sequence lengths/items, not sequence indices. The current description points callers at the wrong contract.

As per coding guidelines "For interfaces that may be used outside a file, prefer docstrings over comments".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/megatron/bridge/diffusion/data/common/sequence_packing_utils.py` around
lines 45 - 47, The Returns docstring in sequence_packing_utils.py is incorrect:
the implementation appends the variable s itself into bins, so each inner list
contains the sequence items/lengths (the s values), not sequence indices; update
the Returns section of the function docstring to state that it returns
List[List[...]] where each inner list contains the appended s values (sequence
lengths/items) and clarify the element type and semantics (i.e., that callers
receive sequence values, not indices), and adjust any example/type hints in the
docstring to match this contract.
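The contract at issue can be illustrated with a minimal first-fit-decreasing sketch (a generic bin-packing routine written for this example, not the code under review): each bin collects the sequence values themselves, not their indices.

```python
# First-fit-decreasing sketch mirroring the documented behavior: each bin
# collects the sequence *values* (lengths), not the indices of those sequences.
def first_fit_decreasing(seq_lens, bin_capacity):
    bins = []
    for s in sorted(seq_lens, reverse=True):
        for b in bins:
            if sum(b) + s <= bin_capacity:
                b.append(s)  # the value s is appended, not its index
                break
        else:
            bins.append([s])
    return bins


packed = first_fit_decreasing([5, 3, 2, 4], bin_capacity=7)
# Every element of every bin is one of the original lengths:
assert packed == [[5, 2], [4, 3]]
```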
tests/unit_tests/diffusion/data/flux/test_flux_mock_datamodule.py-15-18 (1)

15-18: ⚠️ Potential issue | 🟡 Minor

Mark this file as unit tests.

Unlike the other new diffusion unit suites in this PR, these tests are unmarked. That makes marker-based selection/reporting inconsistent and can leave them out of targeted unit jobs.

🏷️ Minimal fix
+import pytest
 import torch
 
 from megatron.bridge.diffusion.data.flux.flux_mock_datamodule import FluxMockDataModuleConfig
 
+
+pytestmark = [pytest.mark.unit]
As per coding guidelines "tests/**/*.py: Use 'pytest.mark' to categorize tests (unit, integration, system)".
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unit_tests/diffusion/data/flux/test_flux_mock_datamodule.py` around
lines 15 - 18, This test module is missing a unit test marker; add pytest import
and a module-level marker so pytest treats all tests here as unit tests (e.g.
add "import pytest" and set "pytestmark = pytest.mark.unit" at top of the file)
to ensure marker-based selection/reporting; locate the file referenced by
FluxMockDataModuleConfig import
(tests/unit_tests/diffusion/data/flux/test_flux_mock_datamodule.py) and add the
pytestmark variable so existing test functions/classes in the module are marked
as unit tests.
src/megatron/bridge/diffusion/data/common/diffusion_sample.py-80-114 (1)

80-114: ⚠️ Potential issue | 🟡 Minor

Return NotImplemented from these dunder methods.

For Python's binary operator protocol, unsupported operand types should return NotImplemented here, not raise NotImplementedError. Raising NotImplementedError short-circuits reflected-operator fallback and prevents Python from attempting the operation on the other operand's type or raising the appropriate TypeError.

🔧 Minimal fix
-        raise NotImplementedError
+        return NotImplemented
@@
-        raise NotImplementedError
+        return NotImplemented
@@
-        raise NotImplementedError
+        return NotImplemented
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/megatron/bridge/diffusion/data/common/diffusion_sample.py` around lines
80 - 114, The dunder methods __add__, __radd__, and __lt__ currently raise
NotImplementedError for unsupported operand types; change those to return
NotImplemented so Python's binary operator protocol and reflected fallbacks work
correctly (keep the existing logic for choosing seq_len_q_padded vs seq_len_q
and the isinstance checks, but replace the final "raise NotImplementedError" in
DiffusionSample.__add__, DiffusionSample.__radd__, and DiffusionSample.__lt__
with "return NotImplemented").
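Why this matters can be shown with two toy classes (hypothetical, not from the PR): returning `NotImplemented` lets Python fall back to the other operand's reflected method, whereas raising `NotImplementedError` would abort the whole operation.

```python
# Minimal demonstration of the binary-operator protocol: returning
# NotImplemented from __add__ lets Python try the other operand's __radd__.
class Left:
    def __add__(self, other):
        return NotImplemented  # signal "I don't handle this operand type"


class Right:
    def __radd__(self, other):
        return "handled by Right.__radd__"


result = Left() + Right()  # falls back to Right.__radd__
assert result == "handled by Right.__radd__"
```

Had `Left.__add__` raised `NotImplementedError` instead, the expression would have propagated that exception and `Right.__radd__` would never run.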
tests/unit_tests/diffusion/recipes/flux/test_flux_recipe.py-100-107 (1)

100-107: ⚠️ Potential issue | 🟡 Minor

The custom-LR case never proves the LR override is wired through.

lr=5e-5 is passed into pretrain_config(), but the assertions stop at config.train. If optimizer or scheduler wiring drops the override, this test still passes.

Also applies to: 109-112

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unit_tests/diffusion/recipes/flux/test_flux_recipe.py` around lines 100
- 107, The test sets lr=5e-5 in pretrain_config but never asserts the
optimizer/scheduler received it; update the test (around pretrain_config usage)
to assert the learning-rate override is wired through by inspecting the created
optimizer/scheduler objects (e.g., check config.optimizer or the built
optimizer's param_groups[0]['lr'] or the scheduler's initial_lr/learning_rate
attribute) and assert it equals 5e-5; add equivalent assertions for the other
case referenced (lines 109-112) so the test fails if the override is dropped.
tests/unit_tests/diffusion/model/flux/conversion/test_flux_hf_pretrained.py-166-185 (1)

166-185: ⚠️ Potential issue | 🟡 Minor

This test never exercises FluxSafeTensorsStateSource.save_generator().

TestFluxSource reimplements the transformer/ path logic locally, so the test still passes even if the real override in flux_hf_module.FluxSafeTensorsStateSource regresses or its inheritance chain changes. Please instantiate the production class and spy on the parent call instead of validating a stand-in subclass.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unit_tests/diffusion/model/flux/conversion/test_flux_hf_pretrained.py`
around lines 166 - 185, The test currently defines TestFluxSource that
reimplements the transformer path, so it doesn't exercise
flux_hf_module.FluxSafeTensorsStateSource.save_generator(); instead, remove the
local TestFluxSource and instantiate the actual production class
flux_hf_module.FluxSafeTensorsStateSource, monkeypatch its parent
(SafeTensorsStateSource) save_generator to capture calls (e.g.,
parent_save_called spy) and then call
flux_hf_module.FluxSafeTensorsStateSource().save_generator(output_path,
strict=False); assert the spy recorded a single call with str(output_path /
"transformer") and False and that the returned result equals "success".
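The spy-on-the-parent approach can be sketched with simplified stand-in classes (names and paths here are illustrative, not the real `FluxSafeTensorsStateSource` hierarchy): patch the parent method and call the production subclass, so the test exercises the real override.

```python
# Sketch: exercise the real subclass override by patching the parent method,
# instead of validating a reimplemented stand-in subclass.
from unittest import mock


class Parent:
    def save_generator(self, path, strict=True):
        return "parent result"


class Child(Parent):
    def save_generator(self, path, strict=True):
        # The behavior under test: redirect into a "transformer" subfolder.
        return super().save_generator(path + "/transformer", strict)


with mock.patch.object(Parent, "save_generator", return_value="success") as spy:
    result = Child().save_generator("/out", strict=False)

# The spy proves the production override forwarded the adjusted arguments.
spy.assert_called_once_with("/out/transformer", False)
assert result == "success"
```

Note that without `autospec=True`, the patched class attribute is not a descriptor, so `self` is not recorded in the call args.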
tests/unit_tests/diffusion/model/wan/flow_matching/test_flow_matching_pipeline_wan.py-15-24 (1)

15-24: ⚠️ Potential issue | 🟡 Minor

Add a unit-test marker for this module.

This file sits under tests/unit_tests/ but doesn't declare pytest.mark.unit, so marker-based selection can miss it.

🧪 Minimal fix
 import pytest
 import torch
 from dfm.src.automodel.flow_matching.adapters.base import FlowMatchingContext
 
 from megatron.bridge.diffusion.models.wan.flow_matching.flow_matching_pipeline_wan import (
     WanAdapter,
     WanFlowMatchingPipeline,
 )
 
+pytestmark = [pytest.mark.unit]
+
 
 class TestWanAdapter:
As per coding guidelines, "`tests/**/*.py`: Use 'pytest.mark' to categorize tests (unit, integration, system)."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@tests/unit_tests/diffusion/model/wan/flow_matching/test_flow_matching_pipeline_wan.py`
around lines 15 - 24, Add a pytest module marker to classify these tests as unit
tests: insert a module-level marker (e.g., pytestmark = pytest.mark.unit) near
the top of the test module that imports WanAdapter and WanFlowMatchingPipeline
so pytest's marker-based selection includes this file; place it after the
imports in
tests/unit_tests/diffusion/model/wan/flow_matching/test_flow_matching_pipeline_wan.py
so the marker applies to all tests in the module.
tests/unit_tests/diffusion/model/flux/conversion/test_flux_hf_pretrained.py-197-217 (1)

197-217: ⚠️ Potential issue | 🟡 Minor

Assert the transformer subfolder in _load_model() as well.

Right now this only records the root path, so _load_model() can stop passing subfolder="transformer" and the test still stays green.

🔍 Tighten the assertion
     class FakeFlux:
         @classmethod
         def from_pretrained(cls, path, **kwargs):
-            calls.append(str(path))
+            calls.append((str(path), kwargs.get("subfolder")))
             return FakeFlux()
@@
-    assert calls[0] == str(src)
+    assert calls[0] == (str(src), "transformer")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unit_tests/diffusion/model/flux/conversion/test_flux_hf_pretrained.py`
around lines 197 - 217, The test records only the root path so it won't catch if
_load_model stops passing subfolder="transformer"; update the assertion to
require the transformer subfolder by expecting the called path to equal str(src
/ "transformer") (or to contain "transformer") and reference
PreTrainedFlux._load_model and FluxTransformer2DModel.from_pretrained (captured
via the calls list) so the test fails if the subfolder argument is omitted.
tests/unit_tests/diffusion/data/flux/test_flux_taskencoder.py-15-17 (1)

15-17: ⚠️ Potential issue | 🟡 Minor

Add a unit-test marker for this file.

Like the other new unit-test modules, this should declare pytest.mark.unit; otherwise marker-based selection can skip it.

🧪 Minimal fix
+import pytest
 import torch
 
 from megatron.bridge.diffusion.data.flux.flux_taskencoder import FluxTaskEncoder, cook, parallel_state
 
+pytestmark = [pytest.mark.unit]
+
As per coding guidelines, "`tests/**/*.py`: Use 'pytest.mark' to categorize tests (unit, integration, system)."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unit_tests/diffusion/data/flux/test_flux_taskencoder.py` around lines
15 - 17, Add a pytest marker to classify this test as a unit test: import pytest
at the top of the file and define pytestmark = pytest.mark.unit (e.g.,
immediately after the imports). This ensures the test module with symbols
FluxTaskEncoder, cook, parallel_state is discovered by marker-based selection.
src/megatron/bridge/diffusion/recipes/wan/wan.py-98-100 (1)

98-100: ⚠️ Potential issue | 🟡 Minor

Guard the precision_config=None case promised by the signature.

The type allows None, but Line 160 dereferences it unconditionally and crashes.

Proposed fix
     if isinstance(precision_config, str):
         precision_config = get_mixed_precision_config(precision_config)
 
-    precision_config.grad_reduce_in_fp32 = False
+    if precision_config is not None:
+        precision_config.grad_reduce_in_fp32 = False

Also applies to: 157-160

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/megatron/bridge/diffusion/recipes/wan/wan.py` around lines 98 - 100, The
code dereferences precision_config (type Optional[Union[MixedPrecisionConfig,
str]]) without guarding the None case—update the function to check if
precision_config is None before accessing its attributes; in the block that
currently assumes precision_config (around where precision_config is used to
build precision-related config values) either set a sensible default (e.g.,
leave precision unset or use a default MixedPrecisionConfig) or skip
constructing precision entries when precision_config is None, and apply the same
guard to the other occurrence referenced in the review (the second use around
lines 157-160) so no attribute access happens on None; reference the
precision_config symbol and MixedPrecisionConfig to locate where to add the
conditional.
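The guard can be sketched with a simplified stand-in config (the dataclass below is an invented placeholder for `MixedPrecisionConfig`, and the string branch stands in for `get_mixed_precision_config`):

```python
# Sketch of normalizing an Optional[Union[str, Config]] argument before
# touching its attributes, so the None case promised by the signature is safe.
from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class MixedPrecisionConfigSketch:
    grad_reduce_in_fp32: bool = True


def resolve_precision(cfg: Optional[Union[str, MixedPrecisionConfigSketch]]):
    if isinstance(cfg, str):
        # stand-in for get_mixed_precision_config(cfg)
        cfg = MixedPrecisionConfigSketch()
    if cfg is not None:  # the guard the signature promises
        cfg.grad_reduce_in_fp32 = False
    return cfg


assert resolve_precision(None) is None  # no crash on None
assert resolve_precision("bf16").grad_reduce_in_fp32 is False
```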
examples/diffusion/recipes/flux/pretrain_flux.py-141-141 (1)

141-141: ⚠️ Potential issue | 🟡 Minor

Wire --debug into logging or remove the flag.

Line 141 advertises debug logging, but main() never changes the logger level, so the option is currently a no-op.

Proposed fix
 def main() -> None:
@@
     args, cli_overrides = parse_cli_args()
+    log_level = logging.DEBUG if args.debug else logging.INFO
+    logging.basicConfig(level=log_level)
+    logger.setLevel(log_level)
 
     logger.info("Megatron-Bridge FLUX Pretraining Script with YAML & CLI Overrides")

Also applies to: 199-203

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/diffusion/recipes/flux/pretrain_flux.py` at line 141, The --debug
argparse flag is declared but not applied; update main() to set the logger level
when args.debug is true (e.g., call logging.basicConfig or set the module/root
logger via logger.setLevel(logging.DEBUG)) so debug logging is enabled; ensure
the same wiring is applied for the other parser instance/usage mentioned (the
second parser block around the pretrain/run entrypoints) so the flag is not a
no-op, and reference parser.add_argument("--debug", ...) and main() (and any
other run/pretrain entrypoint functions) to locate where to set the log level.
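The wiring can be sketched in a few lines (logger name and function name are hypothetical, not the script's actual symbols):

```python
# Sketch of wiring a --debug flag into the logger level.
import argparse
import logging

logger = logging.getLogger("pretrain_flux_sketch")


def configure(argv):
    parser = argparse.ArgumentParser()
    parser.add_argument("--debug", action="store_true", help="Enable debug logging")
    args, _ = parser.parse_known_args(argv)
    level = logging.DEBUG if args.debug else logging.INFO
    logging.basicConfig(level=level)
    logger.setLevel(level)  # without this, --debug is a no-op for this logger
    return args


configure(["--debug"])
assert logger.level == logging.DEBUG
```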
src/megatron/bridge/diffusion/conversion/flux/flux_hf_pretrained.py-119-121 (1)

119-121: ⚠️ Potential issue | 🟡 Minor

Catching bare Exception and silently passing hides configuration errors.

When config export fails, users should be informed. Log a warning at minimum so issues are discoverable.

🛡️ Proposed fix

First ensure logger is available:

import logging
logger = logging.getLogger(__name__)

Then update the exception handling:

-        except Exception:
-            # Best-effort: if config cannot be produced, leave only weights
-            pass
+        except Exception as e:
+            # Best-effort: if config cannot be produced, leave only weights
+            logger.warning(f"Could not export config.json: {e}. Only weights will be saved.")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/megatron/bridge/diffusion/conversion/flux/flux_hf_pretrained.py` around
lines 119 - 121, The except block that currently swallows all errors should be
changed to log the failure so configuration export issues are discoverable: add
a module-level logger (import logging; logger = logging.getLogger(__name__)),
change the bare “except Exception:” to “except Exception as e:” in the
try/except around the config export in flux_hf_pretrained.py, and call
logger.warning (or logger.warning(..., exc_info=True)) with a clear message like
"config export failed, leaving only weights" including the exception details,
then preserve the current best-effort behavior (i.e., do not raise further).
src/megatron/bridge/diffusion/models/flux/flux_step_with_automodel.py-76-77 (1)

76-77: ⚠️ Potential issue | 🟡 Minor

Bare except with silent pass hides errors.

This pattern can mask bugs and make debugging difficult. At minimum, catch a specific exception and log a warning.

🛡️ Proposed fix
         if store_in_state:
             try:
                 from megatron.bridge.training.pretrain import get_current_state

                 state = get_current_state()
                 state._last_validation_batch = _batch
-            except:
-                pass  # If state access fails, silently continue
+            except (ImportError, AttributeError) as e:
+                logger.debug(f"Could not store batch in state: {e}")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/megatron/bridge/diffusion/models/flux/flux_step_with_automodel.py` around
lines 76 - 77, Replace the bare "except: pass" that swallows state-access errors
with a targeted catch and a warning log: catch the likely exceptions (e.g.,
AttributeError and KeyError) or at minimum "except Exception as e", log a
warning that includes the exception and context (use the module/logger for this
file, e.g. logger.warning(..., exc_info=True) or logger.exception(...))
referencing the code block that attempts to read "state" in
flux_step_with_automodel.py so errors aren't silently hidden.
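The pattern both findings above ask for can be sketched as follows (the `store_batch` helper and logger name are invented for illustration): catch a narrow set of exceptions, keep the best-effort behavior, but leave a discoverable log record.

```python
# Sketch of replacing a bare `except: pass` with a targeted catch that
# still continues best-effort but logs a warning for later debugging.
import logging

logger = logging.getLogger("flux_step_sketch")
records = []
handler = logging.Handler()
handler.emit = lambda record: records.append(record.getMessage())
logger.addHandler(handler)
logger.setLevel(logging.DEBUG)


def store_batch(state, batch):
    try:
        state._last_validation_batch = batch  # may fail on exotic state objects
    except (AttributeError, TypeError) as e:
        logger.warning("Could not store batch in state: %s", e)


store_batch(object(), {"x": 1})  # plain object() has no writable attributes
assert any("Could not store batch" in m for m in records)
```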

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 2e089469-0ee3-49f6-89cf-d8941796b46a

📥 Commits

Reviewing files that changed from the base of the PR and between 1d25ea2 and 2e90f66.

📒 Files selected for processing (138)
  • examples/diffusion/README.md
  • examples/diffusion/override_configs/README.md
  • examples/diffusion/override_configs/wan_pretrain_sample_data.yaml
  • examples/diffusion/recipes/README.md
  • examples/diffusion/recipes/__init__.py
  • examples/diffusion/recipes/flux/conf/flux_pretrain_override_example.yaml
  • examples/diffusion/recipes/flux/conversion/convert_checkpoints.py
  • examples/diffusion/recipes/flux/finetune_flux.py
  • examples/diffusion/recipes/flux/inference_flux.py
  • examples/diffusion/recipes/flux/prepare_energon_dataset_flux.py
  • examples/diffusion/recipes/flux/pretrain_flux.py
  • examples/diffusion/recipes/wan/README-perf-test.md
  • examples/diffusion/recipes/wan/conf/gb200_perf_pretrain_mock.yaml
  • examples/diffusion/recipes/wan/conf/gb300_perf_pretrain_mock.yaml
  • examples/diffusion/recipes/wan/conf/h100_perf_pretrain_mock.yaml
  • examples/diffusion/recipes/wan/conf/wan_14B.yaml
  • examples/diffusion/recipes/wan/conf/wan_1_3B.yaml
  • examples/diffusion/recipes/wan/conf/wan_pretrain_override_example.yaml
  • examples/diffusion/recipes/wan/conversion/convert_checkpoints.py
  • examples/diffusion/recipes/wan/inference_wan.py
  • examples/diffusion/recipes/wan/pretrain_wan.py
  • src/megatron/bridge/diffusion/README.md
  • src/megatron/bridge/diffusion/__init__.py
  • src/megatron/bridge/diffusion/base/README.md
  • src/megatron/bridge/diffusion/base/__init__.py
  • src/megatron/bridge/diffusion/common/__init__.py
  • src/megatron/bridge/diffusion/common/tokenizers/__init__.py
  • src/megatron/bridge/diffusion/common/utils/__init__.py
  • src/megatron/bridge/diffusion/common/utils/batch_ops.py
  • src/megatron/bridge/diffusion/common/utils/dynamic_import.py
  • src/megatron/bridge/diffusion/common/utils/save_video.py
  • src/megatron/bridge/diffusion/common/utils/torch_split_tensor_for_cp.py
  • src/megatron/bridge/diffusion/conversion/__init__.py
  • src/megatron/bridge/diffusion/conversion/flux/__init__.py
  • src/megatron/bridge/diffusion/conversion/flux/flux_bridge.py
  • src/megatron/bridge/diffusion/conversion/flux/flux_hf_pretrained.py
  • src/megatron/bridge/diffusion/conversion/wan/__init__.py
  • src/megatron/bridge/diffusion/conversion/wan/wan_bridge.py
  • src/megatron/bridge/diffusion/conversion/wan/wan_hf_pretrained.py
  • src/megatron/bridge/diffusion/data/__init__.py
  • src/megatron/bridge/diffusion/data/common/__init__.py
  • src/megatron/bridge/diffusion/data/common/diffusion_energon_datamodule.py
  • src/megatron/bridge/diffusion/data/common/diffusion_sample.py
  • src/megatron/bridge/diffusion/data/common/diffusion_task_encoder_with_sp.py
  • src/megatron/bridge/diffusion/data/common/sequence_packing_utils.py
  • src/megatron/bridge/diffusion/data/flux/__init__.py
  • src/megatron/bridge/diffusion/data/flux/flux_energon_datamodule.py
  • src/megatron/bridge/diffusion/data/flux/flux_mock_datamodule.py
  • src/megatron/bridge/diffusion/data/flux/flux_taskencoder.py
  • src/megatron/bridge/diffusion/data/wan/__init__.py
  • src/megatron/bridge/diffusion/data/wan/wan_energon_datamodule.py
  • src/megatron/bridge/diffusion/data/wan/wan_mock_datamodule.py
  • src/megatron/bridge/diffusion/data/wan/wan_taskencoder.py
  • src/megatron/bridge/diffusion/models/README.md
  • src/megatron/bridge/diffusion/models/__init__.py
  • src/megatron/bridge/diffusion/models/common/__init__.py
  • src/megatron/bridge/diffusion/models/common/dit_attention.py
  • src/megatron/bridge/diffusion/models/common/dit_embeddings.py
  • src/megatron/bridge/diffusion/models/common/normalization.py
  • src/megatron/bridge/diffusion/models/flux/__init__.py
  • src/megatron/bridge/diffusion/models/flux/flow_matching/__init__.py
  • src/megatron/bridge/diffusion/models/flux/flow_matching/flux_adapter.py
  • src/megatron/bridge/diffusion/models/flux/flow_matching/flux_inference_pipeline.py
  • src/megatron/bridge/diffusion/models/flux/flux_attention.py
  • src/megatron/bridge/diffusion/models/flux/flux_layer_spec.py
  • src/megatron/bridge/diffusion/models/flux/flux_model.py
  • src/megatron/bridge/diffusion/models/flux/flux_provider.py
  • src/megatron/bridge/diffusion/models/flux/flux_step.py
  • src/megatron/bridge/diffusion/models/flux/flux_step_with_automodel.py
  • src/megatron/bridge/diffusion/models/flux/layers.py
  • src/megatron/bridge/diffusion/models/wan/__init__.py
  • src/megatron/bridge/diffusion/models/wan/flow_matching/__init__.py
  • src/megatron/bridge/diffusion/models/wan/flow_matching/flow_inference_pipeline.py
  • src/megatron/bridge/diffusion/models/wan/flow_matching/flow_matching_pipeline_wan.py
  • src/megatron/bridge/diffusion/models/wan/flow_matching/time_shift_utils.py
  • src/megatron/bridge/diffusion/models/wan/inference/__init__.py
  • src/megatron/bridge/diffusion/models/wan/inference/utils.py
  • src/megatron/bridge/diffusion/models/wan/rope_utils.py
  • src/megatron/bridge/diffusion/models/wan/utils.py
  • src/megatron/bridge/diffusion/models/wan/wan_layer_spec.py
  • src/megatron/bridge/diffusion/models/wan/wan_model.py
  • src/megatron/bridge/diffusion/models/wan/wan_provider.py
  • src/megatron/bridge/diffusion/models/wan/wan_step.py
  • src/megatron/bridge/diffusion/recipes/__init__.py
  • src/megatron/bridge/diffusion/recipes/flux/__init__.py
  • src/megatron/bridge/diffusion/recipes/flux/flux.py
  • src/megatron/bridge/diffusion/recipes/wan/__init__.py
  • src/megatron/bridge/diffusion/recipes/wan/wan.py
  • tests/functional_tests/L2_Mcore_Mock_Tests_GPU.sh
  • tests/functional_tests/diffusion/__init__.py
  • tests/functional_tests/diffusion/recipes/__init__.py
  • tests/functional_tests/diffusion/recipes/test_flux_pretrain.py
  • tests/functional_tests/diffusion/recipes/test_wan_pretrain.py
  • tests/unit_tests/diffusion/__init__.py
  • tests/unit_tests/diffusion/conftest.py
  • tests/unit_tests/diffusion/data/common/__init__.py
  • tests/unit_tests/diffusion/data/common/test_diffusion_data_module.py
  • tests/unit_tests/diffusion/data/common/test_diffusion_sample.py
  • tests/unit_tests/diffusion/data/common/test_diffusion_task_encoder.py
  • tests/unit_tests/diffusion/data/common/test_sequence_packing_utils.py
  • tests/unit_tests/diffusion/data/flux/__init__.py
  • tests/unit_tests/diffusion/data/flux/test_flux_energon_datamodule.py
  • tests/unit_tests/diffusion/data/flux/test_flux_mock_datamodule.py
  • tests/unit_tests/diffusion/data/flux/test_flux_taskencoder.py
  • tests/unit_tests/diffusion/data/wan/__init__.py
  • tests/unit_tests/diffusion/data/wan/test_wan_energon_datamodule.py
  • tests/unit_tests/diffusion/data/wan/test_wan_mock_datamodule.py
  • tests/unit_tests/diffusion/data/wan/test_wan_taskencoder.py
  • tests/unit_tests/diffusion/model/common/__init__.py
  • tests/unit_tests/diffusion/model/common/test_normalization.py
  • tests/unit_tests/diffusion/model/flux/__init__.py
  • tests/unit_tests/diffusion/model/flux/conversion/__init__.py
  • tests/unit_tests/diffusion/model/flux/conversion/test_flux_bridge.py
  • tests/unit_tests/diffusion/model/flux/conversion/test_flux_hf_pretrained.py
  • tests/unit_tests/diffusion/model/flux/test_flux_layer_spec.py
  • tests/unit_tests/diffusion/model/flux/test_flux_layers.py
  • tests/unit_tests/diffusion/model/flux/test_flux_pipeline.py
  • tests/unit_tests/diffusion/model/flux/test_flux_provider.py
  • tests/unit_tests/diffusion/model/flux/test_flux_step.py
  • tests/unit_tests/diffusion/model/wan/__init__.py
  • tests/unit_tests/diffusion/model/wan/conversion/__init__.py
  • tests/unit_tests/diffusion/model/wan/conversion/test_wan_bridge.py
  • tests/unit_tests/diffusion/model/wan/conversion/test_wan_hf_pretrained.py
  • tests/unit_tests/diffusion/model/wan/flow_matching/__init__.py
  • tests/unit_tests/diffusion/model/wan/flow_matching/test_flow_inference_pipeline.py
  • tests/unit_tests/diffusion/model/wan/flow_matching/test_flow_matching_pipeline_wan.py
  • tests/unit_tests/diffusion/model/wan/flow_matching/test_time_shift_utils.py
  • tests/unit_tests/diffusion/model/wan/inference/__init__.py
  • tests/unit_tests/diffusion/model/wan/inference/test_inference_init.py
  • tests/unit_tests/diffusion/model/wan/inference/test_inference_utils.py
  • tests/unit_tests/diffusion/model/wan/test_rope_utils.py
  • tests/unit_tests/diffusion/model/wan/test_utils.py
  • tests/unit_tests/diffusion/model/wan/test_wan_layer_spec.py
  • tests/unit_tests/diffusion/model/wan/test_wan_model_misc.py
  • tests/unit_tests/diffusion/model/wan/test_wan_provider.py
  • tests/unit_tests/diffusion/model/wan/test_wan_step.py
  • tests/unit_tests/diffusion/recipes/flux/__init__.py
  • tests/unit_tests/diffusion/recipes/flux/test_flux_recipe.py

@cuichenx cuichenx self-requested a review March 10, 2026 20:48

@ko3n1g (Contributor) commented Mar 10, 2026

/ok to test 2e90f66

@cuichenx (Contributor)

please also check coderabbit's comments

@huvunvidia (Contributor)

/ok to test 2d20c47

@huvunvidia (Contributor)

/ok to test 1dda679

@huvunvidia (Contributor)

/ok to test 9d7f512

@huvunvidia (Contributor)

/ok to test c9e0395

Huy Vu2 and others added 2 commits March 13, 2026 12:06
Signed-off-by: Ao Tang <aot@nvidia.com>
Co-authored-by: Huy Vu2 <huvu@login-eos02.eos.clusters.nvidia.com>

@huvunvidia (Contributor)

/ok to test a8ae02d

cuichenx previously approved these changes Mar 13, 2026

@huvunvidia (Contributor)

/ok to test f550944