Skip to content

Commit 8ee3e64

Browse files
authored
[Support] Evaluate openai_humaneval support (#2100)
* Add support for multiple text columns * Add cli support for multiple columns
1 parent d6132ac commit 8ee3e64

File tree

2 files changed

+96
-20
lines changed

2 files changed

+96
-20
lines changed

src/sparseml/evaluation/integrations/perplexity.py

+79-18
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import List, Optional
15+
from typing import List, Optional, Union
1616

1717
from sparseml.transformers.utils.sparse_model import SparseAutoModelForCausalLM
1818
from sparseml.transformers.utils.sparse_tokenizer import SparseAutoTokenizer
@@ -21,6 +21,7 @@
2121
try:
2222
import numpy
2323
import torch
24+
from datasets import Dataset as HuggingFaceDataset
2425
from datasets import load_dataset
2526
from torch.nn import CrossEntropyLoss
2627
from tqdm import tqdm
@@ -65,6 +66,7 @@ def perplexity_eval(
6566
dataset_config_name=dataset_config_name,
6667
split=split,
6768
limit=limit,
69+
text_column_name=kwargs.pop("text_column_name", None),
6870
)
6971
add_start_token = True
7072
max_length = None
@@ -186,7 +188,7 @@ def perplexity_eval(
186188
return Result(formatted=[eval], raw=raw)
187189

188190

189-
def _infer_dataset_config_name(datasets):
191+
def _infer_dataset_config_name(datasets: str):
190192
"""
191193
:param datasets: The name of the dataset to load
192194
:return: The name of the dataset config to load
@@ -199,7 +201,7 @@ def _infer_dataset_config_name(datasets):
199201
def _load_perplexity_dataset(
200202
dataset_name: str,
201203
dataset_config_name: str,
202-
text_column_name: Optional[str] = None,
204+
text_column_name: Union[str, List[str], None] = None,
203205
split: Optional[str] = None,
204206
limit: Optional[int] = None,
205207
) -> List[str]:
@@ -209,27 +211,86 @@ def _load_perplexity_dataset(
209211
:param dataset_name: The name of the dataset to load
210212
:param dataset_config_name: The name of the dataset config to load
211213
:param text_column_name: The name of the column containing the text data
212-
if None, defaults to "text"
214+
if None, defaults to "text". If a list of column names is passed, the
215+
columns will be concatenated to form the input text
213216
:param split: The split of the dataset to load, if None uses test split
214217
if available, otherwise uses train split
215218
:param limit: The maximum number of non-empty samples to load from the dataset
216219
:return: The loaded dataset as a list of strings
217220
"""
218-
dataset = load_dataset(dataset_name, dataset_config_name, split=split)
219-
if isinstance(dataset, dict):
220-
# check if test split exists
221-
dataset = dataset["test"] if "test" in dataset else dataset["train"]
222-
223-
text_column_name = text_column_name or "text"
224-
if text_column_name not in dataset.column_names:
225-
raise ValueError(
226-
f"Dataset {dataset_name} does not contain a column named {text_column_name}"
227-
)
228-
dataset = dataset[text_column_name]
221+
dataset: HuggingFaceDataset = _fetch_dataset_split(
222+
dataset_name=dataset_name,
223+
dataset_config_name=dataset_config_name,
224+
split=split,
225+
)
226+
text_column_name: List[str] = _verify_text_column_name(
227+
dataset=dataset, text_column_name=text_column_name
228+
)
229+
229230
inputs = []
230-
for s in dataset:
231-
if s != "":
232-
inputs.append(s)
231+
for sample in dataset:
232+
input_sample = "".join(sample[column_name] for column_name in text_column_name)
233+
if input_sample != "":
234+
inputs.append(input_sample)
233235
if limit is not None and len(inputs) >= limit:
234236
break
235237
return inputs
238+
239+
240+
def _fetch_dataset_split(
    dataset_name: str, dataset_config_name: Optional[str] = None, split=None
):
    """
    Load a dataset through ``load_dataset`` and hand back a single split of it.

    :param dataset_name: The name of the dataset to load from the HuggingFace
        datasets library
    :param dataset_config_name: The name of the dataset config to load, if any.
    :param split: The split to load. The value is forwarded unchanged to
        ``load_dataset``, so HuggingFace style splits such as "train[:10%]",
        "test", "validation", etc. are accepted. If None, the "test" split is
        used when present, otherwise the "train" split.
    :return: The loaded dataset split
    :raises ValueError: if no split was requested and the loaded dataset has
        neither a "test" nor a "train" split
    """
    loaded = load_dataset(dataset_name, dataset_config_name, split=split)

    # an explicitly requested split is returned exactly as it was loaded
    if split is not None:
        return loaded

    # nothing requested: prefer "test", then fall back to "train"
    for candidate_split in ("test", "train"):
        if candidate_split in loaded:
            return loaded[candidate_split]

    raise ValueError(
        f"Neither 'test' nor 'train' split found in dataset {dataset_name}. "
        "Specify a valid split using the 'split' argument."
    )
270+
271+
272+
def _verify_text_column_name(
    dataset: "HuggingFaceDataset",
    text_column_name: Union[str, List[str], None] = None,
) -> List[str]:
    """
    Verifies that the dataset contains the specified text column name(s),
    and returns the text column name(s) to use for evaluation as a list.

    :param dataset: The huggingface dataset to verify; only its
        ``column_names`` attribute is read here
    :param text_column_name: The name of the column containing the text data
        if None (or otherwise empty), defaults to "text". If a list of column
        names is passed, all columns must be present in the dataset
    :return: The text column name(s) to use for evaluation, always as a new
        list of strings (the caller's sequence is never returned directly)
    :raises ValueError: if any requested column is missing from the dataset
    """
    # NOTE: the `dataset` annotation is deliberately a string. Signature
    # annotations are evaluated when the `def` statement runs, and the
    # HuggingFaceDataset name is only bound inside the module's guarded
    # `try:` import block, so an unquoted annotation would raise NameError
    # at import time whenever that optional import did not succeed.
    if not text_column_name:
        text_column_names = ["text"]
    elif isinstance(text_column_name, str):
        text_column_names = [text_column_name]
    else:
        # copy into a real list so the declared return type holds for any
        # sequence (e.g. a tuple) and the caller's object is not aliased
        text_column_names = list(text_column_name)

    available_columns = dataset.column_names
    for column_name in text_column_names:
        if column_name not in available_columns:
            raise ValueError(
                f"Dataset {dataset} does not contain a column named {column_name}"
            )
    return text_column_names

src/sparseml/utils/helpers.py

+17-2
Original file line numberDiff line numberDiff line change
@@ -864,6 +864,12 @@ def main(..., kwargs):
864864
output = {'arg1': 1, 'arg2': 2, 'arg3': 3}
865865
```
866866
867+
```
868+
input = ('--arg1', 1, '--arg1', 2, 'arg2', 2, '-arg3', 3)
869+
output = parse_kwarg_tuples(input)
870+
output = {'arg1': [1, 2], 'arg2': 2, 'arg3': 3}
871+
```
872+
867873
:param kwargs: The kwargs to convert. Should be a tuple of alternating
868874
kwargs names and kwargs values e.g. ('--arg1', 1, 'arg2', 2, '-arg3', 3).
869875
The names can optionally have a '-' or `--` in front of them.
@@ -895,8 +901,17 @@ def main(..., kwargs):
895901
pass
896902
# remove any '-' or '--' from the names
897903
kwargs_names = [name.lstrip("-") for name in kwargs_names]
898-
899-
return dict(zip(kwargs_names, kwargs_values))
904+
processed_kwargs = {}
905+
for kwarg_name, kwarg_value in zip(kwargs_names, kwargs_values):
906+
if kwarg_name in processed_kwargs:
907+
# if the kwarg name is already in the processed kwargs,
908+
# then we should convert the value to a list
909+
if not isinstance(processed_kwargs[kwarg_name], list):
910+
processed_kwargs[kwarg_name] = [processed_kwargs[kwarg_name]]
911+
processed_kwargs[kwarg_name].append(kwarg_value)
912+
else:
913+
processed_kwargs[kwarg_name] = kwarg_value
914+
return processed_kwargs
900915

901916

902917
def download_zoo_training_dir(zoo_stub: str) -> str:

0 commit comments

Comments
 (0)