diff --git a/lm_eval/models/hf_vlms.py b/lm_eval/models/hf_vlms.py index f2fcdd7027..1dc94e61f1 100644 --- a/lm_eval/models/hf_vlms.py +++ b/lm_eval/models/hf_vlms.py @@ -1,5 +1,5 @@ import copy -from typing import Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import torch import torch.nn.functional as F @@ -18,8 +18,12 @@ replace_placeholders, stop_sequences_criteria, ) +from lm_eval.utils import add_padding_if_needed +if TYPE_CHECKING: + import PIL + DEFAULT_IMAGE_PLACEHOLDER = "" @@ -43,11 +47,16 @@ def __init__( interleave: bool = True, # TODO: handle whitespace in image placeholder (replacement) max_images: Optional[int] = 999, - convert_img_format=False, + convert_img_format: bool = False, + auto_model_class: str = None, **kwargs, ): + if auto_model_class is not None: + self.AUTO_MODEL_CLASS = getattr(transformers, auto_model_class) + # We initialize using HFLM's init. Sub-methods like _create_model and _create_tokenizer # modify init behavior. + super().__init__(pretrained, **kwargs) assert ( @@ -169,7 +178,9 @@ def tok_multimodal_encode( return text_encoding, encoding # image_encoding is a dict - def _encode_multimodal_pair(self, context, continuation, images): + def _encode_multimodal_pair( + self, context, continuation, images: List["PIL.Image.Image"] + ): """Helper function to perform the role of TemplateLM._encode_pair Except allowing for image input to also be processed alongside `context`. @@ -182,7 +193,12 @@ def _encode_multimodal_pair(self, context, continuation, images): continuation = context[-n_spaces:] + continuation context = context[:-n_spaces] - # TODO: replace default placeholder with self.image_token, for contexts + context = replace_placeholders( + context, DEFAULT_IMAGE_PLACEHOLDER, self.image_token, self.max_images + ) + + if self.rgb: + images = [img.convert("RGB") for img in images] whole_enc, image_enc = self.tok_multimodal_encode( context + continuation, images @@ -267,7 +283,9 @@ def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str def tok_batch_multimodal_encode( self, strings: List[str], # note that input signature of this fn is different - images: List[List], # TODO: images are pil.Image at the moment, update typehint + images: List[ + List["PIL.Image.Image"] # noqa: F821 + ], # TODO: images are pil.Image at the moment, update typehint padding_side: str = "left", left_truncate_len: int = None, truncation: bool = False, @@ -298,14 +316,25 @@ def tok_batch_multimodal_encode( if getattr(self.config, "model_type", "") == "llava": images = flatten_image_list(images) - encoding = self.processor( - images=images, - text=strings, - truncation=truncation, - padding="longest", - return_tensors="pt", - # **add_special_tokens, # TODO: at least some Processors error out when passing this. How do we control whether text gets BOS added? - ) + try: + encoding = self.processor( + images=images, + text=strings, + truncation=truncation, + padding="longest", + return_tensors="pt", + # **add_special_tokens, # TODO: at least some Processors error out when passing this. How do we control whether text gets BOS added? 
+ ) + # Qwen processor errors out if a dimension is too small (defaults to do_resize=True, and that requires a min dimension) + except Exception: + encoding = self.processor( + images=[add_padding_if_needed(image) for image in images], + text=strings, + truncation=truncation, + padding="longest", + return_tensors="pt", + # **add_special_tokens, # TODO: at least some Processors error out when passing this. How do we control whether text gets BOS added? + ) encoding.to( # TODO: our other tokenization methods in HFLM don't typically move to device. this breaks convention self.device, self.model.dtype @@ -325,7 +354,7 @@ def _model_multimodal_call(self, inps, imgs, attn_mask=None, labels=None): """ # note: imgs is a dict. with torch.no_grad(): - return self.model(inps, **imgs).logits + return self.model(inps, **imgs, attention_mask=attn_mask).logits def _model_multimodal_generate(self, inputs, max_length, stop, **generation_kwargs): generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0) @@ -363,7 +392,9 @@ def _batch_images(self, image_encs): batched_imgs[key] = torch.cat( [ torch.tensor( - image_enc[key], device=self.device, dtype=self.model.dtype + image_enc[key], + device=self.device, + dtype=self.model.dtype if key == "pixel_values" else torch.int, ) for image_enc in image_encs ], @@ -380,10 +411,6 @@ def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]: def loglikelihood( self, requests: List[Instance], disable_tqdm: bool = False ) -> List[Tuple[float, bool]]: - raise NotImplementedError( - "'loglikelihood' requests for model type `hf-multimodal` are not yet tested. This feature will be enabled when a loglikelihood-based multiple-choice VQA dataset is added!" - ) - new_reqs = [] for context, continuation, aux_arguments in [req.args for req in requests]: if context == "": @@ -436,16 +463,16 @@ def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]): # allows for the creation of a lookup, so we can reuse logits in case of one-token continuations. # speeds up some multiple-choice tasks proportionally to the number of choices. # groups requests by context+continuation[:-1] and infer on one request/group. - return req[-1] + req[-3] + req[-2][:-1] + return req[-3] + req[-2] re_ord = Collator( requests, sort_fn=_collate, - group_by="contexts" # TODO: can't group-by just "contexts" any more, need to incorporate imgs - if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM - and self.logits_cache - else None, - group_fn=_lookup_one_token_cont, + group_by=None, + # group_by="contexts" # TODO: can't group-by just "contexts" any more, need to incorporate imgs + # if self.backend == "causal" and self.logits_cache + # else None, + # group_fn=_lookup_one_token_cont, ) # automatic (variable) batch size detection for vectorization @@ -529,7 +556,12 @@ def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]): ) # TODO: fix/test for bs>1 case with differently-sized imgs! 
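The `_batch_images` hunk above now keeps `pixel_values` in the model's dtype while casting every other image-encoding key to integer, since fields like `image_sizes` carry index-like metadata that a half-precision cast would mangle. A minimal standalone sketch of that behaviour (the key names, shapes, and dtypes below are illustrative assumptions, not taken from the patch):

```python
# Illustrative sketch of the dtype split introduced in _batch_images above.
# The key names and shapes are assumptions for demonstration only: "pixel_values"
# is the float input to the vision tower, while keys such as "image_sizes" hold
# integer metadata that must not be cast to the model's (half-precision) dtype.
import torch


def batch_image_encodings(image_encs, device="cpu", model_dtype=torch.float16):
    batched = {}
    for key in image_encs[0]:
        dtype = model_dtype if key == "pixel_values" else torch.int
        batched[key] = torch.cat(
            [torch.as_tensor(enc[key], device=device, dtype=dtype) for enc in image_encs],
            dim=0,
        )
    return batched


# Two single-image encodings, as a processor might return them (shapes made up).
encs = [
    {"pixel_values": torch.rand(1, 3, 224, 224), "image_sizes": [[224, 224]]},
    {"pixel_values": torch.rand(1, 3, 224, 224), "image_sizes": [[224, 224]]},
]
batched = batch_image_encodings(encs)
assert batched["pixel_values"].dtype == torch.float16
assert batched["image_sizes"].dtype == torch.int32  # index metadata stays integer
```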
multi_logits = F.log_softmax( - self._model_multimodal_call(batched_inps, batched_imgs, **call_kwargs), + self._model_multimodal_call( + batched_inps, + batched_imgs, + attn_mask=torch.ones_like(batched_inps), + **call_kwargs, + ), dim=-1, ) # [batch, padding_length (inp or cont), vocab] @@ -549,7 +581,7 @@ def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]): # from prompt/prefix tuning tokens, if applicable ctx_len = ( inplen + (logits.shape[0] - padding_len_inp) - if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM + if self.backend == "causal" else None ) logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len) diff --git a/lm_eval/models/openai_completions.py b/lm_eval/models/openai_completions.py index 331a3f6af5..df0cbed112 100644 --- a/lm_eval/models/openai_completions.py +++ b/lm_eval/models/openai_completions.py @@ -1,9 +1,21 @@ +import asyncio +import base64 +import copy +import itertools +import json import os from functools import cached_property +from io import BytesIO from typing import Any, Dict, List, Optional, Tuple, Union +from PIL import Image +from tenacity import retry, stop_after_attempt, wait_exponential +from tqdm import tqdm + +from lm_eval.api.instance import Instance from lm_eval.api.registry import register_model -from lm_eval.models.api_models import TemplateAPI +from lm_eval.models.api_models import JsonChatStr, TemplateAPI +from lm_eval.models.utils import Collator from lm_eval.utils import eval_logger @@ -238,3 +250,186 @@ def loglikelihood(self, requests, **kwargs): raise NotImplementedError( "Loglikelihood (and therefore `multiple_choice`-type tasks) is not supported for chat completions as OpenAI does not provide prompt logprobs. See https://github.com/EleutherAI/lm-evaluation-harness/issues/942#issuecomment-1777836312 or https://github.com/EleutherAI/lm-evaluation-harness/issues/1196 for more background on this limitation." 
) + + +@register_model("pixtral-api") +class PixtralAPI(LocalChatCompletion): + MULTIMODAL = True + DEFAULT_IMAGE_PLACEHOLDER = "" + + def __init__( + self, + max_images: int = 999, + **kwargs, + ): + self.max_images = max_images + super().__init__( + tokenizer_backend=None, + tokenized_requests=False, + model="mistralai/Pixtral-12B-2409", + **kwargs, + ) + + def generate_until( + self, requests: List[Instance], disable_tqdm: bool = False + ) -> List[str]: + res = [] + + def _collate_gen(_requests): + # sort by the length of the non-tokenized contexts + return -len(_requests[0]) + + # Let the API deal with tokenization + requests, all_gen_kwargs, aux_args = zip(*(req.args for req in requests)) + if self.tokenized_requests: + encodings_list = self.tok_encode( + requests, add_special_tokens=self.add_bos_token + ) + else: + requests = [ + self.update_json_chat_str_with_image(req, pil_image["visual"]) + for req, pil_image in zip(requests, aux_args) + ] + encodings_list = [None] * len(requests) + requests = [ + (a, b, c) for a, b, c in zip(requests, all_gen_kwargs, encodings_list) + ] + + re_ord = Collator( + requests, + sort_fn=_collate_gen, + group_by="gen_kwargs", + ) + chunked = re_ord.get_batched( + n=self._batch_size if self._concurrent <= 1 else 0, batch_fn=None + ) + if self._concurrent <= 1: + pbar = tqdm(desc="Requesting API", total=len(requests)) + for chunk in chunked: + contexts, all_gen_kwargs, encodings_list = zip(*chunk) + req = encodings_list if self.tokenized_requests else contexts + outputs = retry( + stop=stop_after_attempt(self.max_retries), + wait=wait_exponential(multiplier=0.5, min=1, max=10), + reraise=True, + )(self.model_call)( + messages=req, + generate=True, + gen_kwargs=copy.deepcopy(all_gen_kwargs[0]), + ) + for generated_text, context in zip( + self.parse_generations( + outputs=outputs, + contexts=contexts, + ), + contexts, + ): + if generated_text is not None: + res.append(generated_text) + + # partial caching + if context is not None: + self.cache_hook.add_partial( + "generate_until", + (context, all_gen_kwargs[0]), + generated_text, + ) + pbar.update(1) + else: + for chunk in chunked: + contexts, all_gen_kwargs, encodings_list = zip(*chunk) + req = encodings_list if self.tokenized_requests else contexts + results = itertools.chain.from_iterable( + asyncio.run( + self.get_batched_requests( + req, + cache_keys=[(ctx, all_gen_kwargs[0]) for ctx in contexts], + generate=True, + gen_kwargs=copy.deepcopy(all_gen_kwargs[0]), + ) + ) + ) + res.extend(results) + + return re_ord.get_original(res) + + @staticmethod + def encode_pillow_image(img): + if img.mode == "P": + img = img.convert("RGB") + if img.mode == "RGBA": + # Create a white background + background = Image.new("RGB", img.size, (255, 255, 255)) + # Paste the image on the background. 
+ # The alpha channel is automatically used as mask + background.paste(img, mask=img.split()[3]) + img = background + + buffered = BytesIO() + img.save(buffered, format="JPEG") + + return base64.b64encode(buffered.getvalue()).decode("utf-8") + + def update_json_chat_str_with_image( + self, json_chat_str, pil_images: Union["Image.Image", List["Image.Image"]] + ): + # Parse the JSON string + chat_data = json.loads(json_chat_str.prompt) + + # Convert single image to list for consistency + if not isinstance(pil_images, list): + pil_images = [pil_images] + + # Encode the Pillow image(s) + base64_images = [self.encode_pillow_image(img) for img in pil_images] + + # Update the image_url(s) in the chat data + image_index = 0 + for message in chat_data: + if message["role"] == "user": + for content in message["content"]: + if content["type"] == "image_url": + if image_index < len(base64_images): + content["image_url"] = { + "url": f"data:image/jpeg;base64,{base64_images[image_index]}" + } + image_index += 1 + else: + # If we run out of images, set to None or handle as needed + content["image_url"] = None + + # Update the JsonChatStr object with the new JSON string + json_chat_str = JsonChatStr(json.dumps(chat_data)) + + return json_chat_str + + def apply_chat_template( + self, chat_history: List[Dict[str, str]] + ) -> Union[str, JsonChatStr]: + """Applies a chat template to a list of chat history between user and model.""" + if self.tokenizer_backend == "huggingface" and self.tokenized_requests: + return self.tokenizer.apply_chat_template( + chat_history, tokenize=False, add_generation_prompt=True + ) + else: + # bit of a hack. We'll load back before sending to the API + new_messages = [] + for message in chat_history: + if message["role"] == "user": + # Split the content at placeholder + parts = message["content"].split("") + new_content = [ + {"type": "text", "text": parts[0].strip()}, + {"type": "image_url", "image_url": None}, + ] + if len(parts) > 1: + new_content.append({"type": "text", "text": parts[1].strip()}) + + new_messages.append( + {"role": message["role"], "content": new_content} + ) + else: + # For non-user messages, keep the format as is + new_messages.append(message) + + return JsonChatStr(json.dumps(new_messages)) diff --git a/lm_eval/tasks/ai2d/ai2d.yaml b/lm_eval/tasks/ai2d/ai2d.yaml new file mode 100644 index 0000000000..38c2e75b2d --- /dev/null +++ b/lm_eval/tasks/ai2d/ai2d.yaml @@ -0,0 +1,19 @@ +task: ai2d +dataset_path: lmms-lab/ai2d +output_type: multiple_choice +test_split: test +doc_to_text: " Question: {{question}}\nAnswer:" +doc_to_target: "{{ answer | int }}" +target_delimiter: "" +doc_to_choice: options +doc_to_image: + - image +metric_list: + - metric: acc + aggregation: mean + higher_is_better: true + - metric: acc_norm + aggregation: mean + higher_is_better: true +metadata: + version: 1.0 diff --git a/lm_eval/tasks/mathvista/mathvista.yaml b/lm_eval/tasks/mathvista/mathvista.yaml new file mode 100644 index 0000000000..eadeb40969 --- /dev/null +++ b/lm_eval/tasks/mathvista/mathvista.yaml @@ -0,0 +1,25 @@ +dataset_path: AI4Math/MathVista +task: mathvista +test_split: testmini +output_type: "generate_until" +#process_docs: !function utils.process_docs +doc_to_image: + - decoded_image +doc_to_text: "{{query}}" +#doc_to_choice: '{{ ["A", "B", "C", "D", "E", "F"][:choices.length] }}' +doc_to_target: answer +process_results: !function utils.process_results +generation_kwargs: + until: + - "<|endoftext|>" + temperature: 0.0 + do_sample: false + max_gen_toks: 1024 
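The `PixtralAPI` helpers above rewrite a plain user message into an OpenAI-style content list: `apply_chat_template` splits the text at the image placeholder and reserves an `image_url` slot, and `update_json_chat_str_with_image` later fills that slot with a base64-encoded JPEG data URL produced by `encode_pillow_image`. A condensed sketch of that transformation, assuming the placeholder is the literal `<image>` tag used elsewhere in the harness (the function names and sample prompt below are illustrative, not part of the patch):

```python
# Standalone sketch of the message rewriting done by PixtralAPI.apply_chat_template
# and update_json_chat_str_with_image above. The placeholder is assumed to be the
# literal "<image>" tag; the sample image and prompt are made up for illustration.
import base64
import json
from io import BytesIO

from PIL import Image

PLACEHOLDER = "<image>"  # assumed value of DEFAULT_IMAGE_PLACEHOLDER


def to_data_url(img: Image.Image) -> str:
    buf = BytesIO()
    img.convert("RGB").save(buf, format="JPEG")
    return "data:image/jpeg;base64," + base64.b64encode(buf.getvalue()).decode("utf-8")


def build_user_message(text: str, img: Image.Image) -> dict:
    before, _, after = text.partition(PLACEHOLDER)
    content = [
        {"type": "text", "text": before.strip()},
        {"type": "image_url", "image_url": {"url": to_data_url(img)}},
    ]
    if after.strip():
        content.append({"type": "text", "text": after.strip()})
    return {"role": "user", "content": content}


msg = build_user_message("<image>What is shown in the figure?", Image.new("RGB", (64, 64)))
print(json.dumps(msg)[:120], "...")
```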
+metric_list: + - metric: acc + aggregation: mean + higher_is_better: true +metadata: + version: 1.0 +dataset_kwargs: + trust_remote_code: true diff --git a/lm_eval/tasks/mathvista/mathvista_mcq.yaml b/lm_eval/tasks/mathvista/mathvista_mcq.yaml new file mode 100644 index 0000000000..2f3736906b --- /dev/null +++ b/lm_eval/tasks/mathvista/mathvista_mcq.yaml @@ -0,0 +1,21 @@ +dataset_path: AI4Math/MathVista +task: mathvista_mcq +test_split: testmini +output_type: "multiple_choice" +doc_to_image: + - decoded_image +doc_to_text: "{{query}}\n\nAnswer:" +process_docs: !function utils.process_docs_mcq +doc_to_choice: '{{ ["A", "B", "C", "D", "E", "F", "G"][:choices|length] }}' +doc_to_target: "{{choices.index(answer)}}" +metric_list: + - metric: acc + aggregation: mean + higher_is_better: true + - metric: acc_norm + aggregation: mean + higher_is_better: true +dataset_kwargs: + trust_remote_code: true +metadata: + version: 1.0 diff --git a/lm_eval/tasks/mathvista/utils.py b/lm_eval/tasks/mathvista/utils.py new file mode 100644 index 0000000000..f9d73f484b --- /dev/null +++ b/lm_eval/tasks/mathvista/utils.py @@ -0,0 +1,152 @@ +import re +from typing import Optional + +from Levenshtein import distance + + +# taken from https://github.com/lupantech/MathVista/blob/main/evaluation/calculate_score.py +def get_most_similar(prediction: str, choices: list) -> float: + """ + Use the Levenshtein distance (or edit distance) to determine which of the choices is most similar to the given prediction + """ + distances = [distance(prediction, choice) for choice in choices] + ind = distances.index(min(distances)) + return choices[ind] + # return min(choices, key=lambda choice: distance(prediction, choice)) + + +# taken from https://github.com/lupantech/MathVista/blob/main/evaluation/extract_answer.py +def normalize_extracted_answer( + extraction: str, + choices: list, + question_type: str, + answer_type: str, + precision, + ignore_empty_extractions=True, +) -> Optional[str]: + """ + Normalize the extracted answer to match the answer type + """ + + if question_type == "multi_choice": + # make sure the extraction is a string + if isinstance(extraction, str): + extraction = extraction.strip() + else: + try: + extraction = str(extraction) + except Exception: + extraction = "" + + # if the extraction is empty, return None + if ignore_empty_extractions and not extraction: + return None + + # extract "A" from "(A) text" + letter = re.findall(r"\(([a-zA-Z])\)", extraction) + if len(letter) > 0: + extraction = letter[0].upper() + + sequential_characters = [chr(ord("A") + i) for i in range(len(choices))] + + # if model output a character, use it as index of available choices + if extraction in sequential_characters: + option_index = sequential_characters.index(extraction) + normalized_extraction = choices[option_index] + else: + # select the most similar option + normalized_extraction = get_most_similar(extraction, choices) + assert normalized_extraction in choices + + elif answer_type == "integer": + try: + normalized_extraction = str(int(float(extraction))) + except Exception: + normalized_extraction = None + + elif answer_type == "float": + try: + normalized_extraction = str(round(float(extraction), int(precision))) + except Exception: + normalized_extraction = None + + elif answer_type == "list": + try: + normalized_extraction = str(extraction) + except Exception: + normalized_extraction = None + + return normalized_extraction + + +def safe_equal(prediction, answer): + """ + Check if the prediction is equal to the answer, 
even if they are of different types + """ + try: + if prediction == answer: + return True + return False + except Exception: + return False + + +def extract_answer(response: str, problem: dict) -> str: + question_type = problem["question_type"] + answer_type = problem["answer_type"] + choices = problem["choices"] + # query = problem["query"] + # pid = problem['pid'] + + if response == "": + return "" + + ### This is not in the original code: + extract = re.findall( + r"[tT]he answer is ([A-Za-z0-9]+(?:\.[A-Za-z0-9]+)?)", response + ) + if extract: + return str(extract[0]) + ### + + if question_type == "multi_choice" and response in choices: + return response + + if answer_type == "integer": + try: + extraction = int(response) + return str(extraction) + except Exception: + pass + + if answer_type == "float": + try: + extraction = str(float(response)) + return extraction + except Exception: + pass + + return response + + +# adapted from https://github.com/lupantech/MathVista/blob/main/evaluation/extract_answer.py +def process_results(doc: dict, results: list[str]): + response = results[0] # noqa: F841 + choices = doc["choices"] + question_type = doc["question_type"] + answer_type = doc["answer_type"] + precision = doc["precision"] # noqa: F841 + answer = doc["answer"] + extracted_answer = extract_answer(response, doc) + normalized_extraction = normalize_extracted_answer( + extracted_answer, choices, question_type, answer_type, precision + ) + res = safe_equal(normalized_extraction, answer) + return {"acc": 1.0} if res else {"acc": 0.0} + + +### MathVista MCQ ### + + +def process_docs_mcq(dataset): + return dataset.filter(lambda x: x["question_type"] == "multi_choice") diff --git a/lm_eval/utils.py b/lm_eval/utils.py index 7166e24d07..7a6de73cc7 100644 --- a/lm_eval/utils.py +++ b/lm_eval/utils.py @@ -499,3 +499,40 @@ def weighted_f1_score(items): preds = unzipped_list[1] fscore = f1_score(golds, preds, average="weighted") return fscore + + +def add_padding_if_needed( + images: List["PIL.Image.Image"], # noqa: F821 + min_width: int = 50, + min_height: int = 50, + color=(255, 255, 255), +) -> List["PIL.Image.Image"]: # noqa: F821 + """Adds (default white) padding to images to make them at least min_width and min_height""" + from PIL import ImageOps + + res = [] + for image in images: + width, height = image.size + + if width >= min_width and height >= min_height: + return image + image = image.convert("RGB") + new_width = max(width, min_width) + new_height = max(height, min_height) + + delta_width = new_width - width + delta_height = new_height - height + + padding_left = delta_width // 2 + padding_right = delta_width - padding_left + padding_top = delta_height // 2 + padding_bottom = delta_height - padding_top + res.append( + ImageOps.expand( + image, + (padding_left, padding_top, padding_right, padding_bottom), + fill=color, + ) + ) + + return res
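Note that `tok_batch_multimodal_encode` above calls `add_padding_if_needed(image)` one image at a time, whereas the helper is written against a list and returns the first sufficiently large image early (the `return image` inside the loop) instead of appending it. A per-image variant consistent with that call site might look like the following sketch (an assumption about the intended shape, not the patch itself):

```python
# Sketch of a per-image padding helper matching the call in tok_batch_multimodal_encode,
# which passes one PIL image at a time. Minimum dimensions and fill colour mirror the
# defaults above; this is an illustrative variant, not the function from the patch.
from PIL import Image, ImageOps


def pad_image_if_needed(
    image: Image.Image,
    min_width: int = 50,
    min_height: int = 50,
    color=(255, 255, 255),
) -> Image.Image:
    width, height = image.size
    if width >= min_width and height >= min_height:
        return image  # already large enough, no padding required

    image = image.convert("RGB")
    delta_w = max(width, min_width) - width
    delta_h = max(height, min_height) - height
    # Split the padding evenly between the two sides of each axis.
    border = (delta_w // 2, delta_h // 2, delta_w - delta_w // 2, delta_h - delta_h // 2)
    return ImageOps.expand(image, border, fill=color)


tiny = Image.new("RGB", (20, 80))
assert pad_image_if_needed(tiny).size == (50, 80)
```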