 one-shot calibration workflows.
 """

-import math
-import multiprocessing
 import re
-from typing import Any, Callable, Optional
 from collections.abc import Iterator, Sized
-from torch.utils._pytree import tree_flatten
+from typing import Any, Callable, Optional

 import torch
 from datasets import Dataset
 from loguru import logger
-from torch.utils.data import DataLoader, SequentialSampler, RandomSampler, Sampler
-from transformers.data import DataCollatorWithPadding
+from torch.utils.data import DataLoader, RandomSampler, Sampler
+from transformers.data import DataCollatorWithPadding, default_data_collator

 from llmcompressor.args import DatasetArguments
 from llmcompressor.transformers.data import TextGenerationDataset
@@ -110,71 +107,31 @@ def get_calibration_dataloader(
         # weight-only quantization or dynamic quantization
         return

+    # load and tokenize dataset
     datasets = get_processed_dataset(
         dataset_args=dataset_args,
         processor=processor,
         do_oneshot=True,
         do_train=False,
     )
-
     calibration_dataset = datasets.get("calibration")
-    collate_fn = dataset_args.data_collator or _make_padding_collator(processor)
-    num_workers = dataset_args.num_data_workers or _infer_num_data_workers()
-
-    return format_calibration_data(
-        tokenized_dataset=calibration_dataset,
-        collate_fn=collate_fn,
-        batch_size=dataset_args.batch_size,
-        num_calibration_samples=dataset_args.num_calibration_samples,
-        do_shuffle=dataset_args.shuffle_calibration_samples,
-        num_workers=num_workers,
-    )
+
+    return format_calibration_data(dataset_args, calibration_dataset, processor)


 def format_calibration_data(
+    args: DatasetArguments,
     tokenized_dataset: Dataset,
-    collate_fn: Callable,
-    batch_size: int = 1,
-    num_calibration_samples: int | None = None,
-    do_shuffle: bool = False,
-    num_workers: int = 1,
+    processor: Processor,
 ) -> list[torch.Tensor]:
142- """
143- Creates a dataloader out of the calibration dataset split, trimming it to
144- the desired number of calibration samples
145- :param tokenized_dataset: dataset to convert to dataloader
146- :param num_calibration_samples: number of batches to convert
147- :param do_shuffle: whether to shuffle the dataset before selecting calibration
148- samples, true by default
149- :param collate_fn: optional custom collate function, or use default
150- :return: list of trimmed calibration data tensors
151- """
152- # (1) shuffle before truncating
153- if do_shuffle :
154- tokenized_dataset = tokenized_dataset .shuffle ()
155-
156- # (2) truncate dataset
157- if num_calibration_samples is not None :
158- if num_calibration_samples > len (tokenized_dataset ):
159- logger .warning (
160- f"Requested { num_calibration_samples } calibration samples but the "
161- f"provided dataset only has { len (tokenized_dataset )} samples."
162- )
163- num_calibration_samples = len (tokenized_dataset )
164- tokenized_dataset = tokenized_dataset .select (range (num_calibration_samples ))
165-
166- # (3) create sampler
167- sampler = _make_sampler (tokenized_dataset , num_calibration_samples , do_shuffle )
168-
169- # (4) create dataloader
170- dataloader_params = {
171- "batch_size" : batch_size ,
172- "sampler" : sampler ,
173- "collate_fn" : collate_fn ,
174- "pin_memory" : False ,
175- "num_workers" : num_workers ,
176- }
177- return DataLoader (tokenized_dataset , ** dataloader_params )
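+    """
+    Create a dataloader from the tokenized calibration dataset, using the
+    sampler, collate function, and worker count configured in the dataset
+    arguments
+    """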
+    return DataLoader(
+        tokenized_dataset,
+        batch_size=args.batch_size,
+        sampler=_make_sampler(args, tokenized_dataset),
+        collate_fn=_make_collate_fn(args, processor),
+        pin_memory=False,
+        num_workers=args.get_num_data_workers(),
+    )


 def make_dataset_splits(
@@ -216,25 +173,30 @@ def make_dataset_splits(
     return split_datasets


-def _make_padding_collator(processor: Processor) -> DataCollatorWithPadding:
-    tokenizer = getattr(processor, "tokenizer", processor)
-    if tokenizer.pad_token is None or tokenizer.pad_token_id < 0:
-        logger.debug("Could not find padding token. Setting PAD token to EOS token")
-        tokenizer.pad_token = tokenizer.eos_token
+def _make_collate_fn(dataset_args: DatasetArguments, processor: Processor) -> Callable:
+    if isinstance(dataset_args.data_collator, Callable):
+        return dataset_args.data_collator

-    return DataCollatorWithPadding(tokenizer)
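+    # otherwise, select one of the built-in collators by name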
+    if dataset_args.data_collator == "truncation":
+        return data_collator_with_truncation
+
+    elif dataset_args.data_collator == "padding":
+        tokenizer = getattr(processor, "tokenizer", processor)
+        if tokenizer.pad_token is None or tokenizer.pad_token_id < 0:
+            logger.debug("Could not find padding token. Setting PAD token to EOS token")
+            tokenizer.pad_token = tokenizer.eos_token
+
+        return DataCollatorWithPadding(tokenizer)
+
+    else:
+        assert False, f"Unrecognized data_collator: {dataset_args.data_collator}"
+
+
+def _make_sampler(args: DatasetArguments, dataset: Dataset) -> Sampler:
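+    # shuffling uses a RandomSampler; otherwise samples are visited in
+    # length-sorted order so each batch groups similarly sized sequences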
+    num_samples = args.num_calibration_samples
+    shuffle = args.shuffle_calibration_samples
+    batch_size = args.batch_size

-def _infer_num_data_workers() -> int:
-    MAX_DATALOADER_WORKERS = 8
-    try:
-        return min(MAX_DATALOADER_WORKERS, multiprocessing.cpu_count() // 2)
-    except NotImplementedError:
-        logger.warning(
-            "Could not determine number of CPUs, defaulting to 0 dataloader workers."
-        )
-        return 0
-
-def _make_sampler(dataset: Dataset, num_samples: int | None, shuffle: bool) -> Sampler:
     if num_samples is not None and num_samples > len(dataset):
         logger.warning(
             f"Requested {num_samples} samples but the provided dataset only has "
@@ -243,19 +205,38 @@ def _make_sampler(dataset: Dataset, num_samples: int | None, shuffle: bool) -> S
         num_samples = len(dataset)

     if shuffle:
-        return RandomSampler(
-            dataset,
-            replacement=False,
-            num_samples=num_samples,
-            generator=None,
-        )
+        if batch_size > 1:
+            logger.warning(
+                "Shuffling a dataset can lead to suboptimal batching when sequence "
+                "lengths are non-uniform. When collating with truncation, this will "
+                "delete a large number of tokens; when collating with padding, it "
+                "will add a large number of padding tokens.\n\nPlease consider "
+                "calling `oneshot` with `batch_size=1`"
+            )
+
+        return RandomSampler(dataset, num_samples=num_samples)
     else:
-        return LengthAwareSampler(
-            dataset,
-            replacement=False,
-            num_samples=num_samples,
-            generator=None,
-        )
+        return LengthAwareSampler(dataset, num_samples=num_samples)
+
+
+def data_collator_with_truncation(
+    features: list[dict[str, Any]], return_tensors: str = "pt"
+) -> dict[str, Any]:
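+    # truncate every list-valued feature in the batch to the length of the
+    # shortest example so the batch can be collated without padding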
+    total_removed = 0
+    total = 0
+
+    keys = set().union(*(feature.keys() for feature in features))
+    for key in keys:
+        lengths = [
+            len(feature[key])
+            for feature in features
+            if isinstance(feature.get(key), list)
+        ]
+        if not lengths:
+            # nothing to truncate for non-sequence features
+            continue
+        min_len = min(lengths)
+        for feature in features:
+            if not isinstance(feature.get(key), list):
+                continue
+            total_removed += len(feature[key]) - min_len
+            total += len(feature[key])
+            feature[key] = feature[key][:min_len]
+
+    if total > 0:
+        logger.debug(f"Truncation collator removed {total_removed}/{total} elements")
+
+    return default_data_collator(features, return_tensors)


 class LengthAwareSampler(Sampler[int]):
@@ -265,20 +246,11 @@ class LengthAwareSampler(Sampler[int]):
     def __init__(
         self,
         data_source: Sized,
-        replacement: bool = False,
         num_samples: Optional[int] = None,
-        generator: Optional[torch.Generator] = None,
     ) -> None:
         self.data_source = data_source
-        self.replacement = replacement
         self._num_samples = num_samples

-        if replacement:
-            raise NotImplementedError()
-
-        if generator:
-            raise NotImplementedError()
-
         lengths = [len(sample) for sample in data_source["input_ids"]]
         self.order = torch.argsort(torch.tensor(lengths), descending=True).tolist()

@@ -293,4 +265,4 @@ def __iter__(self) -> Iterator[int]:
         return iter(self.order)

     def __len__(self) -> int:
-        return self._num_samples
+        return self._num_samples