From bd7d265af6a8a01ca1c1c2ca34da4598a2344c2e Mon Sep 17 00:00:00 2001 From: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com> Date: Wed, 31 Jan 2024 09:19:59 -0500 Subject: [PATCH] Fix unintuitive `--gen_kwargs` behavior (#1329) * don't override do_sample if no value for it is passed * Update gen_kwargs override condition * Update huggingface.py * Update huggingface.py * run linters * silence an erroneous warning --- lm_eval/evaluator.py | 9 ++++++--- lm_eval/models/huggingface.py | 10 ++++++---- lm_eval/tasks/gsm8k/gsm8k.yaml | 1 + 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/lm_eval/evaluator.py b/lm_eval/evaluator.py index 657fdcfef0..89142a5afc 100644 --- a/lm_eval/evaluator.py +++ b/lm_eval/evaluator.py @@ -140,9 +140,12 @@ def simple_evaluate( ) else: default_num_fewshot = config["num_fewshot"] - eval_logger.warning( - f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}" - ) + if default_num_fewshot: + # warn a user, if a specific num_fewshot > 0 was specified. + # if unspecified in config, no warning message + eval_logger.warning( + f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}" + ) task_obj._config["num_fewshot"] = num_fewshot diff --git a/lm_eval/models/huggingface.py b/lm_eval/models/huggingface.py index d5eb645b65..6f24ac61e7 100644 --- a/lm_eval/models/huggingface.py +++ b/lm_eval/models/huggingface.py @@ -705,10 +705,12 @@ def _model_call(self, inps, attn_mask=None, labels=None): return self.model(inps).logits def _model_generate(self, context, max_length, stop, **generation_kwargs): - # we require users to pass do_sample=True explicitly - # for non-greedy gen. This should be reevaluated when considering beam search. 
- if "do_sample" not in generation_kwargs: - generation_kwargs["do_sample"] = False + # if do_sample is false and temp==0.0: + # remove temperature, as do_sample=False takes care of this + # and we don't want a warning from HF + do_sample = generation_kwargs.get("do_sample", None) + if do_sample is False and generation_kwargs.get("temperature") == 0.0: + generation_kwargs.pop("temperature", 0.0) # build stopping criteria stopping_criteria = stop_sequences_criteria( self.tokenizer, stop, context.shape[1], context.shape[0] diff --git a/lm_eval/tasks/gsm8k/gsm8k.yaml b/lm_eval/tasks/gsm8k/gsm8k.yaml index dc5ba61472..76be03ee51 100644 --- a/lm_eval/tasks/gsm8k/gsm8k.yaml +++ b/lm_eval/tasks/gsm8k/gsm8k.yaml @@ -24,6 +24,7 @@ generation_kwargs: - "\n\n" - "Question:" do_sample: false + temperature: 0.0 repeats: 1 num_fewshot: 5 filter_list: