Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions nemo_automodel/components/checkpoint/addons.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,24 @@ def _extract_target_modules(model: nn.Module) -> list[str]:
final_target_modules.add(f"{expert_path}.{expert_id}.down_proj")
break

# When the model is a biencoder, LoRA target modules are discovered under
# the "lm_q." prefix (the biencoder query-encoder wrapper). The standalone
# HF base model (LlamaBidirectionalModel) uses bare names without any
# wrapper prefix (e.g. "layers.0.self_attn.q_proj"). Strip "lm_q." so
# the saved adapter_config.json is compatible with merge_lora / HF PEFT.
adapter = getattr(model, "state_dict_adapter", None)
if adapter is not None:
from nemo_automodel.components.models.biencoder.state_dict_adapter import BiencoderStateDictAdapter

if isinstance(adapter, BiencoderStateDictAdapter):
remapped = set()
for name in final_target_modules:
if name.startswith("lm_q."):
remapped.add(name[len("lm_q.") :])
else:
remapped.add(name)
final_target_modules = remapped

return sorted(list(final_target_modules))


Expand Down
30 changes: 16 additions & 14 deletions nemo_automodel/components/models/biencoder/state_dict_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,11 @@ def __init__(self):
def to_hf(self, state_dict: dict[str, Any], **kwargs) -> dict[str, Any]:
"""Convert from biencoder state dict to HuggingFace format.

Filters to only lm_q keys and converts "lm_q." prefix to "model." prefix.
Also handles the ``base_model.model.`` prefix that PEFT checkpointing adds
so that adapter weights are correctly converted (e.g.
``base_model.model.lm_q.X`` → ``base_model.model.model.X``).
Filters to only lm_q keys and converts "lm_q." prefix to "model." prefix
for base-model weights. For PEFT adapter weights (prefixed with
``base_model.model.``), the ``lm_q.`` segment is stripped so that keys
match the standalone HF model's module names (e.g.
``base_model.model.lm_q.X`` → ``base_model.model.X``).

Args:
state_dict: The biencoder model state dict
Expand All @@ -57,7 +58,7 @@ def to_hf(self, state_dict: dict[str, Any], **kwargs) -> dict[str, Any]:
new_key = "model." + key[len("lm_q.") :]
hf_state_dict[new_key] = value
elif key.startswith(peft_lm_q):
new_key = self._PEFT_PREFIX + "model." + key[len(peft_lm_q) :]
new_key = self._PEFT_PREFIX + key[len(peft_lm_q) :]
hf_state_dict[new_key] = value
elif key.startswith("linear_pooler.") or key.startswith(peft_pooler):
hf_state_dict[key] = value
Expand All @@ -74,7 +75,7 @@ def from_hf(

Converts "model." prefix to "lm_q." prefix for loading into biencoder.
Also handles the ``base_model.model.`` prefix used by PEFT checkpoints
(e.g. ``base_model.model.model.X`` → ``base_model.model.lm_q.X``).
(e.g. ``base_model.model.X`` → ``base_model.model.lm_q.X``).

Args:
hf_state_dict: The HuggingFace format state dict
Expand All @@ -84,20 +85,21 @@ def from_hf(
The converted biencoder format state dict
"""
biencoder_state_dict = {}
peft_model = self._PEFT_PREFIX + "model."
peft_pooler = self._PEFT_PREFIX + "linear_pooler."

for key, value in hf_state_dict.items():
if key.startswith("model."):
if key.startswith("linear_pooler.") or key.startswith(peft_pooler):
biencoder_state_dict[key] = value
elif key.startswith(self._PEFT_PREFIX):
# PEFT format: base_model.model.X → base_model.model.lm_q.X
suffix = key[len(self._PEFT_PREFIX) :]
biencoder_state_dict[self._PEFT_PREFIX + "lm_q." + suffix] = value
biencoder_state_dict[self._PEFT_PREFIX + "lm_p." + suffix] = value
elif key.startswith("model."):
# Full checkpoint: model.X → lm_q.X
suffix = key[len("model.") :]
biencoder_state_dict["lm_q." + suffix] = value
biencoder_state_dict["lm_p." + suffix] = value
elif key.startswith(peft_model):
suffix = key[len(peft_model) :]
biencoder_state_dict[self._PEFT_PREFIX + "lm_q." + suffix] = value
biencoder_state_dict[self._PEFT_PREFIX + "lm_p." + suffix] = value
elif key.startswith("linear_pooler.") or key.startswith(peft_pooler):
biencoder_state_dict[key] = value

return biencoder_state_dict

Expand Down
18 changes: 18 additions & 0 deletions tests/unit_tests/checkpoint/test_addons.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,4 +213,22 @@ def test_result_is_sorted(self):
result = _extract_target_modules(model)
assert result == sorted(result)

def test_biencoder_target_modules_remapped(self):
    """When the model carries a BiencoderStateDictAdapter, the lm_q. wrapper prefix is removed from discovered target modules."""
    from nemo_automodel.components.models.biencoder.state_dict_adapter import BiencoderStateDictAdapter

    # Module names as they appear under the biencoder query-encoder wrapper.
    wrapped_names = [
        "lm_q.layers.0.self_attn.q_proj.lora_A",
        "lm_q.layers.0.self_attn.k_proj.lora_A",
        "lm_q.layers.0.mlp.down_proj.lora_A",
    ]
    model = _make_model_with_named_modules(wrapped_names)
    model.state_dict_adapter = BiencoderStateDictAdapter()

    result = _extract_target_modules(model)

    # The bare (HF base-model) names must come back, with no lm_q. residue.
    for bare_name in (
        "layers.0.self_attn.q_proj",
        "layers.0.self_attn.k_proj",
        "layers.0.mlp.down_proj",
    ):
        assert bare_name in result
    assert not any("lm_q" in module for module in result)


39 changes: 39 additions & 0 deletions tests/unit_tests/models/biencoder/test_state_dict_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,3 +232,42 @@ def test_prefix_replacement_accuracy(self, adapter):
# Should only replace the first occurrence of lm_q.
assert "model.model.layer.sublayer.weight" in hf_state_dict
assert "lm_q.model.layer.sublayer.weight" not in hf_state_dict

def test_to_hf_peft_keys_strip_lm_q(self, adapter):
    """to_hf drops the lm_q. segment from PEFT keys instead of substituting a model. prefix."""
    peft_prefix = "base_model.model."
    module_path = "layers.0.self_attn.q_proj"
    state_dict = {
        f"{peft_prefix}lm_q.{module_path}.lora_A.weight": torch.randn(8, 64),
        f"{peft_prefix}lm_q.{module_path}.lora_B.weight": torch.randn(64, 8),
    }

    converted = adapter.to_hf(state_dict)

    for lora_part in ("lora_A", "lora_B"):
        assert f"{peft_prefix}{module_path}.{lora_part}.weight" in converted
    # Regression guard: no doubled "model." prefix on adapter keys.
    assert f"{peft_prefix}model.{module_path}.lora_A.weight" not in converted

def test_from_hf_peft_keys_add_lm_q(self, adapter):
    """from_hf fans each PEFT key out to both encoders by inserting lm_q./lm_p. after the base_model.model. prefix."""
    key_suffix = "layers.0.self_attn.q_proj.lora_A.weight"
    hf_state_dict = {f"base_model.model.{key_suffix}": torch.randn(8, 64)}

    restored = adapter.from_hf(hf_state_dict)

    for encoder in ("lm_q", "lm_p"):
        assert f"base_model.model.{encoder}.{key_suffix}" in restored

def test_peft_roundtrip(self, adapter):
    """PEFT keys and tensor values must be preserved across to_hf followed by from_hf."""
    original = {
        "base_model.model.lm_q.layers.0.self_attn.q_proj.lora_A.weight": torch.randn(8, 64),
        "base_model.model.lm_q.layers.0.mlp.down_proj.lora_B.weight": torch.randn(64, 8),
    }

    restored = adapter.from_hf(adapter.to_hf(original))

    for key, tensor in original.items():
        assert key in restored, f"Missing key {key} after roundtrip"
        torch.testing.assert_close(restored[key], tensor)
Loading