diff --git a/backend/app/domain/services/base/task.py b/backend/app/domain/services/base/task.py index ea02501a..f2a05bac 100644 --- a/backend/app/domain/services/base/task.py +++ b/backend/app/domain/services/base/task.py @@ -20,6 +20,7 @@ from app.domain.services.builder_and_evaluation.eval_utils.metrics_dicts import ( meta_metrics_dict, ) +from app.infrastructure.models.models import AccessTypeEnum from app.infrastructure.repositories.dataset import DatasetRepository from app.infrastructure.repositories.example import ExampleRepository from app.infrastructure.repositories.historical_data import HistoricalDataRepository @@ -412,7 +413,7 @@ def get_task_with_round_and_metric_data(self, task_id_or_code: Union[int, str]): scoring_dataset_list = [] for dataset in datasets: dataset_list.append({"id": dataset.id, "name": dataset.name}) - if dataset.access_type == "scoring": + if dataset.access_type == AccessTypeEnum.scoring: scoring_dataset_list.append( { "id": dataset.id, diff --git a/backend/app/infrastructure/repositories/dataset.py b/backend/app/infrastructure/repositories/dataset.py index 314be5d7..a335ade2 100644 --- a/backend/app/infrastructure/repositories/dataset.py +++ b/backend/app/infrastructure/repositories/dataset.py @@ -7,7 +7,7 @@ # LICENSE file in the root directory of this source tree. from app.domain.schemas.base.dataset import UpdateDatasetInfo -from app.infrastructure.models.models import Dataset +from app.infrastructure.models.models import AccessTypeEnum, Dataset from app.infrastructure.repositories.abstract import AbstractRepository @@ -17,7 +17,8 @@ def __init__(self) -> None: def get_scoring_datasets(self, task_id: int, dataset_name: str = None) -> dict: scoring_datasets = self.session.query(self.model).filter( - (self.model.access_type == "scoring") & (self.model.tid == task_id) + (self.model.access_type == AccessTypeEnum.scoring) + & (self.model.tid == task_id) ) if dataset_name: scoring_datasets = scoring_datasets.filter(self.model.name == dataset_name) @@ -33,7 +34,8 @@ def get_scoring_datasets(self, task_id: int, dataset_name: str = None) -> dict: def get_not_scoring_datasets(self, task_id: int) -> dict: no_scoring_datasets = self.session.query(self.model).filter( - (self.model.access_type != "scoring") & (self.model.tid == task_id) + (self.model.access_type != AccessTypeEnum.scoring) + & (self.model.tid == task_id) ) jsonl_no_scoring_datasets = [] @@ -66,7 +68,7 @@ def get_order_scoring_datasets_by_task_id(self, task_id: int) -> dict: self.session.query(self.model) .order_by(self.model.id) .filter(self.model.tid == task_id) - .filter(self.model.access_type == "scoring") + .filter(self.model.access_type == AccessTypeEnum.scoring) .all() ) @@ -95,7 +97,7 @@ def get_scoring_datasets_by_task_id(self, task_id: int) -> dict: return ( self.session.query(self.model.id) .filter(self.model.tid == task_id) - .filter(self.model.access_type == "scoring") + .filter(self.model.access_type == AccessTypeEnum.scoring) .all() ) @@ -114,7 +116,7 @@ def get_downstream_datasets(self, task_id: int) -> dict: downstream_datasets = ( self.session.query(self.model) .filter(self.model.tid == task_id) - .filter(self.model.access_type == "scoring") + .filter(self.model.access_type == AccessTypeEnum.scoring) ) jsonl_downstream_datasets = [] for downstream_dataset in downstream_datasets: