Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,35 @@
from jinja2 import Template as Jinja2Template
from gfm_data_processing.common import logger
from gfm_data_processing import raster_data_operations as rdo
from urllib3.exceptions import NameResolutionError
from requests.exceptions import ConnectionError, Timeout
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type, before_sleep_log
import logging

# Retry only transient network errors — permanent failures (auth errors,
# bad requests, missing resources) must NOT be retried and are left out.
RETRYABLE_EXCEPTIONS = (
    NameResolutionError,  # DNS resolution failed (urllib3)
    ConnectionError,      # Connection refused/reset (requests)
    Timeout,              # Request timeout (requests)
)


# Reusable decorator applied to all geoserver-facing calls below.
def retry_on_network_error(func):
    """Wrap *func* so transient network failures trigger automatic retries.

    Retries only the exceptions in RETRYABLE_EXCEPTIONS, up to 5 attempts,
    with exponential backoff between 4 and 60 seconds. A warning is logged
    before each sleep, and the last failure is re-raised to the caller.
    """
    network_retry = retry(
        retry=retry_if_exception_type(RETRYABLE_EXCEPTIONS),
        stop=stop_after_attempt(5),
        wait=wait_exponential(multiplier=1, min=4, max=60),
        before_sleep=before_sleep_log(logger, logging.WARNING),
        reraise=True,
    )
    return network_retry(func)


######################################################################################################
### Add to geoserver
######################################################################################################
@retry_on_network_error
def add_imagemosaic_to_geoserver(geo, workspace, task_folder, layer_name, retrieved_file_paths):
file_type = "imagemosaic"
content_type = "application/zip"
Expand Down Expand Up @@ -79,6 +103,7 @@ def add_imagemosaic_to_geoserver(geo, workspace, task_folder, layer_name, retrie
logger.debug(f"{tss}: published imagemosaic time dimension")


@retry_on_network_error
def add_netcdf_to_geoserver(geo, workspace, file_path, layer_name, coverage_name):
"""
Save netcdf to geoserver, assuming store has one feature
Expand Down Expand Up @@ -131,7 +156,7 @@ def add_netcdf_to_geoserver(geo, workspace, file_path, layer_name, coverage_name
css = geo.get_coveragestore(workspace=workspace, coveragestore_name=layer_name)
return css


@retry_on_network_error
def add_vector_to_geoserver(geo, workspace, file_path, layer_name, store_format):
"""
Save gpkg or shp to geoserver, assuming store has one feature
Expand Down
52 changes: 44 additions & 8 deletions pipelines/components/terrakit_data_fetch/terrakit_data_fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,18 @@
"""

# Dependencies
# pip install terrakit==0.1.0 requests opentelemetry-distro opentelemetry-exporter-otlp
# pip install terrakit==0.1.0 requests opentelemetry-distro opentelemetry-exporter-otlp tenacity

import json
import logging
import os

import numpy as np
from tenacity import (
    retry,
    stop_after_attempt,
    wait_fixed,
    retry_if_exception_type,
    before_sleep_log,
)
from terrakit import DataConnector
from terrakit.download.geodata_utils import save_data_array_to_file
from terrakit.download.transformations.scale_data_xarray import scale_data_xarray
Expand Down Expand Up @@ -49,6 +56,34 @@ def s1grd_to_decibels(da, modality_tag):
return da


@retry(
    stop=stop_after_attempt(3),
    wait=wait_fixed(5),
    retry=retry_if_exception_type((RuntimeError, ConnectionError, OSError)),
    # before_sleep_log requires a numeric logging level: tenacity passes it
    # straight to Logger.log, which raises TypeError on the string "WARNING"
    # the first time a retry sleep fires. Use logging.WARNING instead.
    before_sleep=before_sleep_log(logger, logging.WARNING),
    reraise=True,
)
def fetch_data_with_retry(dc, collection_name, data_date, bbox, maxcc, band_names, save_filepath, task_folder):
    """
    Fetch data from connector with automatic retry on network errors.

    Retries up to 3 times with 5 second delays for network-related errors:
    - RuntimeError (includes RasterioIOError)
    - ConnectionError
    - OSError (includes CURL errors)

    The final failure is re-raised (reraise=True) so the caller still sees
    the original exception after all attempts are exhausted.
    """
    logger.info(f"Attempting to fetch data for collection: {collection_name}")
    return dc.connector.get_data(
        data_collection_name=collection_name,
        date_start=data_date,
        date_end=data_date,
        bbox=bbox,
        maxcc=maxcc,
        bands=band_names,
        save_file=save_filepath,
        working_dir=task_folder,
    )


@metric_manager.count_failures(inference_id=inference_id, task_id=task_id)
@metric_manager.record_duration(inference_id=inference_id, task_id=task_id)
def terrakit_data_fetch():
Expand Down Expand Up @@ -124,15 +159,16 @@ def terrakit_data_fetch():

band_names = list(band_dict.get("band_name") for band_dict in model_input_data_spec["bands"])

da = dc.connector.get_data(
data_collection_name=collection_name,
date_start=data_date,
date_end=data_date,
# Use tenacity for automatic retry on network errors
da = fetch_data_with_retry(
dc=dc,
collection_name=collection_name,
data_date=data_date,
bbox=bbox,
maxcc=maxcc,
bands=band_names,
save_file=save_filepath,
working_dir=task_folder,
band_names=band_names,
save_filepath=save_filepath,
task_folder=task_folder,
)
logger.debug("\n\nRetrieved data cube\n\n")
logger.debug(da)
Expand Down
Loading