Commit b63e6e0

Update conversion mapping to separate renaming from converting (#42254)
* initial commit
* up
* update unexpected later on
* fix
* update
* simplify our lives
* isolate a bit more
* fixup
* small nits
* style
* nit
* fix common cases
* fix post merge
* bnb needs missing keys
* small fix
* better documentation
* no verdict + base class
* take review comments
* take all comments
* fix super init
* update doc to be more real
* small nits
* nits
* fix dtype
* fix dtype issue
* remove one unused function
* cleanup and nits
* up
* should be the final fix!
* fixup
1 parent 7b84f72 commit b63e6e0
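
For orientation, a minimal sketch of the split this commit introduces, built only from the mapping entries visible in the diff below: WeightRenaming covers pure key renames, while WeightConverter stays for mappings that actually transform tensors. The converter's target key and its MergeModulelist/Concatenate operations are illustrative guesses mirroring the mixtral entry; the real definitions live in src/transformers/conversion_mapping.py and src/transformers/core_model_loading.py.

    # Illustrative sketch only; exact class signatures are defined in core_model_loading.py.
    from transformers.core_model_loading import (
        Concatenate,
        MergeModulelist,
        WeightConverter,
        WeightRenaming,
    )

    # Pure rename: only the checkpoint key changes, the tensor itself is untouched.
    gate_rename = WeightRenaming(".block_sparse_moe.gate", ".mlp.gate")

    # Real conversion: several source tensors are merged and concatenated into one
    # target tensor. Target key and operations here are assumptions, not taken from the diff.
    expert_merge = WeightConverter(
        source_keys=[
            "block_sparse_moe.experts.*.w1.weight",
            "block_sparse_moe.experts.*.w3.weight",
        ],
        target_keys="mlp.experts.gate_up_proj",
        operations=[MergeModulelist(dim=0), Concatenate(dim=1)],
    )

One visible effect of the split is readability: entries like the legacy LayerNorm.gamma → LayerNorm.weight mapping are now plainly renames, with no operations attached.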

File tree

5 files changed: +525 −329 lines changed


src/transformers/conversion_mapping.py

Lines changed: 8 additions & 13 deletions
@@ -15,7 +15,7 @@
 
 from copy import deepcopy
 
-from .core_model_loading import Concatenate, MergeModulelist, WeightConverter
+from .core_model_loading import Concatenate, MergeModulelist, WeightConverter, WeightRenaming
 from .utils import is_torch_available
 
 
@@ -26,6 +26,7 @@
 def _build_checkpoint_conversion_mapping():
     mapping = {
         "mixtral": [
+            WeightRenaming(".block_sparse_moe.gate", ".mlp.gate"),
             WeightConverter(
                 source_keys=[
                     "block_sparse_moe.experts.*.w1.weight",
@@ -50,12 +51,6 @@ def _build_checkpoint_conversion_mapping():
                     ),  # each process has two lists of tensors, we cat each list. -> we end up with 2 tensors
                 ],  # we want the loading to add this shard operation here. Though we can't shard after concats and merge, needs to be first
             ),
-            # WeightConverter(
-            #     ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"],
-            #     "self_attn.qkv_proj",
-            #     operations=[Concatenate(dim=0)],  # more like stack?
-            # ),
-            WeightConverter("*.block_sparse_moe.", "*.mlp."),
         ],
         "qwen2_moe": [
             WeightConverter(
@@ -73,34 +68,34 @@ def _build_checkpoint_conversion_mapping():
             ),
         ],
         "legacy": [
-            WeightConverter(
+            WeightRenaming(
                 source_keys="LayerNorm.gamma",
                 target_keys="LayerNorm.weight",
             ),
-            WeightConverter(
+            WeightRenaming(
                 source_keys="LayerNorm.beta",
                 target_keys="LayerNorm.bias",
             ),
         ],
     }
     if hasattr(torch.nn.utils.parametrizations, "weight_norm"):
         mapping["legacy"] += [
-            WeightConverter(
+            WeightRenaming(
                 source_keys="weight_g",
                 target_keys="parametrizations.weight.original0",
             ),
-            WeightConverter(
+            WeightRenaming(
                 source_keys="weight_v",
                 target_keys="parametrizations.weight.original1",
             ),
         ]
     else:
         mapping["legacy"] += [
-            WeightConverter(
+            WeightRenaming(
                 source_keys="parametrizations.weight.original0",
                 target_keys="weight_g",
             ),
-            WeightConverter(
+            WeightRenaming(
                 source_keys="parametrizations.weight.original1",
                 target_keys="weight_v",
             ),
