Skip to content

Commit 7e805a9

Browse files
authored
[Flux] Remove generate_image script and add CI test for inference (#1746)
As follow up of #1726
1 parent 3e1b843 commit 7e805a9

File tree

2 files changed

+23
-140
lines changed

2 files changed

+23
-140
lines changed

torchtitan/experiments/flux/tests/integration_tests.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,19 @@ def build_flux_test_list() -> list[OverrideDefinitions]:
8787
"Flux Validation Test",
8888
"validation",
8989
),
90+
OverrideDefinitions(
91+
[
92+
[
93+
"--checkpoint.enable",
94+
],
95+
[
96+
# placeholder for the generation script's generate step
97+
],
98+
],
99+
"Flux Generation script test",
100+
"test_generate",
101+
ngpu=2,
102+
),
90103
]
91104
return integration_tests_flavors
92105

@@ -116,6 +129,15 @@ def run_single_test(test_flavor: OverrideDefinitions, full_path: str, output_dir
116129
cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} ./torchtitan/experiments/flux/run_train.sh"
117130
# dump compile trace for debugging purpose
118131
cmd = f'TORCH_TRACE="{output_dir}/{test_name}/compile_trace" ' + cmd
132+
133+
# save checkpoint (idx == 0) and load it for generation (idx == 1)
134+
if test_name == "test_generate" and idx == 1:
135+
# For flux generation, test using inference script
136+
cmd = (
137+
f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} "
138+
f"./torchtitan/experiments/flux/inference/run_infer.sh"
139+
)
140+
119141
cmd += " " + model_arg
120142
cmd += " " + dump_folder_arg
121143
cmd += " " + random_init_encoder_arg
@@ -124,10 +146,10 @@ def run_single_test(test_flavor: OverrideDefinitions, full_path: str, output_dir
124146
cmd += " " + tokenzier_path_arg
125147
if override_arg:
126148
cmd += " " + " ".join(override_arg)
149+
127150
logger.info(
128151
f"=====Flux Integration test, flavor : {test_flavor.test_descr}, command : {cmd}====="
129152
)
130-
131153
result = _run_cmd(cmd)
132154
logger.info(result.stdout)
133155
if result.returncode != 0:

torchtitan/experiments/flux/tests/test_generate_image.py

Lines changed: 0 additions & 139 deletions
This file was deleted.

0 commit comments

Comments
 (0)