diff --git a/examples/multimodal_audio/whisper_example.py b/examples/multimodal_audio/whisper_example.py index c8e8833934..9dc0764673 100644 --- a/examples/multimodal_audio/whisper_example.py +++ b/examples/multimodal_audio/whisper_example.py @@ -1,6 +1,10 @@ import torch from datasets import load_dataset -from transformers import WhisperForConditionalGeneration, WhisperProcessor +from transformers import ( + WhisperForConditionalGeneration, + WhisperProcessor, + default_data_collator, +) from llmcompressor import oneshot from llmcompressor.modifiers.quantization import GPTQModifier @@ -55,20 +59,27 @@ def process(sample): return_tensors="pt", ) - inputs["input_features"] = inputs["input_features"].to(dtype=model.dtype) + # treat labels as calibration prefill inputs["decoder_input_ids"] = inputs["labels"] del inputs["labels"] + # strip extra dim added by multimodal processors + inputs = {key: value[0] for key, value in inputs.items()} + return inputs ds = ds.map(process, remove_columns=ds.column_names) -# Define a oneshot data collator for multimodal inputs. -def data_collator(batch): - assert len(batch) == 1 - return {key: torch.tensor(value) for key, value in batch[0].items()} +# Patch: mismatch between processor and model dtype +def data_collator(features): + for feature in features: + feature["input_features"] = torch.tensor( + feature["input_features"], dtype=model.dtype + ) + + return default_data_collator(features, return_tensors="pt") # Recipe diff --git a/examples/multimodal_vision/gemma3_example.py b/examples/multimodal_vision/gemma3_example.py index dce35b7b83..b9fcb59bf8 100644 --- a/examples/multimodal_vision/gemma3_example.py +++ b/examples/multimodal_vision/gemma3_example.py @@ -1,5 +1,4 @@ import requests -import torch from PIL import Image from transformers import AutoProcessor, Gemma3ForConditionalGeneration @@ -13,17 +12,11 @@ processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) # Oneshot arguments -DATASET_ID = "flickr30k" -DATASET_SPLIT = {"calibration": "test[:512]"} +BATCH_SIZE = 4 NUM_CALIBRATION_SAMPLES = 512 MAX_SEQUENCE_LENGTH = 2048 - - -# Define a oneshot data collator for multimodal inputs. -def data_collator(batch): - assert len(batch) == 1 - return {key: torch.tensor(value) for key, value in batch[0].items()} - +DATASET_ID = "flickr30k" +DATASET_SPLIT = {"calibration": f"test[:{NUM_CALIBRATION_SAMPLES}]"} # Recipe recipe = [ @@ -41,14 +34,13 @@ def data_collator(batch): # Perform oneshot oneshot( model=model, - tokenizer=model_id, + processor=processor, dataset=DATASET_ID, splits=DATASET_SPLIT, recipe=recipe, + batch_size=BATCH_SIZE, max_seq_length=MAX_SEQUENCE_LENGTH, num_calibration_samples=NUM_CALIBRATION_SAMPLES, - trust_remote_code_model=True, - data_collator=data_collator, ) # Confirm generations of the quantized model look sane. 
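The whisper collator above is representative of the dtype-fix pattern used throughout this change; a minimal standalone sketch of its behavior is below, assuming toy feature values and torch.bfloat16 as a stand-in for model.dtype (neither is part of the patch). Each sample's floating-point processor output is cast to the model dtype, then default_data_collator stacks the samples into a batch.

import torch
from transformers import default_data_collator

MODEL_DTYPE = torch.bfloat16  # hypothetical stand-in for model.dtype

def data_collator(features):
    # cast floating-point processor outputs to the model dtype before batching
    for feature in features:
        feature["input_features"] = torch.tensor(
            feature["input_features"], dtype=MODEL_DTYPE
        )
    return default_data_collator(features, return_tensors="pt")

# toy features, shaped as if the extra processor batch dim was already stripped
features = [
    {"input_features": [[0.1, 0.2], [0.3, 0.4]], "decoder_input_ids": [1, 2]},
    {"input_features": [[0.5, 0.6], [0.7, 0.8]], "decoder_input_ids": [3, 4]},
]
batch = data_collator(features)
print(batch["input_features"].shape, batch["input_features"].dtype)
# torch.Size([2, 2, 2]) torch.bfloat16

Because default_data_collator handles the stacking, the per-example collators no longer assume a batch size of one.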
diff --git a/examples/multimodal_vision/internvl3_example.py b/examples/multimodal_vision/internvl3_example.py index 9d79d80b01..d10a065ae2 100644 --- a/examples/multimodal_vision/internvl3_example.py +++ b/examples/multimodal_vision/internvl3_example.py @@ -37,20 +37,14 @@ def preprocess_and_tokenize(example): return_dict=True, return_tensors="pt", ) - return inputs - -ds = ds.map(preprocess_and_tokenize) + # remove extra dim added by multimodal processors + inputs = {key: value[0] for key, value in inputs.items()} + return inputs -def data_collator(batch): - assert len(batch) == 1 - item = {key: value for key, value in batch[0].items()} - item["attention_mask"] = torch.tensor([item["attention_mask"]]) - item["input_ids"] = torch.LongTensor([item["input_ids"]]) - - return item +ds = ds.map(preprocess_and_tokenize, remove_columns=ds.column_names) # Recipe recipe = GPTQModifier( @@ -68,7 +62,6 @@ def data_collator(batch): max_seq_length=MAX_SEQUENCE_LENGTH, num_calibration_samples=NUM_CALIBRATION_SAMPLES, trust_remote_code_model=True, - data_collator=data_collator, ) # Save to disk compressed. diff --git a/examples/multimodal_vision/llava_example.py b/examples/multimodal_vision/llava_example.py index da0f712182..3a623cda16 100644 --- a/examples/multimodal_vision/llava_example.py +++ b/examples/multimodal_vision/llava_example.py @@ -1,5 +1,4 @@ import requests -import torch from PIL import Image from transformers import AutoProcessor, LlavaForConditionalGeneration @@ -19,12 +18,6 @@ MAX_SEQUENCE_LENGTH = 2048 -# Define a oneshot data collator for multimodal inputs. -def data_collator(batch): - assert len(batch) == 1 - return {key: torch.tensor(value) for key, value in batch[0].items()} - - # Recipe recipe = [ GPTQModifier( @@ -44,7 +37,6 @@ def data_collator(batch): max_seq_length=MAX_SEQUENCE_LENGTH, num_calibration_samples=NUM_CALIBRATION_SAMPLES, trust_remote_code_model=True, - data_collator=data_collator, sequential_targets=["LlamaDecoderLayer"], ) diff --git a/examples/multimodal_vision/mistral3_example.py b/examples/multimodal_vision/mistral3_example.py index 6f9567cf15..cd59d7a185 100644 --- a/examples/multimodal_vision/mistral3_example.py +++ b/examples/multimodal_vision/mistral3_example.py @@ -4,7 +4,11 @@ import requests import torch from PIL import Image -from transformers import AutoProcessor, Mistral3ForConditionalGeneration +from transformers import ( + AutoProcessor, + Mistral3ForConditionalGeneration, + default_data_collator, +) from llmcompressor import oneshot from llmcompressor.modifiers.quantization import GPTQModifier @@ -27,17 +31,13 @@ MAX_SEQUENCE_LENGTH = 2048 -# Define a oneshot data collator for multimodal inputs. 
-def data_collator(batch): - assert len(batch) == 1 - return { - key: ( - torch.tensor(value) - if key != "pixel_values" - else torch.tensor(value, dtype=model.dtype) +# Patch: mismatch between processor and model dtype +def data_collator(features): + for feature in features: + feature["pixel_values"] = torch.tensor( + feature["pixel_values"], dtype=model.dtype ) - for key, value in batch[0].items() - } + return default_data_collator(features, return_tensors="pt") # Recipe diff --git a/examples/multimodal_vision/mllama_example.py b/examples/multimodal_vision/mllama_example.py index edc7bc91f5..b7969676cc 100644 --- a/examples/multimodal_vision/mllama_example.py +++ b/examples/multimodal_vision/mllama_example.py @@ -1,5 +1,4 @@ import requests -import torch from PIL import Image from transformers import AutoProcessor, MllamaForConditionalGeneration @@ -19,12 +18,6 @@ MAX_SEQUENCE_LENGTH = 2048 -# Define a oneshot data collator for multimodal inputs. -def data_collator(batch): - assert len(batch) == 1 - return {key: torch.tensor(value) for key, value in batch[0].items()} - - # Recipe recipe = [ GPTQModifier( @@ -44,7 +37,6 @@ def data_collator(batch): max_seq_length=MAX_SEQUENCE_LENGTH, num_calibration_samples=NUM_CALIBRATION_SAMPLES, trust_remote_code_model=True, - data_collator=data_collator, sequential_targets=["MllamaSelfAttentionDecoderLayer"], ) diff --git a/examples/multimodal_vision/pixtral_example.py b/examples/multimodal_vision/pixtral_example.py index 3ce58629a2..fa25fdb1d2 100644 --- a/examples/multimodal_vision/pixtral_example.py +++ b/examples/multimodal_vision/pixtral_example.py @@ -1,7 +1,11 @@ import requests import torch from PIL import Image -from transformers import AutoProcessor, LlavaForConditionalGeneration +from transformers import ( + AutoProcessor, + LlavaForConditionalGeneration, + default_data_collator, +) from llmcompressor import oneshot from llmcompressor.modifiers.quantization import GPTQModifier @@ -19,16 +23,13 @@ MAX_SEQUENCE_LENGTH = 2048 -# Define a oneshot data collator for multimodal inputs. 
-# NOTE: for transformers<4.48.0, please squeeze the first dimension of `pixel_values`
-# by appending `[0]` to the end of line 32
-def data_collator(batch):
-    assert len(batch) == 1
-    return {
-        "input_ids": torch.LongTensor(batch[0]["input_ids"]),
-        "attention_mask": torch.tensor(batch[0]["attention_mask"]),
-        "pixel_values": torch.tensor(batch[0]["pixel_values"]),
-    }
+# Patch: mismatch between processor and model dtype
+def data_collator(features):
+    for feature in features:
+        feature["pixel_values"] = torch.tensor(
+            feature["pixel_values"], dtype=model.dtype
+        )
+    return default_data_collator(features, return_tensors="pt")
 
 
 # Recipe
@@ -46,11 +47,11 @@ def data_collator(batch):
     tokenizer=model_id,
     dataset=DATASET_ID,
     splits={"calibration": f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]"},
+    data_collator=data_collator,
     recipe=recipe,
     max_seq_length=MAX_SEQUENCE_LENGTH,
     num_calibration_samples=NUM_CALIBRATION_SAMPLES,
     trust_remote_code_model=True,
-    data_collator=data_collator,
     sequential_targets=["MistralDecoderLayer"],
 )
diff --git a/src/llmcompressor/args/dataset_arguments.py b/src/llmcompressor/args/dataset_arguments.py
index 2618b90197..943ef94ff0 100644
--- a/src/llmcompressor/args/dataset_arguments.py
+++ b/src/llmcompressor/args/dataset_arguments.py
@@ -8,9 +8,7 @@
 """
 
 from dataclasses import dataclass, field
-from typing import Any, Callable
-
-from transformers import DefaultDataCollator
+from typing import Callable
 
 
 @dataclass
@@ -69,9 +67,27 @@ class CustomDatasetArguments(DVCDatasetArguments):
         },
     )
 
-    data_collator: Callable[[Any], Any] = field(
-        default_factory=lambda: DefaultDataCollator(),
-        metadata={"help": "The function to used to form a batch from the dataset"},
+    batch_size: int = field(
+        default=1,
+        metadata={
+            "help": (
+                "Calibration batch size. During calibration, LLM Compressor disables "
+                "lm_head output computations to reduce memory usage from large "
+                "batch sizes. Large batch sizes may result in excess padding or "
+                "truncation, depending on the data_collator"
+            )
+        },
+    )
+
+    data_collator: str | Callable = field(
+        default="truncation",
+        metadata={
+            "help": (
+                "The function used to form a batch from the dataset. Can also "
+                "specify 'truncation' or 'padding' to truncate or pad non-uniform "
+                "sequence lengths in a batch. Defaults to 'truncation'."
+            )
+        },
+    )
@@ -126,8 +142,8 @@ class DatasetArguments(CustomDatasetArguments):
         default=512,
         metadata={"help": "Number of samples to use for one-shot calibration"},
     )
-    shuffle_calibration_samples: bool | None = field(
-        default=True,
+    shuffle_calibration_samples: bool = field(
+        default=False,
         metadata={
             "help": "whether to shuffle the dataset before selecting calibration data"
         },
     )
@@ -142,7 +158,7 @@ class DatasetArguments(CustomDatasetArguments):
     )
     preprocessing_num_workers: int | None = field(
         default=None,
-        metadata={"help": "The number of processes to use for the preprocessing."},
+        metadata={"help": "The number of workers to use for dataset processing."},
     )
     pad_to_max_length: bool = field(
         default=True,
@@ -214,6 +230,14 @@ class DatasetArguments(CustomDatasetArguments):
             "definition"
         },
     )
+    offload_sequential_activations: bool = field(
+        default=True,
+        metadata={
+            "help": "Whether to offload intermediate activations between sequential "
+            "layers to the CPU. Disabling offloading is much faster, but uses "
+            "significantly more memory. Default is True."
+        },
+    )
     quantization_aware_calibration: bool = field(
         default=True,
         metadata={
diff --git a/src/llmcompressor/datasets/utils.py b/src/llmcompressor/datasets/utils.py
index 2b80b1ed9a..0d5fceca8c 100644
--- a/src/llmcompressor/datasets/utils.py
+++ b/src/llmcompressor/datasets/utils.py
@@ -7,20 +7,22 @@
 one-shot calibration workflows.
 """
 
-import multiprocessing
 import re
-from typing import Any, Callable
+from collections.abc import Iterator, Sized
+from typing import Any, Callable, Optional
 
 import torch
 from datasets import Dataset
 from loguru import logger
-from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
-from transformers.data import default_data_collator
+from torch.utils.data import DataLoader, RandomSampler, Sampler
+from transformers.data import DataCollatorWithPadding, default_data_collator
 
 from llmcompressor.args import DatasetArguments
 from llmcompressor.transformers.data import TextGenerationDataset
 from llmcompressor.typing import Processor
 
+BS_WARNING_THRESHOLD = 16
+
 
 def get_processed_dataset(
     dataset_args: DatasetArguments,
@@ -113,67 +115,23 @@ def get_calibration_dataloader(
         do_oneshot=True,
         do_train=False,
     )
-    calibration_dataset = datasets.get("calibration")
-    return format_calibration_data(
-        tokenized_dataset=calibration_dataset,
-        num_calibration_samples=dataset_args.num_calibration_samples,
-        do_shuffle=dataset_args.shuffle_calibration_samples,
-        collate_fn=dataset_args.data_collator,
-    )
+    return format_calibration_data(dataset_args, calibration_dataset, processor)
 
 
 def format_calibration_data(
+    args: DatasetArguments,
     tokenized_dataset: Dataset,
-    num_calibration_samples: int | None = None,
-    do_shuffle: bool = True,
-    collate_fn: Callable = default_data_collator,
-) -> list[torch.Tensor]:
-    """
-    Creates a dataloader out of the calibration dataset split, trimming it to
-    the desired number of calibration samples
-    :param tokenized_dataset: dataset to convert to dataloader
-    :param num_calibration_samples: number of data samples to convert
-    :param do_shuffle: whether to shuffle the dataset before selecting calibration
-        samples, true by default
-    :param collate_fn: optional custom collate function, or use default
-    :return: list of trimmed calibration data tensors
-    """
-    safe_calibration_samples = len(tokenized_dataset)
-    if num_calibration_samples is not None:
-        safe_calibration_samples = min(len(tokenized_dataset), num_calibration_samples)
-        if safe_calibration_samples != num_calibration_samples:
-            logger.warning(
-                f"Requested {num_calibration_samples} calibration samples but "
-                f"the provided dataset only has {safe_calibration_samples}. "
-            )
-
-    if do_shuffle:
-        tokenized_dataset = tokenized_dataset.shuffle()
-    tokenized_calibration = tokenized_dataset.select(range(safe_calibration_samples))
-
-    MAX_DATALOADER_WORKERS = 8
-    try:
-        num_workers = min(MAX_DATALOADER_WORKERS, multiprocessing.cpu_count() // 2)
-    except NotImplementedError:
-        logger.warning(
-            "Could not determine number of CPUs, defaulting to 0 dataloader workers."
-        )
-        num_workers = 0
-
-    dataloader_params = {
-        "batch_size": 1,
-        "sampler": RandomSampler(tokenized_calibration)
-        if do_shuffle
-        else SequentialSampler(tokenized_calibration),
-        "collate_fn": collate_fn,
-        "pin_memory": True,
-        "num_workers": num_workers,
-    }
-
-    calibration_dataloader = DataLoader(tokenized_calibration, **dataloader_params)
-
-    return calibration_dataloader
+    processor: Processor,
+) -> DataLoader:
+    return DataLoader(
+        tokenized_dataset,
+        batch_size=args.batch_size,
+        sampler=_make_sampler(args, tokenized_dataset),
+        collate_fn=_make_collate_fn(args, processor),
+        pin_memory=False,
+    )
 
 
 def make_dataset_splits(
@@ -213,3 +171,125 @@ def make_dataset_splits(
         "calibration": calib_split,
     }
     return split_datasets
+
+
+def _make_collate_fn(args: DatasetArguments, processor: Processor) -> Callable:
+    if isinstance(args.data_collator, Callable):
+        return args.data_collator
+
+    if args.data_collator == "truncation":
+        if args.batch_size > BS_WARNING_THRESHOLD:
+            logger.warning(
+                f"Calibrating with batch sizes greater than {BS_WARNING_THRESHOLD} and "
+                "`data_collator='truncation'` can lead to significant portions of the "
+                "calibration dataset being deleted via truncation. Please consider "
+                "reducing the calibration batch size or filtering the dataset to "
+                "more uniform sequence lengths"
+            )
+
+        return data_collator_with_truncation
+
+    elif args.data_collator == "padding":
+        if args.batch_size > BS_WARNING_THRESHOLD:
+            logger.warning(
+                f"Calibrating with batch sizes greater than {BS_WARNING_THRESHOLD} and "
+                "`data_collator='padding'` can lead to excess tokens used for padding, "
+                "which slows down calibration time and calibrates on padding tokens "
+                "not seen at runtime. Please consider reducing the calibration batch "
+                "size or filtering the dataset to more uniform sequence lengths"
+            )
+
+        tokenizer = getattr(processor, "tokenizer", processor)
+        if tokenizer.pad_token is None or tokenizer.pad_token_id < 0:
+            logger.debug("Could not find padding token. Setting PAD token to EOS token")
+            tokenizer.pad_token = tokenizer.eos_token
+
+        return DataCollatorWithPadding(tokenizer)
+
+    else:
+        raise ValueError(f"Unknown data collator {args.data_collator}")
+
+
+def _make_sampler(args: DatasetArguments, dataset: Dataset) -> Sampler:
+    num_samples = args.num_calibration_samples
+    shuffle = args.shuffle_calibration_samples
+    batch_size = args.batch_size
+
+    if num_samples is not None and num_samples > len(dataset):
+        logger.warning(
+            f"Requested {num_samples} samples but the provided dataset only has "
+            f"{len(dataset)} samples."
+        )
+        num_samples = len(dataset)
+
+    if shuffle:
+        if batch_size > 1:
+            logger.warning(
+                "Shuffling the dataset can lead to suboptimal batching for "
+                "sequences of non-uniform length. When collating with truncation, "
+                "this will delete a large number of tokens. When collating with "
+                "padding, this will add a large number of padding tokens.\n\n"
+                "Please consider calling `oneshot` with `batch_size=1`"
+            )
+
+        return RandomSampler(dataset, num_samples=num_samples)
+    else:
+        return LengthAwareSampler(dataset, num_samples=num_samples)
+
+
+def data_collator_with_truncation(
+    features: list[dict[str, Any]], return_tensors: str = "pt"
+) -> dict[str, Any]:
+    for key in ("input_ids", "labels", "attention_mask"):
+        if any(key not in feature for feature in features):
+            continue
+
+        min_len = min(len(feature[key]) for feature in features)
+        for feature in features:
+            feature[key] = feature[key][:min_len]
+
+    return default_data_collator(features, return_tensors)
+
+
+class LengthAwareSampler(Sampler[int]):
+    """
+    Sample data in order of descending sequence length. Relies on an `input_ids` or
+    `decoder_input_ids` column existing in the dataset
+
+    :param data_source: dataset containing an `input_ids` or `decoder_input_ids` column
+    :param num_samples: Maximum number of samples to draw. The shortest sequences are
+        dropped first
+    """
+
+    data_source: Sized
+    replacement: bool
+
+    def __init__(
+        self,
+        data_source: Dataset,
+        num_samples: Optional[int] = None,
+    ) -> None:
+        self.data_source = data_source
+        self._num_samples = num_samples or len(data_source)
+
+        if "input_ids" in data_source.column_names:
+            feature_name = "input_ids"
+        elif "decoder_input_ids" in data_source.column_names:
+            feature_name = "decoder_input_ids"
+        else:
+            logger.warning(f"Could not find input ids in {data_source.column_names}")
+            self.order = range(len(data_source))
+            return
+
+        lengths = [len(sample) for sample in data_source[feature_name]]
+        self.order = torch.argsort(torch.tensor(lengths), descending=True).tolist()
+
+    @property
+    def num_samples(self) -> int:
+        return self._num_samples
+
+    def __iter__(self) -> Iterator[int]:
+        return iter(self.order[: self._num_samples])
+
+    def __len__(self) -> int:
+        return self._num_samples
diff --git a/src/llmcompressor/entrypoints/oneshot.py b/src/llmcompressor/entrypoints/oneshot.py
index 398ffae372..cd0cf66280 100644
--- a/src/llmcompressor/entrypoints/oneshot.py
+++ b/src/llmcompressor/entrypoints/oneshot.py
@@ -12,7 +12,7 @@
 import os
 from datetime import datetime
 from pathlib import Path
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Callable
 
 from loguru import logger
 from torch.utils.data import DataLoader
@@ -250,8 +250,10 @@ def oneshot(
     dataset_config_name: str | None = None,
     dataset_path: str | None = None,
     splits: str | list[str] | dict[str, str] | None = None,
+    batch_size: int = 1,
+    data_collator: str | Callable = "truncation",
     num_calibration_samples: int = 512,
-    shuffle_calibration_samples: bool = True,
+    shuffle_calibration_samples: bool = False,
     max_seq_length: int = 384,
     pad_to_max_length: bool = True,
     text_column: str = "text",
@@ -308,6 +310,13 @@ def oneshot(
         to use.
     :param dataset_path: Path to a custom dataset. Supports json, csv, dvc.
     :param splits: Optional percentages of each split to download.
+    :param batch_size: calibration dataset batch size. During calibration,
+        LLM Compressor disables lm_head output computations to reduce memory
+        usage from large calibration batch sizes. Large batch sizes may result in
+        excess padding or truncation, depending on the data_collator
+    :param data_collator: The function used to form a batch from the dataset. Can
+        also specify 'truncation' or 'padding' to truncate or pad non-uniform sequence
+        lengths in a batch. Defaults to 'truncation'.
:param num_calibration_samples: Number of samples to use for one-shot calibration. :param shuffle_calibration_samples: Whether to shuffle the dataset before @@ -319,8 +328,7 @@ def oneshot( max_seq_length. :param streaming: True to stream data from a cloud dataset. :param overwrite_cache: Whether to overwrite the cached preprocessed datasets. - :param preprocessing_num_workers: Number of processes for - preprocessing. + :param preprocessing_num_workers: Number of processes for dataset preprocessing. :param min_tokens_per_module: Minimum percentage of tokens per module, relevant for MoE models. :param moe_calibrate_all_experts: Whether to calibrate all experts during MoE diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index fbff5bab16..5a19a96b36 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -402,7 +402,7 @@ def cache_smooth_activations_hook( ): self._smooth_activation_means[smooth_name] = _accumulate_mean( # Assume that first argument is the input - args[0].cpu().abs().detach().squeeze(), + args[0].cpu().abs().detach().flatten(0, -2), self._smooth_activation_means.get(smooth_name, None), ) diff --git a/src/llmcompressor/pipelines/cache.py b/src/llmcompressor/pipelines/cache.py index ea0d5f254c..b647c68244 100644 --- a/src/llmcompressor/pipelines/cache.py +++ b/src/llmcompressor/pipelines/cache.py @@ -64,7 +64,6 @@ def from_dataloader( cls, dataloader: torch.utils.data.DataLoader, model_device: torch.device = torch.device("cpu"), - mask_padding: bool = True, offload_device: Optional[torch.device] = torch.device("cpu"), ): """ @@ -72,20 +71,15 @@ def from_dataloader( :param dataloader: dataloader which generates values to be cached :param model_device: device which values will be onloaded to when fetched - :param mask_padding: zero out padding tokens if True. 
This affects modifiers - such as GPTQ and SparseGPT :param offload_device: device to offload values to """ - # note: list comprehesion was found to not improve performance - batch_intermediates = [] - for batch in tqdm(dataloader, desc="Preparing cache"): - values = {} - for key, value in batch.items(): - if mask_padding and (key == "input_ids") and "attention_mask" in batch: - value = cls._mask_padding(value, batch["attention_mask"]) - values[key] = cls._offload_value(value, offload_device, model_device) - - batch_intermediates.append(values) + batch_intermediates = [ + { + key: cls._offload_value(value, offload_device, model_device) + for key, value in batch.items() + } + for batch in tqdm(dataloader, desc="Preparing cache") + ] return cls(batch_intermediates, offload_device) @@ -274,14 +268,3 @@ def _offload_value( ): warnings.warn(f"Offloading not implemented for type {type(value)}.") return IntermediateValue(value=value, device=None) - - @staticmethod - def _mask_padding( - input_ids: torch.Tensor, attention_mask: torch.Tensor - ) -> torch.Tensor: - if attention_mask.dim() == 4: - # some attention masks, such as those from pixtral, are are 4d - attention_mask = attention_mask[0, 0, 0].unsqueeze(0) - - # Assumes that `attention_mask` only contains zeros and ones - return input_ids * attention_mask diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py index 511a693b95..73d105a201 100644 --- a/src/llmcompressor/pipelines/sequential/pipeline.py +++ b/src/llmcompressor/pipelines/sequential/pipeline.py @@ -66,7 +66,6 @@ def __call__( # prepare to trace subgraphs modifiers = session.lifecycle.recipe.modifiers sequential_targets = get_sequential_targets(modifiers, model, dataset_args) - ignore = dataset_args.tracing_ignore # trace subgraphs @@ -90,7 +89,11 @@ def __call__( stack.enter_context(DisableQuantization(model)) # prepare intermediates cache - activations = IntermediatesCache.from_dataloader(dataloader, model_device) + cache_offload = dataset_args.offload_sequential_activations + offload_device = torch.device("cpu") if cache_offload else None + activations = IntermediatesCache.from_dataloader( + dataloader, model_device, offload_device=offload_device + ) for subgraph_index, subgraph in enumerate(subgraphs): # prepare tqdm description texts diff --git a/src/llmcompressor/transformers/data/base.py b/src/llmcompressor/transformers/data/base.py index 968c2555a8..b5a81c0ed6 100644 --- a/src/llmcompressor/transformers/data/base.py +++ b/src/llmcompressor/transformers/data/base.py @@ -16,6 +16,7 @@ from datasets import Dataset, IterableDataset from datasets.formatting.formatting import LazyRow from loguru import logger +from transformers import ProcessorMixin from llmcompressor.args import DatasetArguments from llmcompressor.transformers.data.data_helpers import ( @@ -266,6 +267,12 @@ def tokenize(self, data: LazyRow) -> Dict[str, Any]: truncation=True, ) + # strip the extra dim added by multimodal processors + if isinstance(self.processor, ProcessorMixin): + for key in data: + if isinstance(data[key], list) and len(data[key]) == 1: + data[key] = data[key][0] + # store unpadded prompt so we can mask out correct number of elements in labels if prompt is not None: data[self.PROMPT_KEY] = self.processor( diff --git a/tests/llmcompressor/pipelines/test_cache.py b/tests/llmcompressor/pipelines/test_cache.py index ff86f9400a..c5c10c3cba 100644 --- a/tests/llmcompressor/pipelines/test_cache.py +++ 
b/tests/llmcompressor/pipelines/test_cache.py @@ -29,7 +29,6 @@ def sample_cache(sample_dataloader): return IntermediatesCache.from_dataloader( dataloader=sample_dataloader, model_device=torch.device("cpu"), - mask_padding=True, offload_device=torch.device("cpu"), ) @@ -47,7 +46,6 @@ def test_initialization(sample_dataloader): cache = IntermediatesCache.from_dataloader( dataloader=sample_dataloader, model_device=torch.device("cpu"), - mask_padding=True, ) assert isinstance(cache, IntermediatesCache) @@ -96,18 +94,6 @@ def test_delete_intermediates(sample_cache): assert "logits" in sample_cache.batch_intermediates[0] -@pytest.mark.unit -def test_mask_padding(): - input_ids = torch.tensor([[1, 2, 3, 0], [4, 5, 6, 0]]) - attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 1, 0]]) - - masked = IntermediatesCache._mask_padding(input_ids, attention_mask) - - # Check if padding tokens are properly masked - expected = torch.tensor([[1, 2, 3, 0], [4, 5, 6, 0]]) - assert torch.equal(masked, expected) - - @pytest.mark.unit @pytest.mark.parametrize("value", values_to_test) def test_from_dataloader(value): @@ -127,18 +113,6 @@ def test_offload_and_onload(value): assert deep_equal(onloaded, value) -@pytest.mark.unit -def test_4d_attention_mask(): - input_ids = torch.tensor([[1, 2, 3, 0]]) - attention_mask = torch.ones(1, 1, 1, 4) # 4D attention mask - - masked = IntermediatesCache._mask_padding(input_ids, attention_mask) - - # Check if the function handles 4D attention mask properly - expected = torch.tensor([[1, 2, 3, 0]]) - assert torch.equal(masked, expected) - - @pytest.mark.unit def test_device_handling(sample_dataloader): if not torch.cuda.is_available(): diff --git a/tests/llmcompressor/transformers/data/test_dataset_loading.py b/tests/llmcompressor/transformers/data/test_dataset_loading.py index 6c07f8554b..590c87e5e8 100644 --- a/tests/llmcompressor/transformers/data/test_dataset_loading.py +++ b/tests/llmcompressor/transformers/data/test_dataset_loading.py @@ -1,5 +1,4 @@ import pytest -import torch from datasets import IterableDataset, load_dataset from llmcompressor.args import DatasetArguments @@ -255,14 +254,9 @@ def preprocess(sample): assert "input_ids" in data_cols assert "attention_mask" in data_cols - # confirm turning shuffle off works - calib_dataloader = format_calibration_data( - tokenized_dataset=calib_dataset, - num_calibration_samples=num_calibration_samples, - do_shuffle=dataset_args.shuffle_calibration_samples, + dataset_args, calib_dataset, tiny_llama_tokenizer ) assert len(calib_dataloader) == num_calibration_samples dataloader_sample = next(iter(calib_dataloader))["input_ids"] - diff = dataloader_sample - torch.Tensor(calib_dataset[0]["input_ids"]) - assert torch.sum(diff) == 0 + assert dataloader_sample[0].tolist() in calib_dataset["input_ids"] diff --git a/tests/llmcompressor/transformers/sparsegpt/test_sparsegpt_owl.py b/tests/llmcompressor/transformers/sparsegpt/test_sparsegpt_owl.py index f6a8327181..d346334f4b 100644 --- a/tests/llmcompressor/transformers/sparsegpt/test_sparsegpt_owl.py +++ b/tests/llmcompressor/transformers/sparsegpt/test_sparsegpt_owl.py @@ -3,6 +3,7 @@ from datasets import Dataset from transformers import AutoModelForCausalLM +from llmcompressor.args import DatasetArguments from llmcompressor.core.session_functions import create_session from llmcompressor.datasets import format_calibration_data from llmcompressor.modifiers.pruning.sparsegpt import SparseGPTModifier @@ -26,7 +27,8 @@ def test_infer_owl_layer_sparsity(): dataset = Dataset.from_dict( 
{"input_ids": torch.randint(0, vocab_size, (ds_size, seq_len))} ) - dataloader = format_calibration_data(dataset) + args = DatasetArguments(data_collator="truncation") + dataloader = format_calibration_data(args, dataset, None) sequential_targets = modifier._infer_sequential_targets(model) layers = get_layers(sequential_targets, model)