Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions nemo_automodel/components/checkpoint/addons.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,24 @@ def _extract_target_modules(model: nn.Module) -> list[str]:
final_target_modules.add(f"{expert_path}.{expert_id}.down_proj")
break

# When the model is a biencoder, LoRA target modules are discovered under
# the "lm_q." prefix (the biencoder query-encoder wrapper). The standalone
# HF base model (LlamaBidirectionalModel) uses bare names without any
# wrapper prefix (e.g. "layers.0.self_attn.q_proj"). Strip "lm_q." so
# the saved adapter_config.json is compatible with merge_lora / HF PEFT.
adapter = getattr(model, "state_dict_adapter", None)
if adapter is not None:
from nemo_automodel.components.models.common.bidirectional import BiencoderStateDictAdapter

if isinstance(adapter, BiencoderStateDictAdapter):
remapped = set()
for name in final_target_modules:
if name.startswith("lm_q."):
remapped.add(name[len("lm_q.") :])
else:
remapped.add(name)
final_target_modules = remapped

return sorted(list(final_target_modules))


Expand Down
40 changes: 30 additions & 10 deletions nemo_automodel/components/models/common/bidirectional.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,23 +40,33 @@ def __init__(self):
self._uses_model_prefix = True

@staticmethod
def _swap_key(key: str, src: str, dst: str, peft_prefix: str) -> Optional[str]:
def _swap_key(key: str, src: str, dst: str, peft_prefix: str, peft_dst: Optional[str] = None) -> Optional[str]:
"""Return *key* with *src* prefix replaced by *dst*, handling an optional PEFT wrapper.

Args:
peft_dst: Destination prefix for PEFT-wrapped keys. Defaults to *dst*.

Returns ``None`` when *key* doesn't match *src* (bare or PEFT-wrapped).
"""
if key.startswith(src):
return dst + key[len(src) :]
peft_src = peft_prefix + src
if key.startswith(peft_src):
return peft_prefix + dst + key[len(peft_src) :]
effective_dst = peft_dst if peft_dst is not None else dst
return peft_prefix + effective_dst + key[len(peft_src) :]
return None

def to_hf(self, state_dict: dict[str, Any], **kwargs) -> dict[str, Any]:
    """Convert biencoder state dict to HF format.

    Base-model weights: ``lm_q.X`` -> ``model.X``.
    PEFT adapter weights: ``base_model.model.lm_q.X`` -> ``base_model.model.X``
    (strips ``lm_q.`` without adding ``model.`` so keys match the standalone
    ``LlamaBidirectionalModel`` which extends ``LlamaModel`` directly).

    Keys matching neither pattern (``_swap_key`` returns ``None``) are dropped
    from the output.
    """
    hf_state_dict = {}
    for key, value in state_dict.items():
        # peft_dst="" strips "lm_q." from PEFT keys without inserting "model.".
        new_key = self._swap_key(key, "lm_q.", "model.", self._PEFT_PREFIX, peft_dst="")
        if new_key is not None:
            hf_state_dict[new_key] = value
    return hf_state_dict
Expand All @@ -67,14 +77,24 @@ def from_hf(
device_mesh: Optional["DeviceMesh"] = None,
**kwargs,
) -> dict[str, Any]:
"""Convert HF state dict to biencoder format (model -> lm_q + lm_p)."""
"""Convert HF state dict to biencoder format (model -> lm_q + lm_p).

Handles both bare keys (``model.X``) and PEFT keys (``base_model.model.X``).
"""
biencoder_state_dict = {}
for key, value in hf_state_dict.items():
q_key = self._swap_key(key, "model.", "lm_q.", self._PEFT_PREFIX)
if q_key is not None:
p_key = self._swap_key(key, "model.", "lm_p.", self._PEFT_PREFIX)
biencoder_state_dict[q_key] = value
biencoder_state_dict[p_key] = value
if key.startswith(self._PEFT_PREFIX):
# PEFT format: base_model.model.X → base_model.model.lm_q.X
# Only restore to lm_q; lm_p shares parameters in shared-encoder
# mode and loading into lm_p would fail with FSDP DTensors.
suffix = key[len(self._PEFT_PREFIX) :]
biencoder_state_dict[self._PEFT_PREFIX + "lm_q." + suffix] = value
else:
q_key = self._swap_key(key, "model.", "lm_q.", self._PEFT_PREFIX)
if q_key is not None:
p_key = self._swap_key(key, "model.", "lm_p.", self._PEFT_PREFIX)
biencoder_state_dict[q_key] = value
biencoder_state_dict[p_key] = value
return biencoder_state_dict

def convert_single_tensor_to_hf(self, fqn: str, tensor: Any, **kwargs) -> list[tuple[str, Any]]:
Expand Down
2 changes: 2 additions & 0 deletions nemo_automodel/recipes/biencoder/train_biencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
build_step_scheduler,
build_wandb,
)
from nemo_automodel.shared.te_patches import apply_te_patches

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -151,6 +152,7 @@ def setup(self):
setup_logging()

apply_cache_compatibility_patches()
apply_te_patches()
self.rng = StatefulRNG(seed=self.cfg.get("seed", 42), ranked=True)

self.dist_setup = setup_distributed(self.cfg, world_size=self.dist_env.world_size)
Expand Down
18 changes: 18 additions & 0 deletions tests/unit_tests/checkpoint/test_addons.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,4 +213,22 @@ def test_result_is_sorted(self):
result = _extract_target_modules(model)
assert result == sorted(result)

def test_biencoder_target_modules_remapped(self):
    """Biencoder lm_q.* target modules have lm_q. prefix stripped."""
    from nemo_automodel.components.models.common.bidirectional import BiencoderStateDictAdapter

    lora_module_names = [
        "lm_q.layers.0.self_attn.q_proj.lora_A",
        "lm_q.layers.0.self_attn.k_proj.lora_A",
        "lm_q.layers.0.mlp.down_proj.lora_A",
    ]
    model = _make_model_with_named_modules(lora_module_names)
    model.state_dict_adapter = BiencoderStateDictAdapter()

    result = _extract_target_modules(model)

    # All discovered modules must appear with the lm_q. wrapper prefix removed.
    expected = {
        "layers.0.self_attn.q_proj",
        "layers.0.self_attn.k_proj",
        "layers.0.mlp.down_proj",
    }
    assert expected.issubset(set(result))
    assert not any("lm_q" in module for module in result)


42 changes: 42 additions & 0 deletions tests/unit_tests/models/biencoder/test_state_dict_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,45 @@ def test_prefix_replacement_accuracy(self, adapter):
# Should only replace the first occurrence of lm_q.
assert "model.model.layer.sublayer.weight" in hf_state_dict
assert "lm_q.model.layer.sublayer.weight" not in hf_state_dict

def test_to_hf_peft_keys_strip_lm_q(self, adapter):
    """PEFT keys should have lm_q. stripped without adding model. prefix."""
    wrapped_base = "base_model.model.lm_q.layers.0.self_attn.q_proj"
    state_dict = {
        f"{wrapped_base}.lora_A.weight": torch.randn(8, 64),
        f"{wrapped_base}.lora_B.weight": torch.randn(64, 8),
    }

    hf_state_dict = adapter.to_hf(state_dict)

    # lm_q. stripped, PEFT wrapper prefix preserved.
    stripped_base = "base_model.model.layers.0.self_attn.q_proj"
    for suffix in ("lora_A.weight", "lora_B.weight"):
        assert f"{stripped_base}.{suffix}" in hf_state_dict
    # No spurious extra "model." prefix should be introduced.
    assert "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight" not in hf_state_dict

def test_from_hf_peft_keys_add_lm_q(self, adapter):
    """PEFT keys should get lm_q. inserted after base_model.model. prefix (not lm_p)."""
    hf_state_dict = {
        "base_model.model.layers.0.self_attn.q_proj.lora_A.weight": torch.randn(8, 64),
    }

    restored = adapter.from_hf(hf_state_dict)

    q_key = "base_model.model.lm_q.layers.0.self_attn.q_proj.lora_A.weight"
    p_key = "base_model.model.lm_p.layers.0.self_attn.q_proj.lora_A.weight"
    assert q_key in restored
    # Shared-encoder mode: PEFT keys must not be fanned out to lm_p.
    assert p_key not in restored

def test_peft_roundtrip(self, adapter):
    """PEFT lm_q keys should survive a to_hf → from_hf roundtrip."""
    original = {
        "base_model.model.lm_q.layers.0.self_attn.q_proj.lora_A.weight": torch.randn(8, 64),
        "base_model.model.lm_q.layers.0.mlp.down_proj.lora_B.weight": torch.randn(64, 8),
    }

    restored = adapter.from_hf(adapter.to_hf(original))

    # Every original key must come back with its tensor intact.
    for key, tensor in original.items():
        assert key in restored, f"Missing key {key} after roundtrip"
        torch.testing.assert_close(restored[key], tensor)
    # lm_p should NOT appear in PEFT roundtrip (shared encoder).
    assert all("lm_p" not in key for key in restored)
Loading