# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

8- """
9- Example CLI to run TorchTitan Qwen3 model inference with vLLM:
10-
11- # Run inference
12- python torchtitan/experiments/deterministic_vllm_rl/infer.py
13- """
14-
import argparse

from vllm import LLM, SamplingParams

# Import models module - this automatically registers TorchTitan models with vLLM.
# NOTE(jianiw): We could use the vLLM plug-in system instead:
# https://docs.vllm.ai/en/latest/design/plugin_system/
from torchtitan.experiments.deterministic_vllm_rl import models  # noqa: F401
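# (`# noqa: F401` silences the unused-import lint warning, since the module is
# imported only for its registration side effect)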


def parse_args():
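    # NOTE: the body of parse_args() is not shown in this excerpt; the lines
    # below are a minimal reconstruction inferred from how `args` is used in
    # main(). Argument names match that usage, but the defaults and help
    # strings are assumptions rather than the original values.
    parser = argparse.ArgumentParser(
        description="Run TorchTitan Qwen3 model inference with vLLM"
    )
    parser.add_argument("--model", type=str, required=True,
                        help="Path to the model checkpoint directory")
    parser.add_argument("--tensor-parallel-size", type=int, default=1,
                        help="Number of GPUs to use for tensor parallelism")
    parser.add_argument("--prompt", type=str, default="Hello, my name is",
                        help="Prompt to generate from")
    parser.add_argument("--temperature", type=float, default=0.8,
                        help="Sampling temperature")
    parser.add_argument("--max-tokens", type=int, default=128,
                        help="Maximum number of tokens to generate")
    return parser.parse_args()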


def main():
    args = parse_args()

    print("=" * 80)
    print("INITIALIZING vLLM WITH TORCHTITAN QWEN3 MODEL")
    print("=" * 80)
    print(f"Model: {args.model}")
    print(f"Tensor Parallel Size: {args.tensor_parallel_size}")
    print()

    # Build hf_overrides with checkpoint path
    hf_overrides = {
        "checkpoint_dir": args.model,
    }
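    # vLLM forwards hf_overrides into the HuggingFace model config, making
    # `checkpoint_dir` visible to the TorchTitan model code at load time.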

    # Initialize vLLM with custom TorchTitan Qwen3 model.
    # The LLM initialization will internally:
    # 1. Load the TrainSpec for Qwen3 (registered via the models import above)
    # 2. Create a TorchTitanVLLMModel instance
    # 3. Process parallelism settings via process_parallelism_settings()
    # 4. Build the device mesh and apply parallelization via build_device_mesh_and_parallelize()
    # 5. Load model weights and prepare for inference
    print("Initializing vLLM engine...")
    llm = LLM(
        model=args.model,  # Model checkpoint path
        hf_overrides=hf_overrides,
        dtype="bfloat16",
        trust_remote_code=True,
        enforce_eager=True,  # Eager mode only (skips CUDA graph capture)
        enable_prefix_caching=False,  # Disable prefix caching (required for now)
        tensor_parallel_size=args.tensor_parallel_size,  # Multi-GPU support
    )

    print("=" * 80)
    print("vLLM ENGINE INITIALIZED - CONFIGURATION DETAILS")
    print("=" * 80)
    print(f"Prompt: {args.prompt}")
    print()

    # Prepare prompt and sampling parameters
    prompts = [args.prompt]
    sampling_params = SamplingParams(
        temperature=args.temperature,
        top_p=0.95,
        max_tokens=args.max_tokens,
    )
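    # top_p=0.95 restricts sampling to the smallest set of tokens whose
    # cumulative probability exceeds 0.95 (nucleus sampling); a temperature
    # of 0 would make vLLM decode greedily instead.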

    # Generate text
    outputs = llm.generate(
        prompts=prompts,
        sampling_params=sampling_params,
    )
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text

        print(f"Prompt: {prompt}")
        print(f"Generated text: {generated_text!r}")
        print()


if __name__ == "__main__":
    main()