Granite Four #13550
Changes from all commits
@@ -4890,6 +4890,9 @@ def __init__(self, dir_model: Path, *args, **kwargs):
            with open(dir_model / "config.json", "r", encoding="utf-8") as f:
                hparams = json.load(f)
        super().__init__(dir_model, *args, hparams=hparams, **kwargs)
+       self.d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
+       self.d_inner = self.find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or 2 * self.d_model
+       self.n_group = self.find_hparam(["n_groups"], optional=True) or 1

    def set_vocab(self):
        vocab_size = self.hparams["vocab_size"]
@@ -4912,32 +4915,29 @@ def set_vocab(self):
            self._set_vocab_builtin("gpt-neox", vocab_size)

    def set_gguf_parameters(self):
-       d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
-       d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
-       d_inner = self.find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or 2 * d_model
-       d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 128
-       head_dim = self.find_hparam(["mamba_d_head", "head_dim"], optional=True) or 64
-       n_group = self.find_hparam(["n_groups"], optional=True) or 1
+       d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
+       d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 128
+       head_dim = self.find_hparam(["mamba_d_head", "head_dim"], optional=True) or 64

        rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5

        # Fail early for models which don't have a block expansion factor of 2
        # TODO: does this really matter?
        # skip the assertion for FalconH1 Model
        if self.model_arch != gguf.MODEL_ARCH.FALCON_H1:
-           assert d_inner == 2 * d_model
-           assert d_inner % head_dim == 0
+           assert self.d_inner == 2 * self.d_model
+           assert self.d_inner % head_dim == 0

        self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
-       self.gguf_writer.add_embedding_length(d_model)
+       self.gguf_writer.add_embedding_length(self.d_model)
        self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading
        self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading
        self.gguf_writer.add_block_count(self.block_count)
        self.gguf_writer.add_ssm_conv_kernel(d_conv)
-       self.gguf_writer.add_ssm_inner_size(d_inner)
+       self.gguf_writer.add_ssm_inner_size(self.d_inner)
        self.gguf_writer.add_ssm_state_size(d_state)
-       self.gguf_writer.add_ssm_time_step_rank(d_inner // head_dim)
-       self.gguf_writer.add_ssm_group_count(n_group)
+       self.gguf_writer.add_ssm_time_step_rank(self.d_inner // head_dim)
+       self.gguf_writer.add_ssm_group_count(self.n_group)
        self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
        self.gguf_writer.add_file_type(self.ftype)
@@ -4962,10 +4962,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
            # (D is also unsqueezed, but for more straightforward broadcast internally)
            data_torch = data_torch.reshape((*data_torch.shape, 1))
        elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_NORM, bid):
-           d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
-           d_inner = self.find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or 2 * d_model
-           n_group = self.hparams.get("n_groups", 1)
-           data_torch = data_torch.reshape((n_group, d_inner // n_group))
+           data_torch = data_torch.reshape((self.n_group, self.d_inner // self.n_group))

        if name.endswith(".A_log"):
            logger.debug("A_log --> A ==> " + new_name)
@@ -6452,18 +6449,148 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_EXP, bid), up),
            ]

        has_experts = bool(self.hparams.get('num_local_experts'))

        if name.endswith("shared_mlp.input_linear.weight"):
            ffn_dim = self.hparams["shared_intermediate_size"]
            assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * shared_intermediate_size"
            gate, up = data_torch.split(ffn_dim, dim=-2)
            if has_experts:
                return [
                    (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_SHEXP, bid), gate),
                    (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_SHEXP, bid), up),
                ]
            return [
                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_SHEXP, bid), gate),
                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_SHEXP, bid), up),
                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE, bid), gate),
                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP, bid), up),
            ]

        if not has_experts and name.endswith("shared_mlp.output_linear.weight"):
            return [
                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_DOWN, bid), data_torch)
            ]

        return super().modify_tensors(data_torch, name, bid)


@ModelBase.register("GraniteMoeHybridForCausalLM", "BambaForCausalLM")
class GraniteHybridModel(Mamba2Model, GraniteMoeModel):
[Review comment] Multiple inheritance in Python works by using methods from the first class where it's present. (At least according to https://stackoverflow.com/questions/3277367/how-does-pythons-super-work-with-multiple-inheritance) In this case, it means methods from Mamba2Model are preferred where both parents define them. The resolution order seems to be:

$ python3
>>> import convert_hf_to_gguf
>>> convert_hf_to_gguf.GraniteHybridModel.__mro__
(<class 'convert_hf_to_gguf.GraniteHybridModel'>, <class 'convert_hf_to_gguf.Mamba2Model'>, <class 'convert_hf_to_gguf.GraniteMoeModel'>, <class 'convert_hf_to_gguf.GraniteModel'>, <class 'convert_hf_to_gguf.LlamaModel'>, <class 'convert_hf_to_gguf.TextModel'>, <class 'convert_hf_to_gguf.ModelBase'>, <class 'object'>)

(Noting this here, because I had to check how that works, not because there's a problem.)

[Review reply] Yep, that's right. I like the suggestions below to be more explicit.
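To make the dispatch concrete, here is a minimal, self-contained sketch of the behavior discussed above. The class names mirror the converter, but the bodies are illustrative stand-ins, not the real implementations: an unqualified call resolves left-to-right through the MRO and lands on Mamba2Model, while an explicit ParentClass.method(self, ...) call, as used in the find_hparam and modify_tensors overrides below, bypasses the MRO entirely.

# Illustrative stand-ins only; the real classes live in convert_hf_to_gguf.py.
class Mamba2Model:
    def find_hparam(self, keys):
        return f"Mamba2Model.find_hparam({keys})"

class GraniteMoeModel:
    def find_hparam(self, keys):
        return f"GraniteMoeModel.find_hparam({keys})"

class GraniteHybridModel(Mamba2Model, GraniteMoeModel):
    pass

model = GraniteHybridModel()
print(model.find_hparam(["n_groups"]))                   # MRO picks Mamba2Model.find_hparam
print(GraniteMoeModel.find_hparam(model, ["n_groups"]))  # explicit parent call bypasses the MRO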
    """GraniteHybrid is a hybrid SSM + Attention model that uses Mamba2 SSM
    layers and optionally uses MoE w/ a shared expert"""
    model_arch = gguf.MODEL_ARCH.GRANITE_HYBRID
    undo_permute = True

    def __init__(self, *args, **kwargs):

        # Hybrid mamba models use a prefix for the mamba-specific params.
        # TODO: Extend this if the prefix(es) need to be configurable
        self.hparam_prefixes = ["mamba"]

        super().__init__(*args, **kwargs)

        # Lists of which layers use ssm vs attention
        self._attn_layers = self.get_attn_layers()
        self._ssm_layers = [
            i for i in range(self.block_count)
            if i not in self._attn_layers
        ]

        # n_group and d_inner are used during reshape_tensors for mamba2
        self.d_model = self.find_hparam(["hidden_size", "d_model"])
        self.n_group = self.find_hparam(["n_groups"])
        self.d_inner = self.find_hparam(["expand"]) * self.d_model

    def get_attn_layers(self):
        # Explicit list of layer type names
        if layer_types := self.hparams.get("layer_types"):
            return [
                i for i, typ in enumerate(layer_types)
                if typ == "attention"
            ]

        # Layer types indicated by index or period
        attn_layers = self.hparams.get("attn_layer_indices", [])
        if not attn_layers:
            attn_period = self.hparams.get("attn_layer_period")
            assert attn_period, "Didn't find attn_layer_indices or attn_layer_period"
            attn_offset = self.hparams.get("attn_layer_offset")
            assert attn_offset is not None, "No attention layer offset set with attn_layer_period"
            attn_layers = [
                i for i in range(self.block_count)
                if i % attn_period == attn_offset
            ]
        return attn_layers

    def find_hparam(self, keys: Iterable[str], *args, **kwargs) -> Any:
        prefixed = []
        for pfx in self.hparam_prefixes:
            prefixed.extend(
                "_".join([pfx, k])
                for k in keys
            )
        keys = list(keys) + prefixed
        return Mamba2Model.find_hparam(self, keys, *args, **kwargs)

    def modify_tensors(
        self, data_torch: Tensor, name: str, bid: int | None
    ) -> Iterable[tuple[str, Tensor]]:
        if (
            name.endswith("block_sparse_moe.input_linear.weight")
            or "shared_mlp" in name
        ):
            return GraniteMoeModel.modify_tensors(self, data_torch, name, bid)

        # Determine whether this is a mamba layer or an attention layer
        if bid in self._ssm_layers:
            return Mamba2Model.modify_tensors(self, data_torch, name, bid)
        elif bid in self._attn_layers:
            return GraniteMoeModel.modify_tensors(self, data_torch, name, bid)
        return [(self.map_tensor_name(name), data_torch)]

    def set_gguf_parameters(self):
        """This method merges params from both parents and some that are
        specific to this model. The result is some duplication of how the params
        get set. The following warnings are expected during conversion:
        WARNING:Duplicated key name 'granitehybrid.attention.head_count_kv'
        WARNING:Duplicated key name 'granitehybrid.context_length'
        """
        GraniteMoeModel.set_gguf_parameters(self)

        ## Mamba mixer params ##
        self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"]))
        self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state"]))
        self.gguf_writer.add_ssm_group_count(self.n_group)
        self.gguf_writer.add_ssm_inner_size(self.d_inner)
        # NOTE: The mamba_dt_rank is _not_ the right field for how this is used
        # in llama.cpp
        self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads"]))

        ## Attention params ##
        head_count_kv = self.find_hparam(["num_key_value_heads", "n_head_kv"])
        head_count_kv_vec = [
            head_count_kv if i in self._attn_layers else 0 for i in range(self.block_count)
        ]
        if rope_dim := self.hparams.get("attn_rotary_emb"):
            self.gguf_writer.add_rope_dimension_count(rope_dim)
        self.gguf_writer.add_head_count_kv(head_count_kv_vec)

        ## If Bamba, use rope, otherwise don't
        use_rope = "BambaForCausalLM" in self.hparams["architectures"]
        self.gguf_writer.add_rope_scaling_finetuned(use_rope)
        if not use_rope:
            self.gguf_writer.add_context_length(2**20)

        ## Validation ##
        d_head = self.find_hparam(["d_head"], optional=True) or 64
        assert self.hparams.get("hidden_act") in [None, "silu"], "Only SILU activation supported"
        assert self.d_inner % d_head == 0, f"SSM inner size {self.d_inner} not a multiple of head dim {d_head}"

    def set_vocab(self):
        self.hparams["pad_vocab_size_multiple"] = 8
        Mamba2Model.set_vocab(self)


@ModelBase.register("BailingMoeForCausalLM")
class BailingMoeModel(TextModel):
    model_arch = gguf.MODEL_ARCH.BAILINGMOE
@@ -6687,7 +6814,7 @@ def __init__(self, *args, **kwargs):
        # Use Llama conversion for attention
        self._transformer_model_class = LlamaModel

-       # n_group and d_inner are used during reshape_tensors for mamaba2
+       # n_group and d_inner are used during reshape_tensors for mamba2
        self.n_group = self.find_hparam(["n_groups"])
        self.d_inner = self.find_hparam(["mamba_d_ssm"])
        self.d_head = self.find_hparam(["d_head"])