
Commit 011a200

feat(fault-injection): Add core testing helper utilities (#4040)
1 parent 7893427 commit 011a200

File tree

3 files changed: +642 −0 lines changed

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
"""
Fault tolerance testing helper utilities.

This package provides reusable utilities for testing fault tolerance scenarios.
"""

__all__ = [
    "InferenceLoadTester",
    "get_inference_endpoint",
    "NodeOperations",
    "PodOperations",
]

from .inference_testing import InferenceLoadTester, get_inference_endpoint
from .k8s_operations import NodeOperations, PodOperations
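
For orientation, a minimal sketch of how a test would consume these re-exports; the package name fault_tolerance_helpers is a placeholder, since the package's on-disk path is not shown in this view:

# Hypothetical package name -- the actual directory is not shown in this diff.
from fault_tolerance_helpers import (
    InferenceLoadTester,
    NodeOperations,
    PodOperations,
    get_inference_endpoint,
)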
Lines changed: 187 additions & 0 deletions
@@ -0,0 +1,187 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
"""
Inference load testing utilities for fault tolerance tests.

Provides continuous load generation and statistics tracking for
validating inference availability during fault injection scenarios.

Supports both local (port-forwarded) and in-cluster execution.
"""

import os
import threading
import time
from typing import Dict, List, Optional

import requests


def get_inference_endpoint(
    deployment_name: str, namespace: str, local_port: int = 8000
) -> str:
    """
    Get inference endpoint URL based on environment.

    Args:
        deployment_name: Name of the deployment
        namespace: Kubernetes namespace
        local_port: Port for local port-forwarding (default: 8000)

    Returns:
        Inference endpoint URL
    """
    in_cluster = os.getenv("KUBERNETES_SERVICE_HOST") is not None

    if in_cluster:
        # Use cluster-internal service DNS
        return (
            f"http://{deployment_name}.{namespace}.svc.cluster.local:80/v1/completions"
        )
    else:
        # Use port-forwarded localhost
        return f"http://localhost:{local_port}/v1/completions"


class InferenceLoadTester:
    """Continuous inference load generator for fault tolerance testing."""

    def __init__(self, endpoint: str, model_name: str, timeout: int = 30):
        """
        Initialize the inference load tester.

        Args:
            endpoint: Inference endpoint URL (e.g., "http://localhost:8000/v1/completions")
            model_name: Model name to use in requests
            timeout: Request timeout in seconds (default: 30)
        """
        self.endpoint = endpoint
        self.model_name = model_name
        self.timeout = timeout
        self.running = False
        self.thread: Optional[threading.Thread] = None
        self.results: List[Dict] = []
        self.lock = threading.Lock()

    def send_inference_request(self, prompt: str = "Hello, world!") -> Dict:
        """
        Send a single inference request and return result.

        Args:
            prompt: Text prompt for inference

        Returns:
            Dict with keys: success, status_code, latency, timestamp, error
        """
        try:
            start_time = time.time()
            response = requests.post(
                self.endpoint,
                json={
                    "model": self.model_name,
                    "prompt": prompt,
                    "max_tokens": 50,
                    "temperature": 0.7,
                },
                timeout=self.timeout,
            )
            latency = time.time() - start_time

            return {
                "success": response.status_code == 200,
                "status_code": response.status_code,
                "latency": latency,
                "timestamp": time.time(),
                "error": None if response.status_code == 200 else response.text[:200],
            }
        except requests.exceptions.Timeout:
            return {
                "success": False,
                "status_code": None,
                "latency": self.timeout,
                "timestamp": time.time(),
                "error": "Request timeout",
            }
        except Exception as e:
            return {
                "success": False,
                "status_code": None,
                "latency": time.time() - start_time if "start_time" in locals() else 0,
                "timestamp": time.time(),
                "error": str(e)[:200],
            }

    def _load_loop(self, interval: float = 2.0):
        """Background loop sending requests at specified interval."""
        while self.running:
            result = self.send_inference_request()
            with self.lock:
                self.results.append(result)
            time.sleep(interval)

    def start(self, interval: float = 2.0):
        """
        Start sending inference requests in background.

        Args:
            interval: Seconds between requests (default: 2.0)
        """
        if self.running:
            return

        self.running = True
        self.results = []
        self.thread = threading.Thread(
            target=self._load_loop, args=(interval,), daemon=True
        )
        self.thread.start()

    def stop(self) -> List[Dict]:
        """
        Stop sending requests and return results.

        Returns:
            List of all request results
        """
        self.running = False
        if self.thread:
            self.thread.join(timeout=5)

        with self.lock:
            return self.results.copy()

    def get_stats(self) -> Dict:
        """
        Get statistics for current results.

        Returns:
            Dict with keys: total, success, failed, success_rate, avg_latency, errors
        """
        with self.lock:
            if not self.results:
                return {
                    "total": 0,
                    "success": 0,
                    "failed": 0,
                    "success_rate": 0.0,
                    "avg_latency": 0.0,
                    "errors": [],
                }

            total = len(self.results)
            success = sum(1 for r in self.results if r["success"])
            failed = total - success
            avg_latency = sum(r["latency"] for r in self.results if r["success"]) / max(
                success, 1
            )

            return {
                "total": total,
                "success": success,
                "failed": failed,
                "success_rate": (success / total) * 100,
                "avg_latency": avg_latency,
                "errors": [r["error"] for r in self.results if r["error"]][:5],
            }
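
For reference, a minimal usage sketch of the helpers added above, assuming the module is importable on its own; the deployment name, namespace, and model name are placeholders rather than values taken from this commit, and the fault-injection step is left as a comment:

# Sketch only: "my-deployment", "my-namespace", and "example-model" are hypothetical.
from inference_testing import InferenceLoadTester, get_inference_endpoint

endpoint = get_inference_endpoint("my-deployment", "my-namespace", local_port=8000)
tester = InferenceLoadTester(endpoint, model_name="example-model", timeout=30)

tester.start(interval=2.0)  # background thread issues one request every 2 seconds
# ... inject the fault under test here (e.g. delete a pod) and wait for recovery ...
results = tester.stop()     # stop the loop and collect per-request result dicts

stats = tester.get_stats()
print(
    f"{stats['success']}/{stats['total']} requests succeeded "
    f"({stats['success_rate']:.1f}%), avg latency {stats['avg_latency']:.2f}s"
)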
