Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion grpc_servicer/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "smg-grpc-servicer"
version = "0.5.2"
version = "0.6.0"
description = "SMG gRPC servicer implementations for LLM inference engines (vLLM, SGLang)"
requires-python = ">=3.10"
dependencies = [
Expand Down Expand Up @@ -32,6 +32,7 @@ classifiers = [
[project.optional-dependencies]
vllm = ["vllm>=0.19.0"]
sglang = ["sglang>=0.5.10"]
test = ["pytest>=7.0", "pytest-asyncio>=0.21", "pytest-timeout>=2.0"]

[project.urls]
Homepage = "https://github.com/lightseekorg/smg"
Expand Down
119 changes: 119 additions & 0 deletions grpc_servicer/smg_grpc_servicer/health_watch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# SPDX-License-Identifier: Apache-2.0
"""Shared Watch() continuous streaming for gRPC health servicers."""

import asyncio
import inspect
import logging
from collections.abc import AsyncIterator

import grpc
from grpc_health.v1 import health_pb2

logger = logging.getLogger(__name__)


class HealthWatchMixin:
"""Continuous Watch() streaming for gRPC health servicers.

Implements the gRPC Health Checking Protocol Watch RPC as a long-lived
server-streaming response that sends updates on status change.

Subclasses must:
1. Call self._init_watch() in __init__
2. Implement _compute_watch_status(service_name) -> ServingStatus
3. Implement _is_shutting_down() -> bool
4. Call self._notify_shutdown() in set_not_serving()
"""

WATCH_POLL_INTERVAL_S = 5.0

def _init_watch(self) -> None:
"""Initialize Watch state. Must be called in subclass __init__.

Note: on Python 3.10-3.11, asyncio.Event() captures the running
event loop at construction time. Both SGLang and vLLM construct
their servicers within the async server context (verified against
smg sglang/server.py and vllm grpc_server.py), so this is safe.
Python 3.12+ removed the loop binding entirely.
"""
self._watch_shutdown_event = asyncio.Event()
self._watch_notified_shutdown = False

def _notify_shutdown(self) -> None:
"""Wake all Watch streams to detect shutdown immediately.
Must be called in subclass set_not_serving(). Sets
_watch_notified_shutdown so _is_shutting_down() implementations
can check it alongside their own shutdown flags."""
self._watch_notified_shutdown = True
self._watch_shutdown_event.set()

def _compute_watch_status(self, service_name: str) -> int:
"""Compute current health status for the given service.

Must not call context.set_code() -- that would pollute the
streaming response. Return a ServingStatus enum value instead.

May be overridden as async def for servicers that need I/O
(e.g., VllmHealthServicer calls await async_llm.check_health()).
"""
raise NotImplementedError(f"{type(self).__name__} must implement _compute_watch_status()")

def _is_shutting_down(self) -> bool:
"""Return True if the server is shutting down."""
raise NotImplementedError(f"{type(self).__name__} must implement _is_shutting_down()")

async def _resolve_watch_status(self, service_name: str) -> int:
"""Call _compute_watch_status, handling both sync and async impls."""
result = self._compute_watch_status(service_name)
if inspect.isawaitable(result):
return await result
return result

async def Watch(
self,
request: health_pb2.HealthCheckRequest,
context: grpc.aio.ServicerContext,
) -> AsyncIterator[health_pb2.HealthCheckResponse]:
"""gRPC Health Watch -- continuous streaming implementation.

Behavior per gRPC Health Checking Protocol:
- Immediately sends the current serving status
- Sends a new message whenever status changes
- Stream ends on server shutdown or client cancellation

Deviation from spec: for unknown services, the stream sends
SERVICE_UNKNOWN once then exits. The spec says to keep the stream
open for dynamic service registration, but smg services are
statically defined and never registered at runtime.
"""
service_name = request.service
logger.debug("Health watch request for service: '%s'", service_name)

last_status = None
try:
while True:
status = await self._resolve_watch_status(service_name)

if status != last_status:
yield health_pb2.HealthCheckResponse(status=status)
last_status = status

if self._is_shutting_down():
return
Comment on lines +101 to +102
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Emit shutdown status before terminating Watch stream

This early return can drop a real status transition during shutdown: if set_not_serving() runs after status was computed for the current iteration but before this check executes, the stream exits immediately without sending the final NOT_SERVING update. In that race window, Watch clients only observe EOF and miss the health-state change event they rely on for routing decisions.

Useful? React with 👍 / 👎.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Evaluated -- pushing back on this one.

The race requires set_not_serving() to fire between _resolve_watch_status() returning and _is_shutting_down() executing. In asyncio's single-threaded cooperative model:

  • SGLang: _compute_watch_status() is sync -- no yield point, race is impossible.
  • vLLM: async via await check_health(), but the window between the return of _resolve_watch_status() and the _is_shutting_down() check is a single Python statement with no await -- no coroutine switch can happen there.

Even in the theoretical case where shutdown lands between _resolve_watch_status yielding control (during check_health()) and the shutdown check: the client would see SERVING then EOF. Any conformant Watch client treats EOF as "server unavailable" -- the gRPC transport-level disconnect is the primary shutdown signal, not an in-band NOT_SERVING message.

Not worth the added complexity.


if status == health_pb2.HealthCheckResponse.SERVICE_UNKNOWN:
return

try:
await asyncio.wait_for(
self._watch_shutdown_event.wait(),
timeout=self.WATCH_POLL_INTERVAL_S,
)
except asyncio.TimeoutError:
pass

except asyncio.CancelledError:
logger.debug(
"Health watch cancelled by client for service: '%s'",
service_name,
)
68 changes: 37 additions & 31 deletions grpc_servicer/smg_grpc_servicer/sglang/health_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,20 @@

import logging
import time
from collections.abc import AsyncIterator

import grpc
from grpc_health.v1 import health_pb2, health_pb2_grpc

from smg_grpc_servicer.health_watch import HealthWatchMixin

logger = logging.getLogger(__name__)

SERVING = health_pb2.HealthCheckResponse.SERVING
NOT_SERVING = health_pb2.HealthCheckResponse.NOT_SERVING
SERVICE_UNKNOWN = health_pb2.HealthCheckResponse.SERVICE_UNKNOWN


class SGLangHealthServicer(health_pb2_grpc.HealthServicer):
class SGLangHealthServicer(HealthWatchMixin, health_pb2_grpc.HealthServicer):
"""
Standard gRPC health check service implementation for Kubernetes probes.
Implements grpc.health.v1.Health protocol.
Expand All @@ -29,6 +34,8 @@ class SGLangHealthServicer(health_pb2_grpc.HealthServicer):
- SERVING: Model loaded and ready to serve requests
"""

SCHEDULER_RESPONSIVENESS_TIMEOUT_S = 30

# Service names we support
OVERALL_SERVER = "" # Empty string for overall server health
SGLANG_SERVICE = "sglang.grpc.scheduler.SglangScheduler"
Expand All @@ -49,6 +56,7 @@ def __init__(self, request_manager, scheduler_info: dict):
self._serving_status[self.OVERALL_SERVER] = health_pb2.HealthCheckResponse.NOT_SERVING
self._serving_status[self.SGLANG_SERVICE] = health_pb2.HealthCheckResponse.NOT_SERVING

self._init_watch()
logger.info("Standard gRPC health service initialized")

def set_serving(self):
Expand All @@ -61,6 +69,7 @@ def set_not_serving(self):
"""Mark services as NOT_SERVING - call this during shutdown."""
self._serving_status[self.OVERALL_SERVER] = health_pb2.HealthCheckResponse.NOT_SERVING
self._serving_status[self.SGLANG_SERVICE] = health_pb2.HealthCheckResponse.NOT_SERVING
self._notify_shutdown()
logger.info("Health service status set to NOT_SERVING")

async def Check(
Expand Down Expand Up @@ -113,11 +122,7 @@ async def Check(
time_since_last_receive = time.time() - self.request_manager.last_receive_tstamp

# If no recent activity and we have active requests, might be stuck
# NOTE: 30s timeout is hardcoded. This is more conservative than
# HEALTH_CHECK_TIMEOUT (20s) used for custom HealthCheck RPC.
# Consider making this configurable via environment variable in the future
# if different workloads need different responsiveness thresholds.
if time_since_last_receive > 30 and len(self.request_manager.rid_to_state) > 0:
if time_since_last_receive > self.SCHEDULER_RESPONSIVENESS_TIMEOUT_S and len(self.request_manager.rid_to_state) > 0:
logger.warning(
f"Service health check: Scheduler not responsive "
f"({time_since_last_receive:.1f}s since last receive, "
Expand All @@ -139,30 +144,31 @@ async def Check(
status=health_pb2.HealthCheckResponse.SERVICE_UNKNOWN
)

async def Watch(
self,
request: health_pb2.HealthCheckRequest,
context: grpc.aio.ServicerContext,
) -> AsyncIterator[health_pb2.HealthCheckResponse]:
"""
Streaming health check - sends updates when status changes.
def _is_shutting_down(self) -> bool:
# _watch_notified_shutdown is set by _notify_shutdown() in set_not_serving();
# gracefully_exit covers external shutdown from the request manager.
return self.request_manager.gracefully_exit or self._watch_notified_shutdown

For now, just send current status once (Kubernetes doesn't use Watch).
A full implementation would monitor status changes and stream updates.

Args:
request: Contains service name
context: gRPC context

Yields:
HealthCheckResponse messages when status changes
"""
service_name = request.service
logger.debug(f"Health watch request for service: '{service_name}'")
def _compute_watch_status(self, service_name: str) -> int:
"""Sync status computation -- no I/O needed."""
if self.request_manager.gracefully_exit:
return NOT_SERVING

# Send current status
response = await self.Check(request, context)
yield response
if service_name == self.OVERALL_SERVER:
return self._serving_status.get(self.OVERALL_SERVER, NOT_SERVING)

if service_name == self.SGLANG_SERVICE:
base_status = self._serving_status.get(self.SGLANG_SERVICE, NOT_SERVING)
if base_status != SERVING:
return base_status
time_since = time.time() - self.request_manager.last_receive_tstamp
if time_since > self.SCHEDULER_RESPONSIVENESS_TIMEOUT_S and len(self.request_manager.rid_to_state) > 0:
logger.warning(
"Scheduler not responsive (%.1fs, %d pending)",
time_since,
len(self.request_manager.rid_to_state),
)
return NOT_SERVING
return SERVING

# Note: Full Watch implementation would monitor status changes
# and stream updates. For K8s probes, Check is sufficient.
return SERVICE_UNKNOWN
52 changes: 20 additions & 32 deletions grpc_servicer/smg_grpc_servicer/vllm/health_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,25 @@
to AsyncLLM.check_health() from the vLLM EngineClient protocol.
"""

from collections.abc import AsyncIterator
from typing import TYPE_CHECKING

import grpc
from grpc_health.v1 import health_pb2, health_pb2_grpc
from vllm.logger import init_logger

from smg_grpc_servicer.health_watch import HealthWatchMixin

if TYPE_CHECKING:
from vllm.v1.engine.async_llm import AsyncLLM

logger = init_logger(__name__)

SERVING = health_pb2.HealthCheckResponse.SERVING
NOT_SERVING = health_pb2.HealthCheckResponse.NOT_SERVING
SERVICE_UNKNOWN = health_pb2.HealthCheckResponse.SERVICE_UNKNOWN


class VllmHealthServicer(health_pb2_grpc.HealthServicer):
class VllmHealthServicer(HealthWatchMixin, health_pb2_grpc.HealthServicer):
"""
Standard gRPC health check service for Kubernetes probes.
Implements grpc.health.v1.Health protocol.
Expand All @@ -44,11 +49,13 @@ def __init__(self, async_llm: "AsyncLLM"):
"""
self.async_llm = async_llm
self._shutting_down = False
self._init_watch()
logger.info("Standard gRPC health service initialized")

def set_not_serving(self):
"""Mark all services as NOT_SERVING during graceful shutdown."""
self._shutting_down = True
self._notify_shutdown()
logger.info("Health service status set to NOT_SERVING")

async def Check(
Expand Down Expand Up @@ -89,43 +96,24 @@ async def Check(
context.set_details(f"Unknown service: {service_name}")
return health_pb2.HealthCheckResponse(status=health_pb2.HealthCheckResponse.SERVICE_UNKNOWN)

async def Watch(
self,
request: health_pb2.HealthCheckRequest,
context: grpc.aio.ServicerContext,
) -> AsyncIterator[health_pb2.HealthCheckResponse]:
"""
Streaming health check - sends current status once.

For now, sends current status once (Kubernetes doesn't use Watch).
A full implementation would monitor status changes and stream updates.
def _is_shutting_down(self) -> bool:
return self._shutting_down

Args:
request: Contains service name
context: gRPC context

Yields:
HealthCheckResponse messages
"""
service_name = request.service
logger.debug(f"Health watch request for service: '{service_name}'")

# Inline status computation to avoid Check()'s context.set_code()
# side effect, which would incorrectly set the RPC status on the
# streaming response for unknown services.
status = health_pb2.HealthCheckResponse.SERVICE_UNKNOWN
async def _compute_watch_status(self, service_name: str) -> int:
"""Async status computation -- delegates to check_health()."""
if self._shutting_down:
status = health_pb2.HealthCheckResponse.NOT_SERVING
elif service_name in (self.OVERALL_SERVER, self.VLLM_SERVICE):
return NOT_SERVING

if service_name in (self.OVERALL_SERVER, self.VLLM_SERVICE):
try:
await self.async_llm.check_health()
status = health_pb2.HealthCheckResponse.SERVING
return SERVING
except Exception:
logger.debug(
"Health watch check failed for service '%s'",
"Health watch check failed for '%s'",
service_name,
exc_info=True,
)
status = health_pb2.HealthCheckResponse.NOT_SERVING
return NOT_SERVING

yield health_pb2.HealthCheckResponse(status=status)
return SERVICE_UNKNOWN
Empty file.
Loading