diff --git a/CHANGELOG.md b/CHANGELOG.md index fa24547..7be4120 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,10 @@ The format is based on Keep a Changelog, and the project follows Semantic Versio ## [Unreleased] +### Added + +- **GEE fetch statistics reporting in `export_batch`.** When `show_progress=True`, a `[gee_fetch]` summary line is now printed to stderr after each prefetch chunk completes, reporting total planned fetches, completed, failed, cache hits, and the most recently processed point/sensor. This gives users visibility into GEE quota consumption, cache reuse, and whether runtime is dominated by fetching vs. inference. No output is emitted when `show_progress=False` or when no GEE provider is involved (e.g. precomputed models). The underlying `FetchStats` class in `tools/progress.py` is thread-safe and accumulates counts cumulatively across chunks. + ### Fixed - **CLI `ModuleNotFoundError` on import.** `rs_embed.cli` was importing from `rs_embed.export` and `rs_embed.inspect`, two modules that do not exist in the current package layout. The imports now point directly to `rs_embed.api` (`export_batch`, `inspect_gee_patch`). The `export-npz` subcommand call site has been updated to match `export_batch`'s current signature: a single spatial argument is wrapped in `spatials=[...]`, the output path becomes `ExportTarget.combined(args.out)`, and the flat boolean flags (`save_inputs`, `save_manifest`, etc.) are grouped into an `ExportConfig` object. The stub injection in `tests/test_cli_parsers.py` that was masking the broken imports has been removed and the integration test updated to patch `cli.export_batch` instead of `cli.export_npz`. diff --git a/src/rs_embed/pipelines/exporter.py b/src/rs_embed/pipelines/exporter.py index 5489089..9f4e2af 100644 --- a/src/rs_embed/pipelines/exporter.py +++ b/src/rs_embed/pipelines/exporter.py @@ -24,6 +24,7 @@ point_resume_manifest, summarize_status, ) +from ..tools.progress import FetchStats from ..tools.progress import create_progress as _default_create_progress from ..tools.serialization import ( sanitize_key, @@ -155,6 +156,8 @@ def _run_per_item(self) -> list[dict[str, Any]]: model_progress, on_model_done = self._create_model_progress(total=len(pending_idxs)) + fetch_stats = FetchStats() if (bool(cfg.show_progress) and need_prefetch) else None + # Chunk pipeline csize = cfg.effective_chunk_size chunk_groups = [pending_idxs[s : s + csize] for s in range(0, len(pending_idxs), csize)] @@ -166,7 +169,7 @@ def _run_per_item(self) -> list[dict[str, Any]]: if need_prefetch and chunk_groups: prefetch_pipeline_ex = ThreadPoolExecutor(max_workers=1) prefetched_chunk_fut = prefetch_pipeline_ex.submit( - self._prefetch_chunk, active_prefetch, chunk_groups[0] + self._prefetch_chunk, active_prefetch, chunk_groups[0], fetch_stats ) for chunk_idx, idxs in enumerate(chunk_groups): @@ -174,6 +177,8 @@ def _run_per_item(self) -> list[dict[str, Any]]: if prefetched_chunk_fut is not None: prefetched_chunk_fut.result() prefetched_chunk_fut = None + if fetch_stats is not None: + fetch_stats.log() # Kick off next chunk prefetch next_prefetch: PrefetchManager | None = None @@ -182,7 +187,10 @@ def _run_per_item(self) -> list[dict[str, Any]]: # is isolated from the current chunk while both are live. next_prefetch = self._clone_prefetch(active_prefetch) prefetched_chunk_fut = prefetch_pipeline_ex.submit( - self._prefetch_chunk, next_prefetch, chunk_groups[chunk_idx + 1] + self._prefetch_chunk, + next_prefetch, + chunk_groups[chunk_idx + 1], + fetch_stats, ) # Chunk-level batch inference when useful for GPU or precomputed models. @@ -267,9 +275,17 @@ def _run_combined(self) -> dict[str, Any]: unit="step", ) + fetch_stats = ( + FetchStats() if (bool(cfg.show_progress) and prefetch.provider is not None) else None + ) + # Prefetch all inputs if prefetch.provider is not None and tasks: - prefetch.fetch_chunk(all_idxs, self.spatials, self.temporal, progress=progress) + prefetch.fetch_chunk( + all_idxs, self.spatials, self.temporal, progress=progress, fetch_stats=fetch_stats + ) + if fetch_stats is not None: + fetch_stats.log() # Store prefetch checkpoint if prefetch.provider is not None: @@ -504,9 +520,14 @@ def _should_batch_per_item(self) -> bool: # get_embeddings_batch(); this is useful on CPU and independent of input_prep. return any(bool(mc.is_precomputed) for mc in self.models) - def _prefetch_chunk(self, prefetch: PrefetchManager, idxs: list[int]) -> None: + def _prefetch_chunk( + self, + prefetch: PrefetchManager, + idxs: list[int], + fetch_stats: FetchStats | None = None, + ) -> None: """Prefetch a chunk of inputs (for use in pipelined prefetch).""" - prefetch.fetch_chunk(idxs, self.spatials, self.temporal) + prefetch.fetch_chunk(idxs, self.spatials, self.temporal, fetch_stats=fetch_stats) def _clone_prefetch(self, src: PrefetchManager) -> PrefetchManager: """Clone a PrefetchManager preserving the plan with fresh per-chunk caches.""" diff --git a/src/rs_embed/pipelines/prefetch.py b/src/rs_embed/pipelines/prefetch.py index 7957b1a..c02f3f5 100644 --- a/src/rs_embed/pipelines/prefetch.py +++ b/src/rs_embed/pipelines/prefetch.py @@ -20,6 +20,7 @@ select_prefetched_channels, ) from ..tools.normalization import normalize_input_array +from ..tools.progress import FetchStats from .runner import run_with_retry @@ -124,12 +125,21 @@ def fetch_chunk( temporal: TemporalSpec | None, *, progress: Any = None, + fetch_stats: FetchStats | None = None, ) -> None: """Prefetch provider inputs for *idxs* into ``self.cache``.""" if not self.enabled or self.provider is None: return tasks = self.build_tasks(idxs, spatials) + + if fetch_stats is not None: + possible = len(idxs) * len(self.fetch_sensor_by_key) + cache_hits = possible - len(tasks) + if cache_hits > 0: + fetch_stats.record_cache_hits(cache_hits) + fetch_stats.record_planned(len(tasks)) + if not tasks: return @@ -176,6 +186,8 @@ def _fetch_one( err_s = repr(e) for member_skey in self.fetch_members.get(skey, []): self.errors[(i, member_skey)] = err_s + if fetch_stats is not None: + fetch_stats.record_failure() if progress is not None: progress.update(1) continue @@ -203,12 +215,20 @@ def _fetch_one( if not cfg.continue_on_error: raise err self.errors[(i, member_skey)] = repr(err) + if fetch_stats is not None: + fetch_stats.record_failure() continue self.input_reports[(i, member_skey)] = rep self.cache[(i, member_skey)] = x_member if fmeta: self.fetch_meta[(i, member_skey)] = fmeta + if fetch_stats is not None: + fsensor = self.fetch_sensor_by_key.get(skey) + fetch_stats.record_success( + point=i, + sensor=fsensor.collection if fsensor is not None else None, + ) if progress is not None: progress.update(1) diff --git a/src/rs_embed/tools/progress.py b/src/rs_embed/tools/progress.py index a2e6e1a..959b255 100644 --- a/src/rs_embed/tools/progress.py +++ b/src/rs_embed/tools/progress.py @@ -1,6 +1,7 @@ from __future__ import annotations import sys +import threading from typing import Any @@ -44,6 +45,93 @@ def close(self) -> None: sys.stderr.flush() +class FetchStats: + """Thread-safe accumulator for GEE image fetch statistics. + + Updated by :class:`~rs_embed.pipelines.prefetch.PrefetchManager` during + ``fetch_chunk`` and surfaced as log messages when progress reporting is on. + """ + + def __init__(self) -> None: + self._lock = threading.Lock() + self._total = 0 + self._completed = 0 + self._failed = 0 + self._cache_hits = 0 + self._last_point: int | None = None + self._last_sensor: str | None = None + + @property + def total(self) -> int: + """Total GEE fetch operations planned across all chunks.""" + with self._lock: + return self._total + + @property + def completed(self) -> int: + """Number of fetch operations that succeeded.""" + with self._lock: + return self._completed + + @property + def failed(self) -> int: + """Number of fetch operations that failed.""" + with self._lock: + return self._failed + + @property + def cache_hits(self) -> int: + """Number of fetch operations skipped due to cache reuse.""" + with self._lock: + return self._cache_hits + + def record_planned(self, n: int = 1) -> None: + """Register *n* newly planned fetch tasks.""" + with self._lock: + self._total += max(0, int(n)) + + def record_cache_hits(self, n: int = 1) -> None: + """Register *n* fetches skipped due to a cache hit.""" + with self._lock: + self._cache_hits += max(0, int(n)) + + def record_success(self, *, point: int | None = None, sensor: str | None = None) -> None: + """Register one successful fetch, optionally recording the point/sensor.""" + with self._lock: + self._completed += 1 + if point is not None: + self._last_point = point + if sensor is not None: + self._last_sensor = sensor + + def record_failure(self) -> None: + """Register one failed fetch.""" + with self._lock: + self._failed += 1 + + def format_summary(self) -> str: + """Return a compact summary line suitable for stderr logging.""" + with self._lock: + t, c, f, h = self._total, self._completed, self._failed, self._cache_hits + last_pt, last_s = self._last_point, self._last_sensor + pct = int(100 * c / t) if t > 0 else 0 + msg = f"[gee_fetch] total={t} | done={c} ({pct}%) | failed={f} | cached={h}" + if last_pt is not None and last_s is not None: + msg += f" | last=point:{last_pt},sensor:{last_s}" + return msg + + def log(self) -> None: + """Write the current summary to stderr, respecting any active tqdm bar.""" + msg = self.format_summary() + try: + from tqdm import tqdm + + tqdm.write(msg, file=sys.stderr) + except Exception: + sys.stderr.write(msg + "\n") + sys.stderr.flush() + + def create_progress(*, enabled: bool, total: int, desc: str, unit: str = "item") -> Any: """Create a progress bar, falling back gracefully when tqdm is unavailable. diff --git a/tests/test_fetch_stats.py b/tests/test_fetch_stats.py new file mode 100644 index 0000000..5cfc94b --- /dev/null +++ b/tests/test_fetch_stats.py @@ -0,0 +1,686 @@ +"""Tests for GEE fetch statistics (FetchStats + PrefetchManager integration).""" + +from __future__ import annotations + +import threading +from unittest.mock import MagicMock + +import numpy as np +import pytest + +from rs_embed.core import registry +from rs_embed.core.embedding import Embedding +from rs_embed.core.specs import OutputSpec, PointBuffer, SensorSpec, TemporalSpec +from rs_embed.core.types import ExportConfig, ExportTarget +from rs_embed.tools.progress import FetchStats +from rs_embed.tools.runtime import get_embedder_bundle_cached + +# ── FetchStats unit tests ────────────────────────────────────────────────────── + + +def test_fetch_stats_initial_state(): + stats = FetchStats() + assert stats.total == 0 + assert stats.completed == 0 + assert stats.failed == 0 + assert stats.cache_hits == 0 + + +def test_fetch_stats_record_planned(): + stats = FetchStats() + stats.record_planned(5) + assert stats.total == 5 + stats.record_planned(3) + assert stats.total == 8 + + +def test_fetch_stats_record_planned_ignores_negative(): + stats = FetchStats() + stats.record_planned(-1) + assert stats.total == 0 + + +def test_fetch_stats_record_cache_hits(): + stats = FetchStats() + stats.record_cache_hits(4) + assert stats.cache_hits == 4 + stats.record_cache_hits(2) + assert stats.cache_hits == 6 + + +def test_fetch_stats_record_cache_hits_ignores_negative(): + stats = FetchStats() + stats.record_cache_hits(-3) + assert stats.cache_hits == 0 + + +def test_fetch_stats_record_success(): + stats = FetchStats() + stats.record_planned(3) + stats.record_success() + stats.record_success() + assert stats.completed == 2 + assert stats.failed == 0 + + +def test_fetch_stats_record_failure(): + stats = FetchStats() + stats.record_planned(3) + stats.record_failure() + assert stats.failed == 1 + assert stats.completed == 0 + + +def test_fetch_stats_mixed_outcomes(): + stats = FetchStats() + stats.record_planned(10) + stats.record_cache_hits(3) + for _ in range(5): + stats.record_success() + stats.record_failure() + assert stats.total == 10 + assert stats.completed == 5 + assert stats.failed == 1 + assert stats.cache_hits == 3 + + +def test_fetch_stats_format_summary_zero(): + stats = FetchStats() + s = stats.format_summary() + assert "gee_fetch" in s + assert "done=0" in s + assert "failed=0" in s + + +def test_fetch_stats_format_summary_with_data(): + stats = FetchStats() + stats.record_planned(10) + stats.record_cache_hits(2) + for _ in range(7): + stats.record_success() + stats.record_failure() + s = stats.format_summary() + assert "total=10" in s + assert "done=7" in s + assert "failed=1" in s + assert "cached=2" in s + + +def test_fetch_stats_format_summary_percentage(): + stats = FetchStats() + stats.record_planned(4) + for _ in range(2): + stats.record_success() + s = stats.format_summary() + assert "50%" in s + + +def test_fetch_stats_format_summary_100_percent(): + stats = FetchStats() + stats.record_planned(5) + for _ in range(5): + stats.record_success() + s = stats.format_summary() + assert "100%" in s + + +def test_fetch_stats_format_summary_no_last_when_none(): + stats = FetchStats() + stats.record_planned(1) + stats.record_success() + s = stats.format_summary() + assert "last=" not in s + + +def test_fetch_stats_format_summary_shows_last_point_and_sensor(): + stats = FetchStats() + stats.record_planned(2) + stats.record_success(point=3, sensor="COPERNICUS/S2_SR_HARMONIZED") + stats.record_success(point=7, sensor="COPERNICUS/S1_GRD") + s = stats.format_summary() + assert "last=point:7" in s + assert "sensor:COPERNICUS/S1_GRD" in s + + +def test_fetch_stats_record_success_partial_info(): + stats = FetchStats() + stats.record_planned(1) + stats.record_success(point=5) # no sensor + s = stats.format_summary() + assert "last=" not in s # both must be set to show the field + + +def test_fetch_stats_log_writes_to_stderr(capsys): + stats = FetchStats() + stats.record_planned(3) + stats.record_success() + stats.log() + captured = capsys.readouterr() + assert "gee_fetch" in captured.err + assert "total=3" in captured.err + + +def test_fetch_stats_thread_safety(): + stats = FetchStats() + n = 200 + + def _worker(): + stats.record_planned(1) + stats.record_success() + + threads = [threading.Thread(target=_worker) for _ in range(n)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert stats.total == n + assert stats.completed == n + assert stats.failed == 0 + + +def test_fetch_stats_thread_safety_mixed(): + stats = FetchStats() + n_success = 100 + n_fail = 50 + n_cache = 30 + + def _success(): + stats.record_planned(1) + stats.record_success() + + def _fail(): + stats.record_planned(1) + stats.record_failure() + + def _cache(): + stats.record_cache_hits(1) + + threads = ( + [threading.Thread(target=_success) for _ in range(n_success)] + + [threading.Thread(target=_fail) for _ in range(n_fail)] + + [threading.Thread(target=_cache) for _ in range(n_cache)] + ) + for t in threads: + t.start() + for t in threads: + t.join() + + assert stats.total == n_success + n_fail + assert stats.completed == n_success + assert stats.failed == n_fail + assert stats.cache_hits == n_cache + + +# ── PrefetchManager integration ──────────────────────────────────────────────── + + +@pytest.fixture(autouse=True) +def clean_registry(): + registry._REGISTRY.clear() + yield + registry._REGISTRY.clear() + + +def _make_sensor() -> SensorSpec: + return SensorSpec(collection="TEST", bands=["B1", "B2"], scale_m=10) + + +def _make_prefetch_manager(fetch_fn=None, inspect_fn=None): + from rs_embed.core.types import ExportConfig + from rs_embed.pipelines.prefetch import PrefetchManager + + sensor = _make_sensor() + provider = MagicMock() + provider.ensure_ready = MagicMock() + provider.normalize_bands = None + + cfg = ExportConfig( + save_inputs=True, + save_embeddings=True, + num_workers=1, + continue_on_error=True, + ) + + default_fetch = fetch_fn or (lambda *a, **kw: np.ones((2, 4, 4), dtype=np.float32)) + default_inspect = inspect_fn or (lambda *a, **kw: {"ok": True}) + + pm = PrefetchManager( + provider=provider, + models=["m1"], + resolved_sensor={"m1": sensor}, + model_type={"m1": "onthefly"}, + config=cfg, + fetch_fn=default_fetch, + inspect_fn=default_inspect, + ) + pm.plan() + return pm, provider + + +def test_prefetch_manager_fetch_stats_records_planned(): + pm, _ = _make_prefetch_manager() + spatials = [PointBuffer(lon=0, lat=0, buffer_m=10), PointBuffer(lon=1, lat=1, buffer_m=10)] + temporal = TemporalSpec.year(2022) + + stats = FetchStats() + pm.fetch_chunk([0, 1], spatials, temporal, fetch_stats=stats) + + assert stats.total == 2 + assert stats.completed == 2 + assert stats.failed == 0 + assert stats.cache_hits == 0 + + +def test_prefetch_manager_fetch_stats_records_cache_hits(): + pm, _ = _make_prefetch_manager() + spatials = [PointBuffer(lon=0, lat=0, buffer_m=10), PointBuffer(lon=1, lat=1, buffer_m=10)] + temporal = TemporalSpec.year(2022) + + # First fetch populates cache + stats = FetchStats() + pm.fetch_chunk([0, 1], spatials, temporal, fetch_stats=stats) + assert stats.total == 2 + assert stats.completed == 2 + assert stats.cache_hits == 0 + + # Second fetch for same indices: all should be cache hits + stats2 = FetchStats() + pm.fetch_chunk([0, 1], spatials, temporal, fetch_stats=stats2) + assert stats2.total == 0 + assert stats2.cache_hits == 2 + assert stats2.completed == 0 + + +def test_prefetch_manager_fetch_stats_partial_cache_hit(): + pm, _ = _make_prefetch_manager() + spatials = [PointBuffer(lon=0, lat=0, buffer_m=10), PointBuffer(lon=1, lat=1, buffer_m=10)] + temporal = TemporalSpec.year(2022) + + # Fetch only index 0 first + stats1 = FetchStats() + pm.fetch_chunk([0], spatials, temporal, fetch_stats=stats1) + assert stats1.total == 1 + assert stats1.completed == 1 + assert stats1.cache_hits == 0 + + # Now fetch both: index 0 is cached, index 1 is new + stats2 = FetchStats() + pm.fetch_chunk([0, 1], spatials, temporal, fetch_stats=stats2) + assert stats2.total == 1 # only index 1 needs fetching + assert stats2.completed == 1 + assert stats2.cache_hits == 1 # index 0 is cached + + +def test_prefetch_manager_fetch_stats_records_failures(): + def _failing_fetch(*a, **kw): + raise RuntimeError("GEE fetch failed") + + pm, _ = _make_prefetch_manager(fetch_fn=_failing_fetch) + spatials = [PointBuffer(lon=0, lat=0, buffer_m=10)] + temporal = TemporalSpec.year(2022) + + stats = FetchStats() + # continue_on_error=True so it won't raise + pm.fetch_chunk([0], spatials, temporal, fetch_stats=stats) + + assert stats.total == 1 + assert stats.failed == 1 + assert stats.completed == 0 + + +def test_prefetch_manager_fetch_stats_none_is_noop(): + """Passing fetch_stats=None should not cause any errors.""" + pm, _ = _make_prefetch_manager() + spatials = [PointBuffer(lon=0, lat=0, buffer_m=10)] + temporal = TemporalSpec.year(2022) + + # Should not raise + pm.fetch_chunk([0], spatials, temporal, fetch_stats=None) + + +def test_prefetch_manager_fetch_stats_no_provider(): + """When provider is None, fetch_chunk returns early; stats stay at zero.""" + from rs_embed.core.types import ExportConfig + from rs_embed.pipelines.prefetch import PrefetchManager + + sensor = _make_sensor() + cfg = ExportConfig(save_inputs=True, save_embeddings=True) + pm = PrefetchManager( + provider=None, + models=["m1"], + resolved_sensor={"m1": sensor}, + model_type={"m1": "onthefly"}, + config=cfg, + ) + pm.plan() + + stats = FetchStats() + spatials = [PointBuffer(lon=0, lat=0, buffer_m=10)] + pm.fetch_chunk([0], spatials, TemporalSpec.year(2022), fetch_stats=stats) + + assert stats.total == 0 + assert stats.completed == 0 + + +# ── BatchExporter / export_batch integration ────────────────────────────────── + + +@pytest.fixture(autouse=True) +def disable_real_progress(monkeypatch): + import rs_embed.api as api + + class _NoOpProgress: + def update(self, n: int = 1): + pass + + def close(self): + pass + + monkeypatch.setattr( + api, + "_create_progress", + lambda *, enabled, total, desc, unit="item": _NoOpProgress(), + ) + + +def _register_onthefly(name: str): + class DummyOntheFly: + def describe(self): + return { + "type": "onthefly", + "inputs": {"collection": "C", "bands": ["B1", "B2", "B3"]}, + "defaults": { + "scale_m": 10, + "cloudy_pct": 30, + "composite": "median", + "fill_value": 0.0, + }, + } + + def get_embedding( + self, *, spatial, temporal, sensor, output, backend, device="auto", input_chw=None + ): + return Embedding(data=np.array([1.0], dtype=np.float32), meta={}) + + DummyOntheFly.__name__ = name + registry.register(name)(DummyOntheFly) + + +def _patch_gee(monkeypatch): + """Patch GEE provider and fetch function to avoid real network calls.""" + + class DummyProvider: + def __init__(self, *a, **kw): + pass + + def ensure_ready(self): + pass + + monkeypatch.setattr("rs_embed.tools.runtime.get_provider", lambda _name, **_kw: DummyProvider()) + monkeypatch.setattr( + "rs_embed.providers.fetch.fetch_sensor_patch_chw", + lambda provider, *, spatial, temporal, sensor: np.ones((3, 8, 8), dtype=np.float32), + ) + monkeypatch.setattr( + "rs_embed.providers.fetch.inspect_fetch_result", + lambda x_chw, *, sensor, name: {"ok": True}, + ) + + +def test_export_batch_fetch_stats_logged_per_item(tmp_path, monkeypatch, capsys): + import rs_embed.api as api + + _register_onthefly("dummy_otf_stats") + _patch_gee(monkeypatch) + get_embedder_bundle_cached.cache_clear() + + spatials = [ + PointBuffer(lon=0, lat=0, buffer_m=10), + PointBuffer(lon=1, lat=1, buffer_m=10), + PointBuffer(lon=2, lat=2, buffer_m=10), + ] + + api.export_batch( + spatials=spatials, + temporal=TemporalSpec.year(2022), + models=["dummy_otf_stats"], + target=ExportTarget.per_item(str(tmp_path / "out")), + config=ExportConfig( + save_inputs=True, + save_embeddings=True, + show_progress=True, + chunk_size=2, + num_workers=1, + ), + backend="gee", + output=OutputSpec.pooled(), + ) + + captured = capsys.readouterr() + assert "[gee_fetch]" in captured.err + + +def test_export_batch_fetch_stats_not_logged_when_progress_disabled(tmp_path, monkeypatch, capsys): + import rs_embed.api as api + + _register_onthefly("dummy_otf_nolog") + _patch_gee(monkeypatch) + get_embedder_bundle_cached.cache_clear() + + spatials = [PointBuffer(lon=0, lat=0, buffer_m=10)] + + api.export_batch( + spatials=spatials, + temporal=TemporalSpec.year(2022), + models=["dummy_otf_nolog"], + target=ExportTarget.per_item(str(tmp_path / "out")), + config=ExportConfig( + save_inputs=True, + save_embeddings=True, + show_progress=False, + num_workers=1, + ), + backend="gee", + output=OutputSpec.pooled(), + ) + + captured = capsys.readouterr() + assert "[gee_fetch]" not in captured.err + + +def test_export_batch_combined_fetch_stats_logged(tmp_path, monkeypatch, capsys): + import rs_embed.api as api + + _register_onthefly("dummy_otf_combined_stats") + _patch_gee(monkeypatch) + get_embedder_bundle_cached.cache_clear() + + spatials = [ + PointBuffer(lon=0, lat=0, buffer_m=10), + PointBuffer(lon=1, lat=1, buffer_m=10), + ] + + api.export_batch( + spatials=spatials, + temporal=TemporalSpec.year(2022), + models=["dummy_otf_combined_stats"], + target=ExportTarget.combined(str(tmp_path / "combined.npz")), + config=ExportConfig( + save_inputs=True, + save_embeddings=True, + show_progress=True, + num_workers=1, + ), + backend="gee", + output=OutputSpec.pooled(), + ) + + captured = capsys.readouterr() + assert "[gee_fetch]" in captured.err + + +def test_export_batch_combined_fetch_stats_not_logged_when_disabled(tmp_path, monkeypatch, capsys): + import rs_embed.api as api + + _register_onthefly("dummy_otf_combined_nolog") + _patch_gee(monkeypatch) + get_embedder_bundle_cached.cache_clear() + + spatials = [PointBuffer(lon=0, lat=0, buffer_m=10)] + + api.export_batch( + spatials=spatials, + temporal=TemporalSpec.year(2022), + models=["dummy_otf_combined_nolog"], + target=ExportTarget.combined(str(tmp_path / "combined.npz")), + config=ExportConfig( + save_inputs=True, + save_embeddings=True, + show_progress=False, + num_workers=1, + ), + backend="gee", + output=OutputSpec.pooled(), + ) + + captured = capsys.readouterr() + assert "[gee_fetch]" not in captured.err + + +def test_export_batch_fetch_stats_counts_in_log(tmp_path, monkeypatch, capsys): + """The log line should reflect the correct completed count for a single chunk.""" + import rs_embed.api as api + + _register_onthefly("dummy_otf_counts") + _patch_gee(monkeypatch) + get_embedder_bundle_cached.cache_clear() + + spatials = [ + PointBuffer(lon=0, lat=0, buffer_m=10), + PointBuffer(lon=1, lat=1, buffer_m=10), + ] + + api.export_batch( + spatials=spatials, + temporal=TemporalSpec.year(2022), + models=["dummy_otf_counts"], + target=ExportTarget.per_item(str(tmp_path / "out")), + config=ExportConfig( + save_inputs=True, + save_embeddings=True, + show_progress=True, + chunk_size=10, # single chunk: both fetches happen together + num_workers=1, + ), + backend="gee", + output=OutputSpec.pooled(), + ) + + captured = capsys.readouterr() + assert "done=2" in captured.err + assert "total=2" in captured.err + + +def test_export_batch_fetch_stats_multi_chunk_cumulative(tmp_path, monkeypatch, capsys): + """Stats accumulate across multiple chunks (chunk_size=1 forces one chunk per point).""" + import rs_embed.api as api + + _register_onthefly("dummy_otf_multi_chunk") + _patch_gee(monkeypatch) + get_embedder_bundle_cached.cache_clear() + + spatials = [ + PointBuffer(lon=0, lat=0, buffer_m=10), + PointBuffer(lon=1, lat=1, buffer_m=10), + PointBuffer(lon=2, lat=2, buffer_m=10), + ] + + api.export_batch( + spatials=spatials, + temporal=TemporalSpec.year(2022), + models=["dummy_otf_multi_chunk"], + target=ExportTarget.per_item(str(tmp_path / "out")), + config=ExportConfig( + save_inputs=True, + save_embeddings=True, + show_progress=True, + chunk_size=1, # one point per chunk → 3 separate fetch log lines + num_workers=1, + ), + backend="gee", + output=OutputSpec.pooled(), + ) + + captured = capsys.readouterr() + lines = [ln for ln in captured.err.splitlines() if "[gee_fetch]" in ln] + # Each chunk completion emits one log line; with chunk_size=1 and 3 points we get 3 + assert len(lines) >= 1 + # The final log line should show all 3 fetches completed (stats are cumulative) + assert "done=3" in captured.err + + +def test_export_batch_fetch_stats_no_provider_no_log(tmp_path, monkeypatch, capsys): + """Precomputed models (no GEE provider) must not emit gee_fetch log lines.""" + import rs_embed.api as api + + class DummyPrecomputed: + def describe(self): + return {"type": "precomputed", "backend": ["local"], "output": ["pooled"]} + + def get_embedding( + self, *, spatial, temporal, sensor, output, backend, device="auto", input_chw=None + ): + return Embedding(data=np.array([1.0], dtype=np.float32), meta={}) + + registry.register("dummy_precomputed_stats")(DummyPrecomputed) + get_embedder_bundle_cached.cache_clear() + + spatials = [PointBuffer(lon=0, lat=0, buffer_m=10)] + + api.export_batch( + spatials=spatials, + temporal=TemporalSpec.year(2022), + models=["dummy_precomputed_stats"], + target=ExportTarget.per_item(str(tmp_path / "out")), + config=ExportConfig( + save_inputs=False, + save_embeddings=True, + show_progress=True, + ), + backend="local", + output=OutputSpec.pooled(), + ) + + captured = capsys.readouterr() + assert "[gee_fetch]" not in captured.err + + +def test_export_batch_fetch_stats_last_point_and_sensor_in_log(tmp_path, monkeypatch, capsys): + """The log line includes last=point:N,sensor:collection after a successful fetch.""" + import rs_embed.api as api + + _register_onthefly("dummy_otf_last_field") + _patch_gee(monkeypatch) + get_embedder_bundle_cached.cache_clear() + + spatials = [PointBuffer(lon=0, lat=0, buffer_m=10)] + + api.export_batch( + spatials=spatials, + temporal=TemporalSpec.year(2022), + models=["dummy_otf_last_field"], + target=ExportTarget.per_item(str(tmp_path / "out")), + config=ExportConfig( + save_inputs=True, + save_embeddings=True, + show_progress=True, + num_workers=1, + ), + backend="gee", + output=OutputSpec.pooled(), + ) + + captured = capsys.readouterr() + assert "last=point:" in captured.err + assert "sensor:" in captured.err