Skip to content

Commit 309bc7c

Browse files
committed
move integration tests
1 parent 283204f commit 309bc7c

File tree

3 files changed

+9
-56
lines changed

3 files changed

+9
-56
lines changed

.github/workflows/integration_test_8gpu_models.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,4 @@ jobs:
5252
5353
mkdir artifacts-to-be-uploaded
5454
python -m tests.integration_tests.run_tests --test_suite models artifacts-to-be-uploaded --ngpu 8
55+
python -m tests.integration_tests.flux artifacts-to-be-uploaded/flux --ngpu 8

torchtitan/models/flux/tests/integration_tests.py renamed to tests/integration_tests/flux.py

Lines changed: 8 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@
77
import argparse
88
import os
99

10+
from torchtitan.tools.logging import logger
11+
1012
from tests.integration_tests import OverrideDefinitions
1113
from tests.integration_tests.run_tests import _run_cmd
1214

13-
from torchtitan.tools.logging import logger
14-
1515

1616
def build_flux_test_list() -> list[OverrideDefinitions]:
1717
"""
@@ -20,53 +20,6 @@ def build_flux_test_list() -> list[OverrideDefinitions]:
2020
same root config file.
2121
"""
2222
integration_tests_flavors = [
23-
# basic tests
24-
OverrideDefinitions(
25-
[
26-
[
27-
"--profiling.enable_profiling",
28-
"--metrics.enable_tensorboard",
29-
],
30-
],
31-
"default",
32-
"default",
33-
),
34-
# Checkpointing tests.
35-
OverrideDefinitions(
36-
[
37-
[
38-
"--checkpoint.enable",
39-
],
40-
[
41-
"--checkpoint.enable",
42-
"--training.steps 20",
43-
],
44-
],
45-
"Checkpoint Integration Test - Save Load Full Checkpoint",
46-
"full_checkpoint",
47-
),
48-
OverrideDefinitions(
49-
[
50-
[
51-
"--checkpoint.enable",
52-
"--checkpoint.last_save_model_only",
53-
],
54-
],
55-
"Checkpoint Integration Test - Save Model Only fp32",
56-
"last_save_model_only_fp32",
57-
),
58-
# Parallelism tests.
59-
OverrideDefinitions(
60-
[
61-
[
62-
"--parallelism.data_parallel_shard_degree 4",
63-
"--parallelism.data_parallel_replicate_degree 1",
64-
]
65-
],
66-
"FSDP",
67-
"fsdp",
68-
ngpu=4,
69-
),
7023
OverrideDefinitions(
7124
[
7225
[
@@ -117,16 +70,16 @@ def run_single_test(test_flavor: OverrideDefinitions, full_path: str, output_dir
11770
# Random init encoder for offline testing
11871
model_arg = "--model.name flux"
11972
random_init_encoder_arg = "--training.test_mode"
120-
clip_encoder_version_arg = "--encoder.clip_encoder torchtitan/experiments/flux/tests/assets/clip-vit-large-patch14/"
121-
t5_encoder_version_arg = (
122-
"--encoder.t5_encoder torchtitan/experiments/flux/tests/assets/t5-v1_1-xxl/"
73+
clip_encoder_version_arg = (
74+
"--encoder.clip_encoder tests/assets/clip-vit-large-patch14/"
12375
)
76+
t5_encoder_version_arg = "--encoder.t5_encoder tests/assets/t5-v1_1-xxl/"
12477
tokenzier_path_arg = "--model.tokenizer_path tests/assets/tokenizer"
12578

12679
all_ranks = ",".join(map(str, range(test_flavor.ngpu)))
12780

12881
for idx, override_arg in enumerate(test_flavor.override_args):
129-
cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} ./torchtitan/experiments/flux/run_train.sh"
82+
cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} ./torchtitan/models/flux/run_train.sh"
13083
# dump compile trace for debugging purpose
13184
cmd = f'TORCH_TRACE="{output_dir}/{test_name}/compile_trace" ' + cmd
13285

@@ -135,7 +88,7 @@ def run_single_test(test_flavor: OverrideDefinitions, full_path: str, output_dir
13588
# For flux generation, test using inference script
13689
cmd = (
13790
f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} "
138-
f"./torchtitan/experiments/flux/inference/run_infer.sh"
91+
f"scripts/flux_inference/run_infer.sh"
13992
)
14093

14194
cmd += " " + model_arg
@@ -189,7 +142,7 @@ def main():
189142
parser.add_argument("output_dir")
190143
parser.add_argument(
191144
"--config_path",
192-
default="./torchtitan/experiments/flux/train_configs/debug_model.toml",
145+
default="./tests/integration_tests/base_config.toml",
193146
help="Base config path for integration tests. This is the config that will be used as a base for all tests.",
194147
)
195148
parser.add_argument(

tests/unit_tests/test_dataset_flux.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
_cc12m_wds_data_processor,
1717
build_flux_dataloader,
1818
DATASETS,
19-
2019
)
2120

2221

0 commit comments

Comments
 (0)