@@ -812,6 +812,7 @@ def __init__(
812
812
813
813
self ._show_progress = show_progress
814
814
self .divergences = 0
815
+ self .draws = 0
815
816
self .completed_draws = 0
816
817
self .total_draws = draws + tune
817
818
self .desc = "Sampling chain"
@@ -827,27 +828,35 @@ def __enter__(self):
827
828
def __exit__ (self , exc_type , exc_val , exc_tb ):
828
829
return self ._progress .__exit__ (exc_type , exc_val , exc_tb )
829
830
831
+ def set_initial_state (self , draws : int = 0 , divergences : int = 0 ):
832
+ self .draws = draws
833
+ self .completed_draws += draws
834
+ self .divergences += divergences
835
+
830
836
def _initialize_tasks (self ):
831
837
if self .combined_progress :
832
838
self .tasks = [
833
839
self ._progress .add_task (
834
840
self .desc .format (self ),
835
- completed = 0 ,
836
- draws = 0 ,
841
+ completed = self . completed_draws ,
842
+ draws = self . completed_draws ,
837
843
total = self .total_draws * self .chains - 1 ,
838
844
chain_idx = 0 ,
839
845
sampling_speed = 0 ,
840
846
speed_unit = "draws/s" ,
841
- ** {stat : value [0 ] for stat , value in self .progress_stats .items ()},
847
+ ** {
848
+ stat : value [0 ] if stat != "diverging" else self .divergences
849
+ for stat , value in self .progress_stats .items ()
850
+ },
842
851
)
843
852
]
844
853
845
854
else :
846
855
self .tasks = [
847
856
self ._progress .add_task (
848
857
self .desc .format (self ),
849
- completed = 0 ,
850
- draws = 0 ,
858
+ completed = self . completed_draws ,
859
+ draws = self . draws ,
851
860
total = self .total_draws - 1 ,
852
861
chain_idx = chain_idx ,
853
862
sampling_speed = 0 ,
0 commit comments