diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py index 636a872487e5..b6aad7e94650 100644 --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py @@ -15,7 +15,7 @@ from copy import deepcopy -from .core_model_loading import Concatenate, MergeModulelist, WeightConverter +from .core_model_loading import Concatenate, MergeModulelist, WeightConverter, WeightRenaming from .utils import is_torch_available @@ -26,6 +26,7 @@ def _build_checkpoint_conversion_mapping(): mapping = { "mixtral": [ + WeightRenaming(".block_sparse_moe.gate", ".mlp.gate"), WeightConverter( source_keys=[ "block_sparse_moe.experts.*.w1.weight", @@ -50,12 +51,6 @@ def _build_checkpoint_conversion_mapping(): ), # each process has two lists of tensors, we cat each list. -> we end up with 2 tensors ], # we want the loading to add this shard operation here. Though we can't shard after concats and merge, needs to be first ), - # WeightConverter( - # ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"], - # "self_attn.qkv_proj", - # operations=[Concatenate(dim=0)], # more like stack? - # ), - WeightConverter("*.block_sparse_moe.", "*.mlp."), ], "qwen2_moe": [ WeightConverter( @@ -73,11 +68,11 @@ def _build_checkpoint_conversion_mapping(): ), ], "legacy": [ - WeightConverter( + WeightRenaming( source_keys="LayerNorm.gamma", target_keys="LayerNorm.weight", ), - WeightConverter( + WeightRenaming( source_keys="LayerNorm.beta", target_keys="LayerNorm.bias", ), @@ -85,22 +80,22 @@ def _build_checkpoint_conversion_mapping(): } if hasattr(torch.nn.utils.parametrizations, "weight_norm"): mapping["legacy"] += [ - WeightConverter( + WeightRenaming( source_keys="weight_g", target_keys="parametrizations.weight.original0", ), - WeightConverter( + WeightRenaming( source_keys="weight_v", target_keys="parametrizations.weight.original1", ), ] else: mapping["legacy"] += [ - WeightConverter( + WeightRenaming( source_keys="parametrizations.weight.original0", target_keys="weight_g", ), - WeightConverter( + WeightRenaming( source_keys="parametrizations.weight.original1", target_keys="weight_v", ), diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py index aeabe61c1a46..f8db2406017c 100644 --- a/src/transformers/core_model_loading.py +++ b/src/transformers/core_model_loading.py @@ -16,16 +16,14 @@ from __future__ import annotations -import itertools import os import re from abc import abstractmethod -from collections import defaultdict from collections.abc import MutableMapping, MutableSet, Sequence from concurrent.futures import Future, ThreadPoolExecutor from contextlib import contextmanager +from copy import deepcopy from dataclasses import dataclass, field -from functools import partial from typing import TYPE_CHECKING, Any, Optional, Union import torch @@ -49,71 +47,57 @@ logger = logging.get_logger(__name__) -def _glob_to_regex_src(glob: str, *, digits_only: bool = True) -> str: +def compile_glob_rule(source_glob: str, target_glob: str) -> tuple[re.Pattern, str]: """ - Convert a glob with '*' into a regex *source* string. We don't use `glob.translate` - '*' matches (\\d+) if digits_only else (.+). Inner groups are non-capturing. + Convert a glob-style source + target into a full regex + replacement. + + Rules: + - '*' in source_glob → (.*) capture group + - '*' in target_glob → \\1, \\2, ... 
backrefs """ - star = r"(\d+)" if digits_only else r"(.+)" - return glob.replace(r"\*", star) + regex = re.compile(source_glob) + counter = 0 -def build_glob_alt( - globs: list[str], -) -> tuple[re.Pattern, dict[str, str]]: - r""" - Build one compiled regex alternation with a named group per glob. This allows to run a single - re.match and get the correct group name to finally get which pattern matched. - Returns (compiled_regex, name->glob map). - - Example: - - ```py - >>> reg, map_ = build_glob_alt(["mlp.*.w1", "mlp.*.w2"]) - >>> print(reg) - (re.compile(r'(?P.*mlp\.(\d+)\.w1)|(?P.*mlp\.(\d+)\.w2)', re.UNICODE), - >>> print(map_) - {'g0': 'mlp.*.w1', 'g1': 'mlp.*.w2'}) - >>> match_ = reg.match("model.layers.0.mlp.0.w1.weight") - >>> print(match_.lastgroup) - 'g0' - >>> print(map_[match_.lastgroup]) - mlp.*.w1 - ``` - """ - name_map: dict[str, str] = {} - parts: list[str] = [] - - for i, g in enumerate(globs): - name = f"g{i}" - name_map[name] = g - pat_src = _glob_to_regex_src(g) - prefix_src = "" - if pat_src.startswith("*"): - prefix_src = "." - elif not pat_src.startswith(r"\^") and not pat_src.startswith(r".*"): - prefix_src = ".*" - - parts.append(f"(?P<{name}>{prefix_src}{pat_src}.*)") - - alt_src = "|".join(parts).replace("\\^", "^").replace("\\.", r"\.") - try: - reg = re.compile(alt_src) - except re.error as e: - logger.error(f"Error compiling regex for alternation: {alt_src}") - raise e + def _star_to_backref(_: re.Match) -> str: + nonlocal counter + counter += 1 + return rf"\{counter}" - return reg, name_map + replacement = re.sub(r"\*", _star_to_backref, target_glob) + return regex, replacement -def match_glob(key: str, alt: re.Pattern, name_map: dict[str, str]) -> Optional[str]: +def build_glob_alternation( + globs: list[Union[WeightRenaming, WeightConverter, str]], +) -> tuple[re.Pattern, dict[str, str], dict[str, str]]: """ - Match the key against the alternation; return the original glob string that matched. + Build a single alternation regex with one named group per glob. 
""" - m = alt.match(key) - if not m: - return None - return name_map.get(m.lastgroup) + src_group_to_glob: dict[str, str] = {} + tgt_group_to_glob: dict[str, str] = {} + branches: list[str] = [] + i = 0 + for glob in globs: + if isinstance(glob, (WeightRenaming, WeightConverter)): + for src in glob.source_keys: + group_name = f"g{i}" + src_group_to_glob[group_name] = src + i += 1 + body = src.replace("*", r".*") + branches.append(f"(?P<{group_name}>{body})") + tgt_group_to_glob[group_name] = glob.target_keys[0] # we index witht the first target + else: + group_name = f"g{i}" + src_group_to_glob[group_name] = glob + i += 1 + body = glob + body = body.replace("*", r".*") + branches.append(f"(?P<{group_name}>{body})") + tgt_group_to_glob[group_name] = glob + + alternation = re.compile("|".join(branches)) + return alternation, src_group_to_glob, tgt_group_to_glob class ConversionOps: @@ -124,8 +108,14 @@ class ConversionOps: @abstractmethod def convert( - self, value: Union[dict[str, torch.Tensor], Sequence[torch.Tensor], torch.Tensor], *args, **kwargs - ) -> torch.Tensor: + self, + value: dict[str, Any], + source_keys: list[str], + target_keys: list[str], + full_layer_name: str, + config, + **kwargs, + ) -> dict[str, list[torch.Tensor]]: raise NotImplementedError @@ -135,20 +125,24 @@ class Chunk(ConversionOps): reverse_op: type[ConversionOps] def __init__(self, dim: int = 0, chunks: Optional[int] = None, sizes: Optional[Sequence[int]] = None): - if chunks is None and sizes is None: - raise ValueError("`chunks` or `sizes` must be provided for Chunk operations.") - if chunks is not None and chunks <= 0: - raise ValueError("`chunks` must be a strictly positive integer.") self.dim = dim self.chunks = chunks self.sizes = list(sizes) if sizes is not None else None self.reverse_op = Concatenate - def convert(self, value: torch.Tensor, *args, **kwargs) -> list[torch.Tensor]: - # chunk requires a single tensor input - if len(value) != 1 or len(value[0]) != 1: - raise ValueError("Chunk operation requires a single tensor input.") - return list(torch.chunk(value[0][0], self.chunks, dim=self.dim)) + def convert( + self, + value: dict[str, list[torch.Tensor]], + source_keys: list[str], + target_keys: list[str], + full_layer_name: str, + config, + ) -> dict[str, list[torch.Tensor]]: + tensors = next(iter(value.values())) + tensor = tensors[0] + sizes = len(target_keys) + chunks = torch.chunk(tensor, sizes, dim=self.dim) + return {full_layer_name.replace(target_keys[0], target): [chunk] for target, chunk in zip(target_keys, chunks)} class Concatenate(ConversionOps): @@ -161,14 +155,20 @@ def __init__(self, dim: int = 0): self.reverse_op = Chunk @torch.no_grad - def convert(self, value: Sequence[torch.Tensor], *args, **kwargs) -> torch.Tensor: - if isinstance(value[0], list): - value = [v[0] for v in value] - tensors = value - if not tensors: - raise ValueError("Fuse requires at least one tensor to concatenate.") + def convert( + self, + value: dict[str, list[torch.Tensor]], + source_keys: list[str], + target_keys: list[str], + full_layer_name: str, + config, + ) -> dict[str, torch.Tensor]: + if len(target_keys) != 1: + raise ValueError("Concatenate expects a single target key.") + if len(value) != len(source_keys): + raise ValueError("Concatenate received an unexpected number of tensors compared to source keys.") - return torch.cat(tuple(tensors), dim=self.dim) + return {full_layer_name: torch.cat(tuple(value.values()), dim=self.dim)} class MergeModulelist(Concatenate): @@ -183,13 +183,21 @@ def 
__init__(self, dim: int = 0): self.reverse_op = SplitModulelist @torch.no_grad - def convert(self, value: Sequence[torch.Tensor], *args, **kwargs) -> list[torch.Tensor]: - merged = [] - for group in value: - if not isinstance(group, Sequence) or len(group) == 0: - raise ValueError("MergeModulelist requires non-empty sub-sequences.") - group = [k for k in group if k.ndim] - merged.append(torch.stack(group, dim=self.dim)) + def convert( + self, + value: dict[str, list[torch.Tensor]], + source_keys: list[str], + target_keys: list[str], + full_layer_name: str, + config, + ) -> dict[str, torch.Tensor]: + merged: dict[str, torch.Tensor] = {} + for idx, key in enumerate(value.keys()): + tensors = value.get(key, []) + if len(source_keys) == 1: + key = full_layer_name + stacked = torch.stack(tensors, dim=self.dim) + merged[key] = stacked return merged @@ -204,18 +212,24 @@ def __init__(self, sizes: Sequence[Sequence[int]], dim: int = 0): self.reverse_op = MergeModulelist @torch.no_grad - def convert(self, value: Sequence[torch.Tensor], *, context: dict[str, Any]) -> list[list[torch.Tensor]]: - if not isinstance(value, Sequence): - raise TypeError("SplitModulelist expects a sequence of tensors.") + def convert( + self, + value: dict[str, list[torch.Tensor]], + source_keys: list[str], + target_keys: list[str], + full_layer_name: str, + config, + ) -> dict[str, list[torch.Tensor]]: if len(value) != len(self.sizes): - raise ValueError("Number of tensors does not match the provided split specifications.") - - result: list[list[torch.Tensor]] = [] - for tensor, split_sizes in zip(value, self.sizes): - if not isinstance(tensor, torch.Tensor): + raise ValueError("SplitModulelist received an unexpected number of tensors.") + result: dict[str, list[torch.Tensor]] = {} + for (key, tensors), split_sizes in zip(value.items(), self.sizes): + if len(tensors) != 1: + raise ValueError("SplitModulelist expects exactly one tensor per key.") + current_tensor = tensors[0] + if not isinstance(current_tensor, torch.Tensor): raise TypeError("SplitModulelist can only split torch.Tensor instances.") - splits = torch.split(tensor, split_sizes, dim=self.dim) - result.append(list(splits)) + result[key] = list(torch.split(current_tensor, split_sizes, dim=self.dim)) return result @@ -237,60 +251,112 @@ def _apply(self, tensor: torch.Tensor) -> torch.Tensor: @torch.no_grad def convert( - self, value: Union[dict[str, torch.Tensor], Sequence[torch.Tensor], torch.Tensor], config - ) -> Union[dict[str, torch.Tensor], list[torch.Tensor], torch.Tensor]: + self, + value: dict[str, list[torch.Tensor]], + source_keys: list[str], + target_keys: list[str], + full_layer_name: str, + config, + ) -> dict[str, list[torch.Tensor]]: self.config = config - out = [[self._apply(x) for x in inner] if isinstance(inner, list) else self._apply(inner) for inner in value] - return out + output: dict[str, list[torch.Tensor]] = {} + for key, tensors in value.items(): + if len(tensors) != 1: + raise ValueError("PermuteForRope expects a single tensor per key.") + output[key] = [self._apply(tensors[0])] + return output @dataclass(slots=True) -class WeightConverter: - r""" - A weight convert that acts on a pattern of source keys. - The keys need to be collected based on the target keys. - - With wild card, glob patterns are matched, so you have to be detailed with what to match. If you match: - `model.layers.*.experts.*` -> it will act on all of them - {"model.layers.*.experts.*": []} - but - `experts.*.mlp` will be layer specific. 
- {"model.layers.1.experts.*": [], } - - source_keys: str | list[str] (wildcards '*' match digits) - - target_keys: str | list[str] | None - - distributed_operation / operations / quantization_operations are ALWAYS lists. - - TODO: for BNB we need to collect model.weight.quant_state_keys - """ - - source_keys: Union[str, list[str]] - target_keys: Optional[Union[str, list[str]]] = None - operations: list[ConversionOps] = field(default_factory=list, repr=False) +class WeightTransform: + source_keys: Union[str, list[str]] = field(init=True) + target_keys: Union[str, list[str]] = field(init=True) distributed_operation: Optional[TensorParallelLayer] = None quantization_operation: Optional[ConversionOps] = None + collected_tensors: dict[str, list[Future]] = field(default_factory=dict, init=False) + layer_targets: dict[str, set[str]] = field(default_factory=dict, init=False) + def __post_init__(self): - if not isinstance(self.source_keys, list): + if isinstance(self.source_keys, str): self.source_keys = [self.source_keys] - targets_were_none = False - if not isinstance(self.target_keys, list): - if self.target_keys is None: - self.target_keys = list(self.source_keys) - targets_were_none = True - else: - self.target_keys = [self.target_keys] + if isinstance(self.target_keys, str): + self.target_keys = [self.target_keys] - if not targets_were_none and bool(len(self.source_keys) - 1) + bool(len(self.target_keys) - 1) >= 2: - raise ValueError( - f"source keys={self.source_keys}, target_keys={self.target_keys} but you can only have one to many, one to one or many to one." - ) + def add_tensor(self, target_key: str, source_key: str, source_pattern: str, future: Future): + bucket = self.collected_tensors.setdefault(source_pattern, []) + bucket += [future] + + bucket = self.layer_targets.setdefault(target_key, set()) + bucket.add(source_key) @dataclass(slots=True) -class ConversionEntry: - weight_converter: WeightConverter - collected_tensors: dict = field(default_factory=lambda: defaultdict(dict)) +class WeightRenaming(WeightTransform): + # Special case of WeightTransform that only renames keys without any conversion. + + def convert(self, layer_name: str, config=None, quantizer=None, missing_keys: Optional[MutableSet[str]] = None): + misc = {} + for pattern, futures in self.collected_tensors.items(): + self.collected_tensors[pattern] = [future.result() for future in futures] + + collected_tensors = self.collected_tensors + if quantizer is not None and self.quantization_operation is not None: + with log_to_misc(layer_name, misc, (self.collected_tensors, layer_name), self.quantization_operation): + collected_tensors = self.quantization_operation.convert( + self.collected_tensors, + source_keys=self.source_keys, + target_keys=self.target_keys, + full_layer_name=layer_name, + config=config, + quant_config=quantizer.quantization_config, + missing_keys=missing_keys, + ) + + return collected_tensors, misc + + +@dataclass(slots=True) +class WeightConverter(WeightTransform): + operations: list[ConversionOps] = field(default_factory=list, repr=False) + + def __post_init__(self): + WeightTransform.__post_init__(self) + if bool(len(self.source_keys) - 1) + bool(len(self.target_keys) - 1) >= 2: + raise ValueError( + f"source keys={self.source_keys}, target_keys={self.target_keys} but you can only have one to many, one to one or many to one." 
+ ) + if not self.operations: + raise ValueError("WeightConverter requires at least one operation.") + + def convert(self, layer_name: str, config=None, quantizer=None, missing_keys: Optional[MutableSet[str]] = None): + misc = {} + for pattern, futures in self.collected_tensors.items(): + self.collected_tensors[pattern] = [future.result() for future in futures] + + collected_tensors = self.collected_tensors + for op in self.operations: + with log_to_misc(layer_name, misc, (collected_tensors, layer_name), op): + collected_tensors = op.convert( + collected_tensors, + source_keys=self.source_keys, + target_keys=self.target_keys, + full_layer_name=layer_name, + config=config, + ) + if quantizer is not None and self.quantization_operation is not None: + with log_to_misc(layer_name, misc, (collected_tensors, layer_name), self.quantization_operation): + collected_tensors = self.quantization_operation.convert( + collected_tensors, + source_keys=self.source_keys, + target_keys=self.target_keys, + full_layer_name=layer_name, + config=config, + quant_config=quantizer.quantization_config, + missing_keys=missing_keys, + ) + return collected_tensors, misc GLOBAL_WORKERS = min(16, (os.cpu_count() or 8) * 2) # NVMe: 8-16; HDD/NFS: 2-4 @@ -353,7 +419,7 @@ def _format_op_name(curr_op: Union[list[ConversionOps], ConversionOps, None]) -> values, target_keys = extras descriptor = f"{op_name} " if op_name else "" misc[layer_name] = ( - f"{e}\nError: {descriptor}on tensors destined for {target_keys}. Ckpt contains: {len(values[0])}" + f"{e}\nError: {descriptor}on tensors destined for {target_keys}. Ckpt contains: {len(values)}" ) elif isinstance(extras, str): suffix = f" via {op_name}" if op_name else "" @@ -372,6 +438,7 @@ def set_param_for_module( mismatch_keys: MutableSet[tuple[str, torch.Size, torch.Size]], missing_keys: MutableSet[str], misc: MutableMapping[str, Any], + unexpected_keys: MutableSet[str], distributed_operation: Optional[TensorParallelLayer], hf_quantizer: HfQuantizer, ): @@ -382,33 +449,38 @@ def set_param_for_module( param_value = param_value[0] elif not isinstance(param_value, torch.nn.Parameter): param_value = param_value[...] 
-        ref = getattr(module_obj, param_name)
-        use_dtensor = hasattr(distributed_operation, "use_dtensor") and distributed_operation.use_dtensor
-        if not isinstance(param_value, torch.nn.Parameter):
-            if distributed_operation is not None:
-                param_value = DTensor.from_local(
-                    param_value,
-                    distributed_operation.device_mesh,
-                    getattr(distributed_operation, "shard", Replicate()),
-                    run_check=False,
-                    shape=ref.size(),
-                    stride=ref.stride(),
-                )
-                if not use_dtensor:
-                    # we convert to local
-                    param_value = param_value.to_local()
-            if param_name not in module_obj._buffers:
-                param_value = torch.nn.Parameter(param_value, requires_grad=param_value.is_floating_point())
-
-        # Remove from missing keys (it's either mismatched, or all good)
-        missing_keys.discard(layer_name)
-        if ref is not None and ref.shape != param_value.shape and hf_quantizer is None:
-            mismatch_keys.add((layer_name, param_value.shape, ref.shape))
-            module_obj.param_name._is_hf_initialized = False  # Needs to be initialized
+        ref = getattr(module_obj, param_name)
+        if ref is None:
+            unexpected_keys.add(layer_name)
         else:
-            param_value._is_hf_initialized = True  # super important otherwise _init_weight re-initi if bias is missing
-            setattr(module_obj, param_name, param_value)
+            use_dtensor = hasattr(distributed_operation, "use_dtensor") and distributed_operation.use_dtensor
+            if not isinstance(param_value, torch.nn.Parameter):
+                if distributed_operation is not None:
+                    param_value = DTensor.from_local(
+                        param_value,
+                        distributed_operation.device_mesh,
+                        getattr(distributed_operation, "shard", Replicate()),
+                        run_check=False,
+                        shape=ref.size(),
+                        stride=ref.stride(),
+                    )
+                    if not use_dtensor:
+                        # we convert to local
+                        param_value = param_value.to_local()
+                if param_name not in module_obj._buffers:
+                    param_value = torch.nn.Parameter(param_value, requires_grad=param_value.is_floating_point())
+
+            # Remove from missing keys (it's either mismatched, or all good)
+            missing_keys.discard(layer_name)
+            if ref is not None and ref.shape != param_value.shape and hf_quantizer is None:
+                mismatch_keys.add((layer_name, param_value.shape, ref.shape))
+                module_obj.param_name._is_hf_initialized = False  # Needs to be initialized
+            else:
+                param_value._is_hf_initialized = (
+                    True  # super important, otherwise _init_weight re-inits it if the bias is missing
+                )
+                setattr(module_obj, param_name, param_value)


 class SkipLayer(Exception):
@@ -417,10 +489,29 @@ class SkipLayer(Exception):
     pass


+def repl(m, repl_map: dict[str, str]) -> str:
+    # Collect all groups that matched
+    matched_groups = [name for name, val in m.groupdict().items() if val]
+
+    if len(matched_groups) == 0:
+        # Should never happen
+        return m.group(0)
+
+    if len(matched_groups) > 1:
+        raise ValueError(
+            "only a single match should happen, your regex patterns are tangled: "
+            f"groups matched = {matched_groups} for the patterns: {repl_map.keys()}"
+        )
+
+    # Exactly one match => return replacement
+    name = matched_groups[0]
+    return repl_map[name]
+
+
 def convert_and_load_state_dict_in_model(
     model: PreTrainedModel,
     state_dict: dict[str, Any],
-    weight_mapping: dict[str, WeightConverter] | None,
+    weight_mapping: list[WeightConverter | WeightRenaming] | None,
     tp_plan: dict[str, str] | None,
     hf_quantizer: HfQuantizer | None,
     dtype: torch.dtype | None = None,
@@ -428,19 +519,101 @@
     dtype_plan: dict | None = None,
     device_mesh: torch.distributed.device_mesh.DeviceMesh | None = None,
 ):
-    """
-    Convert a state dict according to a weight mapping (one
WeightConverter per glob pattern), - collecting tensors per *layer instance* (the concrete indices captured from '*'). + r""" + We build a mapping from the keys obtained by renaming each of the checkpoint keys according to the weight_mapping rules. + Then we load the tensors into the model, applying any conversion operations as needed. + + The `param_name_to_load` will look like this: + { + "model.layers.0.attention.q.weight": # Notice here there is only the first key of the target keys + WeightConverter( + source_keys=["qkv"], + target_keys=["q", "k","v"], + operations=[Chunk(dim=0, chunks=3)]), + collected_tensors={ + "qkv": [Future, Future, Future]}, + layer_targets={ + "model.layers.0.attention.q.weight": {"model.layers.0.attention.qkv.weight"}, + "model.layers.0.attention.k.weight": {"model.layers.0.attention.qkv.weight"}, + "model.layers.0.attention.v.weight": {"model.layers.0.attention.qkv.weight"}, + } + ), + ... + } + + We make sure that the keys are the full keys. The only "nit" here is that 1 key can map to multiple target keys (e.g. qkv -> q, k, v). + In that case the weight converter will take care of doing the appropriate renaming. + + For example for: + ```python + WeightConverter( + source_keys=["mlp.experts.*.gate_proj.weight","mlp.experts.*.up_proj.weight"], + target_keys="mlp.experts.gate_up_proj", + operations=[MergeModulelist(dim=0), Concatenate(dim=1)], + ) + ``` + we would have the following collected tensors: + ```python + collected_tensors = { + "mlp.experts.*.gate_proj.weight": [Future, Future, Future, Future, Future, Future, Future, Future], + "mlp.experts.*.up_proj.weight": [Future, Future, Future, Future, Future, Future, Future, Future], + } + ``` + The first op, `MergeModulelist`, would stack the 8 tensors of each source but will not "rename" them into the fused target name. + The second op, `Concatenate`, would then rename the fused tensor into the final target name. + + If we want to split `qkv` we would have: + ```python + collected_tensors = { + "attention.qkv.weight": [Future], # here its the full SOURCE keys. + } + ``` + The `Chunk` operation would then split the single tensor into 3 and rename them accordingly and update the collected tensors to: + ```python + realized_values = { + "attention.q.weight": [Tensor], + "attention.k.weight": [Tensor], + "attention.v.weight": [Tensor], + } + ``` + + Now that this is done, we can quantize / dequantize accordingly the collected_tensors. + + For some quantization methods, we need to gather different tensors: + + ```python + # for "medmekk/llama-3.2-1b-float8-torchao" + WeightConverter( + source_keys=[":qdata", ":scale"], + target_keys="", + operations=[TorchaoDeserialize()], + ) + ``` + This will collect all tensors that have the same prefix, but end with `:qdata` or `:scale`. This will give us: + ```python + all_weight_mapping = { + "model.layers.13.self_attn.o_proj.weight": WeightConverter( + source_keys=[":qdata", ":scale"], + target_keys="", + operations=[TorchaoDeserialize()], + collected_tensors={ + ":qdata": [Future], + ":scale": [Future], + }, + ... 
+    }
+    ```
+    """
     prefix = model.base_model_prefix
-    tp_plan = tp_plan or {}  # {glob_pattern: plan_obj_or_key}
-    device_map = device_map or {"": "cpu"}  # {exact_target_key: device}
+    tp_plan = tp_plan or {}
+    device_map = device_map or {"": "cpu"}
     device_map_regex = re.compile(
         "|".join(rf"({k})" for k in sorted(device_map.keys(), key=lambda x: x.count("."), reverse=True))
     )
-    dtype_plan = dtype_plan or {}  # {glob_pattern: dtype}
-    weight_mapping = weight_mapping or {}  # {glob_pattern: WeightConverter}
+    dtype_plan = dtype_plan or {}
+    weight_mapping = weight_mapping or []

     meta_model_state_dict = model.state_dict()
     missing_keys = set(meta_model_state_dict.keys())
@@ -450,135 +623,125 @@
     # Global thread_pool
     thread_pool = ThreadPoolExecutor(max_workers=GLOBAL_WORKERS)

-    _patterns = list(itertools.chain.from_iterable([k.source_keys for k in weight_mapping]))
-    source_to_target = {sk: k for k in weight_mapping for sk in k.source_keys}
-    weight_pattern_alt, weight_pattern_by_group_name = build_glob_alt(_patterns)
-    tp_plan_alt, tp_plan_by_group_name = build_glob_alt(list(tp_plan.keys()))
-    dtype_policy_alt, dtype_policy_by_group_name = build_glob_alt(list(dtype_plan.keys()))
+    renamings = [entry for entry in weight_mapping if isinstance(entry, WeightRenaming)]
+    converters = [entry for entry in weight_mapping if isinstance(entry, WeightConverter)]
+    if hf_quantizer:
+        # We will add the quantizer's deserialization WeightConverter here.
+        pass
+
+    param_name_to_load: dict[str, Union[WeightRenaming | WeightConverter]] = {}
+
+    # build '(?P<g0>.*.*\\.block_sparse_moe\\..*)' and group to source {'g0': '*.block_sparse_moe.'}
+    # and group to target {'g0': '*.mlp.'}. This allows us to quickly find which pattern matched.
+    rename_alt, _, rename_by_group = build_glob_alternation(renamings)
+    if converters != []:
+        weight_pattern_alt, src_group_to_glob, tgt_group_to_glob = build_glob_alternation(converters)
+    if tp_plan != {}:
+        tp_plan_alt, tp_plan_by_group_name, _ = build_glob_alternation(list(tp_plan.keys()))
+    if dtype_plan != {}:
+        dtype_policy_alt, dtype_policy_by_group_name, _ = build_glob_alternation(list(dtype_plan.keys()))
+
+    pattern_to_converter = {k: converter for converter in converters for k in converter.source_keys}

     state_dict = sorted(state_dict.items(), key=lambda kv: dot_natural_key(kv[0]))

-    # 1.
Create the conversion entries - by_conversion_pattern: dict[str, ConversionEntry] = {} for original_key, tensor in state_dict: - matched_pattern = match_glob(original_key, weight_pattern_alt, weight_pattern_by_group_name) - if matched_pattern is not None: - converter = source_to_target[matched_pattern] # TODO make sure its the ref - sub_with_extractor = partial(re.sub, matched_pattern.replace("*", r"(\d+)"), string=original_key) - entry_key = "|".join(converter.target_keys) - target_key = "|".join(map(sub_with_extractor, [k.replace("*", "\\1") for k in converter.target_keys])) - entry: ConversionEntry = by_conversion_pattern.setdefault(entry_key, ConversionEntry(converter)) - converter_key = sub_with_extractor(matched_pattern) - else: - converter = WeightConverter(original_key) - converter_key = entry_key = target_key = original_key - entry = by_conversion_pattern.setdefault(converter_key, ConversionEntry(converter)) - - _dtype = dtype - new_target_key = [] # test_load_with_mismatched_shapes for AutoModel.from_pretrained(AutoForCausal, vocab=10) - for t in target_key.split("|"): - if t.startswith(prefix) and meta_model_state_dict.get(re.sub(f"^{prefix}.", "", t, count=1)) is not None: - t = re.sub(f"^{prefix}.", "", t, count=1) - elif meta_model_state_dict.get(f"{prefix}.{t}") is not None: - t = f"{prefix}.{t}" - new_target_key.append(t) - empty_param = meta_model_state_dict.get(t) - # If it does not exist, it's unexpected - if empty_param is None: - unexpected_keys.add(t) - continue + # 1. apply all renamings + renamed_key = rename_alt.sub(lambda m: repl(m, rename_by_group), original_key).replace("\\", "") + + # 2. apply 1 weight conversion on the key + matched_pattern = weight_pattern_alt.search(renamed_key) if converters != [] else None + if matched_pattern is not None: # we have a converter to apply + renamed_key = weight_pattern_alt.sub(lambda m: repl(m, tgt_group_to_glob), renamed_key).replace("\\", "") + + # 3. check if we need to add or remove prefix + if ( + renamed_key.startswith(prefix) + and meta_model_state_dict.get(re.sub(f"^{prefix}.", "", renamed_key, count=1)) is not None + ): + renamed_key = re.sub(f"^{prefix}.", "", renamed_key, count=1) + elif meta_model_state_dict.get(f"{prefix}.{renamed_key}") is not None: + renamed_key = f"{prefix}.{renamed_key}" + + # 4. finally, collect the tensor into the proper converter + if renamed_key in missing_keys: + empty_param = meta_model_state_dict.get(renamed_key) + if matched_pattern: + new_converter = deepcopy(pattern_to_converter[src_group_to_glob[matched_pattern.lastgroup]]) + # each target key gets its own converter instance + mapping = param_name_to_load.setdefault(renamed_key, new_converter) + source_pattern = src_group_to_glob[matched_pattern.lastgroup] + else: + mapping = param_name_to_load.setdefault(renamed_key, WeightRenaming(renamed_key, renamed_key)) + source_pattern = renamed_key + + # 5. 
Handle dtype casting + if hf_quantizer is not None and hf_quantizer.param_needs_quantization(model, renamed_key): + mapping.quantization_operation = hf_quantizer.get_quantize_ops() - if hf_quantizer is not None and hf_quantizer.param_needs_quantization(model, t): - converter.quantization_operation = hf_quantizer.get_quantize_ops() _dtype = dtype - matched_dtype_pattern = match_glob(t, dtype_policy_alt, dtype_policy_by_group_name) - if matched_dtype_pattern is not None: - _dtype = dtype_plan[matched_dtype_pattern] - elif empty_param.dtype != _dtype: - _dtype = empty_param.dtype - - first_target_key = new_target_key[0] - target_key = "|".join(new_target_key) - - future = None - if device_mesh: - if matched_tp_pattern := match_glob(first_target_key, tp_plan_alt, tp_plan_by_group_name): - empty_param = meta_model_state_dict.get(first_target_key) - if getattr(converter, "distributed_operation", {}) is None: - tp_layer = ALL_PARALLEL_STYLES[model.tp_plan[matched_tp_pattern]].__class__ - converter.distributed_operation = tp_layer( - device_mesh=device_mesh, rank=device_map[""].index, empty_param=empty_param.clone() + if dtype_plan != {} and dtype_policy_alt.search(renamed_key): + matched_dtype_pattern = dtype_policy_alt.search(renamed_key) + if matched_dtype_pattern is not None: + _dtype = dtype_plan[matched_dtype_pattern.group()] + elif empty_param is not None and empty_param.dtype != _dtype: + _dtype = empty_param.dtype # usually correct when initializing + + # 6. Handle TP sharding or device_map placement -> scheduled materialization + future = None + if device_mesh: + if matched_tp_pattern := tp_plan_alt.search(renamed_key): + matched_tp_pattern = tp_plan_by_group_name[matched_tp_pattern.lastgroup] + if getattr(mapping, "distributed_operation", None) is None: + tp_layer = ALL_PARALLEL_STYLES[model.tp_plan[matched_tp_pattern]].__class__ + mapping.distributed_operation = tp_layer( + device_mesh=device_mesh, rank=device_map[""].index, empty_param=empty_param.clone() + ) + shard_index = len(mapping.collected_tensors[source_pattern]) + future = spawn_tp_materialize( + thread_pool, + tensor, + _dtype, + mapping.distributed_operation, + shard_index, ) - # VERY IMPORTANT: this tells us wether we collected stuffs or not. - shard_index = len(entry.collected_tensors[target_key].get(converter_key, [])) - future = spawn_tp_materialize( - thread_pool, - tensor, - _dtype, - converter.distributed_operation, - shard_index, - ) - if future is None: # If not TP, async materialize the tensors. TODO handle disk offload? - device_match = device_map_regex.match(first_target_key) - param_device = device_map[device_match.group()] if device_match else device_map.get("", "cpu") - future = spawn_materialize(thread_pool, tensor, param_device, _dtype) - entry.collected_tensors[target_key].setdefault(converter_key, []).append(future) - - # 2. 
Actually convert the ckpt - inverse_converters = {} - keys = list(by_conversion_pattern.keys()) - - with logging.tqdm(total=len(keys), desc="Loading weights") as pbar: - for key in keys[::-1]: # revert to process simple keys first - group = by_conversion_pattern.pop(key) - converter = group.weight_converter - operations = converter.operations if isinstance(converter.operations, list) else [converter.operations] - for layer_name, tensors_for_this_layer in group.collected_tensors.items(): - pbar.update(1) - pbar.set_postfix({"Materializing param": layer_name}) - pbar.refresh() - concrete_target_keys = layer_name.split("|") - try: - if bool(set(concrete_target_keys) - unexpected_keys): - with log_to_misc(layer_name, misc): - values = [[k.result() for k in inner] for inner in tensors_for_this_layer.values()] - - for op in operations: - with log_to_misc(layer_name, misc, (values, concrete_target_keys), operations): - values = op.convert(values, model.config) - - values = [values] if not isinstance(values, list) else values - with log_to_misc(layer_name, misc, (values, concrete_target_keys), operations): - realized_value = { - k: t for k, t in zip(concrete_target_keys, values) if k not in unexpected_keys - } - - for k in list(realized_value.keys()).copy(): - if op := converter.quantization_operation: - with log_to_misc(layer_name, misc, op=op): - realized_value.update( - op.convert({k: realized_value.pop(k)}, model=model, missing_keys=missing_keys) - ) - - for k, output_value in realized_value.items(): - for src in converter.source_keys: # what should happen to k when we meet k at saving - inverse_converters[k] = {src: converter} - set_param_for_module( - model, - k, - output_value, - mismatch_keys, - missing_keys, - misc, - converter.distributed_operation, - hf_quantizer, - ) - - except SkipLayer: - continue - del group - - model.inverse_converters = inverse_converters + if future is None: # TODO handle disk offload + device_match = device_map_regex.match(renamed_key) + param_device = device_map[device_match.group()] if device_match else device_map.get("", "cpu") + future = spawn_materialize(thread_pool, tensor, param_device, _dtype) + + mapping.add_tensor(renamed_key, original_key, source_pattern, future) + elif matched_pattern: # add all target keys as unexpected + mapping = pattern_to_converter[src_group_to_glob[matched_pattern.lastgroup]] + for k in mapping.target_keys: + unexpected_keys.add(renamed_key.replace(mapping.target_keys[0], k)) + else: + unexpected_keys.add(renamed_key) + + total_entries = len(param_name_to_load) + with logging.tqdm(total=total_entries, desc="Loading weights") as pbar: + for layer_name, mapping in param_name_to_load.items(): + pbar.update(1) + pbar.set_postfix({"Materializing param": layer_name}) + pbar.refresh() + try: + realized_value, misc = mapping.convert( + layer_name, config=model.config, quantizer=hf_quantizer, missing_keys=missing_keys + ) + for k, output_value in realized_value.items(): + set_param_for_module( + model, + k, + output_value, + mismatch_keys, + missing_keys, + misc, + unexpected_keys, + mapping.distributed_operation, + hf_quantizer, + ) + except SkipLayer: + continue thread_pool.shutdown(wait=False) return missing_keys, unexpected_keys, mismatch_keys, misc diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index e67c3222f341..fe52822e1f92 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -49,6 +49,7 @@ from .conversion_mapping import 
get_checkpoint_conversion_mapping from .core_model_loading import ( WeightConverter, + WeightRenaming, convert_and_load_state_dict_in_model, revert_weight_conversion, ) @@ -3819,14 +3820,16 @@ def from_pretrained( config, quantization_config, dtype, device_map, weights_only, user_agent ) - weight_conversions: Optional[list[WeightConverter]] = None + weight_conversions: Optional[list[WeightConverter | WeightRenaming]] = None model_type = getattr(config, "model_type", None) if model_type is not None: weight_conversions = get_checkpoint_conversion_mapping(model_type) if weight_conversions is None: weight_conversions = get_checkpoint_conversion_mapping("legacy") if key_mapping is not None: - weight_conversions.extend([WeightConverter(k, v) for k, v in key_mapping.items()]) + weight_conversions.extend( + [WeightRenaming(source_keys=k, target_keys=v) for k, v in key_mapping.items()] + ) if gguf_file: if hf_quantizer is not None: @@ -3997,7 +4000,7 @@ def _load_pretrained_model( hf_quantizer: Optional[HfQuantizer] = None, device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None, weights_only: bool = True, - weight_mapping: Optional[Sequence[WeightConverter]] = None, + weight_mapping: Optional[Sequence[WeightConverter | WeightRenaming]] = None, ): is_quantized = hf_quantizer is not None is_hqq_or_quark = is_quantized and hf_quantizer.quantization_config.quant_method in { diff --git a/tests/utils/test_core_model_loading.py b/tests/utils/test_core_model_loading.py index ac02aad2c608..023b9fd0596d 100644 --- a/tests/utils/test_core_model_loading.py +++ b/tests/utils/test_core_model_loading.py @@ -24,9 +24,10 @@ MergeModulelist, PermuteForRope, WeightConverter, - build_glob_alt, + WeightRenaming, + build_glob_alternation, convert_and_load_state_dict_in_model, - match_glob, + repl, ) from transformers.utils.import_utils import is_triton_available @@ -38,40 +39,47 @@ def setUp(self): "model.layers.*.self_attn.q_proj.weight", "embed_tokens.weight", ] - self.alt_digits, self.map_digits = build_glob_alt(self.weight_globs_digits) + self.alt_digits, self.map_digits, _ = build_glob_alternation(self.weight_globs_digits) self.weight_globs_any = [ "model.layers.*.mlp.gate_up_proj.weight", "model.layers.*.self_attn.q_proj.weight", "embed_tokens.weight", ] - self.alt_any, self.map_any = build_glob_alt(self.weight_globs_any) + self.alt_any, self.map_any, _ = build_glob_alternation(self.weight_globs_any) + + @staticmethod + def _match_glob(key, alt, mapping): + matched = alt.search(key) + return mapping.get(matched.lastgroup) if matched else None def test_exact_match(self): - self.assertEqual(match_glob("embed_tokens.weight", self.alt_digits, self.map_digits), "embed_tokens.weight") + self.assertEqual( + self._match_glob("embed_tokens.weight", self.alt_digits, self.map_digits), "embed_tokens.weight" + ) def test_digits_only_star_accepts_digits(self): self.assertEqual( - match_glob("model.layers.0.mlp.gate_up_proj.weight", self.alt_digits, self.map_digits), + self._match_glob("model.layers.0.mlp.gate_up_proj.weight", self.alt_digits, self.map_digits), "model.layers.*.mlp.gate_up_proj.weight", ) self.assertEqual( - match_glob("model.layers.12.self_attn.q_proj.weight", self.alt_digits, self.map_digits), + self._match_glob("model.layers.12.self_attn.q_proj.weight", self.alt_digits, self.map_digits), "model.layers.*.self_attn.q_proj.weight", ) def test_anychar_star_accepts_nondigits(self): self.assertEqual( - match_glob("model.layers.a.mlp.gate_up_proj.weight", self.alt_any, self.map_any), + 
self._match_glob("model.layers.a.mlp.gate_up_proj.weight", self.alt_any, self.map_any), "model.layers.*.mlp.gate_up_proj.weight", ) self.assertEqual( - match_glob("model.layers.00x.mlp.gate_up_proj.weight", self.alt_any, self.map_any), + self._match_glob("model.layers.00x.mlp.gate_up_proj.weight", self.alt_any, self.map_any), "model.layers.*.mlp.gate_up_proj.weight", ) def test_no_match(self): - self.assertIsNone(match_glob("model.layers.0.mlp.up_proj.weight", self.alt_digits, self.map_digits)) + self.assertIsNone(self._match_glob("model.layers.0.mlp.up_proj.weight", self.alt_digits, self.map_digits)) def test_leftmost_alternative_wins_for_overlapping_patterns(self): # Overlapping patterns: both could match; ensure leftmost wins @@ -79,11 +87,11 @@ def test_leftmost_alternative_wins_for_overlapping_patterns(self): "model.layers.*.mlp.*.weight", # broader (first) "model.layers.0.mlp.gate_up_proj.weight", # more specific (second) ] - alt, mapping = build_glob_alt(globs) + alt, mapping, _ = build_glob_alternation(globs) # Both branches match; Python's regex picks the leftmost alternative → index 0 self.assertEqual( - match_glob("model.layers.0.mlp.gate_up_proj.weight", alt, mapping), "model.layers.*.mlp.*.weight" + self._match_glob("model.layers.0.mlp.gate_up_proj.weight", alt, mapping), "model.layers.*.mlp.*.weight" ) def test_multiple_patterns_same_prefix(self): @@ -92,34 +100,59 @@ def test_multiple_patterns_same_prefix(self): "model.layers.*.self_attn.k_proj.weight", "model.layers.*.self_attn.v_proj.weight", ] - alt, mapping = build_glob_alt( + alt, mapping, _ = build_glob_alternation( globs, ) self.assertEqual( - match_glob("model.layers.3.self_attn.q_proj.weight", alt, mapping), + self._match_glob("model.layers.3.self_attn.q_proj.weight", alt, mapping), "model.layers.*.self_attn.q_proj.weight", ) self.assertEqual( - match_glob("model.layers.3.self_attn.k_proj.weight", alt, mapping), + self._match_glob("model.layers.3.self_attn.k_proj.weight", alt, mapping), "model.layers.*.self_attn.k_proj.weight", ) self.assertEqual( - match_glob("model.layers.3.self_attn.v_proj.weight", alt, mapping), + self._match_glob("model.layers.3.self_attn.v_proj.weight", alt, mapping), "model.layers.*.self_attn.v_proj.weight", ) def test_anchor_full_match_only(self): - self.assertIsNotNone(match_glob("model.layers.0.mlp.gate_up_proj.weight.bar", self.alt_any, self.map_any)) + self.assertIsNotNone( + self._match_glob("model.layers.0.mlp.gate_up_proj.weight.bar", self.alt_any, self.map_any) + ) def test_large_batch_performance_smoke(self): # Not a perf benchmark, but ensures building and matching a larger alternation is OK globs = [f"model.layers.*.mlp.block{i}.weight" for i in range(200)] - alt, mapping = build_glob_alt( - globs, - ) + alt, mapping, _ = build_glob_alternation(globs) key = "model.layers.123.mlp.block57.weight" - self.assertEqual(match_glob(key, alt, mapping), "model.layers.*.mlp.block57.weight") + self.assertEqual(self._match_glob(key, alt, mapping), "model.layers.*.mlp.block57.weight") + + def test_sub_key_rewrites_targets(self): + renamings = [ + WeightRenaming("block_sparse_moe.experts.*.w1.weight", "mlp.experts.gate_up_proj"), + WeightRenaming("block_sparse_moe.experts.*.w2.weight", "mlp.experts.down_proj"), + WeightRenaming("model.language_model.*", "language_model"), + ] + rename_alt, _, rename_by_group = build_glob_alternation(renamings) + + def rename(original_key: str) -> str: + return rename_alt.sub(lambda m: repl(m, rename_by_group), original_key).replace("\\", "") + + 
self.assertEqual(rename("foo.block_sparse_moe.experts.3.w1.weight"), "foo.mlp.experts.gate_up_proj") + self.assertEqual(rename("foo.block_sparse_moe.experts.3.w2.weight"), "foo.mlp.experts.down_proj") + self.assertEqual(rename("model.language_model.lm_head.weight"), "language_model") + + def test_sub_key_no_match_returns_original(self): + renamings = [ + WeightRenaming("block_sparse_moe.experts.*.w1.weight", "*.mlp.experts.gate_up_proj"), + ] + rename_alt, _, rename_by_group = build_glob_alternation(renamings) + + key = "unrelated.key" + renamed_key = rename_alt.sub(lambda m: repl(m, rename_by_group), key).replace("\\", "") + self.assertEqual(renamed_key, key) class DummyParamModule(nn.Module): @@ -215,7 +248,7 @@ def test_moe_and_qkv_conversion(self): ], operations=[Chunk(dim=0, chunks=3)], ), - WeightConverter("mlp.w2.weight", "mlp.down_proj.weight"), + WeightRenaming("mlp.w2.weight", "mlp.down_proj.weight"), ] missing, unexpected, mismatch, misc = convert_and_load_state_dict_in_model( model, state_dict, weight_mapping, tp_plan=None, hf_quantizer=None diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index ec6d8745cfd0..936b41ea23d7 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -628,7 +628,9 @@ def test_model_from_config_dtype_composite(self): # but if the model has `_keep_in_fp32_modules` then those modules should be in fp32 no matter what LlavaForConditionalGeneration._keep_in_fp32_modules = ["multi_modal_projector"] model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, config=config, dtype="auto") - self.assertEqual(model.model.language_model.dtype, torch.float32) + self.assertEqual( + model.model.language_model.dtype, torch.float32 + ) # remember config says float32 for text_config self.assertEqual(model.model.vision_tower.dtype, torch.bfloat16) self.assertEqual(model.model.multi_modal_projector.linear_1.weight.dtype, torch.float32) self.assertIsInstance(model.config.dtype, torch.dtype)