Skip to content

Commit 1e4c94f

Browse files
Feat/custom pipeline (#267)
* feat(platform,application): configurable pipeline * test(application): reactivate test_cli_run_submit_and_describe_and_cancel_and_download_and_delete against production * chore(deps): some
1 parent e4ee905 commit 1e4c94f

File tree

9 files changed

+499
-67
lines changed

9 files changed

+499
-67
lines changed

pyproject.toml

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ requires-python = ">=3.11, <3.14"
7474

7575
dependencies = [
7676
# From Template
77-
"fastapi[standard,all]>=0.121.1,<1",
77+
"fastapi[standard,all]>=0.121.3,<1",
7878
"humanize>=4.14.0,<5",
7979
"nicegui[native]>=3.1.0,<3.2.0", # Regression in 3.2.0
8080
"packaging>=25.0,<26",
@@ -83,24 +83,24 @@ dependencies = [
8383
"pydantic-settings>=2.12.0,<3",
8484
"pywin32>=310,<311 ; sys_platform == 'win32'",
8585
"pyyaml>=6.0.3,<7",
86-
"sentry-sdk>=2.44.0,<3",
86+
"sentry-sdk>=2.45.0,<3",
8787
"typer>=0.20.0,<1",
8888
"uptime>=3.0.1,<4",
8989
# Custom
9090
"aiopath>=0.6.11,<1",
91-
"boto3>=1.40.61,<2",
91+
"boto3>=1.41.0,<2",
9292
"certifi>=2025.11.12,<2026",
9393
"defusedxml>=0.7.1",
9494
"dicom-validator>=0.7.3,<1",
9595
"dicomweb-client[gcp]>=0.59.3,<1",
9696
"duckdb>=0.10.0,<=1.4.1",
9797
"fastparquet>=2024.11.0,<2025",
98-
"google-cloud-storage>=3.5.0,<4",
98+
"google-cloud-storage>=3.6.0,<4",
9999
"google-crc32c>=1.7.1,<2",
100100
"highdicom>=0.26.1,<1",
101101
"html-sanitizer>=2.6.0,<3",
102102
"httpx>=0.28.1,<1",
103-
"idc-index-data==22.1.2",
103+
"idc-index-data==22.1.5",
104104
"ijson>=3.4.0.post0,<4",
105105
"jsf>=0.11.2,<1",
106106
"jsonschema[format-nongpl]>=4.25.1,<5",
@@ -131,7 +131,7 @@ jupyter = ["jupyter>=1.1.1,<2"]
131131
marimo = [
132132
"cloudpathlib>=0.23.0,<1",
133133
"ipython>=9.7.0,<10",
134-
"marimo>=0.17.7,<1",
134+
"marimo>=0.17.8,<1",
135135
"matplotlib>=3.10.7,<4",
136136
"shapely>=2.1.0,<3",
137137
]
@@ -165,7 +165,7 @@ dev = [
165165
"pytest-timeout>=2.4.0,<3",
166166
"pytest-watcher>=0.4.3,<1",
167167
"pytest-xdist[psutil]>=3.8.0,<4",
168-
"ruff>=0.14.4,<1",
168+
"ruff>=0.14.5,<1",
169169
"scalene>=1.5.55,<2",
170170
"sphinx>=8.2.3,<9",
171171
"sphinx-autobuild>=2025.8.25,<2026",

src/aignostics/application/_cli.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,22 @@ def run_execute( # noqa: PLR0913, PLR0917
366366
validate_only: Annotated[
367367
bool, typer.Option(help="If True, cancel the run post validation, before analysis.")
368368
] = False,
369+
gpu_type: Annotated[
370+
str,
371+
typer.Option(help="GPU type to use for processing (L4 or A100)."),
372+
] = "A100",
373+
gpu_provisioning_mode: Annotated[
374+
str,
375+
typer.Option(help="GPU provisioning mode (SPOT or ON_DEMAND)."),
376+
] = "ON_DEMAND",
377+
max_gpus_per_slide: Annotated[
378+
int,
379+
typer.Option(help="Maximum number of GPUs to allocate per slide (1-8).", min=1, max=8),
380+
] = 1,
381+
cpu_provisioning_mode: Annotated[
382+
str,
383+
typer.Option(help="CPU provisioning mode (SPOT or ON_DEMAND)."),
384+
] = "ON_DEMAND",
369385
) -> None:
370386
"""Prepare metadata, upload data to platform, and submit an application run, then incrementally download results.
371387
@@ -401,10 +417,15 @@ def run_execute( # noqa: PLR0913, PLR0917
401417
metadata_csv_file=metadata_csv_file,
402418
application_version=application_version,
403419
note=note,
420+
tags=None,
404421
due_date=due_date,
405422
deadline=deadline,
406423
onboard_to_aignostics_portal=onboard_to_aignostics_portal,
407424
validate_only=validate_only,
425+
gpu_type=gpu_type,
426+
gpu_provisioning_mode=gpu_provisioning_mode,
427+
max_gpus_per_slide=max_gpus_per_slide,
428+
cpu_provisioning_mode=cpu_provisioning_mode,
408429
)
409430
result_download(
410431
run_id=run_id,
@@ -652,6 +673,22 @@ def run_submit( # noqa: PLR0913, PLR0917
652673
validate_only: Annotated[
653674
bool, typer.Option(help="If True, cancel the run post validation, before analysis.")
654675
] = False,
676+
gpu_type: Annotated[
677+
str,
678+
typer.Option(help="GPU type to use for processing (L4 or A100)."),
679+
] = "A100",
680+
gpu_provisioning_mode: Annotated[
681+
str,
682+
typer.Option(help="GPU provisioning mode (SPOT or ON_DEMAND)."),
683+
] = "ON_DEMAND",
684+
max_gpus_per_slide: Annotated[
685+
int,
686+
typer.Option(help="Maximum number of GPUs to allocate per slide (1-8).", min=1, max=8),
687+
] = 1,
688+
cpu_provisioning_mode: Annotated[
689+
str,
690+
typer.Option(help="CPU provisioning mode (SPOT or ON_DEMAND)."),
691+
] = "ON_DEMAND",
655692
) -> str:
656693
"""Submit run by referencing the metadata CSV file.
657694
@@ -701,11 +738,26 @@ def run_submit( # noqa: PLR0913, PLR0917
701738
app_version.version_number,
702739
metadata_dict,
703740
)
741+
742+
# Build custom metadata with pipeline configuration
743+
custom_metadata = {
744+
"pipeline": {
745+
"gpu": {
746+
"gpu_type": gpu_type,
747+
"provisioning_mode": gpu_provisioning_mode,
748+
"max_gpus_per_slide": max_gpus_per_slide,
749+
},
750+
"cpu": {
751+
"provisioning_mode": cpu_provisioning_mode,
752+
},
753+
},
754+
}
755+
704756
application_run = Service().application_run_submit_from_metadata(
705757
application_id=application_id,
706758
metadata=metadata_dict,
707759
application_version=application_version,
708-
custom_metadata=None, # TODO(Helmut): Add support for custom metadata
760+
custom_metadata=custom_metadata,
709761
note=note,
710762
tags={tag.strip() for tag in tags.split(",") if tag.strip()} if tags else None,
711763
due_date=due_date,

src/aignostics/application/_gui/_page_application_describe.py

Lines changed: 98 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@
2828
WIDTH_1200px = "width: 1200px; max-width: none"
2929
MESSAGE_METADATA_GRID_IS_NOT_INITIALIZED = "Metadata grid is not initialized."
3030

31+
CLASS_SUBSECTION_HEADER = "text-h6 mb-0 pb-0"
32+
CLASS_WIDTH_ONE_THIRD = "w-1/3"
33+
3134

3235
@binding.bindable_dataclass
3336
class SubmitForm:
@@ -50,6 +53,10 @@ class SubmitForm:
5053
deadline: str = (datetime.now().astimezone() + timedelta(hours=24)).strftime("%Y-%m-%d %H:%M")
5154
validate_only: bool = False
5255
onboard_to_aignostics_portal: bool = False
56+
gpu_type: str = "A100"
57+
gpu_provisioning_mode: str = "ON_DEMAND"
58+
max_gpus_per_slide: int = 1
59+
cpu_provisioning_mode: str = "ON_DEMAND"
5360

5461

5562
submit_form = SubmitForm()
@@ -619,7 +626,7 @@ class ThumbnailRenderer {
619626
today = now.strftime("%Y/%m/%d")
620627
min_hour = (now + timedelta(hours=1)).hour
621628
min_minute = (now + timedelta(hours=1)).minute
622-
ui.label("Soft Due Date").classes("text-h6 mb-0 pb-0")
629+
ui.label("Soft Due Date").classes("class_subsection_header")
623630
ui.label(
624631
"The platform will try to complete the run before this time, "
625632
"given your subscription tier and available GPU resources."
@@ -672,7 +679,7 @@ class ThumbnailRenderer {
672679
}}
673680
"""
674681
)
675-
ui.label("Hard Deadline").classes("text-h6 mb-0 pb-0")
682+
ui.label("Hard Deadline").classes("class_subsection_header")
676683
ui.label("The platform might cancel the run if not completed by this time.").classes(
677684
"text-sm mt-0 pt-0"
678685
)
@@ -702,11 +709,25 @@ def _submit() -> None:
702709
"""Submit the application run."""
703710
ui.notify("Submitting application run ...", type="info")
704711
try:
712+
# Build custom metadata with pipeline configuration
713+
custom_metadata = {
714+
"pipeline": {
715+
"gpu": {
716+
"gpu_type": submit_form.gpu_type,
717+
"provisioning_mode": submit_form.gpu_provisioning_mode,
718+
"max_gpus_per_slide": submit_form.max_gpus_per_slide,
719+
},
720+
"cpu": {
721+
"provisioning_mode": submit_form.cpu_provisioning_mode,
722+
},
723+
},
724+
}
725+
705726
run = service.application_run_submit_from_metadata(
706727
application_id=str(submit_form.application_id),
707728
metadata=submit_form.metadata or [],
708729
application_version=str(submit_form.application_version),
709-
custom_metadata=None, # TODO(Helmut): Allow user to edit custom metadata
730+
custom_metadata=custom_metadata,
710731
note=submit_form.note,
711732
tags=set(submit_form.tags) if submit_form.tags else None,
712733
due_date=datetime.strptime(submit_form.due_date, "%Y-%m-%d %H:%M")
@@ -816,6 +837,80 @@ def _update_upload_progress() -> None:
816837
break
817838
_upload_ui.refresh(submit_form.metadata)
818839

840+
with ui.step("Pipeline"):
841+
user_info: UserInfo | None = app.storage.tab.get("user_info", None)
842+
can_configure_pipeline = (
843+
user_info
844+
and user_info.organization
845+
and user_info.organization.name
846+
and user_info.organization.name.lower() in {"aignostics", "pre-alpha-org", "lmu", "charite"}
847+
)
848+
849+
if can_configure_pipeline:
850+
with ui.column(align_items="start").classes("w-full"):
851+
ui.label("GPU Configuration").classes("class_subsection_header")
852+
ui.label(
853+
"Configure GPU resources for processing your whole slide images. "
854+
"These settings control the type and provisioning mode of GPUs used during AI analysis."
855+
).classes("text-sm mt-0 pt-0 mb-4")
856+
857+
with ui.row().classes("w-full gap-4"):
858+
ui.select(
859+
label="GPU Type",
860+
options={"L4": "L4", "A100": "A100"},
861+
value=submit_form.gpu_type,
862+
).bind_value(submit_form, "gpu_type").mark("SELECT_GPU_TYPE").classes(CLASS_WIDTH_ONE_THIRD)
863+
864+
ui.number(
865+
label="Max GPUs per Slide",
866+
value=submit_form.max_gpus_per_slide,
867+
min=1,
868+
max=8,
869+
step=1,
870+
).bind_value(submit_form, "max_gpus_per_slide").mark("NUMBER_MAX_GPUS_PER_SLIDE").classes(
871+
CLASS_WIDTH_ONE_THIRD
872+
)
873+
874+
ui.select(
875+
label="GPU Provisioning Mode",
876+
options={
877+
"SPOT": "Spot nodes (lower cost, better availability, might be preempted and retried)",
878+
"ON_DEMAND": (
879+
"On demand nodes (higher cost, limited availability, processing might be delayed)"
880+
),
881+
},
882+
value=submit_form.gpu_provisioning_mode,
883+
).bind_value(submit_form, "gpu_provisioning_mode").mark("SELECT_GPU_PROVISIONING_MODE").classes(
884+
CLASS_WIDTH_ONE_THIRD
885+
)
886+
887+
ui.separator().classes("my-4")
888+
889+
ui.label("CPU Configuration").classes("class_subsection_header")
890+
ui.label("Configure CPU resources for algorithms that do not require GPU acceleration.").classes(
891+
"text-sm mt-0 pt-0 mb-4"
892+
)
893+
894+
with ui.row().classes("w-full gap-4"):
895+
ui.select(
896+
label="CPU Provisioning Mode",
897+
options={
898+
"SPOT": "Spot nodes (lower cost, better availability, might be preempted and retried)",
899+
"ON_DEMAND": "On demand nodes (higher cost, limited availability, might be delayed)",
900+
},
901+
value=submit_form.cpu_provisioning_mode,
902+
).bind_value(submit_form, "cpu_provisioning_mode").mark("SELECT_CPU_PROVISIONING_MODE").classes(
903+
"w-1/2"
904+
)
905+
else:
906+
ui.label(
907+
"Pipeline configuration is not available for your organization. Default settings will be used."
908+
).classes("text-body1")
909+
910+
with ui.stepper_navigation():
911+
ui.button("Next", on_click=stepper.next).mark("BUTTON_PIPELINE_NEXT")
912+
ui.button("Back", on_click=stepper.previous).props("flat")
913+
819914
with ui.step("Submit"):
820915
_upload_ui([])
821916
ui.timer(0.1, callback=_update_upload_progress)

src/aignostics/platform/_sdk_metadata.py

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,76 @@
77
import os
88
import sys
99
from datetime import UTC, datetime
10+
from enum import StrEnum
1011
from typing import Any, Literal
1112

1213
from loguru import logger
13-
from pydantic import BaseModel, Field, ValidationError
14+
from pydantic import BaseModel, Field, PositiveInt, ValidationError
1415

1516
from aignostics.utils import user_agent
1617

1718
SDK_METADATA_SCHEMA_VERSION = "0.0.4"
1819
ITEM_SDK_METADATA_SCHEMA_VERSION = "0.0.3"
1920

21+
# Pipeline orchestration defaults
22+
DEFAULT_GPU_TYPE = "A100"
23+
DEFAULT_MAX_GPUS_PER_SLIDE = 1
24+
DEFAULT_GPU_PROVISIONING_MODE = "ON_DEMAND"
25+
DEFAULT_CPU_PROVISIONING_MODE = "ON_DEMAND"
26+
27+
28+
class GPUType(StrEnum):
29+
"""Type of GPU to use for processing."""
30+
31+
L4 = "L4"
32+
A100 = "A100"
33+
34+
35+
class ProvisioningMode(StrEnum):
36+
"""Provisioning mode for resources."""
37+
38+
SPOT = "SPOT"
39+
ON_DEMAND = "ON_DEMAND"
40+
41+
42+
class CPUConfig(BaseModel):
43+
"""Configuration for CPU resources."""
44+
45+
provisioning_mode: ProvisioningMode = Field(
46+
default_factory=lambda: ProvisioningMode(DEFAULT_CPU_PROVISIONING_MODE),
47+
description="The provisioning mode for CPU resources (SPOT or ON_DEMAND)",
48+
)
49+
50+
51+
class GPUConfig(BaseModel):
52+
"""Configuration for GPU resources."""
53+
54+
gpu_type: GPUType = Field(
55+
default_factory=lambda: GPUType(DEFAULT_GPU_TYPE),
56+
description="The type of GPU to use (L4 or A100)",
57+
)
58+
provisioning_mode: ProvisioningMode = Field(
59+
default_factory=lambda: ProvisioningMode(DEFAULT_GPU_PROVISIONING_MODE),
60+
description="The provisioning mode for GPU resources (SPOT or ON_DEMAND)",
61+
)
62+
max_gpus_per_slide: PositiveInt = Field(
63+
default=DEFAULT_MAX_GPUS_PER_SLIDE,
64+
description="The maximum number of GPUs to allocate per slide",
65+
)
66+
67+
68+
class PipelineConfig(BaseModel):
69+
"""Pipeline configuration for dynamic orchestration."""
70+
71+
gpu: GPUConfig = Field(
72+
default_factory=GPUConfig,
73+
description="GPU resource configuration",
74+
)
75+
cpu: CPUConfig = Field(
76+
default_factory=CPUConfig,
77+
description="CPU resource configuration",
78+
)
79+
2080

2181
class SubmissionMetadata(BaseModel):
2282
"""Metadata about how the SDK was invoked."""
@@ -121,6 +181,7 @@ class RunSdkMetadata(BaseModel):
121181
note: str | None = Field(None, description="Optional user note for the run")
122182
workflow: WorkflowMetadata | None = Field(None, description="Workflow control flags")
123183
scheduling: SchedulingMetadata | None = Field(None, description="Scheduling information")
184+
pipeline: PipelineConfig | None = Field(None, description="Pipeline orchestration configuration")
124185

125186
model_config = {"extra": "forbid"} # Reject unknown fields
126187

tests/aignostics/application/cli_test.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Tests to verify the CLI functionality of the application module."""
22

3-
import os
43
import platform
54
import re
65
from datetime import UTC, datetime, timedelta
@@ -254,10 +253,6 @@ def test_cli_run_submit_fails_on_missing_url(runner: CliRunner, tmp_path: Path,
254253
assert "Invalid platform bucket URL: ''" in normalize_output(result.stdout)
255254

256255

257-
@pytest.mark.skipif(
258-
os.getenv("AIGNOSTICS_PLATFORM_ENVIRONMENT", "staging") == "production",
259-
reason="Broken when targeting production",
260-
)
261256
@pytest.mark.e2e
262257
@pytest.mark.long_running
263258
@pytest.mark.flaky(retries=3, delay=5)

0 commit comments

Comments (0)