
Commit 226150d: refactor v2
Parent: 0c4d28a

File tree: 6 files changed (+257, -210 lines)


torchtitan/experiments/deterministic_vllm_rl/__init__.py

Lines changed: 9 additions & 0 deletions
@@ -15,6 +15,12 @@
 - Qwen3VLLMCompatModel: vLLM-compatible model with merged projections
 - batch_invariant_backward: Gradient support for vLLM's deterministic operations
 - simple_rl: End-to-end RL training loop
+- TorchTitanVLLMModel: Generic wrapper for TorchTitan models with vLLM
+
+For vLLM inference with TorchTitan models, see:
+- models/base_wrapper.py: Core vLLM wrapper
+- models/__init__.py: Auto-registration with vLLM
+- infer.py: Example inference script
 """
 
 from .batch_invariant_backward import (
@@ -23,12 +29,15 @@
     silu_and_mul_with_gradients,
 )
 from .models import VLLMCompatibleFlashAttention
+from .models.base_wrapper import TorchTitanVLLMModel
 from .models.qwen3 import Qwen3VLLMCompatModel
 
+
 __all__ = [
     "VLLMCompatibleFlashAttention",
     "Qwen3VLLMCompatModel",
     "enable_batch_invariant_backward_mode",
     "rms_norm_with_gradients",
     "silu_and_mul_with_gradients",
+    "TorchTitanVLLMModel",
 ]
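
With these exports in place, downstream code can pull the new generic wrapper from the package root alongside the existing Qwen3-specific pieces. A minimal sketch of the resulting import surface, with names taken directly from the updated `__all__`:

```python
# Import surface after this commit; every name below is in the updated __all__.
from torchtitan.experiments.deterministic_vllm_rl import (
    Qwen3VLLMCompatModel,            # vLLM-compatible Qwen3 with merged projections
    TorchTitanVLLMModel,             # new generic wrapper for TorchTitan models
    VLLMCompatibleFlashAttention,
    enable_batch_invariant_backward_mode,
    rms_norm_with_gradients,
    silu_and_mul_with_gradients,
)
```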

torchtitan/experiments/deterministic_vllm_rl/infer.py

Lines changed: 23 additions & 22 deletions
@@ -5,23 +5,12 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-"""
-Example CLI to run TorchTitan Qwen3 model inference with vLLM:
-
-# Run inference
-python torchtitan/experiments/deterministic_vllm_rl/infer.py
-"""
-
 import argparse
 
 from vllm import LLM, SamplingParams
 
-# Import and register the TorchTitan vLLM plugin
-from torchtitan.experiments.deterministic_vllm_rl.register import register
-
-# Register TorchTitan models with vLLM.
-# NOTE(jianiw): We could use plug-in system instead: https://docs.vllm.ai/en/latest/design/plugin_system/
-register()
+# Import models module - this automatically registers TorchTitan models with vLLM
+from torchtitan.experiments.deterministic_vllm_rl import models  # noqa: F401
 
 
 def parse_args():
@@ -66,39 +55,50 @@ def main():
     args = parse_args()
 
     print("=" * 80)
-    print("INITIALIZING vLLM WITH TORCHTITAN QWEN3 MODEL")
+    print("INITIALIZING vLLM WITH TORCHTITAN QWEN3 MODEL ")
     print("=" * 80)
+    print(f"Model: {args.model}")
+    print(f"Tensor Parallel Size: {args.tensor_parallel_size}")
+    print()
 
     # Build hf_overrides with checkpoint path
     hf_overrides = {
         "checkpoint_dir": args.model,
     }
 
     # Initialize vLLM with custom TorchTitan Qwen3 model
+    # The LLM initialization will internally:
+    # 1. Load TrainSpec for Qwen3 (from register())
+    # 2. Create TorchTitanVLLMModel instance
+    # 3. Process parallelism settings via process_parallelism_settings()
+    # 4. Build device mesh and apply parallelization via build_device_mesh_and_parallelize()
+    # 5. Load model weights and prepare for inference
+    print("Initializing vLLM engine...")
     llm = LLM(
-        model=args.model,  # Use temporary directory with config.json
+        model=args.model,  # Model checkpoint path
        hf_overrides=hf_overrides,
        dtype="bfloat16",
        trust_remote_code=True,
-        enforce_eager=True,  # Use eager mode for debugging
-        # Disable kv cache, required for now
-        enable_prefix_caching=False,
+        enforce_eager=True,  # Use eager mode
+        enable_prefix_caching=False,  # Disable kv cache for now
        tensor_parallel_size=args.tensor_parallel_size,  # Multi-GPU support
     )
 
     print("=" * 80)
-    print("vLLM ENGINE INITIALIZED - STARTING GENERATION")
+    print("vLLM ENGINE INITIALIZED - CONFIGURATION DETAILS")
     print("=" * 80)
+    print(f"Prompt: {args.prompt}")
+    print()
 
-    # Prepare prompt
+    # Prepare prompt and sampling parameters
     prompts = [args.prompt]
     sampling_params = SamplingParams(
         temperature=args.temperature,
         top_p=0.95,
         max_tokens=args.max_tokens,
     )
 
-    # Generate
+    # Generate text
     outputs = llm.generate(
         prompts=prompts,
         sampling_params=sampling_params,
@@ -109,8 +109,9 @@ def main():
         prompt = output.prompt
         generated_text = output.outputs[0].text
 
-        print(f"\nPrompt: {prompt}")
+        print(f"Prompt: {prompt}")
         print(f"Generated text: {generated_text!r}")
+        print()
 
 
 if __name__ == "__main__":
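
The script boils down to three steps: import the `models` package for its registration side effect, construct the `LLM` engine against the TorchTitan checkpoint, and generate. A condensed sketch of the same flow; the checkpoint path is a placeholder and the sampling values are sample choices, since the script reads them from CLI flags:

```python
# Condensed version of infer.py's flow; /path/to/checkpoint is a placeholder.
from torchtitan.experiments.deterministic_vllm_rl import models  # noqa: F401 (registers models)
from vllm import LLM, SamplingParams

llm = LLM(
    model="/path/to/checkpoint",
    hf_overrides={"checkpoint_dir": "/path/to/checkpoint"},
    dtype="bfloat16",
    trust_remote_code=True,
    enforce_eager=True,           # eager mode
    enable_prefix_caching=False,  # prefix caching disabled for now
)

outputs = llm.generate(
    prompts=["Explain batch invariance in one sentence."],
    sampling_params=SamplingParams(temperature=0.8, top_p=0.95, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```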

torchtitan/experiments/deterministic_vllm_rl/models/__init__.py

Lines changed: 70 additions & 1 deletion
@@ -6,8 +6,77 @@
 
 """
 Models for deterministic vLLM RL training.
+
+This module automatically registers TorchTitan models with vLLM when imported.
 """
 
+from vllm.logger import init_logger
+
+from torchtitan.protocols.train_spec import get_train_spec, TrainSpec
 from .attention import VLLMCompatibleFlashAttention, VLLMPagedFlashAttention
+from .base_wrapper import TorchTitanVLLMModel
+
+
+logger = init_logger(__name__)
+
+
+def register_torchtitan_model_from_train_spec(
+    train_spec: TrainSpec,
+    model_name: str,
+) -> None:
+    """
+    Register a TorchTitan model with vLLM using a TrainSpec.
+
+    Args:
+        train_spec: TorchTitan TrainSpec containing model components
+        model_name: Name to register in vLLM (e.g., "Qwen3TorchTitanForCausalLM")
+
+    """
+    from vllm.model_executor.models.registry import ModelRegistry
+
+    # Extract model_args from TrainSpec
+    # TrainSpec has model_args as a Mapping, get the first value
+    if isinstance(train_spec.model_args, dict):
+        model_args_cls = type(next(iter(train_spec.model_args.values())))
+    else:
+        model_args_cls = train_spec.model_args
+
+    # Create dynamic model class directly from TrainSpec components
+    class TorchTitanVLLMModelFromSpec(TorchTitanVLLMModel):
+        """Dynamically created vLLM model from TrainSpec."""
+
+        def __init__(self, *, vllm_config, prefix=""):
+            super().__init__(
+                model_cls=train_spec.model_cls,
+                model_args_cls=model_args_cls,
+                state_dict_adapter=train_spec.state_dict_adapter,
+                parallelize_fn=train_spec.parallelize_fn,
+                vllm_config=vllm_config,
+                prefix=prefix,
+            )
+
+    # Set the class name
+    TorchTitanVLLMModelFromSpec.__name__ = model_name
+    TorchTitanVLLMModelFromSpec.__qualname__ = model_name
+
+    # Register with vLLM
+    ModelRegistry.register_model(model_name, TorchTitanVLLMModelFromSpec)
+
+    logger.info(
+        f"Successfully registered {model_name} with vLLM using TrainSpec "
+        f"(model_cls={train_spec.model_cls.__name__})"
+    )
+
+
+# Auto-register TorchTitan models with vLLM when this module is imported
+register_torchtitan_model_from_train_spec(
+    train_spec=get_train_spec("qwen3"),
+    model_name="Qwen3TorchTitanForCausalLM",
+)
+
 
-__all__ = ["VLLMCompatibleFlashAttention", "VLLMPagedFlashAttention"]
+__all__ = [
+    "VLLMCompatibleFlashAttention",
+    "VLLMPagedFlashAttention",
+    "TorchTitanVLLMModel",
+]
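
Because the registration helper is driven entirely by a TrainSpec, wiring up another TorchTitan architecture should reduce to one call. A hypothetical sketch; the "llama3" spec name and the vLLM-facing class name are assumptions, not part of this commit:

```python
# Hypothetical: register a second TorchTitan architecture with vLLM.
# "llama3" is an assumed TrainSpec name; adjust to whatever get_train_spec() knows about.
from torchtitan.protocols.train_spec import get_train_spec
from torchtitan.experiments.deterministic_vllm_rl.models import (
    register_torchtitan_model_from_train_spec,
)

register_torchtitan_model_from_train_spec(
    train_spec=get_train_spec("llama3"),        # assumed spec name
    model_name="Llama3TorchTitanForCausalLM",   # name vLLM will resolve
)
```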
