Skip to content

Commit 868a1c2

Browse files
committed
flux integration test
1 parent bd52d3c commit 868a1c2

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

tests/integration_tests/flux.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def run_single_test(test_flavor: OverrideDefinitions, full_path: str, output_dir
6666
# run_test supports sequence of tests.
6767
test_name = test_flavor.test_name
6868
dump_folder_arg = f"--job.dump_folder {output_dir}/{test_name}"
69+
custom_job_args = "--job.custom_config_module torchtitan/models/flux/job_config.py"
6970

7071
# Random init encoder for offline testing
7172
model_arg = "--model.name flux"
@@ -91,8 +92,9 @@ def run_single_test(test_flavor: OverrideDefinitions, full_path: str, output_dir
9192
f"scripts/flux_inference/run_infer.sh"
9293
)
9394

94-
cmd += " " + model_arg
9595
cmd += " " + dump_folder_arg
96+
cmd += " " + custom_job_args
97+
cmd += " " + model_arg
9698
cmd += " " + random_init_encoder_arg
9799
cmd += " " + clip_encoder_version_arg
98100
cmd += " " + t5_encoder_version_arg

torchtitan/models/flux/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from torchtitan.components.lr_scheduler import build_lr_schedulers
99
from torchtitan.components.optimizer import build_optimizers
1010

11-
from torchtitan.datasets.flux_dataset import build_flux_dataloader
11+
from torchtitan.hf_datasets.flux_dataset import build_flux_dataloader
1212
from torchtitan.protocols.train_spec import TrainSpec
1313
from .infra.parallelize import parallelize_flux
1414
from .model.args import FluxModelArgs

0 commit comments

Comments
 (0)