@@ -29,22 +29,23 @@ def sample_requests(
     dataset = [(data["conversations"][0]["value"],
                 data["conversations"][1]["value"]) for data in dataset]
 
-    # Tokenize the prompts and completions.
-    prompts = [prompt for prompt, _ in dataset]
-    prompt_token_ids = tokenizer(prompts).input_ids
-    completions = [completion for _, completion in dataset]
-    completion_token_ids = tokenizer(completions).input_ids
-    tokenized_dataset = []
-    for i in range(len(dataset)):
-        output_len = len(completion_token_ids[i])
-        if fixed_output_len is not None:
-            output_len = fixed_output_len
-        tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
+    # Shuffle the dataset.
+    random.shuffle(dataset)
 
-    # Filter out too long sequences.
+    # Filter out sequences that are too long or too short
     filtered_dataset: List[Tuple[str, int, int]] = []
-    for prompt, prompt_token_ids, output_len in tokenized_dataset:
+    for i in range(len(dataset)):
+        if len(filtered_dataset) == num_requests:
+            break
+
+        # Tokenize the prompts and completions.
+        prompt = dataset[i][0]
+        prompt_token_ids = tokenizer(prompt).input_ids
+        completion = dataset[i][1]
+        completion_token_ids = tokenizer(completion).input_ids
         prompt_len = len(prompt_token_ids)
+        output_len = len(completion_token_ids
+                         ) if fixed_output_len is None else fixed_output_len
         if prompt_len < 4 or output_len < 4:
             # Prune too short sequences.
             continue
@@ -53,9 +54,7 @@ def sample_requests(
             continue
         filtered_dataset.append((prompt, prompt_len, output_len))
 
-    # Sample the requests.
-    sampled_requests = random.sample(filtered_dataset, num_requests)
-    return sampled_requests
+    return filtered_dataset
 
 
 def run_vllm(
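
For context, the net effect of the change can be read off as a small standalone sketch (illustrative code, not part of the commit): instead of tokenizing every prompt and completion up front and then drawing num_requests samples with random.sample at the end, the new code shuffles once and tokenizes lazily, stopping as soon as enough requests survive the length filter. This avoids tokenizing the whole dataset when num_requests is small. The tokenize callable and the long-sequence pruning thresholds below are assumptions for illustration; the commit's actual thresholds sit in the lines elided between the two hunks.

import random
from typing import Callable, List, Optional, Tuple

def sample_requests_sketch(
    dataset: List[Tuple[str, str]],        # (prompt, completion) pairs
    num_requests: int,
    tokenize: Callable[[str], List[int]],  # stand-in for tokenizer(...).input_ids
    fixed_output_len: Optional[int] = None,
) -> List[Tuple[str, int, int]]:
    # Randomize order up front; this replaces random.sample() over the
    # fully tokenized dataset.
    random.shuffle(dataset)
    filtered: List[Tuple[str, int, int]] = []
    for prompt, completion in dataset:
        if len(filtered) == num_requests:
            break  # early exit: no wasted tokenization past this point
        prompt_len = len(tokenize(prompt))
        output_len = (len(tokenize(completion))
                      if fixed_output_len is None else fixed_output_len)
        if prompt_len < 4 or output_len < 4:
            continue  # prune too-short sequences
        if prompt_len > 1024 or prompt_len + output_len > 2048:
            continue  # prune too-long sequences (illustrative thresholds)
        filtered.append((prompt, prompt_len, output_len))
    return filtered

# Toy usage: whitespace splitting as a stand-in for a real tokenizer.
pairs = [("What is the capital of France and why?",
          "Paris is the capital because " * 4)] * 20
print(sample_requests_sketch(pairs, 5, lambda s: s.split()))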