
Created ReplicateKVHeadTransform to integrate KV-heads replication module within QEfficient library. #625

Open
quic-dhirajku wants to merge 6 commits into quic:main from quic-dhirajku:replicate_kv_heads_transform

Conversation

@quic-dhirajku
Contributor

The Transform enables KV-head replication for both CausalLMs and VLMs.
The feature is enabled by passing the n_kv_head_repeat parameter when initializing the QEff wrapper class for the corresponding model.
The n_kv_head_repeat parameter acts as the multiplier for the number of times the original KV heads are repeated. The operation also updates the num_key_value_heads parameter in the model's config and hash params and adds an orig_kv_heads parameter; this allows exporting the same model with different numbers of KV heads without causing a hash conflict.
Added tests for both CausalLMs and VLMs that compare the outputs of the PyTorch HF model and the AIC model. Two new optional parameters, n_kv_head_repeat and test_kv_replicate, are added for testing purposes. Setting test_kv_replicate to True replicates the KV heads of every model so that the number of KV heads equals the number of attention heads. This avoids test failures from misalignment: simply doubling num_key_value_heads can leave it no longer dividing num_heads evenly.
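
For context, the core of the replication is a repeat_interleave over a head-grouped view of the K/V projection weights. A minimal sketch, assuming standard HF-style k_proj/v_proj Linear layers (function and variable names here are illustrative, not taken from the PR):

```python
import torch
import torch.nn as nn

def replicate_kv_proj(proj: nn.Linear, orig_kv_heads: int, head_dim: int, repeat: int) -> None:
    # View the output rows as [kv_heads, head_dim, hidden] and repeat each
    # head's block `repeat` times, mirroring the bias handling in the diff below.
    w = proj.weight.data.view(orig_kv_heads, head_dim, -1)
    proj.weight.data = torch.repeat_interleave(w, repeat, dim=0).reshape(
        orig_kv_heads * repeat * head_dim, -1
    )
    if proj.bias is not None:
        b = proj.bias.data.view(orig_kv_heads, head_dim)
        proj.bias.data = torch.repeat_interleave(b, repeat, dim=0).reshape(-1)
```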

@quic-rishinr
Contributor

quic-rishinr commented Nov 20, 2025

@ochougul @quic-amitraj please review

…dule within Qefficient library.

The Transform enables KV-head replication for both CausalLMs and VLMs.
The feature is enabled by passing the n_kv_head_repeat parameter when initializing the QEff wrapper class for the corresponding model.
The n_kv_head_repeat parameter acts as the multiplier for the number of times the original KV heads are repeated.
The operation also updates num_key_value_heads in the model's config and hash params and adds an orig_kv_heads parameter, allowing the same model to be exported with different numbers of KV heads without a hash conflict.
Also added tests for both CausalLMs and VLMs that compare the outputs of the PyTorch HF model and the AIC model.
Two new optional parameters, n_kv_head_repeat and test_kv_replicate, are added for testing purposes.
Setting test_kv_replicate to True replicates the KV heads of every model so that the number of KV heads equals the number of attention heads. This avoids test failures from misalignment: simply doubling num_key_value_heads can leave it no longer dividing num_heads evenly.

Signed-off-by: Dhiraj Kumar Sah <dhirajku@qti.qualcomm.com>
… Doing so would prevent any issues during Transforms when we don't wish to apply it.

Signed-off-by: Dhiraj Kumar Sah <dhirajku@qti.qualcomm.com>
Signed-off-by: Dhiraj Kumar Sah <dhirajku@qti.qualcomm.com>
…orm.

Signed-off-by: Dhiraj Kumar Sah <dhirajku@qti.qualcomm.com>
Signed-off-by: Dhiraj Kumar Sah <dhirajku@qti.qualcomm.com>
quic-dhirajku force-pushed the replicate_kv_heads_transform branch from 870cc8d to 08032e1 on January 21, 2026 05:35
Comment thread on tests/transformers/models/image_text_to_text/test_image_text_to_text_models.py (Outdated)
…changes to repeat Bias factor appropriately on quantized layers.

Signed-off-by: Dhiraj Kumar Sah <dhirajku@qti.qualcomm.com>
Contributor

@ochougul left a comment


Write a test that makes sure the ONNX hash is different when a different number of KV heads is passed.
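
A minimal sketch of such a test, assuming the wrapper exposes its export hash via a model_hash attribute (the attribute name is an assumption, as is using gpt2 as the test model):

```python
from QEfficient import QEFFAutoModelForCausalLM

def test_replicate_kv_heads_changes_hash():
    base = QEFFAutoModelForCausalLM.from_pretrained("gpt2")
    repeated = QEFFAutoModelForCausalLM.from_pretrained("gpt2", n_kv_head_repeat=2)
    # num_key_value_heads differs in the hash params, so the exports must not collide.
    assert base.model_hash != repeated.model_hash
```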

Comment on lines +2498 to +2499
# InternVL causes an error if we pass the num_kv_heads_repeat parameter
num_kv_heads_repeat = kwargs.pop("num_kv_heads_repeat", 1)
Contributor


not needed

Comment on lines +2398 to +2400
self.model, replicate_kv_transformed = ReplicateKVHeadTransform.apply(self.model, **kwargs)
if replicate_kv_transformed:
self.hash_params["config"] = model.config.to_diff_dict()
Contributor


better add it to _pytorch_transforms if we are always going to call it.
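
The suggestion, sketched below; the surrounding list contents and import path are illustrative of the existing class layout (verify against the tree), and only the placement of ReplicateKVHeadTransform is the point:

```python
from QEfficient.transformers.models.pytorch_transforms import (
    CustomOpsTransform,
    KVCacheTransform,
)

class QEFFAutoModelForCausalLM(QEFFTransformersBase):
    _pytorch_transforms = [
        CustomOpsTransform,
        KVCacheTransform,
        ReplicateKVHeadTransform,  # applied with the other PyTorch transforms, no special-casing
    ]
```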

Comment on lines +893 to +895
if replicate_kv_transformed:
self.lang_model.hash_params["config"] = model.config.to_diff_dict()
self.vision_model.hash_params["config"] = model.config.to_diff_dict()
Contributor


don't we already dump config somewhere? in _generate_export_hash?
You can just always add the repeat_kv_heads value to self.hash_params; it will be 1 if nothing is passed.
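
i.e., something along these lines (placement is hypothetical; the PR uses both n_kv_head_repeat and num_kv_heads_repeat as names, so pick one):

```python
# Always recorded; defaults to 1, so models without replication hash as before.
self.hash_params["n_kv_head_repeat"] = kwargs.get("n_kv_head_repeat", 1)
```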

}


class ReplicateKVHeadTransform:
Contributor


Make this inherit ModuleMutatorTransform.
You may need to implement a mutate method similar to apply here.
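
A rough sketch of that shape, assuming the ModuleMutatorTransform interface from QEfficient/base/pytorch_transforms.py (the mutate body is a placeholder):

```python
from QEfficient.base.pytorch_transforms import ModuleMutatorTransform

class ReplicateKVHeadTransform(ModuleMutatorTransform):
    _match_class = None  # set to the attention module class the transform targets

    @classmethod
    def mutate(cls, original_module, parent_module):
        # Move the per-module replication logic from the current `apply` here;
        # the base class walks the model and calls `mutate` on each matching module.
        return original_module
```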

@quic-hemagnih
Contributor

@quic-dhirajku Please take this PR up post #595.

layer.bias.data = torch.repeat_interleave(
layer.bias.data.view(orig_kv_heads, head_dim), repeat, 0
).view(new_kv_heads * head_dim)
if layer.bias is not None:
Contributor


lines 782-785 are repeated here, please remove
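
i.e., keep a single guarded block:

```python
if layer.bias is not None:
    layer.bias.data = torch.repeat_interleave(
        layer.bias.data.view(orig_kv_heads, head_dim), repeat, 0
    ).view(new_kv_heads * head_dim)
```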

