diff --git a/client/src/components/EventStream.tsx b/client/src/components/EventStream.tsx index 1a80b70..52f6b72 100644 --- a/client/src/components/EventStream.tsx +++ b/client/src/components/EventStream.tsx @@ -2,6 +2,8 @@ import { useCallback, useEffect, useState } from "react"; import { CircleAlert } from "lucide-react"; import * as z from "zod"; import classNames from "classnames"; +import { useTypedStoreActions } from "../state/store"; +import type { JobStats } from "../state/types"; const EventName = { // Events for JobTask-related things @@ -16,6 +18,7 @@ const EventName = { // Events for Job-related things JOB_COMPLETE: 3001, JOB_CREATED: 3002, + JOB_PROGRESS: 3003, // Events for Project-related things PROJECT_CREATED: 4001, PROJECT_FILE_UPLOADED: 4002, @@ -84,19 +87,37 @@ export const EventStream = () => { const [connected, setConnected] = useState(false); const [logs, setLogs] = useState>([]); - // TODO: Integrate into state management to update job task status in real time + // const setJobsForProject = useTypedStoreActions( + // (actions) => actions.setJobsForProject + // ); + const updateJobStats = useTypedStoreActions( + (actions) => actions.updateJobStats + ); + const _onMessage = useCallback((event: MessageEvent) => { const { data } = event; if (typeof data === "string") { const dataJson = JSON.parse(data); - // console.log(dataJson); const parsedData = EventData.safeParse(dataJson); if (!parsedData.error) { - setLogs((logs) => [...logs, parsedData.data]); + const eventData = parsedData.data; + setLogs((logs) => [...logs, eventData]); + + switch (eventData.event_name) { + case EventName.JOB_PROGRESS: + { + const jobId = eventData.value.job_id; + const stats: JobStats = eventData.value.stats; + updateJobStats({ jobId, stats }); + } + break; + default: + break; + } } } - }, []); + }, [updateJobStats]); const startLogStream = useCallback(() => { const eventSource = new EventSource(event_url); diff --git a/client/src/pages/ProjectPage.tsx 
b/client/src/pages/ProjectPage.tsx index 5e7bb27..29bff80 100644 --- a/client/src/pages/ProjectPage.tsx +++ b/client/src/pages/ProjectPage.tsx @@ -7,8 +7,7 @@ import { DropdownMenuText, DropdownOption } from "../components/DropDownMenus"; import { FileDropArea } from "../components/FileDropArea"; import { ExpandableToast } from "../components/ExpandableToast"; import { TruncatedFileNames } from "../components/TruncatedFileNames"; -import { fetchJobTasksFromBackend } from "../services/jobTaskService"; -import { createJob, fetchJobsForProject } from "../services/jobService"; +import { createJob } from "../services/jobService"; import { fileUploadToBackend, fileFetchFromBackend, @@ -17,9 +16,6 @@ import { ManualEvaluationModal } from "../components/ManualEvaluationModal"; import { Button } from "../components/Button"; import { FetchedFile, - JobTask, - JobTaskStatus, - CreatedJob, LlmConfig, createZeroShotPromptingConfig, JobPromptingType, @@ -88,11 +84,10 @@ const ModelConfiguration: React.FC = ({ {Object.keys(modelParametersSchema.properties).map((key) => { const property = modelParametersSchema.properties[key]; return ( - {`${property.title}: ${ - modelFormValues[key] !== undefined && + {`${property.title}: ${modelFormValues[key] !== undefined && modelFormValues[key] !== "" && modelFormValues[key] - }`} + }`} ); })} @@ -124,7 +119,7 @@ const ModelConfiguration: React.FC = ({ {modelFormValues[key] !== undefined && - modelFormValues[key] !== "" ? ( + modelFormValues[key] !== "" ? ( <>{modelFormValues[key]} ) : ( "" @@ -236,8 +231,8 @@ const ProviderConfiguration: React.FC = ({ {providerFormValues[key] !== undefined && - property.type !== "string" && - providerFormValues[key] !== "" ? ( + property.type !== "string" && + providerFormValues[key] !== "" ? 
( <>{providerFormValues[key]} ) : ( "" @@ -413,7 +408,7 @@ export const ProjectPage = () => { const [evaluateViewMatch] = useRoute("/project/:projectUuid/evaluate"); const [fewShotViewMatch] = useRoute("/project/:projectUuid/few_shot"); const search = useSearch(); - const jobTaskRefetchIntervalMs = 5000; + const [isLlmProviderSelected, setIsLlmProviderSelected] = useState(false); const [modelsLoaded, setModelsLoaded] = useState(false); const [isLlmSelected, setIsLlmSelected] = useState(false); @@ -422,12 +417,11 @@ export const ProjectPage = () => { const getPapers = useTypedStoreState((state) => state.getPapersForProject); const papers = getPapers(projectUuid); - const [createdJobs, setCreatedJobs] = useState([]); const [fetchedFiles, setFetchedFiles] = useState([]); - const [jobTasks, setJobTasks] = useState([]); const [availableModels, setAvailableModels] = useState< Array<{ id: string; created: number; object: "model"; owned_by: string }> >([]); + const loadingProjects = useTypedStoreState((state) => state.loading.projects); const loadProjects = useTypedStoreActions((actions) => actions.fetchProjects); const getProjectByUuid = useTypedStoreState( @@ -436,6 +430,13 @@ export const ProjectPage = () => { const providers = useTypedStoreState((state) => state.providers); const fetchPapers = useTypedStoreActions((actions) => actions.fetchPapers); + const fetchJobsForProject = useTypedStoreActions( + (actions) => actions.fetchJobsForProject, + ); + const jobs = useTypedStoreState( + (state) => state.jobsByProject[projectUuid] || [], + ); + const project = getProjectByUuid(projectUuid); useEffect(() => { @@ -446,6 +447,12 @@ export const ProjectPage = () => { // eslint-disable-next-line react-hooks/exhaustive-deps }, [project, projectUuid]); + useEffect(() => { + if (projectUuid) { + fetchJobsForProject(projectUuid); + } + }, [projectUuid, fetchJobsForProject]); + const paperUuid = useMemo(() => { if (!search) return null; return new 
URLSearchParams(search).get("paperUuid"); @@ -513,23 +520,7 @@ export const ProjectPage = () => { [papers], ); - const evaluationFinished = jobTasks.length > 0 && pendingTasks.length === 0; - - const fetchJobs = useCallback(() => { - async function doFetch() { - try { - const jobs = await fetchJobsForProject(projectUuid); - setCreatedJobs(jobs); - } catch (e) { - console.error("Failed to fetch jobs for project", e); - } - } - doFetch().catch(console.error); - }, [projectUuid]); - - useEffect(() => { - fetchJobs(); - }, [fetchJobs, projectUuid]); + const evaluationFinished = jobs.length > 0 && pendingTasks.length === 0; const fetchModels = useCallback(() => { async function fetch_models() { @@ -545,7 +536,7 @@ export const ProjectPage = () => { } catch (error) { console.error( "Failed to fetch available models for provider " + - selectedLlmProvider.value, + selectedLlmProvider.value, error, ); } @@ -557,7 +548,7 @@ export const ProjectPage = () => { const paperToTaskMap = useMemo(() => { if ( papers.length === 0 || - jobTasks.length === 0 || + jobs.length === 0 || pendingTasks.length === 0 ) { return {}; } @@ -579,7 +570,7 @@ export const ProjectPage = () => { } }); return map; - }, [papers, jobTasks, pendingTasks]); + }, [papers, jobs, pendingTasks]); const currentTaskUuid = paperUuid ?
paperToTaskMap[paperUuid] : undefined; @@ -602,17 +593,8 @@ export const ProjectPage = () => { const promptingConfig = createZeroShotPromptingConfig(); try { - const res = await createJob(projectUuid, llmConfig, promptingConfig); - const createdJob: CreatedJob = { - uuid: res.uuid, - project_uuid: res.project_uuid, - llm_config: res.llm_config, - prompting_config: res.prompting_config, - created_at: res.created_at, - updated_at: res.updated_at, - }; - setCreatedJobs((prev) => [...prev, createdJob]); - // await loadPapers(); + await createJob(projectUuid, llmConfig, promptingConfig); + fetchJobsForProject(projectUuid); } catch (e) { console.error("Error creating job:", e); toast.error("Error creating job"); @@ -623,6 +605,7 @@ export const ProjectPage = () => { modelFormValues, providerFormValues, projectUuid, + fetchJobsForProject, ]); const uploadFilesToBackend = useCallback( @@ -665,7 +648,6 @@ export const ProjectPage = () => { try { await uploadFilesToBackend(files); await fetchFiles(); - // await loadPapers(); } catch (error) { console.error("Problem uploading the files", error); } @@ -687,31 +669,6 @@ export const ProjectPage = () => { })(); }, [fetchFiles]); - useEffect(() => { - if (createdJobs.length === 0) return; - - const fetchAll = () => { - Promise.all( - createdJobs.map((job) => { - // console.log("job.uuid", job.uuid); - // @ts-expect-error Expected - return fetchJobTasksFromBackend(job.uuid, job.id); - }), - ) - .then((results) => { - setJobTasks(results.flat()); - // console.log("results: ", results.flat()); - }) - .catch((error) => { - console.error("Error fetching job tasks:", error); - }); - }; - - fetchAll(); - const interval = setInterval(fetchAll, jobTaskRefetchIntervalMs); - return () => clearInterval(interval); - }, [createdJobs, jobTaskRefetchIntervalMs]); - const openManualEvaluation = useCallback(() => { if (evaluationFinished) return; if (papers.length === 0) { @@ -730,7 +687,7 @@ export const ProjectPage = () => { if (idx !== -1) { 
for (let i = idx + 1; i < papers.length; i++) { const candidate = papers[i]; - if (jobTasks.length === 0 || paperToTaskMap[candidate.uuid]) { + if (jobs.length === 0 || paperToTaskMap[candidate.uuid]) { navigate( `/project/${projectUuid}/evaluate?paperUuid=${candidate.uuid}`, ); @@ -738,17 +695,15 @@ export const ProjectPage = () => { } } } - // await loadPapers(); navigate(`/project/${projectUuid}`); toast.success("Manual evaluation finished."); }, [ paperUuid, papers, - jobTasks.length, + jobs.length, paperToTaskMap, navigate, projectUuid, - // loadPapers, ]); useEffect(() => { @@ -835,19 +790,14 @@ export const ProjectPage = () => {
- {jobTasks.length === 0 && ( + {jobs.length === 0 && ( )} - {createdJobs.map((job) => { - const tasks = jobTasks.filter((task) => task.job_uuid === job.uuid); - const doneCount = tasks.filter( - (task) => task.status === JobTaskStatus.DONE, - ).length; - const errorCount = tasks.filter( - (task) => task.status === JobTaskStatus.ERROR, - ).length; - const totalCount = tasks.length; - const completedCount = doneCount + errorCount; + {jobs.map((job) => { + const successCount = job.stats.success; + const errorCount = job.stats.failed; + const totalCount = job.stats.total; + const completedCount = successCount + errorCount; const progress = totalCount === 0 ? 0 @@ -1140,7 +1090,7 @@ export const ProjectPage = () => { }} onClose={() => { loadProjects(); - fetchJobs(); + fetchJobsForProject(projectUuid); navigate(`/project/${projectUuid}`); }} /> diff --git a/client/src/state/store.ts b/client/src/state/store.ts index 459f341..77c287f 100644 --- a/client/src/state/store.ts +++ b/client/src/state/store.ts @@ -11,13 +11,21 @@ import { import * as projectsService from "../services/projectService"; import * as paperService from "../services/paperService"; import * as providerService from "../services/providerService"; -import type { JobTaskHumanResult, PaperWithModelEval, Provider } from "./types"; +import * as jobService from "../services/jobService"; +import type { + JobTaskHumanResult, + PaperWithModelEval, + Provider, + JobWithStats, + JobStats, +} from "./types"; import type { Project } from "./types/project"; const injections = { projectsService, paperService, providerService, + jobService, }; type LoadingModel = { @@ -44,6 +52,20 @@ interface ProjectModel { refreshProjects: Thunk; } +interface JobModel { + jobsByProject: Record; + jobsById: Record; + + setJobsForProject: Action< + StoreModel, + { projectUuid: string; jobs: JobWithStats[] } + >; + + updateJobStats: Action; + + fetchJobsForProject: Thunk; +} + interface PaperModel { // Papers are study-specific papers: 
Record>; @@ -99,7 +121,11 @@ interface ProviderModel { refreshProviders: Thunk; } -type StoreModel = {} & LoadingModel & ProjectModel & PaperModel & ProviderModel; +type StoreModel = {} & LoadingModel & + ProjectModel & + JobModel & + PaperModel & + ProviderModel; export type Injections = typeof injections; @@ -232,6 +258,45 @@ export const model = { return (name: string) => (state.providers || []).find((provider) => provider.name === name); }), + + jobsByProject: {}, + jobsById: {}, + + setJobsForProject: action((state, payload) => { + const { projectUuid, jobs } = payload; + + state.jobsByProject[projectUuid] = jobs; + + jobs.forEach((job) => { + state.jobsById[job.id] = job; + }); + }), + + updateJobStats: action((state, payload) => { + const { jobId, stats } = payload; + const job = state.jobsById[jobId]; + if (!job) return; + + const updatedJob = { ...job, stats }; + state.jobsById[jobId] = updatedJob; + + for (const projectUuid in state.jobsByProject) { + state.jobsByProject[projectUuid] = state.jobsByProject[projectUuid].map( + (j) => (j.id === jobId ? 
updatedJob : j), + ); + } + }), + + fetchJobsForProject: thunk(async (actions, projectUuid, { injections }) => { + const { jobService } = injections; + + const jobs = await jobService.fetchJobsForProject(projectUuid); + + actions.setJobsForProject({ + projectUuid, + jobs, + }); + }), } satisfies StoreModel; export const store = createStore(model, { diff --git a/client/src/state/types.ts b/client/src/state/types.ts index 35e23f6..a59c767 100644 --- a/client/src/state/types.ts +++ b/client/src/state/types.ts @@ -72,6 +72,14 @@ export enum JobTaskStatus { ERROR = "ERROR", } +export enum JobStatus { + NOT_STARTED = "NOT_STARTED", + RUNNING = "RUNNING", + PARTIAL_SUCCESS = "PARTIAL_SUCCESS", + SUCCESS = "SUCCESS", + FAILED = "FAILED", +} + export type JobTask = { uuid: string; job_uuid: string; @@ -87,6 +95,24 @@ export type JobTask = { error: string | null; }; +export type JobStats = { + total: number; + success: number; + failed: number; + status: JobStatus; +}; + +export type JobWithStats = { + uuid: string; + id: number; + project_uuid: string; + prompting_config: PromptingConfig; + llm_config: LlmConfig; + created_at: Date | null; + updated_at: Date | null; + stats: JobStats; +}; + export type Paper = { uuid: string; paper_id: number; diff --git a/server/src/api/controllers/job.py b/server/src/api/controllers/job.py index 78e2b8a..26eb323 100644 --- a/server/src/api/controllers/job.py +++ b/server/src/api/controllers/job.py @@ -5,7 +5,7 @@ from src.db.db_context import DBContext, get_db_ctx from src.event_queue import EventName, QueueItem, push_event -from src.schemas.job import FewShotPromptingConfig, JobCreate, JobRead +from src.schemas.job import FewShotPromptingConfig, JobCreate, JobRead, JobReadWithStats from src.schemas.project import FewShotPreferences from src.services.job_service import create_job_service from src.services.project_service import ( @@ -17,7 +17,7 @@ @router.get( - "/job", status_code=status.HTTP_200_OK, response_model=list[JobRead],
tags=["Job"] + "/job", status_code=status.HTTP_200_OK, response_model=list[JobReadWithStats], tags=["Job"] ) async def get_jobs( project: Optional[UUID] = None, db_ctx: DBContext = Depends(get_db_ctx) diff --git a/server/src/celery/tasks.py b/server/src/celery/tasks.py index 83c6298..cac99a1 100644 --- a/server/src/celery/tasks.py +++ b/server/src/celery/tasks.py @@ -11,6 +11,7 @@ from src.crud.project_crud import ProjectCrud from src.db.db_context import DBContext from src.event_queue import EventName, QueueItem +from src.helpers.resolve_job_status import resolve_job_status from src.redis_client.client import REDIS_CHANNEL, get_redis_client from src.schemas.job import JobCreate from src.schemas.jobtask import JobTaskStatus @@ -69,6 +70,7 @@ def should_retry_status(response): async def _process_job_task( celery_task: Task, job_task_id: int, + job_id: int, job_data: JobCreate, project_criteria: Criteria, semaphore: asyncio.Semaphore, @@ -112,6 +114,9 @@ async def _process_job_task( ) await task_db_ctx.commit() + async with counter_lock: + counter["success"] += 1 + await _publish_redis_event( redis, EventName.JOB_TASK_DONE, @@ -127,6 +132,10 @@ async def _process_job_task( ) await err_jobtask_crud.update_job_task_error(job_task_id, str(e)) await task_err_db_ctx.commit() + + async with counter_lock: + counter["failed"] += 1 + except Exception as err_db_exc: logger.exception( "Failed to write error to database for job_task %s: %s", @@ -147,23 +156,43 @@ async def _process_job_task( except Exception: logger.exception("Failed to publish error %s", job_task_id) - try: - celery_task.update_state(state="FAILURE", meta={"error": str(e)}) - except Exception: - logger.exception( - "Failed to update celery state for job_task %s", job_task_id - ) + # try: + # celery_task.update_state(state="FAILURE", meta={"error": str(e)}) + # except Exception: + # logger.exception( + # "Failed to update celery state for job_task %s", job_task_id + # ) finally: async with counter_lock: 
counter["completed"] += 1 completed = counter["completed"] total = counter["total"] - try: - celery_task.update_state( - state="PROGRESS", meta={"current": completed, "total": total} - ) - except Exception: - logger.exception("Failed to update celery progress counter") + success = counter["success"] + failed = counter["failed"] + + status = resolve_job_status(total, success, failed) + + # TODO: maybe buffer to avoid spamming + await _publish_redis_event( + redis, + EventName.JOB_PROGRESS, + { + "job_id": job_id, + "stats": { + "total": total, + "success": success, + "failed": failed, + "status": status, + }, + }, + ) + + try: + celery_task.update_state( + state="PROGRESS", meta={"current": completed, "total": total} + ) + except Exception: + logger.exception("Failed to update celery progress counter") async def process_job( @@ -198,12 +227,13 @@ async def process_job( semaphore = asyncio.Semaphore(max_concurrent_tasks) counter_lock = asyncio.Lock() - counter = {"completed": 0, "total": len(job_task_ids)} + counter = {"completed": 0, "success": 0, "failed": 0, "total": len(job_task_ids)} tasks = [ _process_job_task( celery_task=celery_task, job_task_id=jt_id, + job_id=job_id, job_data=job_data, project_criteria=project_criteria, semaphore=semaphore, diff --git a/server/src/crud/job_crud.py b/server/src/crud/job_crud.py index eeebc36..926064b 100644 --- a/server/src/crud/job_crud.py +++ b/server/src/crud/job_crud.py @@ -31,6 +31,7 @@ async def fetch_jobs_by_project(self, project_uuid: UUID): stmt = ( select( Job.uuid, + Job.id, Project.uuid.label("project_uuid"), Job.llm_config, Job.prompting_config, diff --git a/server/src/crud/jobtask_crud.py b/server/src/crud/jobtask_crud.py index 0c28b33..9523d95 100644 --- a/server/src/crud/jobtask_crud.py +++ b/server/src/crud/jobtask_crud.py @@ -1,13 +1,14 @@ from typing import List, Sequence, Tuple from uuid import UUID -from sqlalchemy import Row, select, update +from sqlalchemy import Row, case, func, select, update from 
sqlalchemy.ext.asyncio import AsyncSession from src.db.models.job import Job from src.db.models.jobtask import JobTask from src.db.models.paper import Paper -from src.schemas.jobtask import JobTaskCreate, JobTaskHumanResult +from src.db.models.project import Project +from src.schemas.jobtask import JobTaskCreate, JobTaskHumanResult, JobTaskStatus from src.schemas.llm import StructuredResponse from src.schemas.paper import PaperCreate @@ -57,6 +58,26 @@ async def fetch_job_tasks_by_paper_uuid( ) return (await self.db.execute(stmt)).all() + async def fetch_tasks_stats_by_project(self, project_uuid): + stmt = ( + select( + Job.uuid.label("job_uuid"), + func.count(JobTask.id).label("total_count"), + func.coalesce( + func.sum(case((JobTask.status == JobTaskStatus.DONE, 1), else_=0)), 0 + ).label("success_count"), + func.coalesce( + func.sum(case((JobTask.status == JobTaskStatus.ERROR, 1), else_=0)), 0 + ).label("failed_count"), + ) + .join(Job, JobTask.job_id == Job.id) + .join(Project, Project.id == Job.project_id) + .where(Project.uuid == project_uuid) + .group_by(Job.uuid) + ) + result = await self.db.execute(stmt) + return result.mappings().all() + async def update_job_task_status(self, job_task_id: int, status: str): stmt = update(JobTask).where(JobTask.id == job_task_id).values(status=status) await self.db.execute(stmt) diff --git a/server/src/event_queue.py b/server/src/event_queue.py index 1f0b9e1..0142663 100644 --- a/server/src/event_queue.py +++ b/server/src/event_queue.py @@ -17,6 +17,7 @@ class EventName(Enum): # Events for Job-related things JOB_COMPLETE = 3001 JOB_CREATED = 3002 + JOB_PROGRESS = 3003 # Events for Project-related things PROJECT_CREATED = 4001 PROJECT_FILE_UPLOADED = 4002 diff --git a/server/src/helpers/__init__.py b/server/src/helpers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/server/src/helpers/resolve_job_status.py b/server/src/helpers/resolve_job_status.py new file mode 100644 index 0000000..8da79b1 --- /dev/null +++
b/server/src/helpers/resolve_job_status.py @@ -0,0 +1,19 @@ +from src.schemas.job import JobStatus + + +def resolve_job_status(total: int, success: int, failed: int) -> JobStatus: + if total == 0: + return JobStatus.NOT_STARTED + + finished = success + failed + + if finished < total: + return JobStatus.RUNNING + + if failed == 0: + return JobStatus.SUCCESS + + if success == 0: + return JobStatus.FAILED + + return JobStatus.PARTIAL_SUCCESS diff --git a/server/src/schemas/job.py b/server/src/schemas/job.py index 41805e3..932773d 100644 --- a/server/src/schemas/job.py +++ b/server/src/schemas/job.py @@ -1,7 +1,8 @@ +from datetime import datetime from enum import Enum from typing import Annotated, Any, List, Literal, Union from uuid import UUID -from datetime import datetime + from pydantic import BaseModel, ConfigDict, Field @@ -11,6 +12,21 @@ class JobPromptingType(str, Enum): FEW_SHOT = "FEW_SHOT" +class JobStatus(str, Enum): + NOT_STARTED = "NOT_STARTED" + RUNNING = "RUNNING" + PARTIAL_SUCCESS = "PARTIAL_SUCCESS" + SUCCESS = "SUCCESS" + FAILED = "FAILED" + + +class JobStats(BaseModel): + total: int + success: int + failed: int + status: JobStatus + + class LLMModelConfig(BaseModel): provider_name: str model_name: str @@ -55,3 +71,21 @@ class JobRead(BaseModel): updated_at: datetime model_config = ConfigDict(from_attributes=True) + + +class JobReadWithStats(BaseModel): + uuid: UUID + id: int + project_uuid: UUID + prompting_config: PromptingConfig + llm_config: LLMModelConfig + created_at: datetime + updated_at: datetime + + stats: JobStats + # total_tasks: int = 0 + # success_tasks: int = 0 + # failed_tasks: int = 0 + # status: JobStatus = JobStatus.NOT_STARTED + + model_config = ConfigDict(from_attributes=True) diff --git a/server/src/schemas/jobtask.py b/server/src/schemas/jobtask.py index 06219da..97f7562 100644 --- a/server/src/schemas/jobtask.py +++ b/server/src/schemas/jobtask.py @@ -43,6 +43,7 @@ class JobTaskRead(BaseModel): result: Optional[Dict[str, Any]] 
human_result: JobTaskHumanResult | None = None status_metadata: Optional[Dict[str, Any]] = None + error: Optional[str] = None @field_validator("result", mode="before") @classmethod @@ -65,6 +66,7 @@ class JobTaskReadWithLLMConfig(BaseModel): result: Optional[Dict[str, Any]] = None human_result: Optional[JobTaskHumanResult] = None status_metadata: Optional[Dict[str, Any]] = None + error: Optional[str] = None llm_config: Optional[Dict[str, Any]] = None prompting_config: Optional[Dict[str, Any]] = None diff --git a/server/src/services/job_service.py b/server/src/services/job_service.py index e409038..f34eca2 100644 --- a/server/src/services/job_service.py +++ b/server/src/services/job_service.py @@ -3,7 +3,8 @@ from src.crud.job_crud import JobCrud from src.db.db_context import DBContext -from src.schemas.job import JobCreate, JobRead +from src.helpers.resolve_job_status import resolve_job_status +from src.schemas.job import JobCreate, JobRead, JobReadWithStats, JobStats from src.schemas.jobtask import JobTaskRead from src.services.jobtask_service import JobTaskService, create_jobtask_service @@ -33,9 +34,32 @@ async def fetch_all(self) -> list[JobRead]: for row in rows ] - async def fetch_by_project(self, project_uuid: UUID) -> list[JobRead]: - rows = await self.job_crud.fetch_jobs_by_project(project_uuid) - return [JobRead(**row) for row in rows] + async def fetch_by_project(self, project_uuid: UUID) -> list[JobReadWithStats]: + jobs = await self.job_crud.fetch_jobs_by_project(project_uuid) + stats_rows = await self.jobtask_service.fetch_task_stats_by_project( + project_uuid + ) + + stats_map = {row["job_uuid"]: row for row in stats_rows} + + result = [] + for job in jobs: + stats = stats_map.get(job["uuid"]) + total = stats["total_count"] if stats else 0 + success = stats["success_count"] if stats else 0 + failed = stats["failed_count"] if stats else 0 + + job_status = resolve_job_status(total, success, failed) + result.append( + JobReadWithStats( + **job, + 
stats=JobStats( + total=total, success=success, failed=failed, status=job_status + ), + ) + ) + + return result async def fetch_by_uuid(self, uuid: UUID) -> JobRead: job = await self.job_crud.fetch_job_by_uuid(uuid) diff --git a/server/src/services/jobtask_service.py b/server/src/services/jobtask_service.py index 6853b88..f4eda81 100644 --- a/server/src/services/jobtask_service.py +++ b/server/src/services/jobtask_service.py @@ -59,6 +59,10 @@ async def fetch_job_tasks_for_paper(self, paper_uuid: UUID): for task, job in job_tasks_with_jobs ] + async def fetch_task_stats_by_project(self, project_uuid: UUID): + stats = await self.jobtask_crud.fetch_tasks_stats_by_project(project_uuid) + return stats + async def add_human_result(self, uuid: UUID, human_result: JobTaskHumanResult): await self.jobtask_crud.add_jobtask_human_result(uuid, human_result)