Mathvista #2321

Draft
wants to merge 16 commits into main
86 changes: 59 additions & 27 deletions lm_eval/models/hf_vlms.py
@@ -1,5 +1,5 @@
import copy
from typing import Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
@@ -18,8 +18,12 @@
replace_placeholders,
stop_sequences_criteria,
)
from lm_eval.utils import add_padding_if_needed


if TYPE_CHECKING:
import PIL

DEFAULT_IMAGE_PLACEHOLDER = "<image>"


@@ -43,11 +47,16 @@ def __init__(
interleave: bool = True,
# TODO: handle whitespace in image placeholder (replacement)
max_images: Optional[int] = 999,
convert_img_format=False,
convert_img_format: bool = False,
auto_model_class: str = None,
**kwargs,
):
if auto_model_class is not None:
self.AUTO_MODEL_CLASS = getattr(transformers, auto_model_class)

# We initialize using HFLM's init. Sub-methods like _create_model and _create_tokenizer
# modify init behavior.

super().__init__(pretrained, **kwargs)

assert (
@@ -169,7 +178,9 @@ def tok_multimodal_encode(

return text_encoding, encoding # image_encoding is a dict

def _encode_multimodal_pair(self, context, continuation, images):
def _encode_multimodal_pair(
self, context, continuation, images: List["PIL.Image.Image"]
):
"""Helper function to perform the role of TemplateLM._encode_pair
Except allowing for image input to also be processed alongside `context`.

@@ -182,7 +193,12 @@ def _encode_multimodal_pair(self, context, continuation, images):
continuation = context[-n_spaces:] + continuation
context = context[:-n_spaces]

# TODO: replace default <image> placeholder with self.image_token, for contexts
context = replace_placeholders(
context, DEFAULT_IMAGE_PLACEHOLDER, self.image_token, self.max_images
)

if self.rgb:
images = [img.convert("RGB") for img in images]

whole_enc, image_enc = self.tok_multimodal_encode(
context + continuation, images
@@ -267,7 +283,9 @@ def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str
def tok_batch_multimodal_encode(
self,
strings: List[str], # note that input signature of this fn is different
images: List[List], # TODO: images are pil.Image at the moment, update typehint
images: List[
List["PIL.Image.Image"] # noqa: F821
], # TODO: images are pil.Image at the moment, update typehint
padding_side: str = "left",
left_truncate_len: int = None,
truncation: bool = False,
@@ -298,14 +316,25 @@ def tok_batch_multimodal_encode(
if getattr(self.config, "model_type", "") == "llava":
images = flatten_image_list(images)

encoding = self.processor(
images=images,
text=strings,
truncation=truncation,
padding="longest",
return_tensors="pt",
# **add_special_tokens, # TODO: at least some Processors error out when passing this. How do we control whether text gets BOS added?
)
try:
encoding = self.processor(
images=images,
text=strings,
truncation=truncation,
padding="longest",
return_tensors="pt",
# **add_special_tokens, # TODO: at least some Processors error out when passing this. How do we control whether text gets BOS added?
)
# Qwen processor errors out if a dimension is too small (defaults to do_resize=True, and that requires a min dimension)
except Exception:
encoding = self.processor(
images=[add_padding_if_needed(image) for image in images],
text=strings,
truncation=truncation,
padding="longest",
return_tensors="pt",
# **add_special_tokens, # TODO: at least some Processors error out when passing this. How do we control whether text gets BOS added?
)
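
The fallback above relies on add_padding_if_needed, newly imported from lm_eval.utils at the top of this file. Its implementation is not part of this diff; the following is only a minimal sketch of the idea, assuming the helper pads small images up to a minimum edge length so that processors which resize with a minimum dimension (such as Qwen2-VL's) accept them. The 28-pixel floor and white fill are illustrative assumptions.

from PIL import Image, ImageOps

def add_padding_if_needed(img: Image.Image, min_size: int = 28) -> Image.Image:
    # Illustrative sketch only; the real helper lives in lm_eval.utils.
    # Pad (rather than stretch) any image whose width or height is below
    # min_size, splitting the padding evenly between the two sides.
    width, height = img.size
    if width >= min_size and height >= min_size:
        return img
    pad_w = max(min_size - width, 0)
    pad_h = max(min_size - height, 0)
    border = (pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2)
    return ImageOps.expand(img, border=border, fill="white")
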

encoding.to( # TODO: our other tokenization methods in HFLM don't typically move to device. this breaks convention
self.device, self.model.dtype
@@ -325,7 +354,7 @@ def _model_multimodal_call(self, inps, imgs, attn_mask=None, labels=None):
"""
# note: imgs is a dict.
with torch.no_grad():
return self.model(inps, **imgs).logits
return self.model(inps, **imgs, attention_mask=attn_mask).logits

def _model_multimodal_generate(self, inputs, max_length, stop, **generation_kwargs):
generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
@@ -363,7 +392,9 @@ def _batch_images(self, image_encs):
batched_imgs[key] = torch.cat(
[
torch.tensor(
image_enc[key], device=self.device, dtype=self.model.dtype
image_enc[key],
device=self.device,
dtype=self.model.dtype if key == "pixel_values" else torch.int,
)
for image_enc in image_encs
],
@@ -380,10 +411,6 @@ def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
def loglikelihood(
self, requests: List[Instance], disable_tqdm: bool = False
) -> List[Tuple[float, bool]]:
raise NotImplementedError(
"'loglikelihood' requests for model type `hf-multimodal` are not yet tested. This feature will be enabled when a loglikelihood-based multiple-choice VQA dataset is added!"
)

new_reqs = []
for context, continuation, aux_arguments in [req.args for req in requests]:
if context == "":
@@ -436,16 +463,16 @@ def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
# allows for the creation of a lookup, so we can reuse logits in case of one-token continuations.
# speeds up some multiple-choice tasks proportionally to the number of choices.
# groups requests by context+continuation[:-1] and infer on one request/group.
return req[-1] + req[-3] + req[-2][:-1]
return req[-3] + req[-2]

re_ord = Collator(
requests,
sort_fn=_collate,
group_by="contexts" # TODO: can't group-by just "contexts" any more, need to incorporate imgs
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
and self.logits_cache
else None,
group_fn=_lookup_one_token_cont,
group_by=None,
# group_by="contexts" # TODO: can't group-by just "contexts" any more, need to incorporate imgs
# if self.backend == "causal" and self.logits_cache
# else None,
# group_fn=_lookup_one_token_cont,
)

# automatic (variable) batch size detection for vectorization
@@ -529,7 +556,12 @@ def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
) # TODO: fix/test for bs>1 case with differently-sized imgs!

multi_logits = F.log_softmax(
self._model_multimodal_call(batched_inps, batched_imgs, **call_kwargs),
self._model_multimodal_call(
batched_inps,
batched_imgs,
attn_mask=torch.ones_like(batched_inps),
**call_kwargs,
),
dim=-1,
) # [batch, padding_length (inp or cont), vocab]

@@ -549,7 +581,7 @@ def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
# from prompt/prefix tuning tokens, if applicable
ctx_len = (
inplen + (logits.shape[0] - padding_len_inp)
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
if self.backend == "causal"
else None
)
logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len)
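
For context on the new auto_model_class argument added to __init__ above: it resolves the given string against the transformers namespace and overrides AUTO_MODEL_CLASS before HFLM's init runs. A minimal usage sketch, assuming the class defined in hf_vlms.py is HFMultimodalLM; the checkpoint name and auto class below are illustrative.

from lm_eval.models.hf_vlms import HFMultimodalLM  # class name assumed, not shown in this diff

lm = HFMultimodalLM(
    pretrained="Qwen/Qwen2-VL-2B-Instruct",      # illustrative multimodal checkpoint
    auto_model_class="AutoModelForVision2Seq",   # any transformers Auto* class name
    max_images=4,
)
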
197 changes: 196 additions & 1 deletion lm_eval/models/openai_completions.py
@@ -1,9 +1,21 @@
import asyncio
import base64
import copy
import itertools
import json
import os
from functools import cached_property
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple, Union

from PIL import Image
from tenacity import retry, stop_after_attempt, wait_exponential
from tqdm import tqdm

from lm_eval.api.instance import Instance
from lm_eval.api.registry import register_model
from lm_eval.models.api_models import TemplateAPI
from lm_eval.models.api_models import JsonChatStr, TemplateAPI
from lm_eval.models.utils import Collator
from lm_eval.utils import eval_logger


@@ -238,3 +250,186 @@ def loglikelihood(self, requests, **kwargs):
raise NotImplementedError(
"Loglikelihood (and therefore `multiple_choice`-type tasks) is not supported for chat completions as OpenAI does not provide prompt logprobs. See https://github.com/EleutherAI/lm-evaluation-harness/issues/942#issuecomment-1777836312 or https://github.com/EleutherAI/lm-evaluation-harness/issues/1196 for more background on this limitation."
)


@register_model("pixtral-api")
class PixtralAPI(LocalChatCompletion):
MULTIMODAL = True
DEFAULT_IMAGE_PLACEHOLDER = "<image>"

def __init__(
self,
max_images: int = 999,
**kwargs,
):
self.max_images = max_images
super().__init__(
tokenizer_backend=None,
tokenized_requests=False,
model="mistralai/Pixtral-12B-2409",
**kwargs,
)

def generate_until(
self, requests: List[Instance], disable_tqdm: bool = False
) -> List[str]:
res = []

def _collate_gen(_requests):
# sort by the length of the non-tokenized contexts
return -len(_requests[0])

# Let the API deal with tokenization
requests, all_gen_kwargs, aux_args = zip(*(req.args for req in requests))
if self.tokenized_requests:
encodings_list = self.tok_encode(
requests, add_special_tokens=self.add_bos_token
)
else:
requests = [
self.update_json_chat_str_with_image(req, pil_image["visual"])
for req, pil_image in zip(requests, aux_args)
]
encodings_list = [None] * len(requests)
requests = [
(a, b, c) for a, b, c in zip(requests, all_gen_kwargs, encodings_list)
]

re_ord = Collator(
requests,
sort_fn=_collate_gen,
group_by="gen_kwargs",
)
chunked = re_ord.get_batched(
n=self._batch_size if self._concurrent <= 1 else 0, batch_fn=None
)
if self._concurrent <= 1:
pbar = tqdm(desc="Requesting API", total=len(requests))
for chunk in chunked:
contexts, all_gen_kwargs, encodings_list = zip(*chunk)
req = encodings_list if self.tokenized_requests else contexts
outputs = retry(
stop=stop_after_attempt(self.max_retries),
wait=wait_exponential(multiplier=0.5, min=1, max=10),
reraise=True,
)(self.model_call)(
messages=req,
generate=True,
gen_kwargs=copy.deepcopy(all_gen_kwargs[0]),
)
for generated_text, context in zip(
self.parse_generations(
outputs=outputs,
contexts=contexts,
),
contexts,
):
if generated_text is not None:
res.append(generated_text)

# partial caching
if context is not None:
self.cache_hook.add_partial(
"generate_until",
(context, all_gen_kwargs[0]),
generated_text,
)
pbar.update(1)
else:
for chunk in chunked:
contexts, all_gen_kwargs, encodings_list = zip(*chunk)
req = encodings_list if self.tokenized_requests else contexts
results = itertools.chain.from_iterable(
asyncio.run(
self.get_batched_requests(
req,
cache_keys=[(ctx, all_gen_kwargs[0]) for ctx in contexts],
generate=True,
gen_kwargs=copy.deepcopy(all_gen_kwargs[0]),
)
)
)
res.extend(results)

return re_ord.get_original(res)

@staticmethod
def encode_pillow_image(img):
if img.mode == "P":
img = img.convert("RGB")
if img.mode == "RGBA":
# Create a white background
background = Image.new("RGB", img.size, (255, 255, 255))
# Paste the image on the background.
# The alpha channel is automatically used as mask
background.paste(img, mask=img.split()[3])
img = background

buffered = BytesIO()
img.save(buffered, format="JPEG")

return base64.b64encode(buffered.getvalue()).decode("utf-8")
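
A quick, self-contained check of the static helper above (the semi-transparent test image is illustrative); it exercises the RGBA branch, which composites the image onto a white background before JPEG encoding.

from PIL import Image
from lm_eval.models.openai_completions import PixtralAPI

rgba = Image.new("RGBA", (32, 32), (255, 0, 0, 128))  # semi-transparent red square
b64 = PixtralAPI.encode_pillow_image(rgba)            # static method, no API client needed
assert isinstance(b64, str) and len(b64) > 0          # base64 payload ready for a data: URL
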

def update_json_chat_str_with_image(
self, json_chat_str, pil_images: Union["Image.Image", List["Image.Image"]]
):
# Parse the JSON string
chat_data = json.loads(json_chat_str.prompt)

# Convert single image to list for consistency
if not isinstance(pil_images, list):
pil_images = [pil_images]

# Encode the Pillow image(s)
base64_images = [self.encode_pillow_image(img) for img in pil_images]

# Update the image_url(s) in the chat data
image_index = 0
for message in chat_data:
if message["role"] == "user":
for content in message["content"]:
if content["type"] == "image_url":
if image_index < len(base64_images):
content["image_url"] = {
"url": f"data:image/jpeg;base64,{base64_images[image_index]}"
}
image_index += 1
else:
# If we run out of images, set to None or handle as needed
content["image_url"] = None

# Update the JsonChatStr object with the new JSON string
json_chat_str = JsonChatStr(json.dumps(chat_data))

return json_chat_str

def apply_chat_template(
self, chat_history: List[Dict[str, str]]
) -> Union[str, JsonChatStr]:
"""Applies a chat template to a list of chat history between user and model."""
if self.tokenizer_backend == "huggingface" and self.tokenized_requests:
return self.tokenizer.apply_chat_template(
chat_history, tokenize=False, add_generation_prompt=True
)
else:
# bit of a hack. We'll load back before sending to the API
new_messages = []
for message in chat_history:
if message["role"] == "user":
# Split the content at <image> placeholder
parts = message["content"].split("<image>")
new_content = [
{"type": "text", "text": parts[0].strip()},
{"type": "image_url", "image_url": None},
]
if len(parts) > 1:
new_content.append({"type": "text", "text": parts[1].strip()})

new_messages.append(
{"role": message["role"], "content": new_content}
)
else:
# For non-user messages, keep the format as is
new_messages.append(message)

return JsonChatStr(json.dumps(new_messages))
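
End to end, the two methods above cooperate as follows: apply_chat_template splits each user message at the <image> placeholder and leaves image_url empty, and update_json_chat_str_with_image later fills it with a base64 data URL produced by encode_pillow_image. A hedged sketch, in which the endpoint URL and the image are illustrative stand-ins:

import json

from PIL import Image

from lm_eval.models.openai_completions import PixtralAPI

api = PixtralAPI(base_url="http://localhost:8000/v1/chat/completions")  # illustrative endpoint

chat = [{"role": "user", "content": "What is shown here? <image> Answer briefly."}]
json_chat = api.apply_chat_template(chat)  # image_url is still None at this point

img = Image.new("RGB", (64, 64), "red")  # stand-in for a dataset image
json_chat = api.update_json_chat_str_with_image(json_chat, img)

payload = json.loads(json_chat.prompt)
# payload[0]["content"] now holds three entries:
#   {"type": "text", "text": "What is shown here?"}
#   {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}}
#   {"type": "text", "text": "Answer briefly."}
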