Commit 7c10480

[refactor] graduate custom_config_module and unify args/config naming (#1871)
In the past, the terms "args" and "config" have been used interchangeably. To make the usage unambiguous, in torchtitan we use
- "args", as in `ModelArgs`, to refer to parameters used to define a model in model code
- "config", as in `JobConfig`, to refer to configurable training job options used in the training script

This PR also moves `custom_args_module` (which should be `custom_config_module` according to the naming rule above) from `Experimental` to `Job`, as it has been extensively used by various models in torchtitan, especially those in the `experiments` folder.
1 parent 8c1d1c5 commit 7c10480

File tree

33 files changed: +106 -124 lines


.github/CODEOWNERS

Lines changed: 1 addition & 1 deletion
@@ -10,4 +10,4 @@
 /torchtitan/experiments/

 # codeowners for experiments/forge
-/torchtitan/experiments/forge/* @ebsmothers @pbontrager @joecummings @allenwang28 @tianyu-l @wwwjn
+/torchtitan/experiments/forge/* @ebsmothers @pbontrager @joecummings @allenwang28 @tianyu-l @wwwjn @fegin

benchmarks/README.md

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@ A submission should be a file / files including the following information
 3. The hardware setup, including the types of GPUs, interconnections, etc.
 4. The actual performance report with training configs, e.g. via
    - `.toml` files / commandline arguments
-   - complete configs, which can be found in the log with [`--print_args`](https://github.com/pytorch/torchtitan/blob/e7c0cae934df78d6e9c2835f42ff1f757dc3fddc/torchtitan/config_manager.py#L47) turned on (preferred as the default value not shown in `.toml` or specified in commandline could change from time to time)
+   - complete configs, which can be found in the log with [`--print_config`](https://github.com/pytorch/torchtitan/blob/e7c0cae934df78d6e9c2835f42ff1f757dc3fddc/torchtitan/config_manager.py#L47) turned on (preferred as the default value not shown in `.toml` or specified in commandline could change from time to time)
 5. The versions and date/time of `torchtitan`, `torch`, `torchao`, or any relevant dependencies.
 6. Other notes which could help reproduce the results.

docs/extension.md

Lines changed: 9 additions & 9 deletions
@@ -6,7 +6,7 @@ The extension points and protocols mentioned in this note are subject to change.
 ### `TrainSpec`

 [`TrainSpec`](../torchtitan/protocols/train_spec.py) supports configuring high-level components in model training, including
-- definitions of model class and model args config
+- definitions of model class and model args
 - model parallelization functions
 - loss functions
 - factory methods for creating dataloader / tokenizer / optimizer / learning rate scheduler / metrics processor
@@ -36,7 +36,7 @@ This is an ongoing effort, and the level of grouping is subject to change.

 ### Extending `JobConfig`

-[`JobConfig`](../torchtitan/config/job_config.py) supports custom extension through the `--experimental.custom_args_module` flag.
+[`JobConfig`](../torchtitan/config/job_config.py) supports custom extension through the `--job.custom_config_module` flag.
 This lets you define a custom module that extends `JobConfig` with additional fields.

 When specified, your custom `JobConfig` is merged with the default:
@@ -45,14 +45,14 @@ When specified, your custom `JobConfig` is merged with the default:

 #### Example

-To add a custom `custom_args` section, define your own `JobConfig`:
+To add a custom `custom_config` section, define your own `JobConfig`:

 ```python
-# torchtitan/experiments/your_folder/custom_args.py
+# torchtitan/experiments/your_folder/job_config.py
 from dataclasses import dataclass, field

 @dataclass
-class CustomArgs:
+class CustomConfig:
     how_is_your_day: str = "good"
     """Just an example."""

@@ -68,19 +68,19 @@ class Training:

 @dataclass
 class JobConfig:
-    custom_args: CustomArgs = field(default_factory=CustomArgs)
+    custom_config: CustomConfig = field(default_factory=CustomConfig)
     training: Training= field(default_factory=Training)
 ```

 Then run your script with:

 ```bash
---experimental.custom_args_module=torchtitan.experiments.your_folder.custom_args
+--job.custom_config_module=torchtitan.experiments.your_folder.job_config
 ```

 Or specify it in your `.toml` config:

 ```toml
-[experimental]
-custom_args_module = "torchtitan.experiments.your_folder.custom_args"
+[job]
+custom_config_module = "torchtitan.experiments.your_folder.job_config"
 ```
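
Note (added for context, not part of the diff): a minimal sketch of how the merged section is read back at runtime, following the `CustomConfig` example above; the `ConfigManager` import path is an assumption based on `torchtitan/config/manager.py`.

```python
# Hedged sketch: once --job.custom_config_module points at the module above,
# the parsed JobConfig exposes the extra section next to the built-in ones.
from torchtitan.config.manager import ConfigManager  # assumed import path

config = ConfigManager().parse_args(
    ["--job.custom_config_module=torchtitan.experiments.your_folder.job_config"]
)
print(config.custom_config.how_is_your_day)  # "good" (the default defined above)
print(config.job.custom_config_module)       # echoes the module path we passed
```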

tests/assets/extend_jobconfig_example.py renamed to tests/assets/extended_job_config_example.py

Lines changed: 2 additions & 2 deletions
@@ -8,7 +8,7 @@


 @dataclass
-class CustomArgs:
+class CustomConfig:
     how_is_your_day: str = "good"
     """Just an example helptext"""

@@ -28,5 +28,5 @@ class JobConfig:
     This is an example of how to extend the tyro parser with custom config classes.
     """

-    custom_args: CustomArgs = field(default_factory=CustomArgs)
+    custom_config: CustomConfig = field(default_factory=CustomConfig)
     training: Training = field(default_factory=Training)

tests/integration_tests/base_config.toml

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 [job]
 dump_folder = "./outputs"
-description = "model debug training for integration test"
-print_args = false
+description = "model debug training for integration tests"
+print_config = false

 [profiling]
 enable_profiling = false

tests/unit_tests/test_job_config.py

Lines changed: 16 additions & 16 deletions
@@ -208,7 +208,7 @@ def test_print_help(self):
         parser = get_parser(ConfigManager)
         parser.print_help()

-    def test_extend_jobconfig_directly(self):
+    def test_extend_job_config_directly(self):
         @dataclass
         class CustomCheckpoint:
             convert_path: str = "/custom/path"
@@ -233,19 +233,19 @@ class CustomJobConfig:
         assert hasattr(config, "model")

     def test_custom_parser(self):
-        path = "tests.assets.extend_jobconfig_example"
+        path = "tests.assets.extended_job_config_example"

         config_manager = ConfigManager()
         config = config_manager.parse_args(
             [
-                f"--experimental.custom_args_module={path}",
-                "--custom_args.how-is-your-day",
+                f"--job.custom_config_module={path}",
+                "--custom_config.how-is-your-day",
                 "bad",
                 "--model.converters",
                 "float8,mxfp",
             ]
         )
-        assert config.custom_args.how_is_your_day == "bad"
+        assert config.custom_config.how_is_your_day == "bad"
         assert config.model.converters == ["float8", "mxfp"]
         result = config.to_dict()
         assert isinstance(result, dict)
@@ -254,8 +254,8 @@ def test_custom_parser(self):
         with self.assertRaisesRegex(SystemExit, "2"):
             config = config_manager.parse_args(
                 [
-                    f"--experimental.custom_args_module={path}",
-                    "--custom_args.how-is-your-day",
+                    f"--job.custom_config_module={path}",
+                    "--custom_config.how-is-your-day",
                     "bad",
                     "--model.converters",
                     "float8,mxfp",
@@ -266,8 +266,8 @@ def test_custom_parser(self):
         with tempfile.NamedTemporaryFile(mode="w+b", delete=True) as fp:
             tomli_w.dump(
                 {
-                    "experimental": {
-                        "custom_args_module": path,
+                    "job": {
+                        "custom_config_module": path,
                     }
                 },
                 fp,
@@ -278,14 +278,14 @@ def test_custom_parser(self):
             config = config_manager.parse_args(
                 [
                     f"--job.config_file={fp.name}",
-                    f"--experimental.custom_args_module={path}",
-                    "--custom_args.how-is-your-day",
+                    f"--job.custom_config_module={path}",
+                    "--custom_config.how-is-your-day",
                     "bad",
                     "--model.converters",
                     "float8,mxfp",
                 ]
             )
-        assert config.custom_args.how_is_your_day == "bad"
+        assert config.custom_config.how_is_your_day == "bad"
         assert config.training.my_custom_steps == 32
         assert config.model.converters == ["float8", "mxfp"]
         result = config.to_dict()
@@ -294,10 +294,10 @@ def test_custom_parser(self):
         with tempfile.NamedTemporaryFile(mode="w+b", delete=True) as fp:
             tomli_w.dump(
                 {
-                    "experimental": {
-                        "custom_args_module": path,
+                    "job": {
+                        "custom_config_module": path,
                     },
-                    "custom_args": {"how_is_your_day": "really good"},
+                    "custom_config": {"how_is_your_day": "really good"},
                     "model": {"converters": ["float8", "mxfp"]},
                 },
                 fp,
@@ -311,7 +311,7 @@ def test_custom_parser(self):
                 ]
             )

-        assert config.custom_args.how_is_your_day == "really good"
+        assert config.custom_config.how_is_your_day == "really good"
         assert config.model.converters == ["float8", "mxfp"]
         result = config.to_dict()
         assert isinstance(result, dict)

torchtitan/config/job_config.py

Lines changed: 10 additions & 2 deletions
@@ -19,8 +19,14 @@ class Job:
     description: str = "default job"
     """Description of the job"""

-    print_args: bool = False
-    """Print the args to terminal"""
+    print_config: bool = False
+    """Print the configs to terminal"""
+
+    custom_config_module: str = ""
+    """
+    This option allows users to extend the existing JobConfig with a customized
+    JobConfig dataclass. Users need to ensure that the path can be imported.
+    """


 @dataclass
@@ -834,6 +840,8 @@ class Experimental:

     custom_args_module: str = ""
     """
+    DEPRECATED (moved to Job.custom_config_module). Will be removed soon.
+
     This option allows users to extend TorchTitan's existing JobConfig by extending
     a user defined JobConfig dataclass. Similar to ``--experimental.custom_model_path``, the user
     needs to ensure that the path can be imported.
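
For context (not part of the diff), the renamed and graduated fields can be exercised directly on the dataclass; a minimal sketch, assuming `Job` is importable from this module path and that all of its fields have defaults:

```python
# Hedged sketch: defaults of the renamed / graduated fields on Job.
from torchtitan.config.job_config import Job  # assumed import path

job = Job()
assert job.print_config is False        # renamed from print_args
assert job.custom_config_module == ""   # graduated from Experimental.custom_args_module
```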

torchtitan/config/manager.py

Lines changed: 20 additions & 8 deletions
@@ -45,7 +45,7 @@ def __init__(self, config_cls: Type[JobConfig] = JobConfig):

     def parse_args(self, args: list[str] = sys.argv[1:]) -> JobConfig:
         toml_values = self._maybe_load_toml(args)
-        config_cls = self._maybe_add_custom_args(args, toml_values)
+        config_cls = self._maybe_add_custom_config(args, toml_values)

         base_config = (
             self._dict_to_dataclass(config_cls, toml_values)
@@ -83,16 +83,19 @@ def _maybe_load_toml(self, args: list[str]) -> dict[str, Any] | None:
             logger.exception(f"Error while loading config file: {file_path}")
             raise e

-    def _maybe_add_custom_args(
+    def _maybe_add_custom_config(
         self, args: list[str], toml_values: dict[str, Any] | None
     ) -> Type[JobConfig]: # noqa: B006
-        """Find and merge custom arguments module with current JobConfig class"""
+        """
+        Find and merge custom config module with current JobConfig class, if it is given.
+        The search order is first searching CLI args, then toml config file.
+        """
         module_path = None

         # 1. Check CLI
         valid_keys = {
-            "--experimental.custom_args_module",
-            "--experimental.custom-args-module",
+            "--job.custom_config_module",
+            "--job.custom-config-module",
         }
         for i, arg in enumerate(args):
             key = arg.split("=")[0]
@@ -102,9 +105,9 @@ def _maybe_add_custom_args(

         # 2. If not found in CLI, check TOML
         if not module_path and toml_values:
-            experimental = toml_values.get("experimental", {})
-            if isinstance(experimental, dict):
-                module_path = experimental.get("custom_args_module")
+            job = toml_values.get("job", {})
+            if isinstance(job, dict):
+                module_path = job.get("custom_config_module")

         if not module_path:
             return self.config_cls
@@ -178,6 +181,15 @@ def _dict_to_dataclass(self, cls, data: dict[str, Any]) -> Any:
         return cls(**result)

     def _validate_config(self) -> None:
+        if self.config.experimental.custom_args_module:
+            logger.warning(
+                "This field is being moved to --job.custom_config_module and "
+                "will be deprecated soon. Setting job.custom_config_module to "
+                "experimental.custom_args_module temporarily."
+            )
+            self.config.job.custom_config_module = (
+                self.config.experimental.custom_args_module
+            )
         # TODO: temporary mitigation of BC breaking change in hf_assets_path
         # tokenizer default path, need to remove later
         if self.config.model.tokenizer_path:
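
For context (not part of the diff), the module path can also come from the `[job]` table of a `.toml` file, which `_maybe_add_custom_config` consults only when no CLI flag is given; a minimal sketch mirroring the updated unit test, with import paths assumed:

```python
# Hedged sketch: custom_config_module resolved from a TOML file rather than
# the CLI, following tests/unit_tests/test_job_config.py in this commit.
import tempfile

import tomli_w  # used by the unit tests to write the temporary TOML

from torchtitan.config.manager import ConfigManager  # assumed import path

with tempfile.NamedTemporaryFile(mode="w+b", delete=True) as fp:
    tomli_w.dump(
        {
            "job": {"custom_config_module": "tests.assets.extended_job_config_example"},
            "custom_config": {"how_is_your_day": "really good"},
        },
        fp,
    )
    fp.flush()
    config = ConfigManager().parse_args([f"--job.config_file={fp.name}"])

assert config.custom_config.how_is_your_day == "really good"
```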

torchtitan/experiments/flux/README.md

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ Run the following command to train the model on a single GPU:

 ```

-If you want to train with other model config, run the following command:
+If you want to train with other model args, run the following command:
 ```bash
 CONFIG_FILE="./torchtitan/experiments/flux/train_configs/flux_schnell_model.toml" ./torchtitan/experiments/flux/run_train.sh
 ```

torchtitan/experiments/flux/__init__.py

Lines changed: 2 additions & 5 deletions
@@ -3,9 +3,6 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-#
-# Copyright (c) Meta Platforms, Inc. All Rights Reserved.
-

 from torchtitan.components.lr_scheduler import build_lr_schedulers
 from torchtitan.components.optimizer import build_optimizers
@@ -28,7 +25,7 @@
 ]


-flux_configs = {
+flux_args = {
     "flux-dev": FluxModelArgs(
         in_channels=64,
         out_channels=64,
@@ -110,7 +107,7 @@
 def get_train_spec() -> TrainSpec:
     return TrainSpec(
         model_cls=FluxModel,
-        model_args=flux_configs,
+        model_args=flux_args,
         parallelize_fn=parallelize_flux,
         pipelining_fn=None,
         build_optimizers_fn=build_optimizers,
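
For context (not part of the diff), the renamed dict is what `get_train_spec` hands to `TrainSpec` as `model_args`; a minimal sketch, assuming `flux_args` and `get_train_spec` are importable from the package `__init__`:

```python
# Hedged sketch: looking up a model-args preset after the rename.
from torchtitan.experiments.flux import flux_args, get_train_spec  # assumed imports

dev_args = flux_args["flux-dev"]
print(type(dev_args).__name__, dev_args.in_channels)  # FluxModelArgs 64

spec = get_train_spec()
assert spec.model_args is flux_args  # the spec carries the renamed dict
```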
