3 changes: 2 additions & 1 deletion Makefile
@@ -48,7 +48,8 @@ backend-unit:

# Run all tests in the backend
backend-test:
APP_ENV=test docker compose -f docker-compose-dev.yml -p test up -d backend postgres redis celery
APP_ENV=test docker compose -f docker-compose-dev.yml -p test down -v
APP_ENV=test docker compose -f docker-compose-dev.yml -p test up -d --build backend postgres redis celery
APP_ENV=test RUN_MIGRATIONS=true docker compose -f docker-compose-dev.yml -p test run --rm backend uv run pytest -m asyncio -v -s --cov=src $(REPORT)
APP_ENV=test docker compose -f docker-compose-dev.yml -p test down

2 changes: 1 addition & 1 deletion client/src/components/PaperCard.tsx
@@ -59,7 +59,7 @@ export const PaperCard: React.FC<
"text-gray-400": paper.avg_probability_decision == null,
})}
>
{paper.avg_probability_decision
{paper.avg_probability_decision != null
? paper.avg_probability_decision.toFixed(3)
: hasErrors
? "ERROR"
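The `!= null` check matters because a probability of 0 is falsy in JavaScript, so the old truthiness test would render a genuine zero score as "ERROR" or pending. The same pitfall exists in Python; a minimal sketch for illustration only (not project code):

def format_probability(value):
    # Buggy variant: 0.0 is falsy, so a real zero probability would be treated as missing.
    #     return f"{value:.3f}" if value else "ERROR"
    # Correct variant, mirroring the `!= null` fix in this diff: only None counts as missing.
    return f"{value:.3f}" if value is not None else "ERROR"

assert format_probability(0.0) == "0.000"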
35 changes: 35 additions & 0 deletions server/conftest.py
@@ -16,6 +16,7 @@
ZeroShotPromptingConfig,
)
from src.schemas.project import Criteria, ProjectCreate
from src.schemas.llm import StructuredResponse, Criterion, Decision, LikertDecision
from src.tools.diagnostics.db_check import run_migration

# @pytest.fixture(scope="function")
@@ -55,6 +56,40 @@ def test_job_data(test_project_uuid):
)


@pytest.fixture
def test_structured_response():
return StructuredResponse(
overall_decision=Decision(
binary_decision=True,
probability_decision=0.85,
likert_decision=LikertDecision.agree,
reason="Mock reason",
),
inclusion_criteria=[
Criterion(
name="A",
decision=Decision(
binary_decision=True,
probability_decision=0.9,
likert_decision=LikertDecision.agree,
reason="Included",
),
)
],
exclusion_criteria=[
Criterion(
name="D",
decision=Decision(
binary_decision=False,
probability_decision=0.1,
likert_decision=LikertDecision.disagree,
reason="Excluded",
),
)
],
)


@pytest_asyncio.fixture
async def test_files_working():
file1 = UploadFile(
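A hypothetical test consuming the new `test_structured_response` fixture could look like the sketch below (the test name and asyncio marker are assumptions, not part of this diff):

import pytest

@pytest.mark.asyncio
async def test_structured_response_fixture(test_structured_response):
    response = test_structured_response
    # Values assert the defaults set in the fixture above.
    assert response.overall_decision.probability_decision == 0.85
    assert [c.name for c in response.inclusion_criteria] == ["A"]
    assert [c.name for c in response.exclusion_criteria] == ["D"]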
248 changes: 152 additions & 96 deletions server/src/celery/tasks.py
@@ -1,6 +1,7 @@
import asyncio
import logging
from contextlib import nullcontext
import random
from typing import Dict

from celery import Task
from src.crud.jobtask_crud import JobTaskCrud
@@ -10,6 +11,7 @@
from src.redis_client.client import REDIS_CHANNEL, get_redis_client
from src.schemas.job import JobCreate
from src.schemas.jobtask import JobTaskStatus
from src.schemas.project import Criteria
from src.services.llm_service import create_llm_service
from src.services.paper_service import create_paper_service
from src.tools.llm_decision_creator import get_structured_response
@@ -18,13 +20,6 @@
logger = logging.getLogger(__name__)


@celery_app.task(name="tasks.process_job", bind=True)
def process_job_task(self: Task, job_id: int, job_data: dict):
job_data_unpacked = JobCreate.model_validate(job_data, strict=True)
logger.info("Running job task using asyncio, ID: %s", job_id)
asyncio.run(async_process_job(self, job_id, job_data_unpacked))


@celery_app.task(name="tasks.test_task")
def test_task(name: str):
import time
Expand All @@ -35,131 +30,192 @@ def test_task(name: str):
return f"Hello, {name}!"


async def _async_retry_job_task(func, max_retries=3, base_delay=1):
@celery_app.task(name="tasks.process_job", bind=True)
def process_job_task(self: Task, job_id: int, job_data: dict):
job_data_unpacked = JobCreate.model_validate(job_data, strict=True)
logger.info("Running job task using asyncio, ID: %s", job_id)
asyncio.run(process_job(self, job_id, job_data_unpacked))


async def _publish_redis_event(redis, event_name, value):
await redis.publish(
REDIS_CHANNEL, QueueItem(event_name=event_name, value=value).model_dump_json()
)


async def _run_with_retry(
func, max_retries: int = 3, base_delay: float = 1.0, jitter: float = 0.3
):
retries = 0
while True:
try:
return await func()
except Exception as e:
retries += 1
if retries > max_retries:
raise

delay = base_delay * (2 ** (retries - 1))

jitter_amount = delay * jitter
delay = delay + random.uniform(-jitter_amount, jitter_amount)

logger.warning(
f"Retrying job task (attempt {retries}/{max_retries}) - Error: {e}"
)
if retries > max_retries:
raise e
delay = base_delay * (2 ** (retries - 1))

await asyncio.sleep(delay)


async def async_process_job(
async def _process_job_task(
celery_task: Task,
job_id: int,
job_task_id: int,
job_data: JobCreate,
db_ctx: DBContext | None = None,
project_criteria: Criteria,
semaphore: asyncio.Semaphore,
redis,
counter: Dict[str, int],
counter_lock: asyncio.Lock,
max_retries: int,
):
logger.info("async_process_job: Starting to process job %s", job_id)
redis = get_redis_client()

# Check that who owns the session
close_session = False
if db_ctx is None:
db_ctx = DBContext()
close_session = True

async with (
db_ctx if close_session else nullcontext(db_ctx)
): # Use nullcontext if session has been created
project_crud = db_ctx.crud(ProjectCrud)
jobtask_crud = db_ctx.crud(JobTaskCrud)
llm_service = create_llm_service(db_ctx)
paper_service = create_paper_service(db_ctx)

logger.info("Fetching project by UUID %s", job_data.project_uuid)
project = await project_crud.fetch_project_by_uuid(job_data.project_uuid)
if project is None:
raise RuntimeError("Project not found")

logger.info("Updating job task status to %s", JobTaskStatus.PENDING)
await jobtask_crud.update_job_tasks_status(job_id, JobTaskStatus.PENDING)
await db_ctx.commit() if close_session else await db_ctx.session.flush()
async with semaphore:
try:
async with DBContext() as task_db_ctx:
jobtask_crud = task_db_ctx.crud(JobTaskCrud)
job_task = await jobtask_crud.fetch_job_task_by_id(job_task_id)

job_tasks = await jobtask_crud.fetch_job_tasks_by_job_id(job_id)
llm_service = create_llm_service(task_db_ctx)
paper_service = create_paper_service(task_db_ctx)

for i, job_task in enumerate(job_tasks):
try:
await jobtask_crud.update_job_task_status(
job_task.id, JobTaskStatus.RUNNING
)
await db_ctx.commit() if close_session else await db_ctx.session.flush()
await redis.publish(
REDIS_CHANNEL,
QueueItem(
event_name=EventName.JOB_TASK_RUNNING,
value={
"job_task_id": job_task.id,
"status": JobTaskStatus.RUNNING,
"current": i + 1,
"total": len(job_tasks),
},
).model_dump_json(),
)
celery_task.update_state(
state="PROGRESS",
meta={"current": i + 1, "total": len(job_tasks)},
await task_db_ctx.commit()

await _publish_redis_event(
redis,
EventName.JOB_TASK_RUNNING,
{"job_task_id": job_task.id, "status": JobTaskStatus.RUNNING},
)
llm_result = await _async_retry_job_task(

llm_result = await _run_with_retry(
lambda: get_structured_response(
llm_service,
paper_service,
job_task,
job_data,
project.criteria,
)
project_criteria,
),
max_retries=max_retries,
)

await jobtask_crud.update_job_task_result(job_task.id, llm_result)
logger.info("Updating job task status to %s", JobTaskStatus.DONE)
await jobtask_crud.update_job_task_status(
job_task.id, JobTaskStatus.DONE
)
await db_ctx.commit() if close_session else await db_ctx.session.flush()
await redis.publish(
REDIS_CHANNEL,
QueueItem(
event_name=EventName.JOB_TASK_DONE,
value={
"job_task_id": job_task.id,
"status": JobTaskStatus.DONE,
},
).model_dump_json(),
await task_db_ctx.commit()

await _publish_redis_event(
redis,
EventName.JOB_TASK_DONE,
{"job_task_id": job_task.id, "status": JobTaskStatus.DONE},
)

except Exception as e:
logger.info("Updating job task status to %s", JobTaskStatus.ERROR)
await jobtask_crud.update_job_task_status(
job_task.id, JobTaskStatus.ERROR
except Exception as e:
try:
async with DBContext() as task_err_db_ctx:
err_jobtask_crud = task_err_db_ctx.crud(JobTaskCrud)
await err_jobtask_crud.update_job_task_status(
job_task_id, JobTaskStatus.ERROR
)
await err_jobtask_crud.update_job_task_error(job_task_id, str(e))
await task_err_db_ctx.commit()
except Exception as err_db_exc:
logger.exception(
"Failed to write error to database for job_task %s: %s",
job_task_id,
err_db_exc,
)
await jobtask_crud.update_job_task_error(job_task.id, str(e))

await redis.publish(
REDIS_CHANNEL,
QueueItem(
event_name=EventName.JOB_TASK_ERROR,
value={
"job_task_id": job_task.id,
"status": JobTaskStatus.ERROR,
"message": str(e),
},
).model_dump_json(),

try:
await _publish_redis_event(
redis,
EventName.JOB_TASK_ERROR,
{
"job_task_id": job_task_id,
"status": JobTaskStatus.ERROR,
"message": str(e),
},
)
except Exception:
logger.exception("Failed to publish error %s", job_task_id)

celery_task.update_state(
state="FAILURE",
meta={"error": str(e)},
try:
celery_task.update_state(state="FAILURE", meta={"error": str(e)})
except Exception:
logger.exception(
"Failed to update celery state for job_task %s", job_task_id
)
logger.error(e)
await db_ctx.commit() if close_session else await db_ctx.session.flush()
finally:
async with counter_lock:
counter["completed"] += 1
completed = counter["completed"]
total = counter["total"]
try:
celery_task.update_state(
state="PROGRESS", meta={"current": completed, "total": total}
)
except Exception:
logger.exception("Failed to update celery progress counter")


continue
async def process_job(
celery_task: Task,
job_id: int,
job_data: JobCreate,
max_concurrent_tasks: int = 20,
max_retries: int = 3,
):
logger.info("process_job: Starting to process job %s", job_id)
redis = get_redis_client()

async with DBContext() as db_ctx:
project_crud = db_ctx.crud(ProjectCrud)
jobtask_crud = db_ctx.crud(JobTaskCrud)

logger.info("Fetching project by UUID %s", job_data.project_uuid)
project = await project_crud.fetch_project_by_uuid(job_data.project_uuid)
if project is None:
raise RuntimeError("Project not found")

project_criteria = project.criteria

return {"result": "all job tasks processed"}
logger.info("Updating job task status to %s", JobTaskStatus.PENDING)
await jobtask_crud.update_job_tasks_status(job_id, JobTaskStatus.PENDING)
await db_ctx.commit()

job_tasks = await jobtask_crud.fetch_job_tasks_by_job_id(job_id)
job_task_ids = [jt.id for jt in job_tasks]

semaphore = asyncio.Semaphore(max_concurrent_tasks)
counter_lock = asyncio.Lock()
counter = {"completed": 0, "total": len(job_task_ids)}

tasks = [
_process_job_task(
celery_task=celery_task,
job_task_id=jt_id,
job_data=job_data,
project_criteria=project_criteria,
semaphore=semaphore,
redis=redis,
counter=counter,
counter_lock=counter_lock,
max_retries=max_retries,
)
for jt_id in job_task_ids
]

await asyncio.gather(*tasks)

return {"result": "all job tasks processed"}
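For reference, below is a self-contained sketch of the two patterns this refactor relies on: exponential backoff with jitter, and semaphore-bounded fan-out over asyncio.gather. The helpers `run_with_retry` and `fan_out` are simplified illustrations that mirror `_run_with_retry` and `process_job`; they are not the project code.

import asyncio
import logging
import random

logger = logging.getLogger(__name__)

async def run_with_retry(func, max_retries=3, base_delay=1.0, jitter=0.3):
    # Retry an async callable with exponential backoff plus +/- jitter.
    retries = 0
    while True:
        try:
            return await func()
        except Exception as exc:
            retries += 1
            if retries > max_retries:
                raise
            # Backoff doubles each attempt (1s, 2s, 4s, ...) and is randomised by +/- 30%.
            delay = base_delay * (2 ** (retries - 1))
            delay += random.uniform(-delay * jitter, delay * jitter)
            logger.warning("Retry %d/%d after error: %s", retries, max_retries, exc)
            await asyncio.sleep(delay)

async def fan_out(items, worker, max_concurrent=20):
    # Run one coroutine per item, with at most max_concurrent running concurrently.
    semaphore = asyncio.Semaphore(max_concurrent)

    async def bounded(item):
        async with semaphore:
            return await worker(item)

    return await asyncio.gather(*(bounded(item) for item in items))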
5 changes: 5 additions & 0 deletions server/src/crud/jobtask_crud.py
@@ -42,6 +42,11 @@ async def fetch_job_tasks_by_job_uuid(self, job_uuid: UUID) -> Sequence[JobTask]
result = await self.db.execute(stmt)
return result.scalars().all()

async def fetch_job_task_by_id(self, job_task_id: int) -> JobTask | None:
stmt = select(JobTask).where(JobTask.id == job_task_id)
result = await self.db.execute(stmt)
return result.scalar_one_or_none()

async def fetch_job_tasks_by_paper_uuid(
self, paper_uuid: UUID
) -> Sequence[Row[Tuple[JobTask, Job]]]:
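Because `scalar_one_or_none()` returns None when no row matches, callers of `fetch_job_task_by_id` need a guard; a hypothetical usage sketch (the wrapper name and DBContext wiring are assumed, not part of this diff):

async def load_job_task(db_ctx, job_task_id: int):
    crud = db_ctx.crud(JobTaskCrud)
    job_task = await crud.fetch_job_task_by_id(job_task_id)
    if job_task is None:
        raise RuntimeError(f"Job task {job_task_id} not found")
    return job_task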