11"""MLC LLM bench prompts generator"""
2+
23import json
34import random
5+ from collections import defaultdict
46from pathlib import Path
57from typing import Any , Dict , List , Optional
68
@@ -18,6 +20,7 @@ class PromptsGenerator: # pylint: disable=too-few-public-methods
     def __init__(
         self,
         prompts_path: Optional[str] = None,
+        json_prompts_path: Optional[str] = None,
         tokenizer: Optional[Any] = None,
         seed: Optional[int] = 11111,
     ) -> None:
@@ -32,6 +35,11 @@ def __init__(
             or a .jsonl file where each line is a JSON object formatted as
             {"prompt": "prompt text", "prompt_tokens": 10}.
 
+        json_prompts_path : Optional[str]
+            The path to the file containing the source json prompts. This file is a
+            .jsonl file where each line is a JSON object formatted as
+            {"messages": List[Dict[str, Any]], "response_format": Dict[str, Any]}.
+
         tokenizer : Optional[Any]
             The tokenizer object to use for tokenizing the prompts.
 
@@ -66,6 +74,22 @@ def __init__(
                 prompt_line = file.readline()
                 prompt_tokens = self._count_tokens(prompt_line)
                 self.prompts.append({"prompt": prompt_line, "prompt_tokens": prompt_tokens})
+        if json_prompts_path:
+            self.json_prompts = defaultdict(list)
+            with open(json_prompts_path, "r", encoding="utf-8") as file:
+                for line in file:
+                    json_line = json.loads(line)
+                    assert (
+                        "messages" in json_line
+                    ), "The messages field is required in the JSONL file."
+                    assert (
+                        "response_format" in json_line
+                    ), "The response_format field is required in the JSONL file."
+                    self.json_prompts[json.dumps(json_line["response_format"]["schema"])].append(
+                        json_line["messages"]
+                    )
+        else:
+            self.json_prompts = None
 
     def _count_tokens(self, text: str) -> int:
         """Get the number of tokens.
@@ -82,40 +106,44 @@ def _count_tokens(self, text: str) -> int:
82106 """
83107 return len (self .tokenizer .encode (text ))
84108
-    def generate_prompt(self, tokens_mean: int, tokens_stddev: Optional[int] = 0) -> str:
+    def generate_prompt(self, params: Dict[str, Any]) -> Dict[str, Any]:
         """
-        Generates a prompt that closely matches the desired token count.
+        Generates a prompt based on the params, e.g. prompt_tokens, response_format.
 
         Parameters
         ----------
-        token_mean : int
-            The desired mean number of tokens in the prompt.
+        params : Dict[str, Any]
+            The request parameters, e.g. prompt_tokens and response_format.
 
-        token_stddev : Optional[int]
-            The desired standard deviation of tokens in the prompt.
-
         Returns
         -------
-        out: str
-            A prompt string with the specified number of tokens.
+        override_params : Dict[str, Any]
+            The params to override the original request, e.g. messages, response_format.
         """
123+ if "response_format" in params :
124+ response_format = params ["response_format" ]
125+ if response_format .get ("type" ) == "json_object" :
126+ if response_format .get ("schema" ) in self .json_prompts :
127+ assert len (self .json_prompts [response_format ["schema" ]]) > 0
128+ return {"messages" : random .choice (self .json_prompts [response_format ["schema" ]])}
129+ schema , prompts = random .choice (list (self .json_prompts .items ()))
130+ response_format ["schema" ] = schema
131+ return {"messages" : random .choice (prompts ), "response_format" : response_format }
132+ tokens_mean = params .get ("prompt_tokens" , 128 )
102133 assert tokens_mean > 0 , "The mean number of tokens must be greater than 0."
-        out_prompt_tokens = (
-            int(random.gauss(tokens_mean, tokens_stddev)) if tokens_stddev else tokens_mean
-        )
-        if out_prompt_tokens <= 0:
-            out_prompt_tokens = tokens_mean
-        remaining_prompt_tokens = out_prompt_tokens
+        remaining_prompt_tokens = tokens_mean
         result_prompt = ""
+        override_params = None
         while remaining_prompt_tokens > 0:
             prompt_dict = random.choice(self.prompts)
             cur_prompt_tokens = prompt_dict["prompt_tokens"]
             cur_prompt = prompt_dict["prompt"]
+            if override_params is None:
+                override_params = prompt_dict.get("override_params")
             if remaining_prompt_tokens - cur_prompt_tokens < 0:
                 result_prompt += cur_prompt[:remaining_prompt_tokens]
                 remaining_prompt_tokens = 0
                 break
             result_prompt += cur_prompt
             remaining_prompt_tokens -= cur_prompt_tokens
-        self._count_tokens(result_prompt)
-        return result_prompt
+        return {"messages": [{"role": "system", "content": result_prompt}]}
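And a minimal usage sketch of the new generate_prompt signature. The paths, tokenizer, and schema_str names are hypothetical; any tokenizer object exposing encode(text) should satisfy _count_tokens.

    # Sketch only: hypothetical paths and tokenizer, not part of this change.
    generator = PromptsGenerator(
        prompts_path="prompts.txt",
        json_prompts_path="json_prompts.jsonl",
        tokenizer=tokenizer,  # any object with encode(text) -> token ids
    )

    # Token-count path: concatenates sampled prompts into a ~256-token system message.
    override = generator.generate_prompt({"prompt_tokens": 256})

    # Schema path: reuses stored messages for a matching schema, or picks a random
    # (schema, messages) pair and rewrites response_format["schema"] accordingly.
    override = generator.generate_prompt(
        {"response_format": {"type": "json_object", "schema": schema_str}}
    )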