Run vLLM inference using torchtitan model definition (single GPU) #2119
infer.py (new file, @@ -0,0 +1,114 @@):

```python
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse

from vllm import LLM, SamplingParams
from vllm.logger import init_logger

# Import models module - this automatically registers TorchTitan models with vLLM
from torchtitan.experiments.deterministic_vllm_rl import models  # noqa: F401


logger = init_logger(__name__)


def parse_args():
    parser = argparse.ArgumentParser(
        description="Run TorchTitan model inference with vLLM Engine",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--model_ckpt_path",
        type=str,
        default="torchtitan/experiments/deterministic_vllm_rl/example_checkpoint/qwen3-0.6B",
        help="Path to TorchTitan checkpoint directory",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="Hello, my name is",
        help="Prompt text for generation",
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=100,
        help="Maximum number of tokens to generate",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.8,
        help="Sampling temperature",
    )
    parser.add_argument(
        "--tensor-parallel-size",
        type=int,
        default=1,
        help="Number of GPUs for tensor parallelism (default: 1 for single GPU)",
    )
    return parser.parse_args()


def main():
    args = parse_args()

    logger.info("Initializing vLLM with TorchTitan model")
    logger.info(f"Model: {args.model_ckpt_path}")
    logger.info(f"Tensor Parallel Size: {args.tensor_parallel_size}")

    # Initialize vLLM with custom TorchTitan model
    # The LLM initialization will internally:
    # 1. Load TrainSpec for Qwen3 (from models/__init__.py register())
    # 2. Create TorchTitanVLLMModel instance
    # 3. Create JobConfig and ParallelDims from vLLM config
    # 4. Apply parallelization using parallelize_qwen3
    # 5. Load model weights and prepare for inference
    logger.info("Creating vLLM LLM engine...")

    llm = LLM(
        model=args.model_ckpt_path,  # Model checkpoint path
        hf_overrides={
            "checkpoint_dir": args.model_ckpt_path,
        },
        dtype="bfloat16",
        trust_remote_code=True,
        enforce_eager=True,  # Use eager mode
        tensor_parallel_size=args.tensor_parallel_size,
    )

    logger.info("vLLM engine initialized successfully")
    logger.info(f"Prompt: {args.prompt}")

    # Prepare prompt and sampling parameters
    prompts = [args.prompt]
    sampling_params = SamplingParams(
        temperature=args.temperature,
        top_p=0.95,
        max_tokens=args.max_tokens,
    )

    # Generate text
    logger.info("Generating text...")
    outputs = llm.generate(
        prompts=prompts,
        sampling_params=sampling_params,
    )

    # Print results
    logger.info("Generation complete")
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text

        print(f"\nPrompt: {prompt}")
        print(f"Generated text: {generated_text!r}\n")


if __name__ == "__main__":
    main()
```
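
For quick orientation, here is a condensed sketch of the same flow as a library call rather than a CLI script. All values mirror the script's defaults; the checkpoint path is just the default above and stands in for wherever your TorchTitan checkpoint actually lives.

```python
# Condensed sketch of the flow in infer.py (values mirror the script defaults).
from vllm import LLM, SamplingParams

# Importing the models package registers the TorchTitan model classes with vLLM.
from torchtitan.experiments.deterministic_vllm_rl import models  # noqa: F401

ckpt = "torchtitan/experiments/deterministic_vllm_rl/example_checkpoint/qwen3-0.6B"

llm = LLM(
    model=ckpt,
    hf_overrides={"checkpoint_dir": ckpt},
    dtype="bfloat16",
    trust_remote_code=True,
    enforce_eager=True,
    tensor_parallel_size=1,  # single GPU
)

outputs = llm.generate(
    prompts=["Hello, my name is"],
    sampling_params=SamplingParams(temperature=0.8, top_p=0.95, max_tokens=100),
)
print(outputs[0].outputs[0].text)
```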
torchtitan/experiments/deterministic_vllm_rl/models/__init__.py (changed hunk @@ -6,8 +6,77 @@, shown as the resulting contents; the previous single-class import from .attention and the single-entry __all__ are replaced by the expanded versions below):

```python
"""
Models for deterministic vLLM RL training.

This module automatically registers TorchTitan models with vLLM when imported.
"""

from vllm.logger import init_logger

from torchtitan.protocols.train_spec import get_train_spec, TrainSpec

from .attention import VLLMCompatibleFlashAttention, VLLMPagedFlashAttention
from .vllm_wrapper import TorchTitanVLLMModel


logger = init_logger(__name__)


def register_torchtitan_model_from_train_spec(
    train_spec: TrainSpec,
    model_name: str,
) -> None:
    """
    Register a TorchTitan model with vLLM using a TrainSpec.

    Args:
        train_spec: TorchTitan TrainSpec containing model components
        model_name: Name to register in vLLM (e.g., "Qwen3TorchTitanForCausalLM")
    """
    from vllm.model_executor.models.registry import ModelRegistry

    # Extract model_args from TrainSpec
    # TrainSpec has model_args as a Mapping, get the first value
    if isinstance(train_spec.model_args, dict):
        model_args_cls = type(next(iter(train_spec.model_args.values())))
    else:
        model_args_cls = train_spec.model_args

    # Create dynamic model class directly from TrainSpec components
    class TorchTitanVLLMModelFromSpec(TorchTitanVLLMModel):
        """Dynamically created vLLM model from TrainSpec."""

        def __init__(self, *, vllm_config, prefix=""):
            super().__init__(
                model_cls=train_spec.model_cls,
                model_args_cls=model_args_cls,
                state_dict_adapter=train_spec.state_dict_adapter,
                parallelize_fn=train_spec.parallelize_fn,
                vllm_config=vllm_config,
                prefix=prefix,
            )

    # Set the class name
    TorchTitanVLLMModelFromSpec.__name__ = model_name
    TorchTitanVLLMModelFromSpec.__qualname__ = model_name

    # Register with vLLM
    ModelRegistry.register_model(model_name, TorchTitanVLLMModelFromSpec)

    logger.info(
        f"Successfully registered {model_name} with vLLM using TrainSpec "
        f"(model_cls={train_spec.model_cls.__name__})"
    )


# Auto-register TorchTitan models with vLLM when this module is imported
register_torchtitan_model_from_train_spec(
    train_spec=get_train_spec("qwen3"),
    model_name="Qwen3TorchTitanForCausalLM",
)


__all__ = [
    "VLLMCompatibleFlashAttention",
    "VLLMPagedFlashAttention",
    "TorchTitanVLLMModel",
]
```
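
As a usage sketch (not part of the diff), the same helper could register another TorchTitan architecture. The spec key and architecture name below are hypothetical; as with other vLLM models, the checkpoint's config would have to reference the registered architecture name for the registry lookup to resolve it.

```python
# Hypothetical example: registering a second TorchTitan TrainSpec with vLLM.
# "llama3" as a spec key and the architecture name are assumptions, not part of this PR.
from torchtitan.protocols.train_spec import get_train_spec

from torchtitan.experiments.deterministic_vllm_rl.models import (
    register_torchtitan_model_from_train_spec,
)

register_torchtitan_model_from_train_spec(
    train_spec=get_train_spec("llama3"),       # assumed spec key
    model_name="Llama3TorchTitanForCausalLM",  # name vLLM would resolve from the config
)
```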
Review comment on lines +50 to +53 of models/__init__.py (the TrainSpec-derived arguments passed to `super().__init__`):

Contributor: It seems we need these fields and the wrappers. This is hacky and makes things complicated, as we are dumping a lot of logic (originally in train.py and checkpoint.py) into the model code itself. I feel this is unnecessary if our end goal is to use the engine part of vLLM, not the model-init part.

Contributor (Author): Agreed, the main blocker is that we need to have control over how the Worker instantiates a model. According to the vLLM design, this class is not only a model `nn.Module`, but a …

Review thread on the config handling in `infer.py`:

Reviewer: Not urgent, but we should use "our" config system in the long term.

Author: Currently the entry point is the vLLM engine, so we are taking the config from whatever the vLLM engine passes to us. Let me check the vLLM engine to see if there's anything we could do.

Reviewer: Wait, how is this related to the vLLM config system? You are just using them as-is in `args = parse_args()`.

Author: This `args` is only for the `infer.py` script; it passes the args into the vLLM engine via `LLM()`, and the vLLM engine will create a `VllmConfig` instance internally and pass it to our model wrapper.

Reviewer: I don't think it's passing the `args` to `LLM()`. What would be different if we used our config manager to construct `args`?
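
To make the hand-off in this thread concrete, here is a hypothetical sketch (not from the PR; `InferenceSettings` is an invented stand-in, not torchtitan's config manager): whichever front-end builds the settings, they only ever become keyword arguments to `LLM()`, and vLLM then constructs its engine config internally and passes it to the registered wrapper as `vllm_config`.

```python
# Hypothetical stand-in for a torchtitan-managed config (assumption, not an existing API).
from dataclasses import dataclass

from vllm import LLM


@dataclass
class InferenceSettings:
    model_ckpt_path: str
    tensor_parallel_size: int = 1


def build_engine(cfg: InferenceSettings) -> LLM:
    # Regardless of how cfg was produced (argparse today, a config manager later),
    # only these kwargs reach vLLM; the engine builds its own config object from them
    # and hands it to the wrapper's __init__ as `vllm_config` (see the class above).
    return LLM(
        model=cfg.model_ckpt_path,
        hf_overrides={"checkpoint_dir": cfg.model_ckpt_path},
        enforce_eager=True,
        tensor_parallel_size=cfg.tensor_parallel_size,
    )
```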