Skip to content

Commit 8d20f02

Browse files
authored
followup fix to #1740 (#1747)
My bad that forgot to update qwen3.
1 parent 686b523 commit 8d20f02

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

torchtitan/experiments/qwen3/__init__.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from torchtitan.components.validate import build_validator
1414
from torchtitan.datasets.hf_datasets import build_hf_dataloader
1515
from torchtitan.models.moe import MoEArgs
16-
from torchtitan.protocols.train_spec import register_train_spec, TrainSpec
16+
from torchtitan.protocols.train_spec import TrainSpec
1717

1818
from .infra.parallelize import parallelize_qwen3
1919
from .model.args import Qwen3ModelArgs
@@ -178,8 +178,8 @@
178178
}
179179

180180

181-
register_train_spec(
182-
TrainSpec(
181+
def get_train_spec() -> TrainSpec:
182+
return TrainSpec(
183183
name="qwen3",
184184
model_cls=Qwen3Model,
185185
model_args=qwen3_configs, # Change from dict to Mapping
@@ -193,4 +193,3 @@
193193
build_validator_fn=build_validator,
194194
state_dict_adapter=Qwen3StateDictAdapter,
195195
)
196-
)

0 commit comments

Comments
 (0)