12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
- from typing import List , Optional
15
+ from typing import List , Optional , Union
16
16
17
17
from sparseml .transformers .utils .sparse_model import SparseAutoModelForCausalLM
18
18
from sparseml .transformers .utils .sparse_tokenizer import SparseAutoTokenizer
21
21
try :
22
22
import numpy
23
23
import torch
24
+ from datasets import Dataset as HuggingFaceDataset
24
25
from datasets import load_dataset
25
26
from torch .nn import CrossEntropyLoss
26
27
from tqdm import tqdm
@@ -65,6 +66,7 @@ def perplexity_eval(
65
66
dataset_config_name = dataset_config_name ,
66
67
split = split ,
67
68
limit = limit ,
69
+ text_column_name = kwargs .pop ("text_column_name" , None ),
68
70
)
69
71
add_start_token = True
70
72
max_length = None
@@ -186,7 +188,7 @@ def perplexity_eval(
186
188
return Result (formatted = [eval ], raw = raw )
187
189
188
190
189
- def _infer_dataset_config_name (datasets ):
191
+ def _infer_dataset_config_name (datasets : str ):
190
192
"""
191
193
:param datasets: The name of the dataset to load
192
194
:return: The name of the dataset config to load
@@ -199,7 +201,7 @@ def _infer_dataset_config_name(datasets):
199
201
def _load_perplexity_dataset(
    dataset_name: str,
    dataset_config_name: str,
    text_column_name: Union[str, List[str], None] = None,
    split: Optional[str] = None,
    limit: Optional[int] = None,
) -> List[str]:
    """
    Loads a dataset and flattens it into a list of non-empty text samples
    for perplexity evaluation.

    :param dataset_name: The name of the dataset to load
    :param dataset_config_name: The name of the dataset config to load
    :param text_column_name: The name of the column containing the text data
        if None, defaults to "text". If a list of column names is passed, the
        columns will be concatenated to form the input text
    :param split: The split of the dataset to load, if None uses test split
        if available, otherwise uses train split
    :param limit: The number of samples to load from the dataset
    :return: The loaded dataset as a list of strings
    """
    dataset: HuggingFaceDataset = _fetch_dataset_split(
        dataset_name=dataset_name,
        dataset_config_name=dataset_config_name,
        split=split,
    )
    # normalize the requested column(s) to a validated list of column names;
    # use a distinct name so the parameter is not shadowed by a value of a
    # different type
    text_column_names: List[str] = _verify_text_column_name(
        dataset=dataset, text_column_name=text_column_name
    )

    inputs = []
    for sample in dataset:
        # concatenate the requested columns to form a single input string
        input_sample = "".join(sample[column_name] for column_name in text_column_names)
        if input_sample != "":
            inputs.append(input_sample)
        if limit is not None and len(inputs) >= limit:
            break
    return inputs
238
+
239
+
240
def _fetch_dataset_split(
    dataset_name: str, dataset_config_name: Optional[str] = None, split=None
):
    """
    Loads and returns the specified split of the dataset.

    :param dataset_name: The name of the dataset to load from the HuggingFace
        datasets library
    :param dataset_config_name: The name of the dataset config to load, if any.
    :param split: The split of the dataset to load, if None uses test split
        if available, otherwise uses train split. Also supports HuggingFace
        style splits such as "train[:10%]", "test", "validation", etc.
    :return: The loaded dataset split
    """
    loaded = load_dataset(dataset_name, dataset_config_name, split=split)

    # an explicit split was requested and resolved by load_dataset directly
    if split is not None:
        return loaded

    # no split given: prefer the test split, then fall back to train
    for candidate in ("test", "train"):
        if candidate in loaded:
            return loaded[candidate]

    raise ValueError(
        f"Neither 'test' nor 'train' split found in dataset {dataset_name}. "
        "Specify a valid split using the 'split' argument."
    )
270
+
271
+
272
+ def _verify_text_column_name (
273
+ dataset : HuggingFaceDataset ,
274
+ text_column_name : Union [str , List [str ], None ] = None ,
275
+ ) -> List [str ]:
276
+ """
277
+ Verifies that the dataset contains the specified text column name(s),
278
+ and returns the text column name(s) to use for evaluation as a list.
279
+
280
+ :param dataset: The huggingface dataset to verify
281
+ :param text_column_name: The name of the column containing the text data
282
+ if None, defaults to "text". If a list of column names is passed, all
283
+ columns must be present in the dataset
284
+ :return: The text column name(s) to use for evaluation as a list of strings
285
+ """
286
+ text_column_names = text_column_name or ["text" ]
287
+
288
+ if isinstance (text_column_names , str ):
289
+ text_column_names = [text_column_name ]
290
+
291
+ for column_name in text_column_names :
292
+ if column_name not in dataset .column_names :
293
+ raise ValueError (
294
+ f"Dataset { dataset } does not contain a column named { column_name } "
295
+ )
296
+ return text_column_names
0 commit comments