Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion examples/llm/components/disagg_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

import logging

from utils.ns import get_namespace

from dynamo.runtime import EtcdKvCache
from dynamo.sdk import dynamo_context

Expand All @@ -38,7 +40,7 @@ async def async_init(self):
runtime = dynamo_context["runtime"]
self.etcd_kv_cache = await EtcdKvCache.create(
runtime.etcd_client(),
"/dynamo/disagg_router/",
f"/{get_namespace()}/disagg_router/",
{
"max_local_prefill_length": str(self.max_local_prefill_length),
"max_prefill_queue_size": str(self.max_prefill_queue_size),
Expand Down
20 changes: 16 additions & 4 deletions examples/llm/components/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from components.worker import VllmWorker
from fastapi import FastAPI
from pydantic import BaseModel
from utils.ns import get_namespace

from dynamo import sdk
from dynamo.sdk import async_on_shutdown, depends, service
Expand All @@ -44,15 +45,14 @@ class FrontendConfig(BaseModel):
"""Configuration for the Frontend service including model and HTTP server settings."""

served_model_name: str
endpoint: str
port: int = 8080


# todo this should be called ApiServer
@service(
dynamo={
"enabled": True,
"namespace": "dynamo",
"namespace": get_namespace(),
},
resources={"cpu": "10", "memory": "20Gi"},
workers=1,
Expand All @@ -78,6 +78,8 @@ def setup_model(self):
subprocess.run(
[
"llmctl",
"-n",
get_namespace(),
"http",
"remove",
"chat-models",
Expand All @@ -88,11 +90,13 @@ def setup_model(self):
subprocess.run(
[
"llmctl",
"-n",
get_namespace(),
"http",
"add",
"chat-models",
self.frontend_config.served_model_name,
self.frontend_config.endpoint,
f"{get_namespace()}.Processor.chat/completions",
],
check=False,
)
Expand All @@ -103,7 +107,13 @@ def start_http_server(self):
http_binary = get_http_binary_path()

self.process = subprocess.Popen(
[http_binary, "-p", str(self.frontend_config.port)],
[
http_binary,
"-p",
str(self.frontend_config.port),
"--namespace",
get_namespace(),
],
stdout=None,
stderr=None,
)
Expand All @@ -116,6 +126,8 @@ def cleanup(self):
subprocess.run(
[
"llmctl",
"-n",
get_namespace(),
"http",
"remove",
"chat-models",
Expand Down
7 changes: 4 additions & 3 deletions examples/llm/components/kv_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from components.worker import VllmWorker
from utils.logging import check_required_workers
from utils.ns import get_namespace
from utils.protocol import Tokens
from vllm.logger import logger as vllm_logger

Expand Down Expand Up @@ -70,7 +71,7 @@ def parse_args(service_name, prefix) -> Namespace:
@service(
dynamo={
"enabled": True,
"namespace": "dynamo",
"namespace": get_namespace(),
},
resources={"cpu": "10", "memory": "20Gi"},
workers=1,
Expand All @@ -96,15 +97,15 @@ def __init__(self):
async def async_init(self):
self.runtime = dynamo_context["runtime"]
self.workers_client = (
await self.runtime.namespace("dynamo")
await self.runtime.namespace(get_namespace())
.component("VllmWorker")
.endpoint("generate")
.client()
)

await check_required_workers(self.workers_client, self.args.min_workers)

kv_listener = self.runtime.namespace("dynamo").component("VllmWorker")
kv_listener = self.runtime.namespace(get_namespace()).component("VllmWorker")
await kv_listener.create_service()
self.indexer = KvIndexer(kv_listener, self.args.block_size)
self.metrics_aggregator = KvMetricsAggregator(kv_listener)
Expand Down
15 changes: 10 additions & 5 deletions examples/llm/components/prefill_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from pydantic import BaseModel
from utils.nixl import NixlMetadataStore
from utils.ns import get_namespace
from utils.prefill_queue import PrefillQueue
from utils.vllm import parse_vllm_args
from vllm.entrypoints.openai.api_server import (
Expand All @@ -42,7 +43,7 @@ class RequestType(BaseModel):
@service(
dynamo={
"enabled": True,
"namespace": "dynamo",
"namespace": get_namespace(),
"custom_lease": LeaseConfig(ttl=1), # 1 second
},
resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
Expand Down Expand Up @@ -87,7 +88,7 @@ async def async_init(self):
raise RuntimeError("Failed to initialize engine client")
runtime = dynamo_context["runtime"]
metadata = self.engine_client.nixl_metadata
self._metadata_store = NixlMetadataStore("dynamo", runtime)
self._metadata_store = NixlMetadataStore(get_namespace(), runtime)
await self._metadata_store.put(metadata.engine_id, metadata)
self.task = asyncio.create_task(self.prefill_queue_handler())

Expand Down Expand Up @@ -119,9 +120,13 @@ async def prefill_queue_handler(self):
logger.info("Prefill queue handler entered")
prefill_queue_nats_server = os.getenv("NATS_SERVER", "nats://localhost:4222")
prefill_queue_stream_name = (
self.engine_args.served_model_name
if self.engine_args.served_model_name is not None
else "vllm"
get_namespace()
+ "_"
+ (
self.engine_args.served_model_name
if self.engine_args.served_model_name is not None
else "vllm"
)
)
logger.info(
f"Prefill queue: {prefill_queue_nats_server}:{prefill_queue_stream_name}"
Expand Down
7 changes: 4 additions & 3 deletions examples/llm/components/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from transformers import AutoTokenizer
from utils.chat_processor import ChatProcessor, CompletionsProcessor, ProcessMixIn
from utils.logging import check_required_workers
from utils.ns import get_namespace
from utils.protocol import MyRequestOutput, Tokens, vLLMGenerateRequest
from utils.vllm import RouterType, parse_vllm_args
from vllm.engine.arg_utils import AsyncEngineArgs
Expand All @@ -45,7 +46,7 @@ class RequestType(Enum):
@service(
dynamo={
"enabled": True,
"namespace": "dynamo",
"namespace": get_namespace(),
},
resources={"cpu": "10", "memory": "20Gi"},
workers=1,
Expand Down Expand Up @@ -106,13 +107,13 @@ async def async_init(self):

await check_required_workers(self.worker_client, self.min_workers)

kv_listener = runtime.namespace("dynamo").component("VllmWorker")
kv_listener = runtime.namespace(get_namespace()).component("VllmWorker")
await kv_listener.create_service()
self.metrics_aggregator = KvMetricsAggregator(kv_listener)

self.etcd_kv_cache = await EtcdKvCache.create(
runtime.etcd_client(),
"/dynamo/processor/",
f"/{get_namespace()}/processor/",
{"router": self.engine_args.router},
)

Expand Down
9 changes: 5 additions & 4 deletions examples/llm/components/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from components.disagg_router import PyDisaggregatedRouter
from components.prefill_worker import PrefillWorker
from utils.nixl import NixlMetadataStore
from utils.ns import get_namespace
from utils.prefill_queue import PrefillQueue
from utils.protocol import MyRequestOutput, vLLMGenerateRequest
from utils.vllm import RouterType, parse_vllm_args
Expand All @@ -41,7 +42,7 @@
@service(
dynamo={
"enabled": True,
"namespace": "dynamo",
"namespace": get_namespace(),
"custom_lease": LeaseConfig(ttl=1), # 1 second
},
resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
Expand All @@ -64,7 +65,7 @@ def __init__(self):
self._prefill_queue_nats_server = os.getenv(
"NATS_SERVER", "nats://localhost:4222"
)
self._prefill_queue_stream_name = self.model_name
self._prefill_queue_stream_name = get_namespace() + "_" + self.model_name
logger.info(
f"Prefill queue: {self._prefill_queue_nats_server}:{self._prefill_queue_stream_name}"
)
Expand All @@ -90,7 +91,7 @@ def __init__(self):
self.engine_args.enable_prefix_caching = True

os.environ["VLLM_WORKER_ID"] = str(dynamo_context.get("lease").id())
os.environ["VLLM_KV_NAMESPACE"] = "dynamo"
os.environ["VLLM_KV_NAMESPACE"] = get_namespace()
os.environ["VLLM_KV_COMPONENT"] = class_name

self.metrics_publisher = KvMetricsPublisher()
Expand Down Expand Up @@ -128,7 +129,7 @@ async def async_init(self):

if self.engine_args.remote_prefill:
metadata = self.engine_client.nixl_metadata
metadata_store = NixlMetadataStore("dynamo", runtime)
metadata_store = NixlMetadataStore(get_namespace(), runtime)
await metadata_store.put(metadata.engine_id, metadata)

if self.engine_args.conditional_disagg:
Expand Down
1 change: 0 additions & 1 deletion examples/llm/configs/agg.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ Common:

Frontend:
served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
endpoint: dynamo.Processor.chat/completions
port: 8000

Processor:
Expand Down
1 change: 0 additions & 1 deletion examples/llm/configs/agg_router.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ Common:

Frontend:
served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
endpoint: dynamo.Processor.chat/completions
port: 8000

Processor:
Expand Down
1 change: 0 additions & 1 deletion examples/llm/configs/disagg.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ Common:

Frontend:
served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
endpoint: dynamo.Processor.chat/completions
port: 8000

Processor:
Expand Down
1 change: 0 additions & 1 deletion examples/llm/configs/disagg_router.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ Common:

Frontend:
served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
endpoint: dynamo.Processor.chat/completions
port: 8000

Processor:
Expand Down
1 change: 0 additions & 1 deletion examples/llm/configs/multinode-405b.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

Frontend:
served_model_name: nvidia/Llama-3.1-405B-Instruct-FP8
endpoint: dynamo.Processor.chat/completions
port: 8000

Processor:
Expand Down
1 change: 0 additions & 1 deletion examples/llm/configs/multinode_agg_r1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ Common:

Frontend:
served_model_name: deepseek-ai/DeepSeek-R1
endpoint: dynamo.Processor.chat/completions
port: 8000

Processor:
Expand Down
1 change: 0 additions & 1 deletion examples/llm/configs/mutinode_disagg_r1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ Common:

Frontend:
served_model_name: deepseek-ai/DeepSeek-R1
endpoint: dynamo.Processor.chat/completions
port: 8000

Processor:
Expand Down
21 changes: 21 additions & 0 deletions examples/llm/utils/ns.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import os


def get_namespace() -> str:
    """Return the Dynamo namespace the example components should use.

    The namespace is read from the ``DYN_NAMESPACE`` environment variable each
    time the function is called. When the variable is unset, or set to an
    empty string, the default ``"dynamo"`` is returned.

    Returns:
        The namespace name; never an empty string.
    """
    # NOTE(review): open question raised in code review -- could the default
    # also be taken from a command-line argument / the graph, so that it can
    # be set there and overridden in the graph config?

    # `or` (rather than a getenv default) also covers DYN_NAMESPACE="" --
    # callers interpolate the result into paths such as f"/{ns}/processor/"
    # and names such as ns + "_" + model, where an empty namespace would
    # yield malformed values.
    return os.getenv("DYN_NAMESPACE") or "dynamo"
Loading