8 changes: 6 additions & 2 deletions src/lighteval/models/abstract_model.py
@@ -30,7 +30,7 @@
from pydantic import BaseModel
from transformers.tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase

from lighteval.models.model_input import GenerationParameters
from lighteval.models.model_input import ChatTemplateParameters, GenerationParameters
from lighteval.models.model_output import ModelResponse
from lighteval.tasks.requests import Doc

@@ -51,6 +51,9 @@ class ModelConfig(BaseModel, extra="forbid"):
generation_parameters (GenerationParameters):
Configuration parameters that control text generation behavior, including
temperature, top_p, max_new_tokens, etc. Defaults to empty GenerationParameters.
chat_template_parameters (ChatTemplateParameters):
Configuration parameters that control chat template behavior, including
reasoning_effort, enable_thinking, etc. Defaults to empty ChatTemplateParameters.
system_prompt (str | None):
Optional system prompt to be used with chat models. This prompt sets the
behavior and context for the model during evaluation.
@@ -85,6 +88,7 @@ class ModelConfig(BaseModel, extra="forbid"):
model_name: str | None = None

generation_parameters: GenerationParameters = GenerationParameters()
chat_template_parameters: ChatTemplateParameters = ChatTemplateParameters()
system_prompt: str | None = None
cache_dir: str = "~/.cache/huggingface/lighteval"

@@ -128,7 +132,7 @@ def _parse_args(args: str) -> dict:
'model': {'model_name': 'gpt2', 'generation_parameters': {'temperature': 0.7, 'top_p': 0.9}},
}

>>> parse_args("model_name=gpt2,use_cache,generation_parameters={temperature:0.7}")
>>> parse_args("model_name=gpt2,use_cache,generation_parameters={temperature:0.7},chat_template_parameters={reasoning_effort:low}")
{
'model': {'model_name': 'gpt2', 'use_cache': True, 'generation_parameters': {'temperature': 0.7}, 'chat_template_parameters': {'reasoning_effort': 'low'}},
}
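For context, a minimal sketch of how the new field is meant to be used, assuming the `ChatTemplateParameters` import path added in this diff and the `VLLMModelConfig` subclass touched below; the command-line form mirrors the `_parse_args` docstring above.

```python
from lighteval.models.model_input import ChatTemplateParameters, GenerationParameters
from lighteval.models.vllm.vllm_model import VLLMModelConfig

# Programmatic construction: chat_template_parameters sits next to
# generation_parameters on any ModelConfig subclass.
config = VLLMModelConfig(
    model_name="meta-llama/Llama-3.1-8B-Instruct",
    generation_parameters=GenerationParameters(temperature=0.7),
    chat_template_parameters=ChatTemplateParameters(reasoning_effort="low"),
)

# Equivalent command-line form, parsed by ModelConfig.from_args / _parse_args:
#   model_name=meta-llama/Llama-3.1-8B-Instruct,generation_parameters={temperature:0.7},chat_template_parameters={reasoning_effort:low}
```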
1 change: 0 additions & 1 deletion src/lighteval/models/custom/custom_model.py
@@ -69,5 +69,4 @@ def loglikelihood(self, docs: list[Doc]) -> list[ModelResponse]:
An example of a custom model can be found in `examples/custom_models/google_translate_model.py`.
"""

model_name: str
model_definition_file_path: str
1 change: 0 additions & 1 deletion src/lighteval/models/endpoints/endpoint_model.py
@@ -100,7 +100,6 @@ class ServerlessEndpointModelConfig(ModelConfig):
```
"""

model_name: str
add_special_tokens: bool = True
batch_size: int = 1

1 change: 0 additions & 1 deletion src/lighteval/models/endpoints/litellm_model.py
@@ -103,7 +103,6 @@ class LiteLLMModelConfig(ModelConfig):
```
"""

model_name: str
provider: str | None = None
base_url: str | None = None
api_key: str | None = None
17 changes: 17 additions & 0 deletions src/lighteval/models/model_input.py
@@ -241,3 +241,20 @@ def to_sglang_dict(self) -> dict:
"min_new_tokens": self.min_new_tokens,
}
return {k: v for k, v in args.items() if v is not None}


class ChatTemplateParameters(BaseModel):
reasoning_effort: str | None = None
enable_thinking: bool | None = None

def to_transformers_dict(self) -> dict:
"""Selects relevant chat template parameters for transformers models.

Returns:
dict: Valid parameters for the chat template
"""
args = {
"reasoning_effort": self.reasoning_effort,
"enable_thinking": self.enable_thinking,
}
return {k: v for k, v in args.items() if v is not None}
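A minimal usage sketch of the new helper: fields left as `None` are dropped from the returned dict, mirroring the filtering done by the `GenerationParameters.to_*_dict` helpers above.

```python
from lighteval.models.model_input import ChatTemplateParameters

params = ChatTemplateParameters(reasoning_effort="high")
params.to_transformers_dict()   # {'reasoning_effort': 'high'}; enable_thinking is None, so it is dropped

params = ChatTemplateParameters(reasoning_effort="low", enable_thinking=False)
params.to_transformers_dict()   # {'reasoning_effort': 'low', 'enable_thinking': False}
```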
1 change: 0 additions & 1 deletion src/lighteval/models/sglang/sglang_model.py
@@ -118,7 +118,6 @@ class SGLangModelConfig(ModelConfig):
```
"""

model_name: str
load_format: str = "auto"
dtype: str = "auto"
tp_size: PositiveInt = 1 # how many GPUs to use for tensor parallelism
6 changes: 4 additions & 2 deletions src/lighteval/models/transformers/transformers_model.py
@@ -141,7 +141,6 @@ class TransformersModelConfig(ModelConfig):
(bitsandbytes for 4-bit/8-bit quantization).
"""

model_name: str
tokenizer: str | None = None
subfolder: str | None = None
revision: str = "main"
@@ -234,7 +233,10 @@ def __init__(
model_size = -1

self.prompt_manager = PromptManager(
use_chat_template=self.use_chat_template, tokenizer=self.tokenizer, system_prompt=config.system_prompt
use_chat_template=self.use_chat_template,
tokenizer=self.tokenizer,
system_prompt=config.system_prompt,
chat_template_parameters=config.chat_template_parameters,
)

# Initialize cache for tokenization and predictions
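A hedged sketch of the wiring above: a `TransformersModelConfig` carrying `chat_template_parameters` hands them to the `PromptManager` at model init. The model name is illustrative only.

```python
from lighteval.models.model_input import ChatTemplateParameters
from lighteval.models.transformers.transformers_model import TransformersModelConfig

config = TransformersModelConfig(
    model_name="Qwen/Qwen3-4B",  # illustrative; any chat model whose template honours these kwargs
    chat_template_parameters=ChatTemplateParameters(enable_thinking=False),
)
# The model's __init__ forwards config.chat_template_parameters to PromptManager,
# which unpacks them into tokenizer.apply_chat_template (see prompt_manager.py below).
```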
@@ -112,7 +112,6 @@ class VLMTransformersModelConfig(ModelConfig):
cache_dir (str, optional, defaults to "~/.cache/huggingface/lighteval"): Directory to cache the model.
"""

model_name: str
processor: str | None = None
use_fast_image_processor: bool | None = None
subfolder: str | None = None
1 change: 0 additions & 1 deletion src/lighteval/models/vllm/vllm_model.py
@@ -153,7 +153,6 @@ class VLLMModelConfig(ModelConfig):
```
"""

model_name: str
tokenizer: str | None = None
revision: str = "main" # revision of the model
dtype: str = "bfloat16"
11 changes: 10 additions & 1 deletion src/lighteval/tasks/prompt_manager.py
@@ -28,6 +28,7 @@
from itertools import cycle
from typing import TYPE_CHECKING

from lighteval.models.model_input import ChatTemplateParameters
from lighteval.tasks.requests import Doc
from lighteval.utils.utils import as_list

@@ -40,10 +41,17 @@


class PromptManager:
def __init__(self, use_chat_template: bool = False, tokenizer=None, system_prompt: str | None = None):
def __init__(
self,
use_chat_template: bool = False,
tokenizer=None,
system_prompt: str | None = None,
chat_template_parameters: ChatTemplateParameters | None = None,
):
self.use_chat_template = use_chat_template
self.tokenizer = tokenizer
self.system_prompt = system_prompt # System prompt to be used in chat templates
self.chat_template_parameters = chat_template_parameters if chat_template_parameters else ChatTemplateParameters()

def prepare_prompt(self, doc: Doc) -> str:
"""Prepare a prompt from a document, either using chat template or plain text format.
@@ -133,6 +141,7 @@ def _prepare_chat_template(self, doc: Doc, tokenize: bool = True) -> str:
messages,
tokenize=False,
add_generation_prompt=True,
**self.chat_template_parameters.to_transformers_dict(),
)

else: # for apis
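To make the effect of the new keyword arguments concrete, a small standalone sketch, assuming a tokenizer whose chat template understands `enable_thinking` (Qwen3-style templates are one example); templates that do not define a given parameter are expected to simply ignore it.

```python
from transformers import AutoTokenizer

from lighteval.models.model_input import ChatTemplateParameters
from lighteval.tasks.prompt_manager import PromptManager

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")  # illustrative model choice
pm = PromptManager(
    use_chat_template=True,
    tokenizer=tokenizer,
    system_prompt="You are a helpful assistant.",
    chat_template_parameters=ChatTemplateParameters(enable_thinking=False),
)

# Equivalent to what _prepare_chat_template now does internally:
messages = [{"role": "user", "content": "What is 2 + 2?"}]
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    **pm.chat_template_parameters.to_transformers_dict(),  # e.g. enable_thinking=False
)
```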
17 changes: 17 additions & 0 deletions tests/unit/prompt/test_prompt_manager_class.py
@@ -24,6 +24,7 @@

import pytest

from lighteval.models.model_input import ChatTemplateParameters
from lighteval.tasks.prompt_manager import PromptManager
from lighteval.tasks.requests import Doc

@@ -47,6 +48,22 @@ def test_init_with_chat_template(self):
assert pm.tokenizer == tokenizer
assert pm.system_prompt == system_prompt

def test_init_with_chat_template_and_chat_template_parameters(self):
"""Test PromptManager initialization with chat template enabled and chat template parameters."""
tokenizer = Mock()
system_prompt = "You are a helpful assistant."
pm = PromptManager(
use_chat_template=True,
tokenizer=tokenizer,
system_prompt=system_prompt,
chat_template_parameters=ChatTemplateParameters(reasoning_effort="medium"),
)
assert pm.use_chat_template is True
assert pm.tokenizer == tokenizer
assert pm.system_prompt == system_prompt
assert pm.chat_template_parameters is not None
assert pm.chat_template_parameters.reasoning_effort == "medium"

def test_prepare_prompt_plain_text_basic(self):
"""Test prepare_prompt with plain text format and basic document."""
pm = PromptManager()
84 changes: 84 additions & 0 deletions tests/unit/utils/test_model_config.py
@@ -0,0 +1,84 @@
# MIT License

# Copyright (c) 2024 The HuggingFace Team

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import unittest

from lighteval.models.model_input import ChatTemplateParameters, GenerationParameters
from lighteval.models.utils import ModelConfig


class TestModelConfig(unittest.TestCase):
def test_model_config_init(self):
config = ModelConfig(
model_name="meta-llama/Llama-3.1-8B-Instruct",
generation_parameters=GenerationParameters(temperature=0.7),
system_prompt="You are a helpful assistant.",
chat_template_parameters=ChatTemplateParameters(reasoning_effort="low"),
)

self.assertEqual(config.model_name, "meta-llama/Llama-3.1-8B-Instruct")
self.assertEqual(config.generation_parameters.temperature, 0.7)
self.assertEqual(config.system_prompt, "You are a helpful assistant.")
self.assertEqual(config.chat_template_parameters.reasoning_effort, "low")

def test_model_config_init_command_line(self):
config = ModelConfig.from_args(
'model_name=meta-llama/Llama-3.1-8B-Instruct,system_prompt="You are a helpful assistant.",generation_parameters={temperature:0.7},chat_template_parameters={reasoning_effort:low}'
)

self.assertEqual(config.model_name, "meta-llama/Llama-3.1-8B-Instruct")
self.assertEqual(config.generation_parameters.temperature, 0.7)
self.assertEqual(config.system_prompt, '"You are a helpful assistant."') # is this what we want?
self.assertEqual(config.chat_template_parameters.reasoning_effort, "low")

def test_model_config_generation_parameters_parse_single_int(self):
config = ModelConfig.from_args(
"model_name=meta-llama/Llama-3.1-8B-Instruct,generation_parameters={temperature:0.7}"
)
self.assertEqual(config.generation_parameters.temperature, 0.7)

def test_model_config_generation_parameters_parse_multiple_int(self):
config = ModelConfig.from_args(
"model_name=meta-llama/Llama-3.1-8B-Instruct,generation_parameters={temperature:0.7,top_k:42}"
)
self.assertEqual(config.generation_parameters.temperature, 0.7)
self.assertEqual(config.generation_parameters.top_k, 42)

@unittest.skip("This is not working at this time")
def test_model_config_generation_parameters_parse_string(self):
config = ModelConfig.from_args(
'model_name=meta-llama/Llama-3.1-8B-Instruct,generation_parameters={response_format:{"type":"json_object"}}'
)
self.assertEqual(config.generation_parameters.temperature, 0.7)

@unittest.skip("This is not working at this time")
def test_model_config_chat_template_parameters_parse_single_int(self):
config = ModelConfig.from_args(
"model_name=meta-llama/Llama-3.1-8B-Instruct,chat_template_parameters={temperature:0.7}"
)
self.assertEqual(config.chat_template_parameters.temperature, 0.7)

def test_model_config_chat_template_parameters_parse_string(self):
config = ModelConfig.from_args(
"model_name=meta-llama/Llama-3.1-8B-Instruct,chat_template_parameters={reasoning_effort:low}"
)
self.assertEqual(config.chat_template_parameters.reasoning_effort, "low")