Commit f13c235

Author: hyunsooha
Commit message: Revert removal of eos_id_args
Parent: 50b10b8

6 files changed: +6 −0 lines changed

torchtitan/experiments/deterministic_vllm_rl/models/qwen3/model_vllm_compat.py

Lines changed: 1 addition & 0 deletions
@@ -288,6 +288,7 @@ def __init__(self, model_args: Qwen3ModelArgs):
         self.model_args = model_args
         self.vocab_size = model_args.vocab_size
         self.n_layers = model_args.n_layers
+        self.eos_id = model_args.eos_id
         self.head_dim = model_args.head_dim

         self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)
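For context on why the model keeps eos_id: storing the EOS token id on the module makes it available to whatever drives generation. A minimal sketch of that use, assuming a hypothetical greedy sampling loop (step stands in for one forward pass returning next-token logits; none of these names come from this repo):

import torch

def greedy_decode(step, prompt_ids: torch.Tensor, eos_id: int, max_new_tokens: int = 64) -> torch.Tensor:
    """Append tokens one at a time, stopping once `eos_id` is produced."""
    tokens = prompt_ids
    for _ in range(max_new_tokens):
        logits = step(tokens)                # (seq_len, vocab_size) next-token logits
        next_id = int(logits[-1].argmax())   # greedy pick from the last position
        tokens = torch.cat([tokens, torch.tensor([next_id])])
        if next_id == eos_id:                # the id the model now carries as self.eos_id
            break
    return tokens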

torchtitan/experiments/deterministic_vllm_rl/simple_rl.py

Lines changed: 1 addition & 0 deletions
@@ -332,6 +332,7 @@ def load_model(checkpoint_path: str, model_path: str, use_vllm_compat: bool = True
         max_seq_len=getattr(hf_config, "max_position_embeddings", 32768),
         qk_norm=True,
         depth_init=True,
+        eos_id=getattr(hf_config, "eos_token_id", 151645),
     )

     # state_dict is in standard TorchTitan format (w1, w2, w3)
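The getattr call falls back to 151645 when the Hugging Face config does not expose an eos_token_id (151645 is the Qwen-family <|im_end|> id, matching the Qwen3ModelArgs default further down). A minimal illustration of that fallback behavior, using stand-in config objects:

from types import SimpleNamespace

# Stand-in configs (hypothetical) -- the real code reads a Hugging Face config here.
cfg_with_field = SimpleNamespace(eos_token_id=2)
cfg_without_field = SimpleNamespace()

print(getattr(cfg_with_field, "eos_token_id", 151645))     # 2: the config's own value wins
print(getattr(cfg_without_field, "eos_token_id", 151645))  # 151645: fallback when the field is absent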

torchtitan/experiments/transformers_backend/model/args.py

Lines changed: 1 addition & 0 deletions
@@ -54,6 +54,7 @@ class HFTransformerModelArgs(PretrainedConfig, BaseModelArgs):
             "n_kv_heads": "num_key_value_heads",
             "norm_eps": "rms_norm_eps",
             "max_seq_len": "max_position_embeddings",
+            "eos_id": "eos_token_id",
         }
     }

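The new entry aliases the TorchTitan name eos_id to the Hugging Face field eos_token_id. A minimal sketch of how such an alias table can be consumed, in the spirit of transformers.PretrainedConfig's attribute_map (a hypothetical standalone class, not the repo's actual lookup code):

class AliasedConfig:
    # TorchTitan-style name -> Hugging Face config field name
    attribute_map = {"eos_id": "eos_token_id"}

    def __init__(self, eos_token_id: int = 0):
        self.eos_token_id = eos_token_id

    def __getattr__(self, name):
        # Called only when normal attribute lookup fails; resolve via the alias table.
        mapped = type(self).attribute_map.get(name)
        if mapped is not None and mapped != name:
            return getattr(self, mapped)
        raise AttributeError(name)

cfg = AliasedConfig(eos_token_id=151645)
print(cfg.eos_id)  # 151645, resolved through the "eos_id" -> "eos_token_id" alias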
torchtitan/models/llama3/model/args.py

Lines changed: 1 addition & 0 deletions
@@ -45,6 +45,7 @@ class TransformerModelArgs(BaseModelArgs):

     use_flex_attn: bool = False
     attn_mask_type: str = "causal"
+    eos_id: int = 0

     def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
         seq_len = job_config.training.seq_len
torchtitan/models/qwen3/model/args.py

Lines changed: 1 addition & 0 deletions
@@ -38,6 +38,7 @@ class Qwen3ModelArgs(BaseModelArgs):

     use_flex_attn: bool = False
     attn_mask_type: str = "causal"
+    eos_id: int = 151645

     enable_weight_tying: bool = False
torchtitan/models/qwen3/model/model.py

Lines changed: 1 addition & 0 deletions
@@ -384,6 +384,7 @@ def __init__(self, model_args: Qwen3ModelArgs):
         self.model_args = model_args
         self.vocab_size = model_args.vocab_size
         self.n_layers = model_args.n_layers
+        self.eos_id = model_args.eos_id
         self.head_dim = model_args.head_dim

         self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)
