
Commit dc957a8

WIP: truncate in batch
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 929bb5d commit dc957a8

File tree

6 files changed, +113 -118 lines changed


examples/multimodal_vision/gemma3_example.py

Lines changed: 8 additions & 4 deletions

@@ -25,9 +25,13 @@
 
 # Define a oneshot data collator for multimodal processors
 # remove extra dim added by vision processor
-def data_collator(features: list[dict[str, object]]):
-    features = [{key: feature[key][0] for key in feature} for feature in features]
-    return collator(features)
+# def data_collator(features: list[dict[str, object]]):
+#     features = [{key: feature[key][0] for key in feature} for feature in features]
+#     return collator(features)
+# Define a oneshot data collator for multimodal inputs.
+def data_collator(batch):
+    assert len(batch) == 1
+    return {key: torch.tensor(value) for key, value in batch[0].items()}
 
 
 # Recipe
@@ -57,10 +61,10 @@ def data_collator(features: list[dict[str, object]]):
     max_seq_length=MAX_SEQUENCE_LENGTH,
     num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    data_collator=data_collator,
-    trust_remote_code_model=True,
     pipeline="sequential",
 )
 import torch
+
 del prof._memory.timeline[torch.device("cpu")]
 prof.save_memory_timeline("with_disable.png")
 exit(0)
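
As a quick illustration of the new single-sample collator (a minimal sketch with made-up feature values; only torch is required):

import torch

def data_collator(batch):
    # the calibration dataloader is expected to deliver exactly one sample per batch
    assert len(batch) == 1
    return {key: torch.tensor(value) for key, value in batch[0].items()}

# hypothetical sample as produced by a multimodal processor (extra leading dim kept)
sample = {"input_ids": [[1, 2, 3]], "attention_mask": [[1, 1, 1]]}
out = data_collator([sample])
print(out["input_ids"].shape)  # torch.Size([1, 3])

Unlike the commented-out version, the extra dimension added by the vision processor is no longer stripped; the single sample is converted to tensors as-is.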

examples/quantization_w4a16/llama3_example.py

Lines changed: 2 additions & 2 deletions

@@ -7,7 +7,7 @@
 
 # Select model and load it.
 model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
-#model_id = "meta-llama/Llama-3.2-1B-Instruct"
+# model_id = "meta-llama/Llama-3.2-1B-Instruct"
 model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 
@@ -64,7 +64,7 @@ def tokenize(sample):
     max_seq_length=MAX_SEQUENCE_LENGTH,
     num_calibration_samples=NUM_CALIBRATION_SAMPLES,
     pipeline="sequential",
-    shuffle_calibration_samples=False
+    shuffle_calibration_samples=False,
 )
 exit(0)
 

src/llmcompressor/args/dataset_arguments.py

Lines changed: 21 additions & 4 deletions

@@ -8,7 +8,9 @@
 """
 
 from dataclasses import dataclass, field
-from typing import Callable, Optional
+from typing import Callable
+
+from loguru import logger
 
 
 @dataclass
@@ -67,15 +69,16 @@ class CustomDatasetArguments(DVCDatasetArguments):
         },
     )
 
-    data_collator: Optional[Callable] = field(
-        default=None,
+    data_collator: str | Callable = field(
+        default="truncation",
         metadata={
             "help": (
                 "The function to used to form a batch from the dataset. Defaults to "
                 "`DataCollatorWithPadding(processor)`."
             )
         },
     )
+    # remove_extra_processor_dim: bool = field
 
     batch_size: int = field(
         default=1,
@@ -154,7 +157,7 @@ class DatasetArguments(CustomDatasetArguments):
         default=False,
         metadata={"help": "Overwrite the cached preprocessed datasets or not."},
     )
-    num_data_workers: int | None = field(
+    preprocessing_num_workers: int | None = field(
         default=None,
         metadata={"help": "The number of workers to use for dataset processing."},
     )
@@ -240,3 +243,17 @@ class DatasetArguments(CustomDatasetArguments):
 
     def is_dataset_provided(self) -> bool:
         return self.dataset is not None or self.dataset_path is not None
+
+    def get_num_data_workers(self):
+        import multiprocessing
+
+        if self.preprocessing_num_workers is not None:
+            return self.preprocessing_num_workers
+
+        try:
+            return min(multiprocessing.cpu_count() // 2, 8)  # cap max at 8
+        except NotImplementedError:
+            logger.warning(
+                "Could not determine number of CPUs, defaulting to 1 dataloader worker."
+            )
+            return 1
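
For illustration, the same worker-count heuristic as a standalone function (a minimal sketch; the helper name is hypothetical, the logic mirrors get_num_data_workers above):

import multiprocessing

from loguru import logger


def infer_num_data_workers(preprocessing_num_workers: int | None = None) -> int:
    # an explicit setting always wins; otherwise use half the CPUs, capped at 8
    if preprocessing_num_workers is not None:
        return preprocessing_num_workers
    try:
        return min(multiprocessing.cpu_count() // 2, 8)  # e.g. 16 cores -> 8 workers
    except NotImplementedError:
        logger.warning("Could not determine number of CPUs, defaulting to 1 worker.")
        return 1

On a 4-core machine this yields 2 workers; passing preprocessing_num_workers=0 keeps dataset loading in the main process.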

src/llmcompressor/datasets/utils.py

Lines changed: 70 additions & 98 deletions

@@ -7,18 +7,15 @@
 one-shot calibration workflows.
 """
 
-import math
-import multiprocessing
 import re
-from typing import Any, Callable, Optional
 from collections.abc import Iterator, Sized
-from torch.utils._pytree import tree_flatten
+from typing import Any, Callable, Optional
 
 import torch
 from datasets import Dataset
 from loguru import logger
-from torch.utils.data import DataLoader, SequentialSampler, RandomSampler, Sampler
-from transformers.data import DataCollatorWithPadding
+from torch.utils.data import DataLoader, RandomSampler, Sampler
+from transformers.data import DataCollatorWithPadding, default_data_collator
 
 from llmcompressor.args import DatasetArguments
 from llmcompressor.transformers.data import TextGenerationDataset
@@ -110,71 +107,31 @@ def get_calibration_dataloader(
         # weight-only quantization or dynamic quantization
         return
 
+    # load and tokenize dataset
     datasets = get_processed_dataset(
         dataset_args=dataset_args,
         processor=processor,
         do_oneshot=True,
         do_train=False,
     )
-
     calibration_dataset = datasets.get("calibration")
-    collate_fn = dataset_args.data_collator or _make_padding_collator(processor)
-    num_workers = dataset_args.num_data_workers or _infer_num_data_workers()
-
-    return format_calibration_data(
-        tokenized_dataset=calibration_dataset,
-        collate_fn=collate_fn,
-        batch_size=dataset_args.batch_size,
-        num_calibration_samples=dataset_args.num_calibration_samples,
-        do_shuffle=dataset_args.shuffle_calibration_samples,
-        num_workers=num_workers,
-    )
+
+    return format_calibration_data(dataset_args, calibration_dataset, processor)
 
 
 def format_calibration_data(
+    args: DatasetArguments,
     tokenized_dataset: Dataset,
-    collate_fn: Callable,
-    batch_size: int = 1,
-    num_calibration_samples: int | None = None,
-    do_shuffle: bool = False,
-    num_workers: int = 1,
+    processor: Processor,
 ) -> list[torch.Tensor]:
-    """
-    Creates a dataloader out of the calibration dataset split, trimming it to
-    the desired number of calibration samples
-    :param tokenized_dataset: dataset to convert to dataloader
-    :param num_calibration_samples: number of batches to convert
-    :param do_shuffle: whether to shuffle the dataset before selecting calibration
-        samples, true by default
-    :param collate_fn: optional custom collate function, or use default
-    :return: list of trimmed calibration data tensors
-    """
-    # (1) shuffle before truncating
-    if do_shuffle:
-        tokenized_dataset = tokenized_dataset.shuffle()
-
-    # (2) truncate dataset
-    if num_calibration_samples is not None:
-        if num_calibration_samples > len(tokenized_dataset):
-            logger.warning(
-                f"Requested {num_calibration_samples} calibration samples but the "
-                f"provided dataset only has {len(tokenized_dataset)} samples."
-            )
-            num_calibration_samples = len(tokenized_dataset)
-        tokenized_dataset = tokenized_dataset.select(range(num_calibration_samples))
-
-    # (3) create sampler
-    sampler = _make_sampler(tokenized_dataset, num_calibration_samples, do_shuffle)
-
-    # (4) create dataloader
-    dataloader_params = {
-        "batch_size": batch_size,
-        "sampler": sampler,
-        "collate_fn": collate_fn,
-        "pin_memory": False,
-        "num_workers": num_workers,
-    }
-    return DataLoader(tokenized_dataset, **dataloader_params)
+    return DataLoader(
+        tokenized_dataset,
+        batch_size=args.batch_size,
+        sampler=_make_sampler(args, tokenized_dataset),
+        collate_fn=_make_collate_fn(args, processor),
+        pin_memory=False,
+        num_workers=args.get_num_data_workers(),
+    )
 
 
 def make_dataset_splits(
@@ -216,25 +173,30 @@ def make_dataset_splits(
     return split_datasets
 
 
-def _make_padding_collator(processor: Processor) -> DataCollatorWithPadding:
-    tokenizer = getattr(processor, "tokenizer", processor)
-    if tokenizer.pad_token is None or tokenizer.pad_token_id < 0:
-        logger.debug("Could not find padding token. Setting PAD token to EOS token")
-        tokenizer.pad_token = tokenizer.eos_token
+def _make_collate_fn(dataset_args: DatasetArguments, processor: Processor) -> Callable:
+    if isinstance(dataset_args.data_collator, Callable):
+        return dataset_args.data_collator
 
-    return DataCollatorWithPadding(tokenizer)
+    if dataset_args.data_collator == "truncation":
+        return data_collator_with_truncation
+
+    elif dataset_args.data_collator == "padding":
+        tokenizer = getattr(processor, "tokenizer", processor)
+        if tokenizer.pad_token is None or tokenizer.pad_token_id < 0:
+            logger.debug("Could not find padding token. Setting PAD token to EOS token")
+            tokenizer.pad_token = tokenizer.eos_token
+
+        return DataCollatorWithPadding(tokenizer)
+
+    else:
+        assert False
+
+
+def _make_sampler(args: DatasetArguments, dataset: Dataset) -> Sampler:
+    num_samples = args.num_calibration_samples
+    shuffle = args.shuffle_calibration_samples
+    batch_size = args.batch_size
 
-def _infer_num_data_workers() -> int:
-    MAX_DATALOADER_WORKERS = 8
-    try:
-        return min(MAX_DATALOADER_WORKERS, multiprocessing.cpu_count() // 2)
-    except NotImplementedError:
-        logger.warning(
-            "Could not determine number of CPUs, defaulting to 0 dataloader workers."
-        )
-        return 0
-
-def _make_sampler(dataset: Dataset, num_samples: int | None, shuffle: bool) -> Sampler:
     if num_samples is not None and num_samples > len(dataset):
         logger.warning(
             f"Requested {num_samples} samples but the provided dataset only has "
@@ -243,19 +205,38 @@ def _make_sampler(dataset: Dataset, num_samples: int | None, shuffle: bool) -> Sampler:
             num_samples = len(dataset)
 
     if shuffle:
-        return RandomSampler(
-            dataset,
-            replacement=False,
-            num_samples=num_samples,
-            generator=None,
-        )
+        if batch_size > 1:
+            logger.warning(
+                "Shuffling a dataset can lead to unoptimal batching for sequence "
+                "lengths non-uniform sizes. When collating with truncation, this will "
+                "delete a large number of tokens. When collating with padding, this "
+                "will add a large number of padding tokens.\n\nPlease consider calling "
+                "`oneshot` with `batch_size=1`"
+            )
+
+        return RandomSampler(dataset, num_samples=num_samples)
     else:
-        return LengthAwareSampler(
-            dataset,
-            replacement=False,
-            num_samples=num_samples,
-            generator=None,
-        )
+        return LengthAwareSampler(dataset, num_samples=num_samples)
+
+
+def data_collator_with_truncation(
+    features: list[dict[str, Any]], return_tensors: str = "pt"
+) -> dict[str, Any]:
+    total_removed = 0
+    total = 0
+
+    keys = set().union(*(feature.keys() for feature in features))
+    for key in keys:
+        lengths = [
+            len(feature[key]) for feature in features if isinstance(feature[key], list)
+        ]
+        min_len = min(lengths)
+        for feature in features:
+            total_removed += len(feature[key]) - min_len
+            total += len(feature[key])
+            feature[key] = feature[key][:min_len]
+
+    return default_data_collator(features, return_tensors)
 
 
 class LengthAwareSampler(Sampler[int]):
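
A quick sketch of what data_collator_with_truncation (added above) does with a toy batch; feature values are made up and the import path follows this commit's module layout:

from llmcompressor.datasets.utils import data_collator_with_truncation

features = [
    {"input_ids": [1, 2, 3, 4, 5], "attention_mask": [1, 1, 1, 1, 1]},
    {"input_ids": [6, 7, 8], "attention_mask": [1, 1, 1]},
]

# every list-valued feature is cut to the shortest length in the batch (3 here),
# then default_data_collator stacks the now-uniform lists into tensors
batch = data_collator_with_truncation(features)
print(batch["input_ids"].shape)  # torch.Size([2, 3])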
@@ -265,20 +246,11 @@ class LengthAwareSampler(Sampler[int]):
     def __init__(
         self,
         data_source: Sized,
-        replacement: bool = False,
         num_samples: Optional[int] = None,
-        generator: Optional[torch.Generator] = None,
     ) -> None:
         self.data_source = data_source
-        self.replacement = replacement
         self._num_samples = num_samples
 
-        if replacement:
-            raise NotImplementedError()
-
-        if generator:
-            raise NotImplementedError()
-
         lengths = [len(sample) for sample in data_source["input_ids"]]
         self.order = torch.argsort(torch.tensor(lengths), descending=True).tolist()
 
@@ -293,4 +265,4 @@ def __iter__(self) -> Iterator[int]:
         return iter(self.order)
 
     def __len__(self) -> int:
-        return self._num_samples
+        return self._num_samples
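
And a minimal sketch of the ordering the slimmed-down LengthAwareSampler builds internally (toy data; assumes the dataset exposes an "input_ids" column as in the diff):

import torch

# core of the sampler: sort sample indices by sequence length, longest first
data_source = {"input_ids": [[1, 2], [3, 4, 5, 6], [7, 8, 9]]}
lengths = [len(sample) for sample in data_source["input_ids"]]
order = torch.argsort(torch.tensor(lengths), descending=True).tolist()
print(order)  # [1, 2, 0]

Grouping similarly sized sequences this way limits how many tokens the truncation collator drops (or how many pad tokens the padding collator adds) when batch_size > 1, which is exactly the situation the warning in _make_sampler cautions about.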

src/llmcompressor/entrypoints/oneshot.py

Lines changed: 3 additions & 3 deletions

@@ -249,7 +249,7 @@ def oneshot(
     dataset_path: str | None = None,
     splits: str | list[str] | dict[str, str] | None = None,
     batch_size: int = 1,
-    data_collator: Optional[Callable] = None,
+    data_collator: Optional[Callable] = "truncation",
     num_calibration_samples: int = 512,
     shuffle_calibration_samples: bool = True,
     max_seq_length: int = 384,
@@ -258,7 +258,7 @@
     concatenate_data: bool = False,
     streaming: bool = False,
     overwrite_cache: bool = False,
-    num_data_workers: int | None = None,
+    preprocessing_num_workers: int | None = None,
     min_tokens_per_module: float | None = None,
     moe_calibrate_all_experts: bool = True,
     quantization_aware_calibration: bool = True,
@@ -319,7 +319,7 @@
         max_seq_length.
     :param streaming: True to stream data from a cloud dataset.
    :param overwrite_cache: Whether to overwrite the cached preprocessed datasets.
-    :param num_data_workers: Number of processes for dataset preprocessing.
+    :param preprocessing_num_workers: Number of processes for dataset preprocessing.
     :param min_tokens_per_module: Minimum percentage of tokens per
         module, relevant for MoE models.
     :param moe_calibrate_all_experts: Whether to calibrate all experts during MoE
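
For reference, a hedged usage sketch of oneshot with the renamed and re-defaulted arguments from this commit (the model id comes from the llama3 example above; the dataset, recipe, and numeric values are illustrative, not prescribed by the commit):

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier

oneshot(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    dataset="ultrachat_200k",  # placeholder calibration dataset
    recipe=GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
    batch_size=4,
    data_collator="truncation",          # new default; "padding" or a custom Callable also accepted
    num_calibration_samples=512,
    shuffle_calibration_samples=False,   # keep length-sorted ordering for tighter batches
    preprocessing_num_workers=4,         # renamed from num_data_workers
)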
