diff --git a/examples/ray_orchestrator/llm_inference_async_ray.py b/examples/ray_orchestrator/llm_inference_async_ray.py
index ea57975291a..87b29d48e75 100644
--- a/examples/ray_orchestrator/llm_inference_async_ray.py
+++ b/examples/ray_orchestrator/llm_inference_async_ray.py
@@ -1,4 +1,5 @@
 # Generate text asynchronously with Ray orchestrator.
+import argparse
 import asyncio
 
 from tensorrt_llm import LLM, SamplingParams
@@ -6,6 +7,16 @@
 
 
 def main():
+    parser = argparse.ArgumentParser(
+        description="Generate text asynchronously with Ray orchestrator.")
+    parser.add_argument(
+        "--model",
+        type=str,
+        default="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+        help=
+        "HuggingFace model name or path to local HF model (default: TinyLlama/TinyLlama-1.1B-Chat-v1.0)"
+    )
+    args = parser.parse_args()
     # Configure KV cache memory usage fraction.
     kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.5,
                                     max_tokens=4096,
@@ -13,7 +24,7 @@ def main():
 
     # model could accept HF model name or a path to local HF model.
     llm = LLM(
-        model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+        model=args.model,
         kv_cache_config=kv_cache_config,
         max_seq_len=1024,
         max_batch_size=1,
diff --git a/tests/integration/defs/examples/test_ray.py b/tests/integration/defs/examples/test_ray.py
index 5ced89ea87a..9d2666d8430 100644
--- a/tests/integration/defs/examples/test_ray.py
+++ b/tests/integration/defs/examples/test_ray.py
@@ -14,7 +14,8 @@ def ray_example_root(llm_root):
 def test_llm_inference_async_ray(ray_example_root, llm_venv):
     script_path = os.path.join(ray_example_root, "llm_inference_async_ray.py")
 
-    venv_check_call(llm_venv, [script_path])
+    model_path = f"{llm_models_root()}/llama-models-v2/TinyLlama-1.1B-Chat-v1.0"
+    venv_check_call(llm_venv, [script_path, "--model", model_path])
 
 
 @pytest.mark.skip_less_device(2)
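Usage sketch: the example keeps its old zero-argument behavior and only gains an optional flag, which the updated test uses to point at a pre-downloaded checkpoint from llm_models_root(). The local path below is a placeholder for illustration, not a path taken from this change.

    # Unchanged default: pulls TinyLlama/TinyLlama-1.1B-Chat-v1.0 from the HuggingFace Hub.
    python examples/ray_orchestrator/llm_inference_async_ray.py

    # New: run against a local HF checkpoint instead (placeholder path).
    python examples/ray_orchestrator/llm_inference_async_ray.py --model /path/to/TinyLlama-1.1B-Chat-v1.0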