@@ -94,7 +94,15 @@ def _evaluate_dataframe(
9494
9595
9696def _evaluate_basic (
97- to_evaluate , input_cols , evaluation_function , func_args , func_kwargs , mapper , task_ids , db
97+ to_evaluate ,
98+ input_cols ,
99+ evaluation_function ,
100+ func_args ,
101+ func_kwargs ,
102+ mapper ,
103+ task_ids ,
104+ db ,
105+ progress_bar = True ,
98106):
99107 res = []
100108 # Setup the function to apply to the data
@@ -109,8 +117,11 @@ def _evaluate_basic(
109117 arg_list = list (to_evaluate .loc [task_ids , input_cols ].to_dict ("index" ).items ())
110118
111119 try :
120+ tasks = mapper (eval_func , arg_list )
121+ if progress_bar :
122+ tasks = tqdm (tasks , total = len (task_ids ))
112123 # Compute and collect the results
113- for task_id , result , exception in tqdm ( mapper ( eval_func , arg_list ), total = len ( task_ids )) :
124+ for task_id , result , exception in tasks :
114125 res .append (dict ({"df_index" : task_id , "exception" : exception }, ** result ))
115126
116127 # Save the results into the DB
@@ -163,6 +174,7 @@ def evaluate(
163174 func_args = None ,
164175 func_kwargs = None ,
165176 shuffle_rows = True ,
177+ progress_bar = True ,
166178 ** mapper_kwargs ,
167179):
168180 """Evaluate and save results in a sqlite database on the fly and return dataframe.
@@ -185,12 +197,14 @@ def evaluate(
185197 func_args (list): the arguments to pass to the evaluation_function.
186198 func_kwargs (dict): the keyword arguments to pass to the evaluation_function.
187199 shuffle_rows (bool): if :obj:`True`, it will shuffle the rows before computing the results.
200+ progress_bar (bool): if :obj:`True`, a progress bar will be displayed during computation.
188201 **mapper_kwargs: the keyword arguments are passed to the get_mapper() method of the
189202 :class:`ParallelFactory` instance.
190203
191204 Return:
192205 pandas.DataFrame: dataframe with new columns containing the computed results.
193206 """
207+ # pylint: disable=too-many-branches
194208 # Initialize the parallel factory
195209 if isinstance (parallel_factory , str ) or parallel_factory is None :
196210 parallel_factory = init_parallel_factory (parallel_factory )
@@ -243,6 +257,8 @@ def evaluate(
243257 return to_evaluate
244258
245259 # Get the factory mapper
260+ if isinstance (parallel_factory , DaskDataFrameFactory ):
261+ mapper_kwargs ["progress_bar" ] = progress_bar
246262 mapper = parallel_factory .get_mapper (** mapper_kwargs )
247263
248264 if isinstance (parallel_factory , DaskDataFrameFactory ):
@@ -267,6 +283,7 @@ def evaluate(
267283 mapper ,
268284 task_ids ,
269285 db ,
286+ progress_bar ,
270287 )
271288 to_evaluate .loc [res_df .index , res_df .columns ] = res_df
272289
0 commit comments