Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support retry loop for prepare/ingest/materialize. #150

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 32 additions & 13 deletions rslearn/data_sources/planet_basemap.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,25 +83,16 @@ def __init__(
environment variable).
"""
self.config = config
self.series_id = series_id
self.bands = bands

self.session = requests.Session()
if api_key is None:
api_key = os.environ["PL_API_KEY"]
self.session.auth = (api_key, "")

# List mosaics.
self.mosaics = {}
for mosaic_dict in self._api_get_paginate(
path=f"series/{series_id}/mosaics", list_key="mosaics"
):
shp = shapely.box(*mosaic_dict["bbox"])
time_range = (
datetime.fromisoformat(mosaic_dict["first_acquired"]),
datetime.fromisoformat(mosaic_dict["last_acquired"]),
)
geom = STGeometry(WGS84_PROJECTION, shp, time_range)
self.mosaics[mosaic_dict["id"]] = geom
# Lazily load mosaics.
self.mosaics: dict | None = None

@staticmethod
def from_config(config: LayerConfig, ds_path: UPath) -> "PlanetBasemap":
Expand All @@ -123,6 +114,31 @@ def from_config(config: LayerConfig, ds_path: UPath) -> "PlanetBasemap":
kwargs[optional_key] = d[optional_key]
return PlanetBasemap(**kwargs)

def _load_mosaics(self) -> dict[str, STGeometry]:
    """Lazily fetch and cache the mosaics in the configured series_id.

    Loading is deferred until the first get_items call because it takes time
    and the caller may never call get_items. Deferring also means API failures
    here are covered by the retry loop in prepare_dataset_windows.

    Returns:
        mapping from mosaic ID to an STGeometry combining the mosaic's WGS84
        bounding box with its (first_acquired, last_acquired) time range.
    """
    if self.mosaics is not None:
        return self.mosaics

    # Build into a local dict and only publish it to self.mosaics on success.
    # Assigning self.mosaics up front (as a partially-filled dict) would cache
    # an incomplete mosaic list if pagination fails midway, and later calls
    # (e.g. from a retry loop) would silently return partial results.
    mosaics: dict[str, STGeometry] = {}
    for mosaic_dict in self._api_get_paginate(
        path=f"series/{self.series_id}/mosaics", list_key="mosaics"
    ):
        shp = shapely.box(*mosaic_dict["bbox"])
        time_range = (
            datetime.fromisoformat(mosaic_dict["first_acquired"]),
            datetime.fromisoformat(mosaic_dict["last_acquired"]),
        )
        mosaics[mosaic_dict["id"]] = STGeometry(WGS84_PROJECTION, shp, time_range)

    self.mosaics = mosaics
    return self.mosaics

def _api_get(
self,
path: str | None = None,
Expand Down Expand Up @@ -159,6 +175,7 @@ def _api_get(
raise ApiError(
f"{url}: got status code {response.status_code}: {response.text}"
)

return response.json()

def _api_get_paginate(
Expand Down Expand Up @@ -204,6 +221,8 @@ def get_items(
Returns:
List of groups of items that should be retrieved for each geometry.
"""
mosaics = self._load_mosaics()

groups = []
for geometry in geometries:
geom_bbox = geometry.to_projection(WGS84_PROJECTION).shp.bounds
Expand All @@ -212,7 +231,7 @@ def get_items(
# Find the relevant mosaics that the geometry intersects.
# For each relevant mosaic, identify the intersecting quads.
items = []
for mosaic_id, mosaic_geom in self.mosaics.items():
for mosaic_id, mosaic_geom in mosaics.items():
if not geometry.intersects(mosaic_geom):
continue
logger.info(f"found mosaic {mosaic_geom} for geom {geometry}")
Expand Down
130 changes: 112 additions & 18 deletions rslearn/dataset/manage.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
"""Functions to manage datasets."""

import random
import time
from collections.abc import Callable
from datetime import timedelta
from typing import Any

import rslearn.data_sources
from rslearn.config import (
LayerConfig,
Expand All @@ -17,8 +23,39 @@
logger = get_logger(__name__)


def retry(fn: Callable, retry_max_attempts: int, retry_backoff: timedelta) -> Any:
    """Retry the function multiple times in case of error.

    The function is called until either it returns without raising, or the
    original attempt plus retry_max_attempts retries have all failed, in which
    case the last exception is re-raised (with its original traceback).

    Args:
        fn: the function to call.
        retry_max_attempts: retry this many times (in addition to the original
            attempt) before giving up and re-raising the last exception.
        retry_backoff: the base backoff time used to compute how long to wait
            between retries. The actual time is (retry_backoff * attempts) * r,
            where r is a random number between 1 and 2, and attempts is the
            number of attempts tried so far.

    Returns:
        whatever fn returns on its first successful call.
    """
    # One original attempt plus retry_max_attempts retries. max(..., 0) keeps
    # the behavior of attempting at least once even if a caller passes a
    # negative retry_max_attempts.
    total_attempts = max(retry_max_attempts, 0) + 1
    for attempt_idx in range(total_attempts):
        try:
            return fn()
        except Exception as e:
            if attempt_idx == total_attempts - 1:
                # Attempts exhausted: propagate the last failure to the caller.
                raise
            logger.debug(f"Retrying after catching error in retry loop: {e}")
            # Linear backoff scaled by a random factor in [1, 2) so that
            # concurrent workers don't all retry at the same moment.
            sleep_base_seconds = retry_backoff.total_seconds() * (attempt_idx + 1)
            time.sleep(sleep_base_seconds * (1 + random.random()))


def prepare_dataset_windows(
dataset: Dataset, windows: list[Window], force: bool = False
dataset: Dataset,
windows: list[Window],
force: bool = False,
retry_max_attempts: int = 0,
retry_backoff: timedelta = timedelta(minutes=1),
) -> None:
"""Prepare windows in a dataset.

Expand All @@ -30,11 +67,15 @@ def prepare_dataset_windows(
windows: the windows to prepare
force: whether to prepare windows even if they were previously prepared
(default false)
retry_max_attempts: set greater than zero to retry for this many attempts in
case of error.
retry_backoff: how long to wait before retrying (see retry).
"""
# Iterate over retrieved layers, and prepare each one.
for layer_name, layer_cfg in dataset.layers.items():
if not layer_cfg.data_source:
continue
data_source_cfg = layer_cfg.data_source

# Get windows that need to be prepared for this layer.
needed_windows = []
Expand All @@ -59,13 +100,13 @@ def prepare_dataset_windows(
geometry = window.get_geometry()

# Apply temporal modifiers.
time_offset = layer_cfg.data_source.time_offset
time_offset = data_source_cfg.time_offset
if geometry.time_range and time_offset:
geometry.time_range = (
geometry.time_range[0] + time_offset,
geometry.time_range[1] + time_offset,
)
duration = layer_cfg.data_source.duration
duration = data_source_cfg.duration
if geometry.time_range and duration:
geometry.time_range = (
geometry.time_range[0],
Expand All @@ -74,7 +115,12 @@ def prepare_dataset_windows(

geometries.append(geometry)

results = data_source.get_items(geometries, layer_cfg.data_source.query_config)
results = retry(
fn=lambda: data_source.get_items(geometries, data_source_cfg.query_config),
retry_max_attempts=retry_max_attempts,
retry_backoff=retry_backoff,
)

for window, result in zip(needed_windows, results):
layer_datas = window.load_layer_datas()
layer_datas[layer_name] = WindowLayerData(
Expand All @@ -86,7 +132,12 @@ def prepare_dataset_windows(
window.save_layer_datas(layer_datas)


def ingest_dataset_windows(dataset: Dataset, windows: list[Window]) -> None:
def ingest_dataset_windows(
dataset: Dataset,
windows: list[Window],
retry_max_attempts: int = 0,
retry_backoff: timedelta = timedelta(minutes=1),
) -> None:
"""Ingest items for retrieved layers in a dataset.

The items associated with the specified windows are downloaded and divided into
Expand All @@ -95,6 +146,9 @@ def ingest_dataset_windows(dataset: Dataset, windows: list[Window]) -> None:
Args:
dataset: the dataset
windows: the windows to ingest
retry_max_attempts: set greater than zero to retry for this many attempts in
case of error.
retry_backoff: how long to wait before retrying (see retry).
"""
tile_store = dataset.get_tile_store()
for layer_name, layer_cfg in dataset.layers.items():
Expand Down Expand Up @@ -123,10 +177,19 @@ def ingest_dataset_windows(dataset: Dataset, windows: list[Window]) -> None:

print(f"Ingesting {len(geometries_by_item)} items in layer {layer_name}")
geometries_and_items = list(geometries_by_item.items())
data_source.ingest(
tile_store=get_tile_store_with_layer(tile_store, layer_name, layer_cfg),
items=[item for item, _ in geometries_and_items],
geometries=[geometries for _, geometries in geometries_and_items],

# Use retry loop for the actual data source ingest call.
def ingest() -> None:
data_source.ingest(
tile_store=get_tile_store_with_layer(tile_store, layer_name, layer_cfg),
items=[item for item, _ in geometries_and_items],
geometries=[geometries for _, geometries in geometries_and_items],
)

retry(
fn=ingest,
retry_max_attempts=retry_max_attempts,
retry_backoff=retry_backoff,
)


Expand Down Expand Up @@ -186,6 +249,8 @@ def materialize_window(
tile_store: TileStore,
layer_name: str,
layer_cfg: LayerConfig,
retry_max_attempts: int = 0,
retry_backoff: timedelta = timedelta(minutes=1),
) -> None:
"""Materialize a window.

Expand All @@ -196,6 +261,9 @@ def materialize_window(
tile_store: tile store of the dataset to materialize from
layer_name: the layer name
layer_cfg: the layer config
retry_max_attempts: set greater than zero to retry for this many attempts in
case of error.
retry_backoff: how long to wait before retrying (see retry).
"""
# Check if layer is materialized already.
if window.is_layer_completed(layer_name):
Expand Down Expand Up @@ -237,23 +305,39 @@ def materialize_window(
materializer = Materializers[dataset.materializer_name]
else:
materializer = Materializers[layer_cfg.layer_type.value]
materializer.materialize(
get_tile_store_with_layer(tile_store, layer_name, layer_cfg),
window,
layer_name,
layer_cfg,
item_groups,

retry(
fn=lambda: materializer.materialize(
get_tile_store_with_layer(tile_store, layer_name, layer_cfg),
window,
layer_name,
layer_cfg,
item_groups,
),
retry_max_attempts=retry_max_attempts,
retry_backoff=retry_backoff,
)

else:
# This window is meant to be materialized directly from the data source.
print(
f"Materializing {len(item_groups)} item groups in layer {layer_name} via data source"
)
data_source.materialize(window, item_groups, layer_name, layer_cfg)
retry(
fn=lambda: data_source.materialize(
window, item_groups, layer_name, layer_cfg
),
retry_max_attempts=retry_max_attempts,
retry_backoff=retry_backoff,
)


def materialize_dataset_windows(dataset: Dataset, windows: list[Window]) -> None:
def materialize_dataset_windows(
dataset: Dataset,
windows: list[Window],
retry_max_attempts: int = 0,
retry_backoff: timedelta = timedelta(minutes=1),
) -> None:
"""Materialize items for retrieved layers in a dataset.

The portions of items corresponding to dataset windows are extracted from the tile
Expand All @@ -262,6 +346,9 @@ def materialize_dataset_windows(dataset: Dataset, windows: list[Window]) -> None
Args:
dataset: the dataset
windows: the windows to materialize
retry_max_attempts: set greater than zero to retry for this many attempts in
case of error.
retry_backoff: how long to wait before retrying (see retry).
"""
tile_store = dataset.get_tile_store()
for layer_name, layer_cfg in dataset.layers.items():
Expand All @@ -274,5 +361,12 @@ def materialize_dataset_windows(dataset: Dataset, windows: list[Window]) -> None

for window in windows:
materialize_window(
window, dataset, data_source, tile_store, layer_name, layer_cfg
window=window,
dataset=dataset,
data_source=data_source,
tile_store=tile_store,
layer_name=layer_name,
layer_cfg=layer_cfg,
retry_max_attempts=retry_max_attempts,
retry_backoff=retry_backoff,
)
Loading
Loading