Skip to content

Commit 129d988

Browse files
committed
Abstract special behavior of NUTS divergences in ProgressBar
Every step sampler can now signal whether sampling is failing by setting the "failing" key in the dict returned by its progress bar update function.
1 parent bfd9189 commit 129d988

File tree

4 files changed

+69
-27
lines changed

4 files changed

+69
-27
lines changed

pymc/progress_bar.py

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -168,28 +168,28 @@ def call_column(column, task):
168168
return table
169169

170170

171-
class DivergenceBarColumn(BarColumn):
172-
"""Rich colorbar that changes color when a chain has detected a divergence."""
171+
class RecolorOnFailureBarColumn(BarColumn):
172+
"""Rich colorbar that changes color when a chain has detected a failure."""
173173

174-
def __init__(self, *args, diverging_color="red", **kwargs):
174+
def __init__(self, *args, failing_color="red", **kwargs):
175175
from matplotlib.colors import to_rgb
176176

177-
self.diverging_color = diverging_color
178-
self.diverging_rgb = [int(x * 255) for x in to_rgb(self.diverging_color)]
177+
self.failing_color = failing_color
178+
self.failing_rgb = [int(x * 255) for x in to_rgb(self.failing_color)]
179179

180180
super().__init__(*args, **kwargs)
181181

182-
self.non_diverging_style = self.complete_style
183-
self.non_diverging_finished_style = self.finished_style
182+
self.default_complete_style = self.complete_style
183+
self.default_finished_style = self.finished_style
184184

185185
def callbacks(self, task: "Task"):
186-
divergences = task.fields.get("divergences", 0)
187-
if isinstance(divergences, float | int) and divergences > 0:
188-
self.complete_style = Style.parse("rgb({},{},{})".format(*self.diverging_rgb))
189-
self.finished_style = Style.parse("rgb({},{},{})".format(*self.diverging_rgb))
186+
if task.fields["failing"]:
187+
self.complete_style = Style.parse("rgb({},{},{})".format(*self.failing_rgb))
188+
self.finished_style = Style.parse("rgb({},{},{})".format(*self.failing_rgb))
190189
else:
191-
self.complete_style = self.non_diverging_style
192-
self.finished_style = self.non_diverging_finished_style
190+
# Not failing (never failed, or recovered from failing): restore the default styles
191+
self.complete_style = self.default_complete_style
192+
self.finished_style = self.default_finished_style
193193

194194

195195
class ProgressBarManager:
@@ -284,7 +284,6 @@ def __init__(
284284
self.update_stats_functions = step_method._make_progressbar_update_functions()
285285

286286
self._show_progress = show_progress
287-
self.divergences = 0
288287
self.completed_draws = 0
289288
self.total_draws = draws + tune
290289
self.desc = "Sampling chain"
@@ -311,6 +310,7 @@ def _initialize_tasks(self):
311310
chain_idx=0,
312311
sampling_speed=0,
313312
speed_unit="draws/s",
313+
failing=False,
314314
**{stat: value[0] for stat, value in self.progress_stats.items()},
315315
)
316316
]
@@ -325,6 +325,7 @@ def _initialize_tasks(self):
325325
chain_idx=chain_idx,
326326
sampling_speed=0,
327327
speed_unit="draws/s",
328+
failing=False,
328329
**{stat: value[chain_idx] for stat, value in self.progress_stats.items()},
329330
)
330331
for chain_idx in range(self.chains)
@@ -354,27 +355,30 @@ def update(self, chain_idx, is_last, draw, tuning, stats):
354355
elapsed = self._progress.tasks[chain_idx].elapsed
355356
speed, unit = self.compute_draw_speed(elapsed, draw)
356357

357-
if not tuning and stats and stats[0].get("diverging"):
358-
self.divergences += 1
359-
360358
if self.full_stats:
361-
# TODO: Index by chain already?
359+
failing = False
360+
all_step_stats = {}
361+
362362
chain_progress_stats = [
363-
update_states_fn(step_stats)
364-
for update_states_fn, step_stats in zip(
363+
update_stats_fn(step_stats)
364+
for update_stats_fn, step_stats in zip(
365365
self.update_stats_functions, stats, strict=True
366366
)
367367
]
368-
all_step_stats = {}
369368
for step_stats in chain_progress_stats:
370369
for key, val in step_stats.items():
370+
if key == "failing":
371+
failing |= val
372+
continue
373+
371374
if key in all_step_stats:
372375
# TODO: Figure out how to integrate duplicate / non-scalar keys, ignoring them for now
373376
continue
374377
else:
375378
all_step_stats[key] = val
376379

377380
else:
381+
failing = False
378382
all_step_stats = {}
379383

380384
self._progress.update(
@@ -383,6 +387,7 @@ def update(self, chain_idx, is_last, draw, tuning, stats):
383387
draws=draw,
384388
sampling_speed=speed,
385389
speed_unit=unit,
390+
failing=failing,
386391
**all_step_stats,
387392
)
388393

@@ -410,9 +415,9 @@ def create_progress_bar(self, step_columns, progressbar, progressbar_theme):
410415
]
411416

412417
return CustomProgress(
413-
DivergenceBarColumn(
418+
RecolorOnFailureBarColumn(
414419
table_column=Column("Progress", ratio=2),
415-
diverging_color="tab:red",
420+
failing_color="tab:red",
416421
complete_style=Style.parse("rgb(31,119,180)"), # tab:blue
417422
finished_style=Style.parse("rgb(31,119,180)"), # tab:blue
418423
),

pymc/step_methods/hmc/base_hmc.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ def __init__(
184184

185185
self._step_rand = step_rand
186186
self._num_divs_sample = 0
187+
self.divergences = 0
187188

188189
@abstractmethod
189190
def _hamiltonian_step(self, start, p0, step_size) -> HMCStepData:
@@ -266,11 +267,14 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
266267
divergence_info=info_store,
267268
)
268269

270+
diverging = bool(hmc_step.divergence_info)
271+
if not self.tune:
272+
self.divergences += diverging
269273
self.iter_count += 1
270274

271275
stats: dict[str, Any] = {
272276
"tune": self.tune,
273-
"diverging": bool(hmc_step.divergence_info),
277+
"diverging": diverging,
274278
"perf_counter_diff": perf_end - perf_start,
275279
"process_time_diff": process_end - process_start,
276280
"perf_counter_start": perf_start,
@@ -288,6 +292,8 @@ def reset_tuning(self, start=None):
288292
self.reset(start=None)
289293

290294
def reset(self, start=None):
295+
self.iter_count = 0
296+
self.divergences = 0
291297
self.tune = True
292298
self.potential.reset()
293299

pymc/step_methods/hmc/hmc.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919

2020
import numpy as np
2121

22+
from rich.progress import TextColumn
23+
from rich.table import Column
24+
2225
from pymc.stats.convergence import SamplerWarning
2326
from pymc.step_methods.compound import Competence
2427
from pymc.step_methods.hmc.base_hmc import BaseHMC, BaseHMCState, DivergenceInfo, HMCStepData
@@ -55,6 +58,7 @@ class HamiltonianMC(BaseHMC):
5558
"accept": (np.float64, []),
5659
"diverging": (bool, []),
5760
"energy_error": (np.float64, []),
61+
"divergences": (np.int64, []),
5862
"energy": (np.float64, []),
5963
"path_length": (np.float64, []),
6064
"accepted": (bool, []),
@@ -202,3 +206,27 @@ def competence(var, has_grad):
202206
if var.dtype in discrete_types or not has_grad:
203207
return Competence.INCOMPATIBLE
204208
return Competence.COMPATIBLE
209+
210+
@staticmethod
211+
def _progressbar_config(n_chains=1):
212+
columns = [
213+
TextColumn("{task.fields[divergences]}", table_column=Column("Divergences", ratio=1)),
214+
TextColumn("{task.fields[n_steps]}", table_column=Column("Grad evals", ratio=1)),
215+
]
216+
217+
stats = {
218+
"divergences": [0] * n_chains,
219+
"n_steps": [0] * n_chains,
220+
}
221+
222+
return columns, stats
223+
224+
def _make_progressbar_update_functions(self):
225+
def update_stats(stats):
226+
divergences = self.divergences
227+
return {key: stats[key] for key in ("n_steps",)} | {
228+
"failing": divergences > 0,
229+
"divergences": divergences,
230+
}
231+
232+
return (update_stats,)

pymc/step_methods/hmc/nuts.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -247,10 +247,13 @@ def _progressbar_config(n_chains=1):
247247

248248
return columns, stats
249249

250-
@staticmethod
251-
def _make_update_stats_functions():
250+
def _make_update_stats_functions(self):
252251
def update_stats(stats):
253-
return {key: stats[key] for key in ("diverging", "step_size", "tree_size")}
252+
divergences = self.divergences
253+
return {key: stats[key] for key in ("step_size", "tree_size")} | {
254+
"failing": divergences > 0,
255+
"divergences": divergences,
256+
}
254257

255258
return (update_stats,)
256259

0 commit comments

Comments
 (0)