diff --git a/pyproject.toml b/pyproject.toml index e1f4527..5c4fd83 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,7 +72,7 @@ dependencies = [ "optuna-integration", "psutil", "tabulate>=0.9.0", -"torch<=2.2.2", +"torch", "seaborn" ] @@ -94,7 +94,8 @@ dev = [ test = [ "coverage", "pytest", - "pytest-cov" + "pytest-cov", + "deepdiff" ] utility = [ diff --git a/terratorch_iterate/iterate_types.py b/terratorch_iterate/iterate_types.py index 1ff0201..dd9bca9 100644 --- a/terratorch_iterate/iterate_types.py +++ b/terratorch_iterate/iterate_types.py @@ -6,7 +6,7 @@ import copy import enum from dataclasses import dataclass, field, replace -from typing import Any, Optional, Union +from typing import Any, Optional, Union, TYPE_CHECKING from terratorch.tasks import ( ClassificationTask, MultiLabelClassificationTask, @@ -16,6 +16,20 @@ ) from torchgeo.datamodules import BaseDataModule +import logging + +try: + from geobench_v2.datamodules import GeoBenchDataModule + GEOBENCH_AVAILABLE = True +except ImportError: + GeoBenchDataModule = None # type: ignore + GEOBENCH_AVAILABLE = False + logging.getLogger("terratorch").debug("geobench_v2 not installed") + + +if TYPE_CHECKING: + from geobench_v2.datamodules import GeoBenchDataModule + valid_task_types = type[ SemanticSegmentationTask | ClassificationTask @@ -129,7 +143,7 @@ class Task: name: str type: TaskTypeEnum = field(repr=False) - datamodule: BaseDataModule = field(repr=False) + datamodule: Union[BaseDataModule, "GeoBenchDataModule"] = field(repr=False) direction: str terratorch_task: Optional[dict[str, Any]] = None metric: str = "val/loss"