From 7c7bee2f95fe54944b13a7c57197c2e3f7427943 Mon Sep 17 00:00:00 2001 From: adil-a Date: Sat, 7 Mar 2026 00:29:27 +0000 Subject: [PATCH 1/2] fix: biencoder PEFT adapter key remapping for merge_lora compatibility MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The biencoder wraps the base model as `lm_q`, so PEFT adapter weights and target modules were saved with `lm_q.` prefix. However, the standalone base model (LlamaBidirectionalModel extends LlamaModel) uses bare module names like `layers.0.self_attn.q_proj` — no `model.` prefix. This caused merge_lora.py to fail because PEFT could not match the target modules or weight keys against the base model. Changes: - state_dict_adapter: fix PEFT key path in to_hf to strip `lm_q.` without adding `model.` (base_model.model.lm_q.X → base_model.model.X) - state_dict_adapter: fix from_hf to handle new PEFT key format (base_model.model.X → base_model.model.lm_q.X) - addons: strip `lm_q.` prefix from target modules for biencoder so adapter_config.json is compatible with the standalone HF base model Co-Authored-By: Claude Opus 4.6 --- .../components/checkpoint/addons.py | 18 +++++++++ .../models/biencoder/state_dict_adapter.py | 34 ++++++++-------- tests/unit_tests/checkpoint/test_addons.py | 18 +++++++++ .../biencoder/test_state_dict_adapter.py | 39 +++++++++++++++++++ 4 files changed, 93 insertions(+), 16 deletions(-) diff --git a/nemo_automodel/components/checkpoint/addons.py b/nemo_automodel/components/checkpoint/addons.py index 207bbc216..053339029 100644 --- a/nemo_automodel/components/checkpoint/addons.py +++ b/nemo_automodel/components/checkpoint/addons.py @@ -308,6 +308,24 @@ def _extract_target_modules(model: nn.Module) -> list[str]: final_target_modules.add(f"{expert_path}.{expert_id}.down_proj") break + # When the model is a biencoder, LoRA target modules are discovered under + # the "lm_q." prefix (the biencoder query-encoder wrapper). The standalone + # HF base model (LlamaBidirectionalModel) uses bare names without any + # wrapper prefix (e.g. "layers.0.self_attn.q_proj"). Strip "lm_q." so + # the saved adapter_config.json is compatible with merge_lora / HF PEFT. + adapter = getattr(model, "state_dict_adapter", None) + if adapter is not None: + from nemo_automodel.components.models.biencoder.state_dict_adapter import BiencoderStateDictAdapter + + if isinstance(adapter, BiencoderStateDictAdapter): + remapped = set() + for name in final_target_modules: + if name.startswith("lm_q."): + remapped.add(name[len("lm_q."):]) + else: + remapped.add(name) + final_target_modules = remapped + return sorted(list(final_target_modules)) diff --git a/nemo_automodel/components/models/biencoder/state_dict_adapter.py b/nemo_automodel/components/models/biencoder/state_dict_adapter.py index f19bfea72..2d251c7e6 100644 --- a/nemo_automodel/components/models/biencoder/state_dict_adapter.py +++ b/nemo_automodel/components/models/biencoder/state_dict_adapter.py @@ -37,10 +37,11 @@ def __init__(self): def to_hf(self, state_dict: dict[str, Any], **kwargs) -> dict[str, Any]: """Convert from biencoder state dict to HuggingFace format. - Filters to only lm_q keys and converts "lm_q." prefix to "model." prefix. - Also handles the ``base_model.model.`` prefix that PEFT checkpointing adds - so that adapter weights are correctly converted (e.g. - ``base_model.model.lm_q.X`` → ``base_model.model.model.X``). + Filters to only lm_q keys and converts "lm_q." prefix to "model." prefix + for base-model weights. For PEFT adapter weights (prefixed with + ``base_model.model.``), the ``lm_q.`` segment is stripped so that keys + match the standalone HF model's module names (e.g. + ``base_model.model.lm_q.X`` → ``base_model.model.X``). Args: state_dict: The biencoder model state dict @@ -54,10 +55,10 @@ def to_hf(self, state_dict: dict[str, Any], **kwargs) -> dict[str, Any]: for key, value in state_dict.items(): if key.startswith("lm_q."): - new_key = "model." + key[len("lm_q.") :] + new_key = "model." + key[len("lm_q."):] hf_state_dict[new_key] = value elif key.startswith(peft_lm_q): - new_key = self._PEFT_PREFIX + "model." + key[len(peft_lm_q) :] + new_key = self._PEFT_PREFIX + key[len(peft_lm_q):] hf_state_dict[new_key] = value elif key.startswith("linear_pooler.") or key.startswith(peft_pooler): hf_state_dict[key] = value @@ -74,7 +75,7 @@ def from_hf( Converts "model." prefix to "lm_q." prefix for loading into biencoder. Also handles the ``base_model.model.`` prefix used by PEFT checkpoints - (e.g. ``base_model.model.model.X`` → ``base_model.model.lm_q.X``). + (e.g. ``base_model.model.X`` → ``base_model.model.lm_q.X``). Args: hf_state_dict: The HuggingFace format state dict @@ -84,20 +85,21 @@ def from_hf( The converted biencoder format state dict """ biencoder_state_dict = {} - peft_model = self._PEFT_PREFIX + "model." peft_pooler = self._PEFT_PREFIX + "linear_pooler." for key, value in hf_state_dict.items(): - if key.startswith("model."): - suffix = key[len("model.") :] - biencoder_state_dict["lm_q." + suffix] = value - biencoder_state_dict["lm_p." + suffix] = value - elif key.startswith(peft_model): - suffix = key[len(peft_model) :] + if key.startswith("linear_pooler.") or key.startswith(peft_pooler): + biencoder_state_dict[key] = value + elif key.startswith(self._PEFT_PREFIX): + # PEFT format: base_model.model.X → base_model.model.lm_q.X + suffix = key[len(self._PEFT_PREFIX):] biencoder_state_dict[self._PEFT_PREFIX + "lm_q." + suffix] = value biencoder_state_dict[self._PEFT_PREFIX + "lm_p." + suffix] = value - elif key.startswith("linear_pooler.") or key.startswith(peft_pooler): - biencoder_state_dict[key] = value + elif key.startswith("model."): + # Full checkpoint: model.X → lm_q.X + suffix = key[len("model."):] + biencoder_state_dict["lm_q." + suffix] = value + biencoder_state_dict["lm_p." + suffix] = value return biencoder_state_dict diff --git a/tests/unit_tests/checkpoint/test_addons.py b/tests/unit_tests/checkpoint/test_addons.py index edbd33502..4323093d8 100644 --- a/tests/unit_tests/checkpoint/test_addons.py +++ b/tests/unit_tests/checkpoint/test_addons.py @@ -213,4 +213,22 @@ def test_result_is_sorted(self): result = _extract_target_modules(model) assert result == sorted(result) + def test_biencoder_target_modules_remapped(self): + """Biencoder lm_q.* target modules have lm_q. prefix stripped.""" + from nemo_automodel.components.models.biencoder.state_dict_adapter import BiencoderStateDictAdapter + + model = _make_model_with_named_modules( + [ + "lm_q.layers.0.self_attn.q_proj.lora_A", + "lm_q.layers.0.self_attn.k_proj.lora_A", + "lm_q.layers.0.mlp.down_proj.lora_A", + ] + ) + model.state_dict_adapter = BiencoderStateDictAdapter() + result = _extract_target_modules(model) + assert "layers.0.self_attn.q_proj" in result + assert "layers.0.self_attn.k_proj" in result + assert "layers.0.mlp.down_proj" in result + assert all("lm_q" not in m for m in result) + diff --git a/tests/unit_tests/models/biencoder/test_state_dict_adapter.py b/tests/unit_tests/models/biencoder/test_state_dict_adapter.py index 29ea3ec81..32e8427f1 100644 --- a/tests/unit_tests/models/biencoder/test_state_dict_adapter.py +++ b/tests/unit_tests/models/biencoder/test_state_dict_adapter.py @@ -232,3 +232,42 @@ def test_prefix_replacement_accuracy(self, adapter): # Should only replace the first occurrence of lm_q. assert "model.model.layer.sublayer.weight" in hf_state_dict assert "lm_q.model.layer.sublayer.weight" not in hf_state_dict + + def test_to_hf_peft_keys_strip_lm_q(self, adapter): + """PEFT keys should have lm_q. stripped without adding model. prefix.""" + state_dict = { + "base_model.model.lm_q.layers.0.self_attn.q_proj.lora_A.weight": torch.randn(8, 64), + "base_model.model.lm_q.layers.0.self_attn.q_proj.lora_B.weight": torch.randn(64, 8), + } + + hf_state_dict = adapter.to_hf(state_dict) + + assert "base_model.model.layers.0.self_attn.q_proj.lora_A.weight" in hf_state_dict + assert "base_model.model.layers.0.self_attn.q_proj.lora_B.weight" in hf_state_dict + # Should NOT have extra model. prefix + assert "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight" not in hf_state_dict + + def test_from_hf_peft_keys_add_lm_q(self, adapter): + """PEFT keys should get lm_q./lm_p. inserted after base_model.model. prefix.""" + hf_state_dict = { + "base_model.model.layers.0.self_attn.q_proj.lora_A.weight": torch.randn(8, 64), + } + + biencoder_state_dict = adapter.from_hf(hf_state_dict) + + assert "base_model.model.lm_q.layers.0.self_attn.q_proj.lora_A.weight" in biencoder_state_dict + assert "base_model.model.lm_p.layers.0.self_attn.q_proj.lora_A.weight" in biencoder_state_dict + + def test_peft_roundtrip(self, adapter): + """PEFT keys should survive a to_hf → from_hf roundtrip.""" + original = { + "base_model.model.lm_q.layers.0.self_attn.q_proj.lora_A.weight": torch.randn(8, 64), + "base_model.model.lm_q.layers.0.mlp.down_proj.lora_B.weight": torch.randn(64, 8), + } + + hf = adapter.to_hf(original) + restored = adapter.from_hf(hf) + + for key in original: + assert key in restored, f"Missing key {key} after roundtrip" + torch.testing.assert_close(restored[key], original[key]) From 019acfce0c112679f1344642c2febb8c19ae1e7c Mon Sep 17 00:00:00 2001 From: adil-a Date: Sat, 7 Mar 2026 04:55:23 +0000 Subject: [PATCH 2/2] lint Signed-off-by: adil-a --- nemo_automodel/components/checkpoint/addons.py | 2 +- .../components/models/biencoder/state_dict_adapter.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/nemo_automodel/components/checkpoint/addons.py b/nemo_automodel/components/checkpoint/addons.py index 053339029..0c2bc1042 100644 --- a/nemo_automodel/components/checkpoint/addons.py +++ b/nemo_automodel/components/checkpoint/addons.py @@ -321,7 +321,7 @@ def _extract_target_modules(model: nn.Module) -> list[str]: remapped = set() for name in final_target_modules: if name.startswith("lm_q."): - remapped.add(name[len("lm_q."):]) + remapped.add(name[len("lm_q.") :]) else: remapped.add(name) final_target_modules = remapped diff --git a/nemo_automodel/components/models/biencoder/state_dict_adapter.py b/nemo_automodel/components/models/biencoder/state_dict_adapter.py index 2d251c7e6..759cc9215 100644 --- a/nemo_automodel/components/models/biencoder/state_dict_adapter.py +++ b/nemo_automodel/components/models/biencoder/state_dict_adapter.py @@ -55,10 +55,10 @@ def to_hf(self, state_dict: dict[str, Any], **kwargs) -> dict[str, Any]: for key, value in state_dict.items(): if key.startswith("lm_q."): - new_key = "model." + key[len("lm_q."):] + new_key = "model." + key[len("lm_q.") :] hf_state_dict[new_key] = value elif key.startswith(peft_lm_q): - new_key = self._PEFT_PREFIX + key[len(peft_lm_q):] + new_key = self._PEFT_PREFIX + key[len(peft_lm_q) :] hf_state_dict[new_key] = value elif key.startswith("linear_pooler.") or key.startswith(peft_pooler): hf_state_dict[key] = value @@ -92,12 +92,12 @@ def from_hf( biencoder_state_dict[key] = value elif key.startswith(self._PEFT_PREFIX): # PEFT format: base_model.model.X → base_model.model.lm_q.X - suffix = key[len(self._PEFT_PREFIX):] + suffix = key[len(self._PEFT_PREFIX) :] biencoder_state_dict[self._PEFT_PREFIX + "lm_q." + suffix] = value biencoder_state_dict[self._PEFT_PREFIX + "lm_p." + suffix] = value elif key.startswith("model."): # Full checkpoint: model.X → lm_q.X - suffix = key[len("model."):] + suffix = key[len("model.") :] biencoder_state_dict["lm_q." + suffix] = value biencoder_state_dict["lm_p." + suffix] = value