Commit 11d273d
Initial multiturn scheduler and types
1 parent 0e78c65

File tree: 4 files changed, +226 -163 lines

src/guidellm/request/session.py

Lines changed: 63 additions & 0 deletions (new file)

import itertools
from abc import ABC, abstractmethod
from typing import Generic, TypeVar

from guidellm.request.request import GenerationRequest

__all__ = ["RequestSession"]

# TODO: Replace with specific types that implement needed features
RequestT = TypeVar("RequestT")
ResponseT = TypeVar("ResponseT")


class RequestSession(ABC, Generic[RequestT, ResponseT]):
    # Number of requests (turns) in the session; the scheduler uses this
    # to queue one timing entry per turn
    @abstractmethod
    def __len__(self) -> int: ...

    @abstractmethod
    def get_next_request(self) -> RequestT: ...

    @abstractmethod
    def get_next_delay(self) -> float: ...

    @abstractmethod
    def push_response(self, response: ResponseT) -> None: ...

    @property
    @abstractmethod
    def complete(self) -> bool: ...


# FIXME: Dummy implementation
class GenerativeRequestSession(RequestSession[GenerationRequest, str]):
    def __init__(self, prompts: list[GenerationRequest]) -> None:
        if not prompts:
            raise ValueError("Prompts cannot be empty")

        self.prompts = prompts
        self.responses: list[str] = []

    def __len__(self) -> int:
        # One scheduled request per prompt turn
        return len(self.prompts)

    def get_next_request(self) -> GenerationRequest:
        completed_responses = len(self.responses)
        base_request = self.prompts[completed_responses].model_copy()
        # Interleave prior prompts with their responses, then append the
        # current (still unanswered) prompt, forming the full conversation
        base_request.content = "".join(
            itertools.chain.from_iterable(
                itertools.zip_longest(
                    (x.content for x in self.prompts[: completed_responses + 1]),
                    self.responses,
                    fillvalue="",
                )
            )
        )
        base_request.stats["prompt_tokens"] = sum(
            x.stats["prompt_tokens"] for x in self.prompts[: completed_responses + 1]
        )
        # Note: excludes the current turn's output budget
        base_request.constraints["output_tokens"] = sum(
            x.constraints["output_tokens"] for x in self.prompts[:completed_responses]
        )

        return base_request

    def get_next_delay(self) -> float:
        # Dummy implementation: no think time between turns
        return 0.0

    def push_response(self, response: str) -> None:
        if len(self.responses) < len(self.prompts):
            self.responses.append(response)
        else:
            raise ValueError("Response list full")

    @property
    def complete(self) -> bool:
        return len(self.responses) >= len(self.prompts)
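
For orientation, this is how a consumer would drive one of these sessions end to end. A minimal sketch: the GenerationRequest constructor arguments and the fake_backend stand-in are illustrative assumptions, not APIs defined in this commit.

from guidellm.request.request import GenerationRequest
from guidellm.request.session import GenerativeRequestSession

# Hypothetical turn definitions; GenerationRequest's real constructor
# signature is not part of this diff.
prompts = [
    GenerationRequest(
        content="What is GuideLLM?",
        stats={"prompt_tokens": 6},
        constraints={"output_tokens": 128},
    ),
    GenerationRequest(
        content="How does its scheduler work?",
        stats={"prompt_tokens": 7},
        constraints={"output_tokens": 128},
    ),
]

def fake_backend(request: GenerationRequest) -> str:
    # Stand-in for the real worker/backend call, which is out of scope here
    return "stub response"

session = GenerativeRequestSession(prompts)
while not session.complete:
    request = session.get_next_request()  # prior turns stitched into content
    session.push_response(fake_backend(request))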

src/guidellm/scheduler/scheduler.py

Lines changed: 57 additions & 62 deletions

@@ -1,10 +1,10 @@
 import asyncio
 import math
-import multiprocessing
-import multiprocessing.queues
 import time
 from collections.abc import AsyncGenerator, Iterable, Iterator
 from concurrent.futures import ProcessPoolExecutor
+from multiprocessing import Manager, Queue
+from queue import Empty as QueueEmpty
 from typing import (
     Any,
     Generic,
@@ -15,17 +15,22 @@
 from loguru import logger

 from guidellm.config import settings
+from guidellm.request.session import RequestSession
 from guidellm.scheduler.result import (
     SchedulerRequestResult,
     SchedulerResult,
     SchedulerRunInfo,
 )
 from guidellm.scheduler.strategy import SchedulingStrategy
-from guidellm.scheduler.types import RequestT, ResponseT
+from guidellm.scheduler.types import (
+    MPQueues,
+    RequestT,
+    ResponseT,
+    WorkerProcessRequestTime,
+    WorkerProcessResult,
+)
 from guidellm.scheduler.worker import (
     RequestsWorker,
-    WorkerProcessRequest,
-    WorkerProcessResult,
 )

 __all__ = ["Scheduler"]
@@ -114,13 +119,13 @@ async def run(
             raise ValueError(f"Invalid max_duration: {max_duration}")

         with (
-            multiprocessing.Manager() as manager,
+            Manager() as manager,
             ProcessPoolExecutor(
                 max_workers=scheduling_strategy.processes_limit
             ) as executor,
         ):
             requests_iter: Optional[Iterator[Any]] = None
-            futures, requests_queue, responses_queue = await self._start_processes(
+            futures, queues = await self._start_processes(
                 manager, executor, scheduling_strategy
             )
             run_info, requests_iter, times_iter = self._run_setup(
@@ -149,13 +154,14 @@ async def run(
                 requests_iter = self._add_requests(
                     requests_iter,
                     times_iter,
-                    requests_queue,
+                    queues.requests,
+                    queues.times,
                     run_info,
                 )
                 await asyncio.sleep(0)  # enable requests to start

                 iter_result = self._check_result_ready(
-                    responses_queue,
+                    queues.responses,
                     run_info,
                 )
                 if iter_result is not None:
@@ -171,7 +177,7 @@ async def run(
                 run_info=run_info,
             )

-            await self._stop_processes(futures, requests_queue)
+            await self._stop_processes(futures, queues.requests)

     async def _start_processes(
         self,
@@ -180,14 +186,16 @@ async def _start_processes(
         scheduling_strategy: SchedulingStrategy,
     ) -> tuple[
         list[asyncio.Future],
-        multiprocessing.Queue,
-        multiprocessing.Queue,
+        MPQueues[RequestT, ResponseT],
     ]:
         await self.worker.prepare_multiprocessing()
-        requests_queue = manager.Queue(
-            maxsize=scheduling_strategy.queued_requests_limit
+        queues: MPQueues[RequestT, ResponseT] = MPQueues(
+            requests=manager.Queue(
+                maxsize=scheduling_strategy.processing_requests_limit
+            ),
+            times=manager.Queue(maxsize=scheduling_strategy.processing_requests_limit),
+            responses=manager.Queue(),
         )
-        responses_queue = manager.Queue()

         num_processes = min(
             scheduling_strategy.processes_limit,
@@ -212,36 +220,20 @@ async def _start_processes(
         futures = []
         loop = asyncio.get_event_loop()
         for id_, requests_limit in zip(process_ids, process_requests_limits):
-            if scheduling_strategy.processing_mode == "sync":
-                futures.append(
-                    loop.run_in_executor(
-                        executor,
-                        self.worker.process_loop_synchronous,
-                        requests_queue,
-                        responses_queue,
-                        id_,
-                    )
-                )
-            elif scheduling_strategy.processing_mode == "async":
-                futures.append(
-                    loop.run_in_executor(
-                        executor,
-                        self.worker.process_loop_asynchronous,
-                        requests_queue,
-                        responses_queue,
-                        requests_limit,
-                        id_,
-                    )
-                )
-            else:
-                raise ValueError(
-                    f"Invalid processing mode: {scheduling_strategy.processing_mode} "
-                    f"for strategy: {scheduling_strategy}"
+            futures.append(
+                loop.run_in_executor(
+                    executor,
+                    self.worker.process_loop_asynchronous,
+                    queues,
+                    False,  # TODO: Make configurable
+                    requests_limit,
+                    id_,
                 )
+            )

         await asyncio.sleep(0.1)  # give time for processes to start

-        return futures, requests_queue, responses_queue
+        return futures, queues

     def _run_setup(
         self,
@@ -284,7 +276,8 @@ def _add_requests(
         self,
         requests_iter: Optional[Iterator[Any]],
         times_iter: Iterator[float],
-        requests_queue: multiprocessing.Queue,
+        requests_queue: Queue[RequestSession[RequestT, ResponseT]],
+        times_queue: Queue[WorkerProcessRequestTime],
         run_info: SchedulerRunInfo,
     ) -> Optional[Iterator[Any]]:
         if requests_iter is not None:
@@ -298,23 +291,24 @@ def _add_requests(
                     if run_info.created_requests >= run_info.end_number:
                         raise StopIteration

-                    if (
-                        request_time := next(times_iter)
-                    ) >= run_info.end_time or time.time() >= run_info.end_time:
-                        raise StopIteration
-
-                    request = next(requests_iter)
-                    work_req: WorkerProcessRequest[RequestT] = WorkerProcessRequest(
-                        request=request,
-                        start_time=request_time,
-                        timeout_time=run_info.end_time,
-                        queued_time=time.time(),
-                    )
-                    requests_queue.put(work_req)
-
-                    run_info.created_requests += 1
-                    run_info.queued_requests += 1
-                    added_count += 1
+                    session = next(requests_iter)
+                    requests_queue.put(session)
+                    for _ in range(len(session)):
+                        if (
+                            request_time := next(times_iter)
+                        ) >= run_info.end_time or time.time() >= run_info.end_time:
+                            raise StopIteration
+
+                        work_req = WorkerProcessRequestTime(
+                            start_time=request_time,
+                            timeout_time=run_info.end_time,
+                            queued_time=time.time(),
+                        )
+                        times_queue.put(work_req)
+
+                        run_info.created_requests += 1
+                        run_info.queued_requests += 1
+                        added_count += 1
             except StopIteration:
                 # we've reached the limit number, limit time, or exhausted the requests
                 # set to None to stop adding more and tell the loop no more requests
@@ -324,14 +318,14 @@ def _add_requests(

     def _check_result_ready(
         self,
-        responses_queue: multiprocessing.Queue,
+        responses_queue: Queue[WorkerProcessResult[RequestT, ResponseT]],
         run_info: SchedulerRunInfo,
     ) -> Optional[SchedulerRequestResult[RequestT, ResponseT]]:
         try:
             process_response: WorkerProcessResult[RequestT, ResponseT] = (
                 responses_queue.get_nowait()
             )
-        except multiprocessing.queues.Empty:  # type: ignore[attr-defined]
+        except QueueEmpty:
             return None

         if process_response.type_ == "request_scheduled":
@@ -374,8 +368,9 @@
     async def _stop_processes(
         self,
         futures: list[asyncio.Future],
-        requests_queue: multiprocessing.Queue,
+        requests_queue: Queue[RequestSession[RequestT, ResponseT]],
     ):
+        # FIXME: Need new method for stopping workers
         for _ in futures:
             requests_queue.put(None)
src/guidellm/scheduler/types.py

Lines changed: 38 additions & 2 deletions

@@ -1,7 +1,43 @@
-from typing import TypeVar
+from dataclasses import dataclass
+from multiprocessing import Queue
+from typing import Generic, Literal, Optional, TypeVar

-__all__ = ["RequestT", "ResponseT"]
+from guidellm.request.session import RequestSession
+from guidellm.scheduler.result import SchedulerRequestInfo
+
+__all__ = [
+    "MPQueues",
+    "RequestT",
+    "ResponseT",
+    "WorkerProcessRequestTime",
+    "WorkerProcessResult",
+]


 RequestT = TypeVar("RequestT")
 ResponseT = TypeVar("ResponseT")
+
+
+# TODO: Move dataclasses somewhere else
+
+
+@dataclass
+class WorkerProcessRequestTime:
+    start_time: float
+    timeout_time: float
+    queued_time: float
+
+
+@dataclass
+class WorkerProcessResult(Generic[RequestT, ResponseT]):
+    type_: Literal["request_scheduled", "request_start", "request_complete"]
+    request: RequestT
+    response: Optional[ResponseT]
+    info: SchedulerRequestInfo
+
+
+@dataclass
+class MPQueues(Generic[RequestT, ResponseT]):
+    requests: Queue[RequestSession[RequestT, ResponseT]]
+    times: Queue[WorkerProcessRequestTime]
+    responses: Queue[WorkerProcessResult[RequestT, ResponseT]]
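
Finally, a small sketch of wiring the queue bundle up, mirroring what _start_processes now does; maxsize=8 stands in for scheduling_strategy.processing_requests_limit. One assumption worth flagging: with a plain from multiprocessing import Queue, Queue is a factory function rather than a subscriptable class, and the Queue[...] annotations above are evaluated when the dataclass is created, so this sketch assumes they resolve (for example via from __future__ import annotations in types.py).

import time
from multiprocessing import Manager

from guidellm.scheduler.types import MPQueues, WorkerProcessRequestTime

with Manager() as manager:
    # manager.Queue() returns proxy queues that can be shared with the
    # ProcessPoolExecutor worker processes, as in _start_processes above
    queues = MPQueues(
        requests=manager.Queue(maxsize=8),
        times=manager.Queue(maxsize=8),
        responses=manager.Queue(),
    )

    now = time.time()
    queues.times.put(
        WorkerProcessRequestTime(
            start_time=now,
            timeout_time=now + 30.0,  # illustrative per-request deadline
            queued_time=now,
        )
    )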
