-
Notifications
You must be signed in to change notification settings - Fork 26
Expand file tree
/
Copy pathclient.py
More file actions
256 lines (205 loc) · 7.31 KB
/
client.py
File metadata and controls
256 lines (205 loc) · 7.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
"""
Client framework for GPU inference game.
Provides base class for participants to implement.
"""
import socket
import threading
import time
import queue
from collections import defaultdict
from typing import Dict, List, Optional, Callable, Any
from dataclasses import dataclass
from abc import ABC, abstractmethod
from protocol import (
ProtocolHandler,
SocketReader,
SocketWriter,
RegisterMessage,
InferenceRequest,
InferenceResponse,
ScoreUpdate,
Heartbeat,
ErrorMessage,
)
@dataclass
class PendingRequest:
    """A single inference request that is waiting to be processed."""

    # Identifier of this request.
    unique_id: int
    # Symbol the request belongs to.
    symbol: str
    # Feature values supplied with the request.
    features: List[float]
    # time.time() value recorded when the request was received.
    received_time: float

    def age_ms(self) -> float:
        """Return the time elapsed since ``received_time``, in milliseconds."""
        elapsed_seconds = time.time() - self.received_time
        return elapsed_seconds * 1000
class BaseInferenceClient(ABC):
    """
    Base class for inference clients.

    Participants should inherit from this and implement process_batch().

    run() connects to the server and starts two daemon threads: a receive
    loop that queues incoming requests per symbol, and a process loop that
    drains those queues, calls process_batch() and sends back the result.
    """

    # Minimum number of seconds between two heartbeats sent by the receive loop.
    HEARTBEAT_INTERVAL_S = 5.0
    # Sleep between polls in the background loops when there is nothing to do.
    POLL_INTERVAL_S = 0.001

    def __init__(
        self,
        num_symbols: int,
        server_host: str = "localhost",
        server_port: int = 8080,
        max_queue_size: int = 100,
    ):
        """
        Initialize the client. No connection is opened until connect()/run().

        Args:
            num_symbols: Number of symbols; stored on the instance for
                subclasses (this base class does not otherwise use it)
            server_host: Server hostname
            server_port: Server TCP port
            max_queue_size: Maximum pending requests per symbol
        """
        self.num_symbols = num_symbols
        self.server_host = server_host
        self.server_port = server_port
        self.max_queue_size = max_queue_size

        # Connection (populated by connect(), cleared by disconnect())
        self.socket: Optional[socket.socket] = None
        self.reader = None
        self.writer = None

        # Request queues by symbol; a bounded queue is created on first access.
        self.request_queues: Dict[str, queue.Queue] = defaultdict(
            lambda: queue.Queue(maxsize=max_queue_size)
        )
        # Guards request_queues, which is touched by both background threads.
        self.queue_lock = threading.RLock()

        # Threading
        self.running = False
        self.receive_thread: Optional[threading.Thread] = None
        self.process_thread: Optional[threading.Thread] = None

        # Response tracking
        self.response_queue: queue.Queue = queue.Queue()

    @abstractmethod
    def process_batch(
        self, requests_by_symbol: Dict[str, List["PendingRequest"]]
    ) -> "InferenceResponse":
        """
        Produce a response for a batch of pending requests.

        Called by the process loop whenever at least one request is queued.

        Args:
            requests_by_symbol: Drained requests, keyed by symbol; only
                symbols with at least one request are present.

        Returns:
            The response message that the process loop sends to the server.
        """
        pass

    def connect(self) -> bool:
        """
        Connect to the server and send the registration message.

        Returns:
            True on success. False if connecting or registering failed, in
            which case the socket is closed again so nothing is leaked.
        """
        try:
            self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            self.socket.connect((self.server_host, self.server_port))
            self.reader = SocketReader(self.socket)
            self.writer = SocketWriter(self.socket)

            # Send registration
            if not self.writer.send_message(RegisterMessage()):
                print("Failed to send registration")
                self.disconnect()
                return False
        except Exception as e:
            print(f"Connection failed: {e}")
            # Release the half-open socket instead of leaving it dangling.
            self.disconnect()
            return False

        print(f"Connected to {self.server_host}:{self.server_port}")
        return True

    def disconnect(self):
        """Close the socket (best effort) and drop the reader/writer."""
        sock = self.socket
        # Clear the references first so they never point at a closed socket.
        self.socket = None
        self.reader = None
        self.writer = None
        if sock is None:
            return
        try:
            sock.close()
        except OSError:
            # Closing is best effort; there is nothing useful left to do.
            pass

    def _receive_loop(self):
        """
        Background thread to receive messages from server.

        Dispatches each received message by type and sends a heartbeat every
        HEARTBEAT_INTERVAL_S seconds. Any exception ends the loop and clears
        ``running`` so that run() shuts the client down instead of idling
        forever without a receiver.

        Raises:
            RuntimeError: If started before connect() set up reader/writer.
        """
        reader, writer = self.reader, self.writer
        if reader is None or writer is None:
            raise RuntimeError("_receive_loop started before connect()")

        # Monotonic clock: the interval must not be affected by wall-clock
        # adjustments, and a busy iteration must not cause a skipped beat.
        last_heartbeat_send = time.monotonic()

        while self.running:
            try:
                for msg in reader.read_all_available():
                    if isinstance(msg, InferenceRequest):
                        self._handle_request(msg)
                    elif isinstance(msg, ScoreUpdate):
                        self._handle_score(msg)
                    elif isinstance(msg, ErrorMessage):
                        print(f"Server error: {msg.error}")

                now = time.monotonic()
                if now - last_heartbeat_send >= self.HEARTBEAT_INTERVAL_S:
                    writer.send_message(Heartbeat(timestamp=time.time()))
                    last_heartbeat_send = now

                time.sleep(self.POLL_INTERVAL_S)
            except Exception as e:
                # Errors raised while shutting down are expected; stay quiet.
                if self.running:
                    print(f"Receive error: {e}")
                self.running = False
                break

    def _handle_request(self, request: "InferenceRequest"):
        """
        Handle incoming inference request.

        Splits the message into one PendingRequest per (id, symbol, features)
        triple and enqueues it on that symbol's queue. A request is dropped
        (with a message) when its symbol queue is full.
        """
        for unique_id, symbol, features in zip(
            request.unique_ids, request.symbols, request.features
        ):
            pending = PendingRequest(
                unique_id=unique_id,
                symbol=symbol,
                features=features,
                received_time=time.time(),
            )
            with self.queue_lock:
                try:
                    self.request_queues[symbol].put_nowait(pending)
                except queue.Full:
                    print(f"Queue full for symbol {symbol}, dropping request")

    def _handle_score(self, score: "ScoreUpdate"):
        """Handle score update from server. The base implementation ignores it."""
        pass

    def _process_loop(self):
        """
        Background thread to process inference requests.

        Repeatedly drains the queues, passes any requests to process_batch()
        and sends the returned response. Exceptions are reported and the loop
        keeps running after a short pause.
        """
        # Function-local import, hoisted so it is not re-executed per error.
        import traceback

        while self.running:
            try:
                # Gather current requests by symbol
                requests_by_symbol = self._gather_requests()
                if not requests_by_symbol:
                    # No requests, small sleep
                    time.sleep(self.POLL_INTERVAL_S)
                    continue

                # Call user's implementation
                response = self.process_batch(requests_by_symbol)

                # Bind once: disconnect() may clear self.writer concurrently.
                writer = self.writer
                if writer is not None and not writer.send_message(response):
                    print(f"Failed to send response for {response.unique_ids = }")
            except Exception as e:
                print(f"Process error: {e}")
                traceback.print_exc()
                time.sleep(0.1)

    def _gather_requests(self) -> Dict[str, List["PendingRequest"]]:
        """
        Gather all pending requests by symbol.

        Returns:
            A dict mapping each symbol that had queued requests to the list
            of requests removed from its queue, in queue order. Symbols with
            an empty queue are omitted.
        """
        requests_by_symbol = {}
        with self.queue_lock:
            for symbol, q in self.request_queues.items():
                drained = []
                while True:
                    try:
                        drained.append(q.get_nowait())
                    except queue.Empty:
                        break
                if drained:
                    requests_by_symbol[symbol] = drained
        return requests_by_symbol

    def run(self):
        """
        Main client run loop.

        Connects, starts the receive and process threads, then blocks until
        ``running`` becomes False or Ctrl+C is pressed; always calls stop().
        Returns immediately if the connection cannot be established.
        """
        if not self.connect():
            return

        self.running = True

        # Start background threads
        self.receive_thread = threading.Thread(target=self._receive_loop, daemon=True)
        self.receive_thread.start()
        self.process_thread = threading.Thread(target=self._process_loop, daemon=True)
        self.process_thread.start()

        print("Client running. Press Ctrl+C to stop.")
        try:
            while self.running:
                time.sleep(1)
        except KeyboardInterrupt:
            print("\nShutting down...")
        finally:
            self.stop()

    def stop(self):
        """
        Stop the client.

        Clears ``running``, waits up to one second for each background
        thread, then disconnects. Safe to call from one of the background
        threads (a thread never joins itself) and before the threads start.
        """
        self.running = False

        # Wait for threads
        current = threading.current_thread()
        for worker in (self.receive_thread, self.process_thread):
            if worker is None or worker is current:
                continue
            # join() raises RuntimeError on a thread that was never started.
            if worker.is_alive():
                worker.join(timeout=1)

        self.disconnect()