Skip to content

Commit

Permalink
Fix unintuitive --gen_kwargs behavior (#1329)
Browse files Browse the repository at this point in the history
* don't override do_sample if no value for it is passed

* Update gen_kwargs override condition

* Update huggingface.py

* Update huggingface.py

* run linters

* silence an erroneous warning
  • Loading branch information
haileyschoelkopf authored Jan 31, 2024
1 parent 1554066 commit bd7d265
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 7 deletions.
9 changes: 6 additions & 3 deletions lm_eval/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,12 @@ def simple_evaluate(
)
else:
default_num_fewshot = config["num_fewshot"]
eval_logger.warning(
f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}"
)
if default_num_fewshot:
# warn a user, if a specific num_fewshot > 0 was specified.
# if unspecified in config, no warning message
eval_logger.warning(
f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}"
)

task_obj._config["num_fewshot"] = num_fewshot

Expand Down
10 changes: 6 additions & 4 deletions lm_eval/models/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,10 +705,12 @@ def _model_call(self, inps, attn_mask=None, labels=None):
return self.model(inps).logits

def _model_generate(self, context, max_length, stop, **generation_kwargs):
# we require users to pass do_sample=True explicitly
# for non-greedy gen. This should be reevaluated when considering beam search.
if "do_sample" not in generation_kwargs:
generation_kwargs["do_sample"] = False
# if do_sample is false and temp==0.0:
# remove temperature, as do_sample=False takes care of this
# and we don't want a warning from HF
do_sample = generation_kwargs.get("do_sample", None)
if do_sample is False and "temperature" == 0.0:
generation_kwargs.pop("temperature", 0.0)
# build stopping criteria
stopping_criteria = stop_sequences_criteria(
self.tokenizer, stop, context.shape[1], context.shape[0]
Expand Down
1 change: 1 addition & 0 deletions lm_eval/tasks/gsm8k/gsm8k.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ generation_kwargs:
- "\n\n"
- "Question:"
do_sample: false
temperature: 0.0
repeats: 1
num_fewshot: 5
filter_list:
Expand Down

0 comments on commit bd7d265

Please sign in to comment.