Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ The format is based on Keep a Changelog, and the project follows Semantic Versio

## [Unreleased]

### Added

- **GEE fetch statistics reporting in `export_batch`.** When `show_progress=True`, a `[gee_fetch]` summary line is now printed to stderr after each prefetch chunk completes, reporting total planned fetches, completed, failed, cache hits, and the most recently processed point/sensor. This gives users visibility into GEE quota consumption, cache reuse, and whether runtime is dominated by fetching vs. inference. No output is emitted when `show_progress=False` or when no GEE provider is involved (e.g. precomputed models). The underlying `FetchStats` class in `tools/progress.py` is thread-safe and accumulates counts cumulatively across chunks.

### Fixed

- **CLI `ModuleNotFoundError` on import.** `rs_embed.cli` was importing from `rs_embed.export` and `rs_embed.inspect`, two modules that do not exist in the current package layout. The imports now point directly to `rs_embed.api` (`export_batch`, `inspect_gee_patch`). The `export-npz` subcommand call site has been updated to match `export_batch`'s current signature: a single spatial argument is wrapped in `spatials=[...]`, the output path becomes `ExportTarget.combined(args.out)`, and the flat boolean flags (`save_inputs`, `save_manifest`, etc.) are grouped into an `ExportConfig` object. The stub injection in `tests/test_cli_parsers.py` that was masking the broken imports has been removed and the integration test updated to patch `cli.export_batch` instead of `cli.export_npz`.
Expand Down
31 changes: 26 additions & 5 deletions src/rs_embed/pipelines/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
point_resume_manifest,
summarize_status,
)
from ..tools.progress import FetchStats
from ..tools.progress import create_progress as _default_create_progress
from ..tools.serialization import (
sanitize_key,
Expand Down Expand Up @@ -155,6 +156,8 @@ def _run_per_item(self) -> list[dict[str, Any]]:

model_progress, on_model_done = self._create_model_progress(total=len(pending_idxs))

fetch_stats = FetchStats() if (bool(cfg.show_progress) and need_prefetch) else None

# Chunk pipeline
csize = cfg.effective_chunk_size
chunk_groups = [pending_idxs[s : s + csize] for s in range(0, len(pending_idxs), csize)]
Expand All @@ -166,14 +169,16 @@ def _run_per_item(self) -> list[dict[str, Any]]:
if need_prefetch and chunk_groups:
prefetch_pipeline_ex = ThreadPoolExecutor(max_workers=1)
prefetched_chunk_fut = prefetch_pipeline_ex.submit(
self._prefetch_chunk, active_prefetch, chunk_groups[0]
self._prefetch_chunk, active_prefetch, chunk_groups[0], fetch_stats
)

for chunk_idx, idxs in enumerate(chunk_groups):
# Wait for prefetched data
if prefetched_chunk_fut is not None:
prefetched_chunk_fut.result()
prefetched_chunk_fut = None
if fetch_stats is not None:
fetch_stats.log()

# Kick off next chunk prefetch
next_prefetch: PrefetchManager | None = None
Expand All @@ -182,7 +187,10 @@ def _run_per_item(self) -> list[dict[str, Any]]:
# is isolated from the current chunk while both are live.
next_prefetch = self._clone_prefetch(active_prefetch)
prefetched_chunk_fut = prefetch_pipeline_ex.submit(
self._prefetch_chunk, next_prefetch, chunk_groups[chunk_idx + 1]
self._prefetch_chunk,
next_prefetch,
chunk_groups[chunk_idx + 1],
fetch_stats,
)

# Chunk-level batch inference when useful for GPU or precomputed models.
Expand Down Expand Up @@ -267,9 +275,17 @@ def _run_combined(self) -> dict[str, Any]:
unit="step",
)

fetch_stats = (
FetchStats() if (bool(cfg.show_progress) and prefetch.provider is not None) else None
)

# Prefetch all inputs
if prefetch.provider is not None and tasks:
prefetch.fetch_chunk(all_idxs, self.spatials, self.temporal, progress=progress)
prefetch.fetch_chunk(
all_idxs, self.spatials, self.temporal, progress=progress, fetch_stats=fetch_stats
)
if fetch_stats is not None:
fetch_stats.log()

# Store prefetch checkpoint
if prefetch.provider is not None:
Expand Down Expand Up @@ -504,9 +520,14 @@ def _should_batch_per_item(self) -> bool:
# get_embeddings_batch(); this is useful on CPU and independent of input_prep.
return any(bool(mc.is_precomputed) for mc in self.models)

def _prefetch_chunk(self, prefetch: PrefetchManager, idxs: list[int]) -> None:
def _prefetch_chunk(
self,
prefetch: PrefetchManager,
idxs: list[int],
fetch_stats: FetchStats | None = None,
) -> None:
"""Prefetch a chunk of inputs (for use in pipelined prefetch)."""
prefetch.fetch_chunk(idxs, self.spatials, self.temporal)
prefetch.fetch_chunk(idxs, self.spatials, self.temporal, fetch_stats=fetch_stats)

def _clone_prefetch(self, src: PrefetchManager) -> PrefetchManager:
"""Clone a PrefetchManager preserving the plan with fresh per-chunk caches."""
Expand Down
20 changes: 20 additions & 0 deletions src/rs_embed/pipelines/prefetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
select_prefetched_channels,
)
from ..tools.normalization import normalize_input_array
from ..tools.progress import FetchStats
from .runner import run_with_retry


Expand Down Expand Up @@ -124,12 +125,21 @@ def fetch_chunk(
temporal: TemporalSpec | None,
*,
progress: Any = None,
fetch_stats: FetchStats | None = None,
) -> None:
"""Prefetch provider inputs for *idxs* into ``self.cache``."""
if not self.enabled or self.provider is None:
return

tasks = self.build_tasks(idxs, spatials)

if fetch_stats is not None:
possible = len(idxs) * len(self.fetch_sensor_by_key)
cache_hits = possible - len(tasks)
if cache_hits > 0:
fetch_stats.record_cache_hits(cache_hits)
fetch_stats.record_planned(len(tasks))

if not tasks:
return

Expand Down Expand Up @@ -176,6 +186,8 @@ def _fetch_one(
err_s = repr(e)
for member_skey in self.fetch_members.get(skey, []):
self.errors[(i, member_skey)] = err_s
if fetch_stats is not None:
fetch_stats.record_failure()
if progress is not None:
progress.update(1)
continue
Expand Down Expand Up @@ -203,12 +215,20 @@ def _fetch_one(
if not cfg.continue_on_error:
raise err
self.errors[(i, member_skey)] = repr(err)
if fetch_stats is not None:
fetch_stats.record_failure()
continue
self.input_reports[(i, member_skey)] = rep
self.cache[(i, member_skey)] = x_member
if fmeta:
self.fetch_meta[(i, member_skey)] = fmeta

if fetch_stats is not None:
fsensor = self.fetch_sensor_by_key.get(skey)
fetch_stats.record_success(
point=i,
sensor=fsensor.collection if fsensor is not None else None,
)
if progress is not None:
progress.update(1)

Expand Down
88 changes: 88 additions & 0 deletions src/rs_embed/tools/progress.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import sys
import threading
from typing import Any


Expand Down Expand Up @@ -44,6 +45,93 @@ def close(self) -> None:
sys.stderr.flush()


class FetchStats:
"""Thread-safe accumulator for GEE image fetch statistics.

Updated by :class:`~rs_embed.pipelines.prefetch.PrefetchManager` during
``fetch_chunk`` and surfaced as log messages when progress reporting is on.
"""

def __init__(self) -> None:
self._lock = threading.Lock()
self._total = 0
self._completed = 0
self._failed = 0
self._cache_hits = 0
self._last_point: int | None = None
self._last_sensor: str | None = None

@property
def total(self) -> int:
"""Total GEE fetch operations planned across all chunks."""
with self._lock:
return self._total

@property
def completed(self) -> int:
"""Number of fetch operations that succeeded."""
with self._lock:
return self._completed

@property
def failed(self) -> int:
"""Number of fetch operations that failed."""
with self._lock:
return self._failed

@property
def cache_hits(self) -> int:
"""Number of fetch operations skipped due to cache reuse."""
with self._lock:
return self._cache_hits

def record_planned(self, n: int = 1) -> None:
"""Register *n* newly planned fetch tasks."""
with self._lock:
self._total += max(0, int(n))

def record_cache_hits(self, n: int = 1) -> None:
"""Register *n* fetches skipped due to a cache hit."""
with self._lock:
self._cache_hits += max(0, int(n))

def record_success(self, *, point: int | None = None, sensor: str | None = None) -> None:
"""Register one successful fetch, optionally recording the point/sensor."""
with self._lock:
self._completed += 1
if point is not None:
self._last_point = point
if sensor is not None:
self._last_sensor = sensor

def record_failure(self) -> None:
"""Register one failed fetch."""
with self._lock:
self._failed += 1

def format_summary(self) -> str:
"""Return a compact summary line suitable for stderr logging."""
with self._lock:
t, c, f, h = self._total, self._completed, self._failed, self._cache_hits
last_pt, last_s = self._last_point, self._last_sensor
pct = int(100 * c / t) if t > 0 else 0
msg = f"[gee_fetch] total={t} | done={c} ({pct}%) | failed={f} | cached={h}"
if last_pt is not None and last_s is not None:
msg += f" | last=point:{last_pt},sensor:{last_s}"
return msg

def log(self) -> None:
"""Write the current summary to stderr, respecting any active tqdm bar."""
msg = self.format_summary()
try:
from tqdm import tqdm

tqdm.write(msg, file=sys.stderr)
except Exception:
sys.stderr.write(msg + "\n")
sys.stderr.flush()


def create_progress(*, enabled: bool, total: int, desc: str, unit: str = "item") -> Any:
"""Create a progress bar, falling back gracefully when tqdm is unavailable.

Expand Down
Loading
Loading