1 change: 1 addition & 0 deletions server/pyproject.toml
@@ -16,6 +16,7 @@ dependencies = [
"openai>=2.14.0",
"pandas>=2.3.3",
"pydantic>=2.12.5",
"pydantic-ai-slim[openai,openrouter,retries]>=1.52.0",
"pytest>=9.0.2",
"pytest-asyncio>=1.3.0",
"pytest-cov>=7.0.0",
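The new `pydantic-ai-slim` extras matter here: `retries` is what makes the `pydantic_ai.retries` module used in `server/src/celery/tasks.py` importable, and it pulls in `tenacity` alongside it. A quick sanity check, assuming the pinned versions resolve:

```python
# Sanity-check sketch (assumes pydantic-ai-slim>=1.52.0 is installed with the
# retries extra, which brings in tenacity): both imports should succeed.
from pydantic_ai.retries import AsyncTenacityTransport, RetryConfig, wait_retry_after
from tenacity import retry_if_exception_type, stop_after_attempt, wait_exponential

print("retry stack available")
```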
67 changes: 33 additions & 34 deletions server/src/celery/tasks.py
@@ -1,8 +1,11 @@
 import asyncio
 import logging
-import random
 from typing import Dict

+from httpx import AsyncClient, HTTPStatusError
+from pydantic_ai.retries import AsyncTenacityTransport, RetryConfig, wait_retry_after
+from tenacity import retry_if_exception_type, stop_after_attempt, wait_exponential
+
 from celery import Task
 from src.crud.jobtask_crud import JobTaskCrud
 from src.crud.project_crud import ProjectCrud
@@ -43,28 +46,24 @@ async def _publish_redis_event(redis, event_name, value):
     )


-async def _run_with_retry(
-    func, max_retries: int = 3, base_delay: float = 1.0, jitter: float = 0.3
-):
-    retries = 0
-    while True:
-        try:
-            return await func()
-        except Exception as e:
-            retries += 1
-            if retries > max_retries:
-                raise
-
-            delay = base_delay * (2 ** (retries - 1))
-
-            jitter_amount = delay * jitter
-            delay = delay + random.uniform(-jitter_amount, jitter_amount)
-
-            logger.warning(
-                f"Retrying job task (attempt {retries}/{max_retries}) - Error: {e}"
-            )
-
-            await asyncio.sleep(delay)
+def _create_retrying_client(max_attempts: int = 3, max_wait_seconds: int = 60) -> AsyncClient:
+    def should_retry_status(response):
+        if response.status_code in (429, 502, 503, 504):
+            response.raise_for_status()
+
+    transport = AsyncTenacityTransport(
+        config=RetryConfig(
+            retry=retry_if_exception_type((HTTPStatusError, ConnectionError)),
+            wait=wait_retry_after(
+                fallback_strategy=wait_exponential(multiplier=1, max=60),
+                max_wait=max_wait_seconds,
+            ),
+            stop=stop_after_attempt(max_attempts),
+            reraise=True,
+        ),
+        validate_response=should_retry_status,
+    )
+    return AsyncClient(transport=transport)


 async def _process_job_task(
@@ -76,7 +75,7 @@ async def _process_job_task(
     redis,
     counter: Dict[str, int],
     counter_lock: asyncio.Lock,
-    max_retries: int,
+    client: AsyncClient,
 ):
     async with semaphore:
         try:
@@ -98,15 +97,13 @@ async def _process_job_task(
{"job_task_id": job_task.id, "status": JobTaskStatus.RUNNING},
)

llm_result = await _run_with_retry(
lambda: get_structured_response(
llm_service,
paper_service,
job_task,
job_data,
project_criteria,
),
max_retries=max_retries,
llm_result = await get_structured_response(
llm_service,
paper_service,
job_task,
job_data,
project_criteria,
client,
)

await jobtask_crud.update_job_task_result(job_task.id, llm_result)
@@ -179,6 +176,8 @@ async def process_job(
     logger.info("process_job: Starting to process job %s", job_id)
     redis = get_redis_client()

+    client = _create_retrying_client(max_attempts=max_retries)
+
     async with DBContext() as db_ctx:
         project_crud = db_ctx.crud(ProjectCrud)
         jobtask_crud = db_ctx.crud(JobTaskCrud)
@@ -211,7 +210,7 @@ async def process_job(
                 redis=redis,
                 counter=counter,
                 counter_lock=counter_lock,
-                max_retries=max_retries,
+                client=client,
             )
             for jt_id in job_task_ids
         ]
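Because the retry policy now lives inside the HTTP transport, call sites stay retry-free. A minimal usage sketch; the URL is a placeholder and `main()` is illustrative:

```python
# Hypothetical usage sketch; assumes _create_retrying_client (defined above in
# server/src/celery/tasks.py) is importable from the caller's context.
import asyncio

from src.celery.tasks import _create_retrying_client

async def main() -> None:
    client = _create_retrying_client(max_attempts=3)
    # A 429/502/503/504 response raises HTTPStatusError inside the transport;
    # tenacity then waits (honouring any Retry-After header via wait_retry_after)
    # and retries before the response ever reaches this call site.
    response = await client.get("https://example.com/api")  # placeholder URL
    print(response.status_code)
    await client.aclose()

asyncio.run(main())
```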
70 changes: 42 additions & 28 deletions server/src/core/llm/providers/local_openai_sdk.py
@@ -1,9 +1,14 @@
 from typing import Any, List, Type

+from httpx import AsyncClient
+from openai.types.model import Model
 from pydantic import BaseModel, Field
+from pydantic_ai import Agent
+from pydantic_ai.models.openai import OpenAIResponsesModel, OpenAIResponsesModelSettings
+from pydantic_ai.output import ToolOutput
+from pydantic_ai.providers.openai import OpenAIProvider as PAI_OpenAIProvider

-from src.core.llm.providers.provider import T, BaseLLMParams, LLMProvider
-from openai.types.model import Model
+from src.core.llm.providers.provider import BaseLLMParams, LLMProvider, T


 class LocalOpenAISDKProviderParams(BaseModel):
@@ -29,42 +34,51 @@ class LocalOpenAISDKProvider(
     config_parameters = []

     async def generate_answer_async(
-        self, model_parameters: dict[str, Any], schema: Type[T], prompt: str
-    ) -> tuple[T, str]:
+        self,
+        client: AsyncClient,
+        model_parameters: dict[str, Any],
+        schema: Type[T],
+        prompt: str,
+    ) -> T:
         model_cfg = self.parse_model_parameters(model_parameters)

-        from openai import AsyncOpenAI
-
         if self.provider_parameters is None:
             raise RuntimeError("Provider parameters needs to be defined")

         if self.runtime_parameters.model is None:
             raise RuntimeError("Model needs to be defined")

-        async with AsyncOpenAI(
+        provider = PAI_OpenAIProvider(
             api_key="Foo",
             base_url=self.provider_parameters.base_url,
-        ) as client:
-            try:
-                response = await client.responses.parse(
-                    model=self.runtime_parameters.model,
-                    input=[
-                        {
-                            "role": "system",
-                            "content": self.runtime_parameters.system_prompt,
-                        },
-                        {"role": "user", "content": prompt},
-                    ],
-                    top_p=model_cfg.top_p,
-                    temperature=model_cfg.temperature,
-                    # Structured Outputs is available in OpenAI's latest large language models, starting with GPT-4o
-                    text_format=schema,
-                )
-                if response.output_parsed is None:
-                    raise RuntimeError("Output from LLM was empty")
-                return response.output_parsed, ""
-            except Exception as e:
-                raise RuntimeError("LLM call failed") from e
+            http_client=client,
+        )
+
+        settings = OpenAIResponsesModelSettings(
+            temperature=model_cfg.temperature,
+            top_p=model_cfg.top_p,
+        )
+
+        model = OpenAIResponsesModel(
+            str(self.runtime_parameters.model),
+            provider=provider,
+            settings=settings,
+        )
+
+        agent = Agent(
+            model,
+            system_prompt=self.runtime_parameters.system_prompt,
+            retries=3,
+            output_retries=5,
+            output_type=ToolOutput(schema, name=schema.__name__.lower()),
+        )
+
+        result = await agent.run(prompt)
+
+        if result.output is None:
+            raise RuntimeError("Output from LLM was empty")
+
+        return result.output

     async def get_available_models(self) -> List[Model]:
         if self.provider_parameters is None:
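For anyone wanting to poke at the new Agent wiring without a live endpoint, here is a minimal, self-contained sketch of the same pattern; the `Verdict` schema and both prompt strings are made up for illustration, and pydantic-ai's `TestModel` stands in for `OpenAIResponsesModel`:

```python
# Minimal sketch of the ToolOutput/Agent pattern this diff adopts.
# Verdict is a hypothetical stand-in for StructuredResponse; TestModel
# fabricates schema-valid output so the example runs offline.
from pydantic import BaseModel
from pydantic_ai import Agent
from pydantic_ai.models.test import TestModel
from pydantic_ai.output import ToolOutput

class Verdict(BaseModel):
    include: bool
    reason: str

agent = Agent(
    TestModel(),  # swap in OpenAIResponsesModel(...) against a real server
    system_prompt="You screen papers against inclusion criteria.",
    output_type=ToolOutput(Verdict, name="verdict"),
)

result = agent.run_sync("Does this paper meet the criteria?")
print(result.output)  # a Verdict instance with generated field values
```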
68 changes: 35 additions & 33 deletions server/src/core/llm/providers/mock.py
@@ -2,6 +2,7 @@
 import random
 from typing import Any, List

+from httpx import AsyncClient
 from pydantic import BaseModel, Field

 from src.core.llm.providers.provider import (
@@ -55,8 +56,12 @@ def __init__(
     config_parameters = []

     async def generate_answer_async(
-        self, model_parameters: dict[str, Any], schema: type[T], prompt
-    ) -> tuple[StructuredResponse, str]:
+        self,
+        client: AsyncClient,
+        model_parameters: dict[str, Any],
+        schema: type[T],
+        prompt,
+    ) -> StructuredResponse:
         if self.provider_parameters is None:
             raise RuntimeError("Provider parameters needs to be defined")

@@ -67,38 +72,35 @@ async def generate_answer_async(
         delay_ms = max(0.0, self.provider_parameters.delay + jitter_ms)
         await asyncio.sleep(delay_ms / 1000.0)

-        return (
-            StructuredResponse(
-                overall_decision=Decision(
-                    binary_decision=True,
-                    probability_decision=1.0,
-                    likert_decision=LikertDecision.stronglyAgree,
-                    reason="The paper completely meets the inclusion criteria.",
-                ),
-                inclusion_criteria=[
-                    Criterion(
-                        name="Example criteria",
-                        decision=Decision(
-                            binary_decision=True,
-                            probability_decision=1.0,
-                            likert_decision=LikertDecision.stronglyAgree,
-                            reason="The criteria is met.",
-                        ),
-                    )
-                ],
-                exclusion_criteria=[
-                    Criterion(
-                        name="Example criteria",
-                        decision=Decision(
-                            binary_decision=False,
-                            probability_decision=0.0,
-                            likert_decision=LikertDecision.stronglyDisagree,
-                            reason="The criteria is not met.",
-                        ),
-                    )
-                ],
-            ),
-            "",
-        )
+        return StructuredResponse(
+            overall_decision=Decision(
+                binary_decision=True,
+                probability_decision=1.0,
+                likert_decision=LikertDecision.stronglyAgree,
+                reason="The paper completely meets the inclusion criteria.",
+            ),
+            inclusion_criteria=[
+                Criterion(
+                    name="Example criteria",
+                    decision=Decision(
+                        binary_decision=True,
+                        probability_decision=1.0,
+                        likert_decision=LikertDecision.stronglyAgree,
+                        reason="The criteria is met.",
+                    ),
+                )
+            ],
+            exclusion_criteria=[
+                Criterion(
+                    name="Example criteria",
+                    decision=Decision(
+                        binary_decision=False,
+                        probability_decision=0.0,
+                        likert_decision=LikertDecision.stronglyDisagree,
+                        reason="The criteria is not met.",
+                    ),
+                )
+            ],
+        )

     async def get_available_models(self) -> List[Model]:
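The mock's signature now matches the real provider: it takes the shared `AsyncClient` and returns the `StructuredResponse` directly instead of a `(response, "")` tuple. A runnable sketch of the new call shape; `StubProvider` and `Answer` are made-up stand-ins, since this PR does not show `MockProvider`'s constructor:

```python
# Runnable sketch of the call shape implied by the new signature.
import asyncio
from typing import Any

from httpx import AsyncClient
from pydantic import BaseModel

class Answer(BaseModel):  # hypothetical schema
    text: str

class StubProvider:
    async def generate_answer_async(
        self,
        client: AsyncClient,
        model_parameters: dict[str, Any],
        schema: type[Answer],
        prompt,
    ) -> Answer:
        # Mirrors the diff: return the parsed object itself, not a tuple.
        return schema(text=f"echo: {prompt}")

async def main() -> None:
    async with AsyncClient() as client:
        result = await StubProvider().generate_answer_async(client, {}, Answer, "hi")
        print(result.text)

asyncio.run(main())
```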