File tree Expand file tree Collapse file tree 2 files changed +4
-8
lines changed Expand file tree Collapse file tree 2 files changed +4
-8
lines changed Original file line number Diff line number Diff line change @@ -142,7 +142,7 @@ def run(self, server_url: str | None = None) -> tuple:
142142 agg_score : float = 0.0
143143
144144 results = self ._run_mmlu (server_url )
145- for task , result in results .items ():
145+ for task , result in results [ 'results' ] .items ():
146146 agg_score += float (result ["acc,none" ])
147147 individual_scores [task ] = {
148148 "score" : float (result ["acc,none" ]),
@@ -154,7 +154,7 @@ def run(self, server_url: str | None = None) -> tuple:
154154 return overall_score , individual_scores
155155
156156 def _run_mmlu (
157- self , server_url : str | None = None , return_all_results : bool = False
157+ self , server_url : str | None = None
158158 ) -> dict :
159159 if server_url is not None :
160160 # Requires lm_eval >= 0.4.4
@@ -179,11 +179,7 @@ def _run_mmlu(
179179 device = self .device ,
180180 task_manager = tm ,
181181 )
182- if return_all_results :
183- results = mmlu_output
184- else :
185- results = mmlu_output ["results" ]
186- return results
182+ return mmlu_output
187183
188184 # This method converts general errors from simple_evaluate
189185 # into a more user-understandable error
Original file line number Diff line number Diff line change @@ -90,7 +90,7 @@ def run(self, server_url: str | None = None) -> tuple:
9090 self .prepare_unitxt_files ()
9191 logger .debug (locals ())
9292 os .environ ["TOKENIZERS_PARALLELISM" ] = "true"
93- results = self ._run_mmlu (server_url = server_url , return_all_results = True )
93+ results = self ._run_mmlu (server_url = server_url )
9494 taskname = self .tasks [0 ]
9595 global_scores = results ["results" ][taskname ]
9696 global_scores .pop ("alias" )
You can’t perform that action at this time.
0 commit comments