diff --git a/src/lighteval/metrics/imports/bert_scorer.py b/src/lighteval/metrics/imports/bert_scorer.py
index b8025bf3f..88a063c6a 100644
--- a/src/lighteval/metrics/imports/bert_scorer.py
+++ b/src/lighteval/metrics/imports/bert_scorer.py
@@ -37,6 +37,16 @@
 logger = logging.getLogger(__name__)
 
 
+def validate_tokenizer_length(tokenizer: AutoTokenizer, override_length: int | None) -> int:
+    if override_length is not None:
+        return override_length
+    if tokenizer.model_max_length == int(1e30):
+        logger.warning("Could not read model_max_length attribute for BERTScorer's tokenizer - defaulting to 512.")
+        return 512
+    else:
+        return tokenizer.model_max_length
+
+
 def padding(arr, pad_token, dtype=torch.long):
     lens = torch.LongTensor([len(a) for a in arr])
     max_len = lens.max().item()
@@ -321,6 +331,7 @@ def __init__(
         lang=None,
         rescale_with_baseline=False,
         baseline_path=None,
+        tokenizer_max_len: int | None = None,
     ):
         """Initialize BERTScorer.
 
@@ -343,6 +354,7 @@
             return_hash (bool): Return hash code of the setting.
             rescale_with_baseline (bool): Rescale bertscore with pre-computed baseline.
             baseline_path (str): Customized baseline file.
+            tokenizer_max_len (int, optional): If set, overrides the tokenizer's max model length.
         """
 
         assert lang is not None or model_type is not None, "Either lang or model_type should be specified"
@@ -366,6 +378,7 @@
 
         # Model and tokenizer are lazily loaded in `score()`.
         self._tokenizer = None
+        self._tokenizer_len = tokenizer_max_len
         self._model = None
         self._idf_dict = None
 
@@ -430,6 +443,9 @@ def score(self, cands, refs, verbose=False, batch_size=64, return_hash=False):
         if self._model is None:
             logger.info(f"Loading BERTScorer model `{self._model_type}`")
             self._tokenizer = AutoTokenizer.from_pretrained(self._model_type)
+            self._tokenizer.model_max_length = validate_tokenizer_length(
+                tokenizer=self._tokenizer, override_length=self._tokenizer_len
+            )
             self._model = AutoModel.from_pretrained(self._model_type)
             self._model.eval()
             self._model.to(self.device)
diff --git a/src/lighteval/metrics/metrics_sample.py b/src/lighteval/metrics/metrics_sample.py
index 25b4f68ff..37d4e40cb 100644
--- a/src/lighteval/metrics/metrics_sample.py
+++ b/src/lighteval/metrics/metrics_sample.py
@@ -643,7 +643,11 @@ def compute(self, doc: Doc, model_response: ModelResponse, **kwargs) -> dict[str
             logger.warning("The first metric computation step might be a bit longer as we need to download the model.")
             # We only initialize on first compute
             self.bert_scorer = BERTScorer(
-                model_type="microsoft/deberta-large-mnli", lang="en", rescale_with_baseline=True, num_layers=9
+                model_type="microsoft/deberta-large-mnli",
+                lang="en",
+                rescale_with_baseline=True,
+                num_layers=9,
+                tokenizer_max_len=512,
             )
         golds = as_list(golds)
         predictions = as_list(predictions)
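
Below is a minimal sketch, not part of the patch above, of how the new validate_tokenizer_length helper resolves the effective maximum length once the patch is applied. The checkpoint name is only an example; whether the int(1e30) sentinel branch fires depends on that checkpoint's tokenizer config.

from transformers import AutoTokenizer

from lighteval.metrics.imports.bert_scorer import validate_tokenizer_length

tok = AutoTokenizer.from_pretrained("microsoft/deberta-large-mnli")

# An explicit override always wins, whatever the tokenizer reports.
assert validate_tokenizer_length(tokenizer=tok, override_length=256) == 256

# Without an override, the helper keeps the tokenizer's declared limit unless
# it equals the Hugging Face "unset" sentinel int(1e30), in which case it
# logs a warning and falls back to 512.
print(validate_tokenizer_length(tokenizer=tok, override_length=None))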
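
And a hedged end-to-end example of the call site that metrics_sample.py now uses; the (precision, recall, F1) return shape assumed here follows the upstream bert-score API that this vendored scorer mirrors.

from lighteval.metrics.imports.bert_scorer import BERTScorer

# Same arguments as the lazy initialization in the second file of the patch:
# tokenizer_max_len=512 caps inputs even when the tokenizer's own
# model_max_length is the unset sentinel, so overlong candidate/reference
# pairs get truncated instead of overflowing the model's context.
scorer = BERTScorer(
    model_type="microsoft/deberta-large-mnli",
    lang="en",
    rescale_with_baseline=True,
    num_layers=9,
    tokenizer_max_len=512,
)
precision, recall, f1 = scorer.score(
    cands=["The cat sat on the mat."],
    refs=["A cat was sitting on the mat."],
)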