diff --git a/grpc_servicer/pyproject.toml b/grpc_servicer/pyproject.toml index 14172f280..91fccab2b 100644 --- a/grpc_servicer/pyproject.toml +++ b/grpc_servicer/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "smg-grpc-servicer" -version = "0.5.2" +version = "0.6.0" description = "SMG gRPC servicer implementations for LLM inference engines (vLLM, SGLang)" requires-python = ">=3.10" dependencies = [ @@ -32,6 +32,7 @@ classifiers = [ [project.optional-dependencies] vllm = ["vllm>=0.19.0"] sglang = ["sglang>=0.5.10"] +test = ["pytest>=7.0", "pytest-asyncio>=0.21", "pytest-timeout>=2.0"] [project.urls] Homepage = "https://github.com/lightseekorg/smg" diff --git a/grpc_servicer/smg_grpc_servicer/health_watch.py b/grpc_servicer/smg_grpc_servicer/health_watch.py new file mode 100644 index 000000000..9763abec2 --- /dev/null +++ b/grpc_servicer/smg_grpc_servicer/health_watch.py @@ -0,0 +1,119 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Shared Watch() continuous streaming for gRPC health servicers.""" + +import asyncio +import inspect +import logging +from collections.abc import AsyncIterator + +import grpc +from grpc_health.v1 import health_pb2 + +logger = logging.getLogger(__name__) + + +class HealthWatchMixin: + """Continuous Watch() streaming for gRPC health servicers. + + Implements the gRPC Health Checking Protocol Watch RPC as a long-lived + server-streaming response that sends updates on status change. + + Subclasses must: + 1. Call self._init_watch() in __init__ + 2. Implement _compute_watch_status(service_name) -> ServingStatus + 3. Implement _is_shutting_down() -> bool + 4. Call self._notify_shutdown() in set_not_serving() + """ + + WATCH_POLL_INTERVAL_S = 5.0 + + def _init_watch(self) -> None: + """Initialize Watch state. Must be called in subclass __init__. + + Note: asyncio.Event() bound to an event loop at construction only on + Python 3.8-3.9; since 3.10 it binds lazily on first use. 
Both SGLang and vLLM construct + their servicers within the async server context (verified against + smg sglang/server.py and vllm grpc_server.py), so this is safe. + Cross-loop use would make wait()/set() raise RuntimeError. + """ + self._watch_shutdown_event = asyncio.Event() + self._watch_notified_shutdown = False + + def _notify_shutdown(self) -> None: + """Wake all Watch streams to detect shutdown immediately. + Must be called in subclass set_not_serving(). Sets + _watch_notified_shutdown so _is_shutting_down() implementations + can check it alongside their own shutdown flags.""" + self._watch_notified_shutdown = True + self._watch_shutdown_event.set() + + def _compute_watch_status(self, service_name: str) -> int: + """Compute current health status for the given service. + + Must not call context.set_code() -- that would pollute the + streaming response. Return a ServingStatus enum value instead. + + May be overridden as async def for servicers that need I/O + (e.g., VllmHealthServicer calls await async_llm.check_health()). + """ + raise NotImplementedError(f"{type(self).__name__} must implement _compute_watch_status()") + + def _is_shutting_down(self) -> bool: + """Return True if the server is shutting down.""" + raise NotImplementedError(f"{type(self).__name__} must implement _is_shutting_down()") + + async def _resolve_watch_status(self, service_name: str) -> int: + """Call _compute_watch_status, handling both sync and async impls.""" + result = self._compute_watch_status(service_name) + if inspect.isawaitable(result): + return await result + return result + + async def Watch( + self, + request: health_pb2.HealthCheckRequest, + context: grpc.aio.ServicerContext, + ) -> AsyncIterator[health_pb2.HealthCheckResponse]: + """gRPC Health Watch -- continuous streaming implementation. 
+ + Behavior per gRPC Health Checking Protocol: + - Immediately sends the current serving status + - Sends a new message whenever status changes + - Stream ends on server shutdown or client cancellation + + Deviation from spec: for unknown services, the stream sends + SERVICE_UNKNOWN once then exits. The spec says to keep the stream + open for dynamic service registration, but smg services are + statically defined and never registered at runtime. + """ + service_name = request.service + logger.debug("Health watch request for service: '%s'", service_name) + + last_status = None + try: + while True: + status = await self._resolve_watch_status(service_name) + + if status != last_status: + yield health_pb2.HealthCheckResponse(status=status) + last_status = status + + if self._is_shutting_down(): + return + + if status == health_pb2.HealthCheckResponse.SERVICE_UNKNOWN: + return + + try: + await asyncio.wait_for( + self._watch_shutdown_event.wait(), + timeout=self.WATCH_POLL_INTERVAL_S, + ) + except asyncio.TimeoutError: + pass + + except asyncio.CancelledError: + logger.debug( + "Health watch cancelled by client for service: '%s'", + service_name, + ) diff --git a/grpc_servicer/smg_grpc_servicer/sglang/health_servicer.py b/grpc_servicer/smg_grpc_servicer/sglang/health_servicer.py index c7127d520..65d8d8687 100644 --- a/grpc_servicer/smg_grpc_servicer/sglang/health_servicer.py +++ b/grpc_servicer/smg_grpc_servicer/sglang/health_servicer.py @@ -7,15 +7,20 @@ import logging import time -from collections.abc import AsyncIterator import grpc from grpc_health.v1 import health_pb2, health_pb2_grpc +from smg_grpc_servicer.health_watch import HealthWatchMixin + logger = logging.getLogger(__name__) +SERVING = health_pb2.HealthCheckResponse.SERVING +NOT_SERVING = health_pb2.HealthCheckResponse.NOT_SERVING +SERVICE_UNKNOWN = health_pb2.HealthCheckResponse.SERVICE_UNKNOWN + -class SGLangHealthServicer(health_pb2_grpc.HealthServicer): +class SGLangHealthServicer(HealthWatchMixin, 
health_pb2_grpc.HealthServicer): """ Standard gRPC health check service implementation for Kubernetes probes. Implements grpc.health.v1.Health protocol. @@ -29,6 +34,8 @@ class SGLangHealthServicer(health_pb2_grpc.HealthServicer): - SERVING: Model loaded and ready to serve requests """ + SCHEDULER_RESPONSIVENESS_TIMEOUT_S = 30 + # Service names we support OVERALL_SERVER = "" # Empty string for overall server health SGLANG_SERVICE = "sglang.grpc.scheduler.SglangScheduler" @@ -49,6 +56,7 @@ def __init__(self, request_manager, scheduler_info: dict): self._serving_status[self.OVERALL_SERVER] = health_pb2.HealthCheckResponse.NOT_SERVING self._serving_status[self.SGLANG_SERVICE] = health_pb2.HealthCheckResponse.NOT_SERVING + self._init_watch() logger.info("Standard gRPC health service initialized") def set_serving(self): @@ -61,6 +69,7 @@ def set_not_serving(self): """Mark services as NOT_SERVING - call this during shutdown.""" self._serving_status[self.OVERALL_SERVER] = health_pb2.HealthCheckResponse.NOT_SERVING self._serving_status[self.SGLANG_SERVICE] = health_pb2.HealthCheckResponse.NOT_SERVING + self._notify_shutdown() logger.info("Health service status set to NOT_SERVING") async def Check( @@ -113,11 +122,7 @@ async def Check( time_since_last_receive = time.time() - self.request_manager.last_receive_tstamp # If no recent activity and we have active requests, might be stuck - # NOTE: 30s timeout is hardcoded. This is more conservative than - # HEALTH_CHECK_TIMEOUT (20s) used for custom HealthCheck RPC. - # Consider making this configurable via environment variable in the future - # if different workloads need different responsiveness thresholds. 
- if time_since_last_receive > 30 and len(self.request_manager.rid_to_state) > 0: + if time_since_last_receive > self.SCHEDULER_RESPONSIVENESS_TIMEOUT_S and len(self.request_manager.rid_to_state) > 0: logger.warning( f"Service health check: Scheduler not responsive " f"({time_since_last_receive:.1f}s since last receive, " @@ -139,30 +144,31 @@ async def Check( status=health_pb2.HealthCheckResponse.SERVICE_UNKNOWN ) - async def Watch( - self, - request: health_pb2.HealthCheckRequest, - context: grpc.aio.ServicerContext, - ) -> AsyncIterator[health_pb2.HealthCheckResponse]: - """ - Streaming health check - sends updates when status changes. + def _is_shutting_down(self) -> bool: + # _watch_notified_shutdown is set by _notify_shutdown() in set_not_serving(); + # gracefully_exit covers external shutdown from the request manager. + return self.request_manager.gracefully_exit or self._watch_notified_shutdown - For now, just send current status once (Kubernetes doesn't use Watch). - A full implementation would monitor status changes and stream updates. 
- - Args: - request: Contains service name - context: gRPC context - - Yields: - HealthCheckResponse messages when status changes - """ - service_name = request.service - logger.debug(f"Health watch request for service: '{service_name}'") + def _compute_watch_status(self, service_name: str) -> int: + """Sync status computation -- no I/O needed.""" + if self.request_manager.gracefully_exit: + return NOT_SERVING - # Send current status - response = await self.Check(request, context) - yield response + if service_name == self.OVERALL_SERVER: + return self._serving_status.get(self.OVERALL_SERVER, NOT_SERVING) + + if service_name == self.SGLANG_SERVICE: + base_status = self._serving_status.get(self.SGLANG_SERVICE, NOT_SERVING) + if base_status != SERVING: + return base_status + time_since = time.time() - self.request_manager.last_receive_tstamp + if time_since > self.SCHEDULER_RESPONSIVENESS_TIMEOUT_S and len(self.request_manager.rid_to_state) > 0: + logger.warning( + "Scheduler not responsive (%.1fs, %d pending)", + time_since, + len(self.request_manager.rid_to_state), + ) + return NOT_SERVING + return SERVING - # Note: Full Watch implementation would monitor status changes - # and stream updates. For K8s probes, Check is sufficient. + return SERVICE_UNKNOWN diff --git a/grpc_servicer/smg_grpc_servicer/vllm/health_servicer.py b/grpc_servicer/smg_grpc_servicer/vllm/health_servicer.py index f138b5cdb..024abfa60 100644 --- a/grpc_servicer/smg_grpc_servicer/vllm/health_servicer.py +++ b/grpc_servicer/smg_grpc_servicer/vllm/health_servicer.py @@ -5,20 +5,25 @@ to AsyncLLM.check_health() from the vLLM EngineClient protocol. 
""" -from collections.abc import AsyncIterator from typing import TYPE_CHECKING import grpc from grpc_health.v1 import health_pb2, health_pb2_grpc from vllm.logger import init_logger +from smg_grpc_servicer.health_watch import HealthWatchMixin + if TYPE_CHECKING: from vllm.v1.engine.async_llm import AsyncLLM logger = init_logger(__name__) +SERVING = health_pb2.HealthCheckResponse.SERVING +NOT_SERVING = health_pb2.HealthCheckResponse.NOT_SERVING +SERVICE_UNKNOWN = health_pb2.HealthCheckResponse.SERVICE_UNKNOWN + -class VllmHealthServicer(health_pb2_grpc.HealthServicer): +class VllmHealthServicer(HealthWatchMixin, health_pb2_grpc.HealthServicer): """ Standard gRPC health check service for Kubernetes probes. Implements grpc.health.v1.Health protocol. @@ -44,11 +49,13 @@ def __init__(self, async_llm: "AsyncLLM"): """ self.async_llm = async_llm self._shutting_down = False + self._init_watch() logger.info("Standard gRPC health service initialized") def set_not_serving(self): """Mark all services as NOT_SERVING during graceful shutdown.""" self._shutting_down = True + self._notify_shutdown() logger.info("Health service status set to NOT_SERVING") async def Check( @@ -89,43 +96,24 @@ async def Check( context.set_details(f"Unknown service: {service_name}") return health_pb2.HealthCheckResponse(status=health_pb2.HealthCheckResponse.SERVICE_UNKNOWN) - async def Watch( - self, - request: health_pb2.HealthCheckRequest, - context: grpc.aio.ServicerContext, - ) -> AsyncIterator[health_pb2.HealthCheckResponse]: - """ - Streaming health check - sends current status once. - - For now, sends current status once (Kubernetes doesn't use Watch). - A full implementation would monitor status changes and stream updates. 
+ def _is_shutting_down(self) -> bool: + return self._shutting_down - Args: - request: Contains service name - context: gRPC context - - Yields: - HealthCheckResponse messages - """ - service_name = request.service - logger.debug(f"Health watch request for service: '{service_name}'") - - # Inline status computation to avoid Check()'s context.set_code() - # side effect, which would incorrectly set the RPC status on the - # streaming response for unknown services. - status = health_pb2.HealthCheckResponse.SERVICE_UNKNOWN + async def _compute_watch_status(self, service_name: str) -> int: + """Async status computation -- delegates to check_health().""" if self._shutting_down: - status = health_pb2.HealthCheckResponse.NOT_SERVING - elif service_name in (self.OVERALL_SERVER, self.VLLM_SERVICE): + return NOT_SERVING + + if service_name in (self.OVERALL_SERVER, self.VLLM_SERVICE): try: await self.async_llm.check_health() - status = health_pb2.HealthCheckResponse.SERVING + return SERVING except Exception: logger.debug( - "Health watch check failed for service '%s'", + "Health watch check failed for '%s'", service_name, exc_info=True, ) - status = health_pb2.HealthCheckResponse.NOT_SERVING + return NOT_SERVING - yield health_pb2.HealthCheckResponse(status=status) + return SERVICE_UNKNOWN diff --git a/grpc_servicer/tests/__init__.py b/grpc_servicer/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/grpc_servicer/tests/conftest.py b/grpc_servicer/tests/conftest.py new file mode 100644 index 000000000..743f7e638 --- /dev/null +++ b/grpc_servicer/tests/conftest.py @@ -0,0 +1,78 @@ +# SPDX-License-Identifier: Apache-2.0 +import sys +from unittest.mock import MagicMock + +import pytest +from grpc_health.v1 import health_pb2 + +# Stub out vllm and its submodules so vllm health_servicer can be imported +# without a full vLLM installation. MagicMock-based stubs auto-satisfy +# any attribute access and from-import statements at collection time. 
+# This must run before any smg_grpc_servicer imports are resolved. +_VLLM_STUBS = [ + "vllm", + "vllm.engine", + "vllm.engine.protocol", + "vllm.inputs", + "vllm.inputs.engine", + "vllm.logger", + "vllm.logprobs", + "vllm.multimodal", + "vllm.multimodal.inputs", + "vllm.outputs", + "vllm.sampling_params", + "vllm.v1", + "vllm.v1.engine", + "vllm.v1.engine.async_llm", +] +for _name in _VLLM_STUBS: + if _name not in sys.modules: + sys.modules[_name] = MagicMock() + +# Stub out sglang and its submodules so health_servicer can be imported +# without a full SGLang installation. MagicMock-based stubs auto-satisfy +# any attribute access and from-import statements at collection time. +# This must run before any smg_grpc_servicer imports are resolved. +_SGLANG_STUBS = [ + "sglang", + "sglang.srt", + "sglang.srt.configs", + "sglang.srt.configs.model_config", + "sglang.srt.disaggregation", + "sglang.srt.disaggregation.kv_events", + "sglang.srt.disaggregation.utils", + "sglang.srt.managers", + "sglang.srt.managers.data_parallel_controller", + "sglang.srt.managers.disagg_service", + "sglang.srt.managers.io_struct", + "sglang.srt.managers.schedule_batch", + "sglang.srt.managers.scheduler", + "sglang.srt.observability", + "sglang.srt.observability.req_time_stats", + "sglang.srt.sampling", + "sglang.srt.sampling.sampling_params", + "sglang.srt.server_args", + "sglang.srt.utils", + "sglang.srt.utils.network", + "sglang.srt.utils.torch_memory_saver_adapter", + "sglang.utils", +] +for _name in _SGLANG_STUBS: + if _name not in sys.modules: + sys.modules[_name] = MagicMock() + +SERVING = health_pb2.HealthCheckResponse.SERVING +NOT_SERVING = health_pb2.HealthCheckResponse.NOT_SERVING +SERVICE_UNKNOWN = health_pb2.HealthCheckResponse.SERVICE_UNKNOWN + + +@pytest.fixture +def grpc_context(): + return MagicMock(spec=["set_code", "set_details", "cancelled", "done"]) + + +@pytest.fixture +def request_msg(): + msg = MagicMock() + msg.service = "" + return msg diff --git 
a/grpc_servicer/tests/test_sglang_health_watch.py b/grpc_servicer/tests/test_sglang_health_watch.py new file mode 100644 index 000000000..fca910581 --- /dev/null +++ b/grpc_servicer/tests/test_sglang_health_watch.py @@ -0,0 +1,177 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for SGLangHealthServicer Watch() continuous streaming.""" + +import asyncio +from unittest.mock import MagicMock + +import pytest +from grpc_health.v1 import health_pb2 +from smg_grpc_servicer.sglang.health_servicer import SGLangHealthServicer + +SERVING = health_pb2.HealthCheckResponse.SERVING +NOT_SERVING = health_pb2.HealthCheckResponse.NOT_SERVING +SERVICE_UNKNOWN = health_pb2.HealthCheckResponse.SERVICE_UNKNOWN + + +@pytest.fixture +def request_manager(): + mgr = MagicMock() + mgr.gracefully_exit = False + # float("inf") ensures scheduler-responsiveness timeout (30s) never triggers + mgr.last_receive_tstamp = float("inf") + mgr.rid_to_state = {} + return mgr + + +@pytest.fixture +def servicer(request_manager): + return SGLangHealthServicer( + request_manager=request_manager, + scheduler_info={"model_path": "test"}, + ) + + +@pytest.mark.asyncio +async def test_watch_sends_initial_status(servicer, request_msg, grpc_context): + """Watch must immediately send the current status.""" + servicer.set_serving() + request_msg.service = "" + + received = [] + async for response in servicer.Watch(request_msg, grpc_context): + received.append(response.status) + if len(received) == 1: + servicer.set_not_serving() + + assert received[0] == SERVING + + +@pytest.mark.asyncio +async def test_watch_yields_on_status_change(servicer, request_msg, grpc_context, request_manager): + """Watch must send a new response when status changes.""" + servicer.set_serving() + request_msg.service = "" + + received = [] + + async def trigger_shutdown(): + await asyncio.sleep(0.1) + servicer.set_not_serving() + + task = asyncio.create_task(trigger_shutdown()) + + async for response in servicer.Watch(request_msg, 
grpc_context): + received.append(response.status) + + await task + assert received == [SERVING, NOT_SERVING] + + +@pytest.mark.asyncio +async def test_watch_exits_on_shutdown(servicer, request_msg, grpc_context): + """set_not_serving() must cause Watch to end the stream.""" + servicer.set_serving() + request_msg.service = "" + + async def trigger_shutdown(): + await asyncio.sleep(0.05) + servicer.set_not_serving() + + task = asyncio.create_task(trigger_shutdown()) + + received = [] + async for response in servicer.Watch(request_msg, grpc_context): + received.append(response.status) + + await task + assert len(received) >= 1 + assert received[-1] == NOT_SERVING + + +@pytest.mark.asyncio +async def test_watch_handles_client_cancel(servicer, request_msg, grpc_context): + """Task cancellation (real client disconnect) must not raise unexpected errors.""" + servicer.set_serving() + request_msg.service = "" + + async def consume_forever(): + async for _ in servicer.Watch(request_msg, grpc_context): + pass + + task = asyncio.create_task(consume_forever()) + await asyncio.sleep(0.05) + task.cancel() + # Watch() catches CancelledError internally. The task may complete + # normally or propagate cancellation depending on asyncio internals. + # Either outcome is correct -- verify no unexpected exception. 
+ try: + await task + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +async def test_watch_unknown_service(servicer, request_msg, grpc_context): + """Unknown service: single SERVICE_UNKNOWN, no context.set_code().""" + servicer.set_serving() + request_msg.service = "nonexistent.Service" + + received = [] + async for response in servicer.Watch(request_msg, grpc_context): + received.append(response.status) + + assert received == [SERVICE_UNKNOWN] + grpc_context.set_code.assert_not_called() + + +@pytest.mark.asyncio +async def test_watch_no_duplicate_on_stable_status(servicer, request_msg, grpc_context): + """Stable status must not yield duplicate responses.""" + servicer.set_serving() + request_msg.service = "" + + original_interval = servicer.WATCH_POLL_INTERVAL_S + servicer.WATCH_POLL_INTERVAL_S = 0.05 + + received = [] + + async def stop_after_polls(): + await asyncio.sleep(0.2) + servicer.set_not_serving() + + task = asyncio.create_task(stop_after_polls()) + + async for response in servicer.Watch(request_msg, grpc_context): + received.append(response.status) + + await task + servicer.WATCH_POLL_INTERVAL_S = original_interval + + assert received == [SERVING, NOT_SERVING] + + +@pytest.mark.asyncio +async def test_watch_detects_graceful_exit_via_poll( + servicer, request_msg, grpc_context, request_manager +): + """Watch must detect request_manager.gracefully_exit on next poll cycle, + even without _notify_shutdown() (simulates external shutdown signal).""" + servicer.set_serving() + request_msg.service = "" + + servicer.WATCH_POLL_INTERVAL_S = 0.05 + + received = [] + + async def trigger_graceful_exit(): + await asyncio.sleep(0.1) + request_manager.gracefully_exit = True + + task = asyncio.create_task(trigger_graceful_exit()) + + async for response in servicer.Watch(request_msg, grpc_context): + received.append(response.status) + + await task + assert received[0] == SERVING + assert received[-1] == NOT_SERVING diff --git 
a/grpc_servicer/tests/test_vllm_health_watch.py b/grpc_servicer/tests/test_vllm_health_watch.py new file mode 100644 index 000000000..8057cb3bf --- /dev/null +++ b/grpc_servicer/tests/test_vllm_health_watch.py @@ -0,0 +1,166 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for VllmHealthServicer Watch() continuous streaming.""" + +import asyncio +from unittest.mock import AsyncMock, MagicMock + +import pytest +from grpc_health.v1 import health_pb2 +from smg_grpc_servicer.vllm.health_servicer import VllmHealthServicer + +SERVING = health_pb2.HealthCheckResponse.SERVING +NOT_SERVING = health_pb2.HealthCheckResponse.NOT_SERVING +SERVICE_UNKNOWN = health_pb2.HealthCheckResponse.SERVICE_UNKNOWN + + +@pytest.fixture +def async_llm(): + mock = MagicMock() + mock.check_health = AsyncMock() + return mock + + +@pytest.fixture +def servicer(async_llm): + return VllmHealthServicer(async_llm) + + +@pytest.mark.asyncio +async def test_watch_sends_initial_serving(servicer, request_msg, grpc_context): + """Watch must immediately send SERVING when engine is healthy.""" + request_msg.service = "" + + received = [] + async for response in servicer.Watch(request_msg, grpc_context): + received.append(response.status) + if len(received) == 1: + servicer.set_not_serving() + + assert received[0] == SERVING + + +@pytest.mark.asyncio +async def test_watch_yields_on_engine_failure(servicer, request_msg, grpc_context, async_llm): + """Watch must send NOT_SERVING when check_health() starts failing.""" + request_msg.service = "" + servicer.WATCH_POLL_INTERVAL_S = 0.05 + + received = [] + poll_count = 0 + + original_check = async_llm.check_health + + async def check_health_with_failure(): + nonlocal poll_count + poll_count += 1 + if poll_count >= 3: + raise Exception("engine dead") + await original_check() + + async_llm.check_health = AsyncMock(side_effect=check_health_with_failure) + + async def stop_eventually(): + await asyncio.sleep(0.5) + servicer.set_not_serving() + + task = 
asyncio.create_task(stop_eventually()) + + async for response in servicer.Watch(request_msg, grpc_context): + received.append(response.status) + + await task + assert SERVING in received + assert NOT_SERVING in received + serving_idx = received.index(SERVING) + not_serving_idx = received.index(NOT_SERVING) + assert serving_idx < not_serving_idx + + +@pytest.mark.asyncio +async def test_watch_exits_on_shutdown(servicer, request_msg, grpc_context): + """set_not_serving() must wake Watch and end the stream.""" + request_msg.service = "" + + async def trigger_shutdown(): + await asyncio.sleep(0.05) + servicer.set_not_serving() + + task = asyncio.create_task(trigger_shutdown()) + + received = [] + async for response in servicer.Watch(request_msg, grpc_context): + received.append(response.status) + + await task + assert len(received) >= 1 + assert received[-1] == NOT_SERVING + + +@pytest.mark.asyncio +async def test_watch_handles_client_cancel(servicer, request_msg, grpc_context): + """Task cancellation (real client disconnect) must not raise unexpected errors.""" + request_msg.service = "" + + async def consume_forever(): + async for _ in servicer.Watch(request_msg, grpc_context): + pass + + task = asyncio.create_task(consume_forever()) + await asyncio.sleep(0.05) + task.cancel() + # Watch() catches CancelledError internally. The task may complete + # normally or propagate cancellation depending on asyncio internals. + # Either outcome is correct -- verify no unexpected exception. 
+ try: + await task + except asyncio.CancelledError: + pass + + +@pytest.mark.asyncio +async def test_watch_unknown_service(servicer, request_msg, grpc_context): + """Unknown service: single SERVICE_UNKNOWN, no context.set_code().""" + request_msg.service = "fake.Service" + + received = [] + async for response in servicer.Watch(request_msg, grpc_context): + received.append(response.status) + + assert received == [SERVICE_UNKNOWN] + grpc_context.set_code.assert_not_called() + + +@pytest.mark.asyncio +async def test_watch_no_duplicate_on_stable_status(servicer, request_msg, grpc_context): + """Stable SERVING must not yield duplicates across poll cycles.""" + request_msg.service = "" + servicer.WATCH_POLL_INTERVAL_S = 0.05 + + received = [] + + async def stop_after_polls(): + await asyncio.sleep(0.2) + servicer.set_not_serving() + + task = asyncio.create_task(stop_after_polls()) + + async for response in servicer.Watch(request_msg, grpc_context): + received.append(response.status) + + await task + assert received == [SERVING, NOT_SERVING] + + +@pytest.mark.asyncio +async def test_watch_shutdown_overrides_healthy(servicer, request_msg, grpc_context, async_llm): + """After set_not_serving(), Watch returns NOT_SERVING even if + check_health() would succeed.""" + servicer.set_not_serving() + request_msg.service = "" + + received = [] + async for response in servicer.Watch(request_msg, grpc_context): + received.append(response.status) + + assert received == [NOT_SERVING] + async_llm.check_health.assert_not_awaited()