Greptile Summary

This PR adds a self-contained TC tracking recipe built around TempestExtremes. Key observations:
| Filename | Overview |
|---|---|
| recipes/tc_tracking/src/tempest_extremes.py | Core implementation of TempestExtremes integration — provides synchronous TempestExtremes and asynchronous AsyncTempestExtremes classes; AsyncTempestExtremes.__call__ correctly submits tracking to a background thread pool enabling GPU/CPU overlap, but cleanup()/wait_for_completion() abort on the first task failure and silently abandon remaining failing tasks. |
| recipes/tc_tracking/src/modes/generate_tc_hunt_ensembles.py | Main inference loop orchestrating ensemble generation, stability checking, and cyclone tracking; logic is sound and correctly uses the async TempestExtremes API; previously flagged debug comments have been cleaned up. |
| recipes/tc_tracking/pyproject.toml | Package metadata and dependencies; contains a placeholder description ("no, i won't") and unpinned git sources for both earth2studio and torch-harmonics, which reduce build reproducibility. |
| recipes/tc_tracking/Dockerfile | Docker build environment that compiles TempestExtremes from source; clones TempestExtremes at HEAD without a pinned tag/commit, which makes image builds non-reproducible. |
| recipes/tc_tracking/tc_hunt.py | Entry-point script with Hydra configuration; still contains an informal print("finished **yaaayyyy**") celebration message (previously flagged). |
Last reviewed commit: "TE workers"
```python
for ic, mems, seed in ic_mems:
    mini_batch_size = len(mems)

    data_source = data_source_mngr.select_data_source(ic)

    # if new IC, fetch data, create iterator
    if ic != ic_prev:
        if cfg.store_type == "netcdf":
            store = initialise_netcdf_output(cfg, out_coords, ic, ic_mems)
        x0, coords0 = fetch_data(
            data_source,
            time=[np.datetime64(ic)],
            lead_time=model.input_coords()["lead_time"],
            variable=model.input_coords()["variable"],
            device=dist.device,
        )
        ic_prev = ic

    coords = {"ensemble": np.array(mems)} | coords0.copy()
    xx = x0.unsqueeze(0).repeat(mini_batch_size, *([1] * x0.ndim))

    if stability_check:
        stability_check.reset(deepcopy(coords))
        # print(stability_check.input_coords)
        # exit()

    # set random state or apply perturbation
    if ("model" not in cfg) or (cfg.model == "fcn3"):
        model.set_rng(seed=seed)
    elif (
        cfg.model[:4] == "aifs"
    ):  # no need for perturbation, but also cannot set internal noise state
        pass
    else:
        sg = SphericalGaussian(noise_amplitude=0.0005)
        xx, coords = sg(xx, coords)

    iterator = model.create_iterator(xx, coords)

    # roll out the model and record data as desired
    for _, (xx, coords) in tqdm(
        zip(range(cfg.n_steps + 1), iterator), total=cfg.n_steps + 1
    ):
        write_to_store(store, xx, coords, out_coords)
        if cyclone_tracking:
            cyclone_tracking.record_state(xx, coords)

        if stability_check:
            yy, coy = map_coords(xx, coords, stability_check.input_coords)
            stab, _ = stability_check(yy, coy)
            if not stab.all():
                ic_mems.append((ic, mems, seed + 1))
                print(
                    f"CAUTION: one of members {mems} became unstable. will re-create with new seed."
                )
                break
```
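The rollout loop caps a potentially unbounded model iterator by zipping it with a finite range. A minimal standalone sketch of the same idiom, where `itertools.count` stands in for the recipe's `model.create_iterator`:

```python
import itertools

n_steps = 3
iterator = itertools.count(start=0)  # stand-in for an unbounded model iterator

recorded = []
for _, state in zip(range(n_steps + 1), iterator):
    # zip stops as soon as range is exhausted, so exactly n_steps + 1
    # states (initial condition + n_steps forecast steps) are consumed
    recorded.append(state)

print(recorded)
# → [0, 1, 2, 3]
```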
Unbounded retry loop for unstable members
When a member is detected as unstable (line 260), it is re-appended to ic_mems with seed + 1. Because Python's for loop over a list processes newly-appended items, this creates an unbounded retry cycle — there is no guard on how many times any given (ic, mems) combination can be re-queued.
If a particular initial condition consistently produces unstable trajectories (e.g., a known degenerate edge case), the job will never terminate. A maximum-retry counter should be tracked per (ic, seed) pair, and members that exceed the limit should be skipped with a warning rather than being re-queued indefinitely.
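A hedged sketch of the suggested guard; `MAX_RETRIES`, the `retry_count` bookkeeping, and the always-failing stability stand-in are illustrative, not code from the PR. It also shows why the loop is currently unbounded: a Python `for` loop over a list visits items appended during iteration.

```python
from collections import defaultdict

MAX_RETRIES = 3  # illustrative limit, not part of the PR

retry_count: dict = defaultdict(int)
ic_mems = [("ic0", (0,), 1)]  # (initial condition, members, seed)
processed = []

for ic, mems, seed in ic_mems:
    processed.append((ic, mems, seed))
    unstable = True  # stand-in for a stability check that always fails
    if unstable:
        key = (ic, mems)
        retry_count[key] += 1
        if retry_count[key] <= MAX_RETRIES:
            # re-queue with a bumped seed, as the recipe does today
            ic_mems.append((ic, mems, seed + 1))
        else:
            # without this branch the loop would re-queue forever
            print(f"WARNING: members {mems} unstable after {MAX_RETRIES} retries; skipping.")
```

With the guard, a persistently unstable member is attempted `MAX_RETRIES + 1` times and then dropped with a warning instead of spinning until the scheduler kills the job.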
In practice, such jobs will be killed by the system after exceeding their allocated time.
In a future version I want to update the scheduling anyway to something smarter, as individual ensemble members might then not always take roughly the same time to execute, as they do now.
```python
def cleanup(self, timeout_per_task: int | None = None) -> None:
    """Explicitly clean up and wait for all background tasks to complete.

    This method should be called before the object is destroyed or the program exits
    to ensure all cyclone tracking tasks complete successfully.

    Parameters
    ----------
    timeout_per_task : int | None, optional
        Timeout in seconds for each task. If None, uses self.timeout.

    Raises
    ------
    ChildProcessError
        If any background task failed
    Exception
        If any task failed with other exceptions
    """
    if self._cleanup_done:
        return

    if timeout_per_task is None:
        timeout_per_task = self.timeout

    try:
        # Wait for all instance tasks to complete
        if hasattr(self, "_instance_tasks") and hasattr(self, "_instance_lock"):
            with self._instance_lock:
                tasks_to_wait = list(self._instance_tasks)

            if tasks_to_wait:
                print(
                    f"AsyncTempestExtremes: waiting for {len(tasks_to_wait)} background tasks to complete..."
                )

                for i, future in enumerate(tasks_to_wait):
                    try:
                        print(f"  Waiting for task {i+1}/{len(tasks_to_wait)}...")
                        future.result(timeout=timeout_per_task)
                        print(
                            f"  Task {i+1}/{len(tasks_to_wait)} completed successfully"
                        )
                    except ChildProcessError as e:
                        print(
                            f"  Task {i+1}/{len(tasks_to_wait)} failed with ChildProcessError: {e}"
                        )
                        raise  # Re-raise to propagate the error
                    except Exception as e:
                        print(f"  Task {i+1}/{len(tasks_to_wait)} failed: {e}")
                        raise  # Re-raise to propagate the error

                print(
                    f"All {len(tasks_to_wait)} background tasks completed successfully"
                )

        self._cleanup_done = True

    except Exception as _:
        self._cleanup_done = True  # Mark as done even on failure to avoid retry
        raise
```
wait_for_completion and cleanup abort on first failure, silently abandoning later tasks
Both wait_for_completion() and cleanup() iterate over pending tasks and raise immediately on the first error. Any subsequent task failures are never collected — their exceptions are silently swallowed by the background threads and never surfaced to the caller. In a scenario where multiple members fail concurrently, only the first error reaches the user and the remaining failed tasks are abandoned.
Consider collecting all failures before re-raising, similar to the pattern used in _run_te_and_cleanup:

```python
errors = []
for i, future in enumerate(tasks_to_wait):
    try:
        future.result(timeout=timeout_per_task)
    except Exception as e:
        print(f"  Task {i+1}/{len(tasks_to_wait)} failed: {e}")
        errors.append(e)
if errors:
    raise ChildProcessError(
        f"{len(errors)} background task(s) failed: {errors}"
    )
```

The same change should be applied to wait_for_completion().
Earth2Studio Pull Request