@@ -168,28 +168,28 @@ def call_column(column, task):
168
168
return table
169
169
170
170
171
- class DivergenceBarColumn (BarColumn ):
172
- """Rich colorbar that changes color when a chain has detected a divergence ."""
171
+ class RecolorOnFailureBarColumn (BarColumn ):
172
+ """Rich colorbar that changes color when a chain has detected a failure ."""
173
173
174
- def __init__ (self , * args , diverging_color = "red" , ** kwargs ):
174
+ def __init__ (self , * args , failing_color = "red" , ** kwargs ):
175
175
from matplotlib .colors import to_rgb
176
176
177
- self .diverging_color = diverging_color
178
- self .diverging_rgb = [int (x * 255 ) for x in to_rgb (self .diverging_color )]
177
+ self .failing_color = failing_color
178
+ self .failing_rgb = [int (x * 255 ) for x in to_rgb (self .failing_color )]
179
179
180
180
super ().__init__ (* args , ** kwargs )
181
181
182
- self .non_diverging_style = self .complete_style
183
- self .non_diverging_finished_style = self .finished_style
182
+ self .default_complete_style = self .complete_style
183
+ self .default_finished_style = self .finished_style
184
184
185
185
def callbacks (self , task : "Task" ):
186
- divergences = task .fields .get ("divergences" , 0 )
187
- if isinstance (divergences , float | int ) and divergences > 0 :
188
- self .complete_style = Style .parse ("rgb({},{},{})" .format (* self .diverging_rgb ))
189
- self .finished_style = Style .parse ("rgb({},{},{})" .format (* self .diverging_rgb ))
186
+ if task .fields ["failing" ]:
187
+ self .complete_style = Style .parse ("rgb({},{},{})" .format (* self .failing_rgb ))
188
+ self .finished_style = Style .parse ("rgb({},{},{})" .format (* self .failing_rgb ))
190
189
else :
191
- self .complete_style = self .non_diverging_style
192
- self .finished_style = self .non_diverging_finished_style
190
+ # Recovered from failing yay
191
+ self .complete_style = self .default_complete_style
192
+ self .finished_style = self .default_finished_style
193
193
194
194
195
195
class ProgressBarManager :
@@ -284,7 +284,6 @@ def __init__(
284
284
self .update_stats_functions = step_method ._make_progressbar_update_functions ()
285
285
286
286
self ._show_progress = show_progress
287
- self .divergences = 0
288
287
self .completed_draws = 0
289
288
self .total_draws = draws + tune
290
289
self .desc = "Sampling chain"
@@ -311,6 +310,7 @@ def _initialize_tasks(self):
311
310
chain_idx = 0 ,
312
311
sampling_speed = 0 ,
313
312
speed_unit = "draws/s" ,
313
+ failing = False ,
314
314
** {stat : value [0 ] for stat , value in self .progress_stats .items ()},
315
315
)
316
316
]
@@ -325,6 +325,7 @@ def _initialize_tasks(self):
325
325
chain_idx = chain_idx ,
326
326
sampling_speed = 0 ,
327
327
speed_unit = "draws/s" ,
328
+ failing = False ,
328
329
** {stat : value [chain_idx ] for stat , value in self .progress_stats .items ()},
329
330
)
330
331
for chain_idx in range (self .chains )
@@ -354,27 +355,30 @@ def update(self, chain_idx, is_last, draw, tuning, stats):
354
355
elapsed = self ._progress .tasks [chain_idx ].elapsed
355
356
speed , unit = self .compute_draw_speed (elapsed , draw )
356
357
357
- if not tuning and stats and stats [0 ].get ("diverging" ):
358
- self .divergences += 1
359
-
360
358
if self .full_stats :
361
- # TODO: Index by chain already?
359
+ failing = False
360
+ all_step_stats = {}
361
+
362
362
chain_progress_stats = [
363
- update_states_fn (step_stats )
364
- for update_states_fn , step_stats in zip (
363
+ update_stats_fn (step_stats )
364
+ for update_stats_fn , step_stats in zip (
365
365
self .update_stats_functions , stats , strict = True
366
366
)
367
367
]
368
- all_step_stats = {}
369
368
for step_stats in chain_progress_stats :
370
369
for key , val in step_stats .items ():
370
+ if key == "failing" :
371
+ failing |= val
372
+ continue
373
+
371
374
if key in all_step_stats :
372
375
# TODO: Figure out how to integrate duplicate / non-scalar keys, ignoring them for now
373
376
continue
374
377
else :
375
378
all_step_stats [key ] = val
376
379
377
380
else :
381
+ failing = False
378
382
all_step_stats = {}
379
383
380
384
self ._progress .update (
@@ -383,6 +387,7 @@ def update(self, chain_idx, is_last, draw, tuning, stats):
383
387
draws = draw ,
384
388
sampling_speed = speed ,
385
389
speed_unit = unit ,
390
+ failing = failing ,
386
391
** all_step_stats ,
387
392
)
388
393
@@ -410,9 +415,9 @@ def create_progress_bar(self, step_columns, progressbar, progressbar_theme):
410
415
]
411
416
412
417
return CustomProgress (
413
- DivergenceBarColumn (
418
+ RecolorOnFailureBarColumn (
414
419
table_column = Column ("Progress" , ratio = 2 ),
415
- diverging_color = "tab:red" ,
420
+ failing_color = "tab:red" ,
416
421
complete_style = Style .parse ("rgb(31,119,180)" ), # tab:blue
417
422
finished_style = Style .parse ("rgb(31,119,180)" ), # tab:blue
418
423
),
0 commit comments