77import argparse
88import os
99
10+ from torchtitan .tools .logging import logger
11+
1012from tests .integration_tests import OverrideDefinitions
1113from tests .integration_tests .run_tests import _run_cmd
1214
13- from torchtitan .tools .logging import logger
14-
1515
1616def build_flux_test_list () -> list [OverrideDefinitions ]:
1717 """
@@ -20,53 +20,6 @@ def build_flux_test_list() -> list[OverrideDefinitions]:
2020 same root config file.
2121 """
2222 integration_tests_flavors = [
23- # basic tests
24- OverrideDefinitions (
25- [
26- [
27- "--profiling.enable_profiling" ,
28- "--metrics.enable_tensorboard" ,
29- ],
30- ],
31- "default" ,
32- "default" ,
33- ),
34- # Checkpointing tests.
35- OverrideDefinitions (
36- [
37- [
38- "--checkpoint.enable" ,
39- ],
40- [
41- "--checkpoint.enable" ,
42- "--training.steps 20" ,
43- ],
44- ],
45- "Checkpoint Integration Test - Save Load Full Checkpoint" ,
46- "full_checkpoint" ,
47- ),
48- OverrideDefinitions (
49- [
50- [
51- "--checkpoint.enable" ,
52- "--checkpoint.last_save_model_only" ,
53- ],
54- ],
55- "Checkpoint Integration Test - Save Model Only fp32" ,
56- "last_save_model_only_fp32" ,
57- ),
58- # Parallelism tests.
59- OverrideDefinitions (
60- [
61- [
62- "--parallelism.data_parallel_shard_degree 4" ,
63- "--parallelism.data_parallel_replicate_degree 1" ,
64- ]
65- ],
66- "FSDP" ,
67- "fsdp" ,
68- ngpu = 4 ,
69- ),
7023 OverrideDefinitions (
7124 [
7225 [
@@ -117,16 +70,16 @@ def run_single_test(test_flavor: OverrideDefinitions, full_path: str, output_dir
11770 # Random init encoder for offline testing
11871 model_arg = "--model.name flux"
11972 random_init_encoder_arg = "--training.test_mode"
120- clip_encoder_version_arg = "--encoder.clip_encoder torchtitan/experiments/flux/tests/assets/clip-vit-large-patch14/"
121- t5_encoder_version_arg = (
122- "--encoder.t5_encoder torchtitan/experiments/flux/tests/assets/t5-v1_1-xxl/"
73+ clip_encoder_version_arg = (
74+ "--encoder.clip_encoder tests/assets/clip-vit-large-patch14/"
12375 )
76+ t5_encoder_version_arg = "--encoder.t5_encoder tests/assets/t5-v1_1-xxl/"
12477 tokenzier_path_arg = "--model.tokenizer_path tests/assets/tokenizer"
12578
12679 all_ranks = "," .join (map (str , range (test_flavor .ngpu )))
12780
12881 for idx , override_arg in enumerate (test_flavor .override_args ):
129- cmd = f"CONFIG_FILE={ full_path } NGPU={ test_flavor .ngpu } LOG_RANK={ all_ranks } ./torchtitan/experiments /flux/run_train.sh"
82+ cmd = f"CONFIG_FILE={ full_path } NGPU={ test_flavor .ngpu } LOG_RANK={ all_ranks } ./torchtitan/models /flux/run_train.sh"
13083 # dump compile trace for debugging purpose
13184 cmd = f'TORCH_TRACE="{ output_dir } /{ test_name } /compile_trace" ' + cmd
13285
@@ -135,7 +88,7 @@ def run_single_test(test_flavor: OverrideDefinitions, full_path: str, output_dir
13588 # For flux generation, test using inference script
13689 cmd = (
13790 f"CONFIG_FILE={ full_path } NGPU={ test_flavor .ngpu } LOG_RANK={ all_ranks } "
138- f"./torchtitan/experiments/flux/inference /run_infer.sh"
91+ f"scripts/flux_inference /run_infer.sh"
13992 )
14093
14194 cmd += " " + model_arg
@@ -189,7 +142,7 @@ def main():
189142 parser .add_argument ("output_dir" )
190143 parser .add_argument (
191144 "--config_path" ,
192- default = "./torchtitan/experiments/flux/train_configs/debug_model .toml" ,
145+ default = "./tests/integration_tests/base_config .toml" ,
193146 help = "Base config path for integration tests. This is the config that will be used as a base for all tests." ,
194147 )
195148 parser .add_argument (
0 commit comments