Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ repos:
rev: 1.7.0
hooks:
- id: interrogate
exclude: ^(setup.py|test/|earth2studio/models/nn/)
exclude: ^(setup.py|test/|earth2studio/models/nn/|serve/server/example_workflows/)
args: [--config=pyproject.toml]

- repo: https://github.com/igorshubovych/markdownlint-cli
Expand Down
1 change: 1 addition & 0 deletions earth2studio/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from .mrms import MRMS
from .ncar import NCAR_ERA5
from .planetary_computer import (
GeoCatalogClient,
PlanetaryComputerECMWFOpenDataIFS,
PlanetaryComputerGOES,
PlanetaryComputerMODISFire,
Expand Down
273 changes: 273 additions & 0 deletions earth2studio/data/planetary_computer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,25 @@

import asyncio
import hashlib
import json
import logging
import os
import pathlib
import shutil
import time as _time_module
from collections.abc import Callable, Mapping, Sequence
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from time import perf_counter
from typing import Any, TypeVar
from urllib.parse import urlparse
from uuid import uuid4

import nest_asyncio
import netCDF4
import numpy as np
import pygrib
import requests
import xarray as xr
from loguru import logger
from tqdm import tqdm
Expand Down Expand Up @@ -1223,3 +1229,270 @@ def _validate_time(self, times: list[datetime]) -> None:
raise ValueError(
f"Requested date time {time} is after {self._satellite} was retired ({end_date})"
)


class GeoCatalogClient:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will let @NickGeneva comment whether this is the right place to put the client that starts the data ingestion into Microsoft Planetary Computer from Azure Blob Storage. It is not a data source, more like an IO utility so we may want to put it somewhere else.

"""Client for ingesting STAC features into a Planetary Computer GeoCatalog.

Workflow-specific behavior (templates, tile settings, render options, step size,
start-time parameter key) is loaded from JSON files in the caller-provided config
directory. This class is independent of concrete workflows.

Expected files per workflow (workflow_name is passed by the caller):
- parameters-{workflow_name}.json: step_size_hours, start_time_parameter_key
- template-collection-{suffix}.json, template-feature-{suffix}.json
- tile-settings-{suffix}.json, render-options-{suffix}.json

Parameters
----------
workflow_name : str
Workflow name used in filenames (e.g. "fcn3", "fcn3-stormscope-goes").
config_dir : str | pathlib.Path
Directory containing the GeoCatalog JSON config and template files.
"""

APPLICATION_URL = "https://geocatalog.spatio.azure.com/"
REQUESTS_TIMEOUT = 30
CREATION_TIMEOUT = 300

def __init__(
self,
workflow_name: str,
config_dir: str | pathlib.Path,
) -> None:
try:
from azure.identity import DefaultAzureCredential as _DefaultAzureCredential
except ImportError as e:
raise ImportError(
"GeoCatalogClient requires 'azure-identity'. "
"Install with the serve extra or pip install azure-identity."
) from e
self._DefaultAzureCredential = _DefaultAzureCredential
self._workflow_name = workflow_name
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

create filenames with workflow name as prefix. keep workflow name consistent throughout. parameter mapping is awkward - fix that too.

self._config_dir = pathlib.Path(config_dir)
self._parameters: dict[str, Any] = {}
self._load_parameters()
self.headers: dict | None = None

def _load_parameters(self) -> None:
path = self._config_dir / f"parameters-{self._workflow_name}.json"
with open(path) as f:
self._parameters = json.load(f)

def update_headers(self) -> None:
"""Refresh the Authorization header using a new Azure credential token."""
credential = self._DefaultAzureCredential()
token = credential.get_token(self.APPLICATION_URL)
self.headers = {"Authorization": f"Bearer {token.token}"}

def _get(self, url: str) -> Any:
return requests.get(
url,
headers=self.headers,
params={"api-version": "2025-04-30-preview"},
timeout=self.REQUESTS_TIMEOUT,
)

def _post(self, url: str, body: dict | None = None) -> Any:
return requests.post(
url,
json=body,
headers=self.headers,
params={"api-version": "2025-04-30-preview"},
timeout=self.REQUESTS_TIMEOUT,
)

def _put(self, url: str, body: dict | None = None) -> Any:
return requests.put(
url,
json=body,
headers=self.headers,
params={"api-version": "2025-04-30-preview"},
timeout=self.REQUESTS_TIMEOUT,
)

def _create_element(self, url: str, stac_config: dict) -> bool:
response = self._post(url, body=stac_config)
log = logging.getLogger("planetary_computer.geocatalog")
if response.status_code not in {200, 201, 202}:
log.error(
"POST to '%s' failed: %s - %s", url, response.status_code, response.text
)
return False
location = response.headers["location"]
log.info("Creating '%s'...", stac_config["id"])
start = perf_counter()
while True:
if (perf_counter() - start) > self.CREATION_TIMEOUT:
log.error("Creation of '%s' timed out", stac_config["id"])
return False
response = self._get(location)
if response.status_code not in {200, 201, 202}:
log.warning(
"Polling '%s' returned %s, retrying...",
location,
response.status_code,
)
_time_module.sleep(5)
continue
try:
status = response.json()["status"]
except (ValueError, KeyError) as exc:
log.warning("Unexpected polling response: %s", exc)
_time_module.sleep(5)
continue
log.info(status)
if status not in {"Pending", "Running"}:
break
_time_module.sleep(5)
if status == "Succeeded":
log.info("Successfully created '%s'", stac_config["id"])
return True
else:
log.error("Failed to create '%s': %s", stac_config["id"], response.text)
return False

def _get_collection_json(self, collection_id: str | None) -> dict:
path = self._config_dir / f"template-collection-{self._workflow_name}.json"
with open(path) as f:
stac_config = json.load(f)
if collection_id is None:
stac_config["id"] = stac_config["id"].format(uuid=uuid4())
else:
stac_config["id"] = collection_id
return stac_config

def _get_feature_json(
self,
start_time: datetime,
end_time: datetime,
blob_url: str,
) -> dict:
path = self._config_dir / f"template-feature-{self._workflow_name}.json"
with open(path) as f:
stac_config = json.load(f)
iso_start = start_time.isoformat()
iso_end = end_time.isoformat()
stac_config["id"] = stac_config["id"].format(
start_time=iso_start[:13], uuid=uuid4()
)
stac_config["properties"]["datetime"] = iso_start
stac_config["properties"]["start_datetime"] = iso_start
stac_config["properties"]["end_datetime"] = iso_end
stac_config["assets"]["data"]["href"] = blob_url
stac_config["assets"]["data"]["description"] = stac_config["assets"]["data"][
"description"
].format(start_time=iso_start, end_time=iso_end)
return stac_config

def _update_tile_settings(self, geocatalog_url: str, collection_id: str) -> None:
path = self._config_dir / f"tile-settings-{self._workflow_name}.json"
with open(path) as f:
tile_settings = json.load(f)
response = self._put(
f"{geocatalog_url}/stac/collections/{collection_id}/configurations/tile-settings",
body=tile_settings,
)
if response.status_code not in {200, 201}:
log = logging.getLogger("planetary_computer.geocatalog")
log.error(
"Could not update tile settings: Error %s - %s",
response.status_code,
response.text,
)

def _update_render_options(self, geocatalog_url: str, collection_id: str) -> None:
path = self._config_dir / f"render-options-{self._workflow_name}.json"
with open(path) as f:
render_params = json.load(f)
log = logging.getLogger("planetary_computer.geocatalog")
for params in render_params:
render_option = {
"id": f"auto-{params['id']}",
"name": params["id"],
"type": "raster-tile",
"options": (
f"assets=data&subdataset_name={params['id']}"
"&sel=time=2100-01-01&sel=ensemble=0&sel_method=nearest"
f"&rescale={params['scale'][0]},{params['scale'][1]}"
f"&colormap_name={params['cmap']}"
),
"minZoom": 0,
}
response = self._post(
f"{geocatalog_url}/stac/collections/{collection_id}/configurations/render-options",
body=render_option,
)
if response.status_code not in {200, 201}:
log.error(
"Could not update render options: Error %s - %s",
response.status_code,
response.text,
)

def _create_collection(
self,
geocatalog_url: str,
collection_id: str | None,
) -> str:
stac_config = self._get_collection_json(collection_id)
success = self._create_element(
url=f"{geocatalog_url}/stac/collections",
stac_config=stac_config,
)
if not success:
raise RuntimeError(f"Failed to create collection '{stac_config['id']}'")
self._update_tile_settings(geocatalog_url, stac_config["id"])
self._update_render_options(geocatalog_url, stac_config["id"])
return stac_config["id"]

def _ensure_collection_exists(self, geocatalog_url: str, collection_id: str) -> str:
response = self._get(f"{geocatalog_url}/stac/collections/{collection_id}")
if response.status_code == 200:
return collection_id
if response.status_code != 404:
raise RuntimeError(
f"Failed to retrieve collection: Error {response.status_code} - {response.text}"
)
return self._create_collection(geocatalog_url, collection_id)

def _resolve_start_time(self, parameters: dict) -> datetime:
key = self._parameters["start_time_parameter_key"]
raw = parameters.get(key)
if raw is None:
raise ValueError(
f"Missing {key!r} in parameters for workflow {self._workflow_name!r}"
)
if isinstance(raw, str):
normalized = raw.replace("Z", "+00:00")
return datetime.fromisoformat(normalized)
if isinstance(raw, datetime):
return raw
if hasattr(raw, "isoformat"):
return datetime.fromisoformat(raw.isoformat())
raise TypeError(f"start time must be str or datetime, got {type(raw)}")

def create_feature(
self,
geocatalog_url: str,
collection_id: str | None,
parameters: dict,
blob_url: str,
) -> tuple[str, str]:
"""Ingest a new STAC feature into the collection."""
self.update_headers()
if collection_id is None:
collection_id = self._create_collection(geocatalog_url, None)
else:
self._ensure_collection_exists(geocatalog_url, collection_id)
start_time = self._resolve_start_time(parameters)
step_hours = self._parameters["step_size_hours"]
end_time = start_time + timedelta(hours=step_hours)
stac_config = self._get_feature_json(start_time, end_time, blob_url)
success = self._create_element(
url=f"{geocatalog_url}/stac/collections/{collection_id}/items",
stac_config=stac_config,
)
if not success:
raise RuntimeError(f"Failed to create feature '{stac_config['id']}'")
return collection_id, stac_config["id"]
4 changes: 4 additions & 0 deletions earth2studio/serve/client/e2client.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,10 @@ def as_dataset(self) -> xr.Dataset:
zarr_path = "/".join(result_path.split("/")[1:])
mapper = fsspec_utils.get_mapper(request_result, zarr_path)
ds = xr.open_zarr(mapper, consolidated=True, **self.workflow.xr_args)

elif request_result.storage_type == StorageType.AZURE:
mapper = fsspec_utils.get_mapper(request_result, "results.zarr")
ds = xr.open_zarr(mapper, consolidated=True, **self.workflow.xr_args)
elif request_result.storage_type == StorageType.SERVER:
result_url = urljoin(
self.workflow.base_url + "/",
Expand Down
Loading