Allow for pymc native samplers to resume sampling from ZarrTrace
#7687
base: main
Conversation
ZarrTrace
        vars=trace_vars,
        test_point=initial_point,
    )
except TraceAlreadyInitialized:
Maybe just InitializedTrace? Seems a little verbose!
Sounds fine to me, it's an internal thing
pymc/sampling/mcmc.py
Outdated
if isinstance(trace, ZarrChain):
    progress_manager.set_initial_state(*trace.completed_draws_and_divergences())
    progress_manager._progress.update(
        progress_manager.tasks[i],
        draws=progress_manager.completed_draws
        if progress_manager.combined_progress
        else progress_manager.draws,
        divergences=progress_manager.divergences,
        refresh=True,
I still don't like this abstraction leaking elsewhere, just provide a default to the Ndarray backend that makes it work for either method. In that case I suppose start everything at zero
I tried to improve the abstraction to prevent most of the custom ZarrChain checks.
pymc/sampling/mcmc.py
Outdated
if isinstance(trace, ZarrChain):
    trace.link_stepper(step)
    stored_draw_idx = trace._sampling_state.draw_idx[chain]
Same here, all this logic including the old link_stepper can have a sensible default in the base trace class so you don't need to worry about what kind of trace you have here. Just make link_stepper a no-op and stored_draw_idx zero by default?
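A minimal sketch of the kind of base-class defaults being suggested (the method names come from this diff; the bodies are illustrative, not the PR's implementation):

class IBaseTrace:
    def link_stepper(self, step_method) -> None:
        # Default: do nothing; only backends that can resume sampling care about the stepper.
        pass

    def completed_draws_and_divergences(self, chain_specific: bool = True) -> tuple[int, int]:
        # Default: a backend that cannot resume always reports zero completed draws and divergences.
        return 0, 0

With defaults like these, the sampling loop can call the methods unconditionally instead of branching on isinstance(trace, ZarrChain).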
if stored_draw_idx > 0:
    if stored_sampling_state is not None:
        self._step_method.sampling_state = stored_sampling_state
    else:
        raise RuntimeError(
            "Cannot use the supplied ZarrTrace to restart sampling because "
            "it has no sampling_state information stored. You will have to "
            "resample from scratch."
        )
    draw = stored_draw_idx
self._write_point(trace.get_mcmc_point())
Duplicated logic, should be a property of the backend object?
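One way the duplicated restart logic above could be folded into the backend, sketched with hypothetical names (not the PR's actual code):

def restore_sampling_state(self, step_method) -> int:
    # Hypothetical backend helper: push the stored sampling state onto the step
    # method and return the draw index to resume from (0 means start from scratch).
    if self._stored_draw_idx > 0:
        if self._stored_sampling_state is None:
            raise RuntimeError(
                "Cannot restart sampling because the trace has no sampling_state "
                "information stored. You will have to resample from scratch."
            )
        step_method.sampling_state = self._stored_sampling_state
    return self._stored_draw_idx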
pymc/sampling/parallel.py
Outdated
@@ -491,6 +509,10 @@ def __init__(
            progressbar=progressbar,
            progressbar_theme=progressbar_theme,
        )
        if self.zarr_recording:
abstraction leaking
I like the new functionality, but I am deeply against all the if isinstance(..., ZarrTrace) in the codebase. Either our code is supposed to allow different trace backends or it is not; this suggests you want to drop the Ndarray backend altogether, which is fine if you do.
Otherwise, all these cases seem like they could be handled by the BaseTrace having sensible defaults for these methods. We used to have continuation of traces in the past with Ndarray, and I don't see anything that fundamentally needs ZarrTrace other than dev interest in it. So just make them raise NotImplementedError or make them no-ops, and adjust the external code appropriately.
I stopped half-way, so this was not an extensive review. I think this is a bigger design point that needs a decision before settling on the details of the PR.
Codecov Report
Attention: Patch coverage is
Additional details and impacted files

@@            Coverage Diff             @@
##             main    #7687      +/-   ##
==========================================
- Coverage   92.83%   92.79%   -0.04%
==========================================
  Files         107      107
  Lines       18354    18708     +354
==========================================
+ Hits        17039    17361     +322
- Misses       1315     1347      +32
@@ -201,6 +201,42 @@ def _slice(self, idx: slice) -> "IBaseTrace":
    def point(self, idx: int) -> dict[str, np.ndarray]:
        return self._chain.get_draws_at(idx, [v.name for v in self._chain.variables.values()])

    def completed_draws_and_divergences(self, chain_specific: bool = True) -> tuple[int, int]:
Question: why do we have special handling for divergences, instead of being sampler agnostic?
# We only need to pass the traces when zarr_recording is happening because
# it's the only backend that can resume sampling
traces=traces if zarr_recording else None,
Someone could implement a different trace outside of PyMC so why assume this?
@@ -194,6 +195,24 @@ def _start_loop(self):

        draw = 0
        tuning = True
        if self._zarr_recording:
This should be a method of the trace; it's still zarr-specific.
traces_send: Sequence[IBaseTrace] | bytes | None = None
if traces_pickled is not None:
    traces_send = traces_pickled
elif traces is not None:
    if mp_ctx.get_start_method() == "spawn":
Note this will soon be the default on Linux as well (fork is being deprecated). Mentioning it in case we are not automating something we should be automating.
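An illustrative sketch of dropping the start-method branch and always pickling the traces for the worker processes (assuming cloudpickle, which the parallel sampling machinery already relies on, is acceptable here):

import cloudpickle

# Always send pickled traces so behaviour does not depend on the multiprocessing
# start method ("spawn" is expected to become the default on Linux too).
traces_send = cloudpickle.dumps(traces) if traces is not None else None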
        return compatible_dataclass_values(self, other)


def resolve_typehint(hint: Any, anchor: object = None) -> Any:
These new functions seem like a big code smell. Can we approach it differently?
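For reference, the standard library already resolves string and forward-reference annotations, which may cover part of what resolve_typehint does; a small self-contained example (the dataclass is made up for illustration):

from dataclasses import dataclass
from typing import get_type_hints

@dataclass
class ChainState:  # hypothetical example class
    draw_idx: "int"  # string annotation, as produced by `from __future__ import annotations`

# get_type_hints evaluates string annotations against the defining module's namespace.
print(get_type_hints(ChainState))  # {'draw_idx': <class 'int'>}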
with pytest.raises(
    AssertionError, match="The supplied state is incompatible with the current sampling state."
):
    b1.sampling_state = b5.sampling_state
Nit, I still dislike the magic property. A method call makes it much more obvious that some checks may be done and that you're not just doing a random assignment. It also gives you a hatch to disable the checks if you're confident the state is correct and want to avoid costly validation, because you're allowed to have kwargs then.
And you can't refactor properties into methods while keeping back-compat. We had stuff like that with model.logp, and it was a pain for the sake of saving two parentheses back in the day.
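A sketch of the method-based alternative being discussed (the method name and keyword are assumptions, mirroring the test above):

def set_sampling_state(self, state, check_compatibility: bool = True) -> None:
    # Explicit method instead of a property setter: the kwarg is the escape hatch
    # to skip potentially costly compatibility checks when the caller trusts the state.
    if check_compatibility and not self._is_compatible(state):  # _is_compatible is hypothetical
        raise AssertionError(
            "The supplied state is incompatible with the current sampling state."
        )
    self._sampling_state = state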
Changes look okay, but I still have some reservations about code modularity and complexity. I don't like the extra work we're doing to validate that states are compatible; if it's that hard, just don't do it.
Some places still treat Zarr specially without much reason for it.
I'm okay with approving this next iteration, or now if you disagree, so this doesn't hang forever.
Description
Big PR approaching! This finishes adding the ability of pymc native step methods to resume sampling from an existing trace (as long as it's a ZarrTrace!). This means that you can now continue tuning or sampling from a pre-existing sample run. For example:
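A minimal sketch of what resuming could look like, assuming the entry point is simply passing the reloaded ZarrTrace back to pm.sample through the trace argument (the store path, draw counts, and the ZarrTrace constructor arguments are illustrative, not the final API):

import pymc as pm
from pymc.backends.zarr import ZarrTrace

with pm.Model():
    x = pm.Normal("x")

    # First run: draws and the periodic sampling state go into a persistent zarr store.
    trace = ZarrTrace(store="samples.zarr")  # constructor arguments are an assumption
    pm.sample(draws=500, tune=500, trace=trace)

    # Later, or after a crash: reload the trace from the store and keep sampling.
    resumed = ZarrTrace.from_store("samples.zarr")  # from_store as described below
    pm.sample(draws=1000, trace=resumed)  # continues from the stored sampling state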
Another thing is that the chunks_per_draw setting of ZarrTrace, along with its persistent storage backends (like ZipStore or DirectoryStore), makes sampling store the results and the final sampling state periodically. So, in case of a crash during sampling, you can use the existing store to load the trace with ZarrTrace.from_store and then resume sampling from there.
The only thing that I haven't tested yet is adding an Op that makes pm.sample crash, to see if I can reload the partial results from the store and resume sampling. @ricardoV94 gave me some pointers on that, but I won't be working on this for the rest of the month and I thought it best to open a draft PR to kick off any discussion you have or collect feedback.
Related Issue
Checklist
Type of change
📚 Documentation preview 📚: https://pymc--7687.org.readthedocs.build/en/7687/