@@ -87,6 +87,19 @@ def build_flux_test_list() -> list[OverrideDefinitions]:
8787 "Flux Validation Test" ,
8888 "validation" ,
8989 ),
90+ OverrideDefinitions (
91+ [
92+ [
93+ "--checkpoint.enable" ,
94+ ],
95+ [
96+ # placeholder for the generation script's generate step
97+ ],
98+ ],
99+ "Flux Generation script test" ,
100+ "test_generate" ,
101+ ngpu = 2 ,
102+ ),
90103 ]
91104 return integration_tests_flavors
92105
@@ -116,6 +129,15 @@ def run_single_test(test_flavor: OverrideDefinitions, full_path: str, output_dir
116129 cmd = f"CONFIG_FILE={ full_path } NGPU={ test_flavor .ngpu } LOG_RANK={ all_ranks } ./torchtitan/experiments/flux/run_train.sh"
117130 # dump compile trace for debugging purpose
118131 cmd = f'TORCH_TRACE="{ output_dir } /{ test_name } /compile_trace" ' + cmd
132+
133+ # save checkpoint (idx == 0) and load it for generation (idx == 1)
134+ if test_name == "test_generate" and idx == 1 :
135+ # For flux generation, test using inference script
136+ cmd = (
137+ f"CONFIG_FILE={ full_path } NGPU={ test_flavor .ngpu } LOG_RANK={ all_ranks } "
138+ f"./torchtitan/experiments/flux/inference/run_infer.sh"
139+ )
140+
119141 cmd += " " + model_arg
120142 cmd += " " + dump_folder_arg
121143 cmd += " " + random_init_encoder_arg
@@ -124,10 +146,10 @@ def run_single_test(test_flavor: OverrideDefinitions, full_path: str, output_dir
124146 cmd += " " + tokenzier_path_arg
125147 if override_arg :
126148 cmd += " " + " " .join (override_arg )
149+
127150 logger .info (
128151 f"=====Flux Integration test, flavor : { test_flavor .test_descr } , command : { cmd } ====="
129152 )
130-
131153 result = _run_cmd (cmd )
132154 logger .info (result .stdout )
133155 if result .returncode != 0 :
0 commit comments