diff --git a/nemo_automodel/components/checkpoint/addons.py b/nemo_automodel/components/checkpoint/addons.py index 077bfb28b..ed8b94c38 100644 --- a/nemo_automodel/components/checkpoint/addons.py +++ b/nemo_automodel/components/checkpoint/addons.py @@ -308,6 +308,24 @@ def _extract_target_modules(model: nn.Module) -> list[str]: final_target_modules.add(f"{expert_path}.{expert_id}.down_proj") break + # When the model is a biencoder, LoRA target modules are discovered under + # the "lm_q." prefix (the biencoder query-encoder wrapper). The standalone + # HF base model (LlamaBidirectionalModel) uses bare names without any + # wrapper prefix (e.g. "layers.0.self_attn.q_proj"). Strip "lm_q." so + # the saved adapter_config.json is compatible with merge_lora / HF PEFT. + adapter = getattr(model, "state_dict_adapter", None) + if adapter is not None: + from nemo_automodel.components.models.common.bidirectional import BiencoderStateDictAdapter + + if isinstance(adapter, BiencoderStateDictAdapter): + remapped = set() + for name in final_target_modules: + if name.startswith("lm_q."): + remapped.add(name[len("lm_q.") :]) + else: + remapped.add(name) + final_target_modules = remapped + return sorted(list(final_target_modules)) diff --git a/nemo_automodel/components/models/common/bidirectional.py b/nemo_automodel/components/models/common/bidirectional.py index 4ac525c51..b093dcf51 100644 --- a/nemo_automodel/components/models/common/bidirectional.py +++ b/nemo_automodel/components/models/common/bidirectional.py @@ -40,23 +40,33 @@ def __init__(self): self._uses_model_prefix = True @staticmethod - def _swap_key(key: str, src: str, dst: str, peft_prefix: str) -> Optional[str]: + def _swap_key(key: str, src: str, dst: str, peft_prefix: str, peft_dst: Optional[str] = None) -> Optional[str]: """Return *key* with *src* prefix replaced by *dst*, handling an optional PEFT wrapper. + Args: + peft_dst: Destination prefix for PEFT-wrapped keys. Defaults to *dst*. + Returns ``None`` when *key* doesn't match *src* (bare or PEFT-wrapped). """ if key.startswith(src): return dst + key[len(src) :] peft_src = peft_prefix + src if key.startswith(peft_src): - return peft_prefix + dst + key[len(peft_src) :] + effective_dst = peft_dst if peft_dst is not None else dst + return peft_prefix + effective_dst + key[len(peft_src) :] return None def to_hf(self, state_dict: dict[str, Any], **kwargs) -> dict[str, Any]: - """Convert biencoder state dict to HF format (lm_q -> model).""" + """Convert biencoder state dict to HF format. + + Base-model weights: ``lm_q.X`` → ``model.X``. + PEFT adapter weights: ``base_model.model.lm_q.X`` → ``base_model.model.X`` + (strips ``lm_q.`` without adding ``model.`` so keys match the standalone + ``LlamaBidirectionalModel`` which extends ``LlamaModel`` directly). + """ hf_state_dict = {} for key, value in state_dict.items(): - new_key = self._swap_key(key, "lm_q.", "model.", self._PEFT_PREFIX) + new_key = self._swap_key(key, "lm_q.", "model.", self._PEFT_PREFIX, peft_dst="") if new_key is not None: hf_state_dict[new_key] = value return hf_state_dict @@ -67,14 +77,24 @@ def from_hf( device_mesh: Optional["DeviceMesh"] = None, **kwargs, ) -> dict[str, Any]: - """Convert HF state dict to biencoder format (model -> lm_q + lm_p).""" + """Convert HF state dict to biencoder format (model -> lm_q + lm_p). + + Handles both bare keys (``model.X``) and PEFT keys (``base_model.model.X``). + """ biencoder_state_dict = {} for key, value in hf_state_dict.items(): - q_key = self._swap_key(key, "model.", "lm_q.", self._PEFT_PREFIX) - if q_key is not None: - p_key = self._swap_key(key, "model.", "lm_p.", self._PEFT_PREFIX) - biencoder_state_dict[q_key] = value - biencoder_state_dict[p_key] = value + if key.startswith(self._PEFT_PREFIX): + # PEFT format: base_model.model.X → base_model.model.lm_q.X + # Only restore to lm_q; lm_p shares parameters in shared-encoder + # mode and loading into lm_p would fail with FSDP DTensors. + suffix = key[len(self._PEFT_PREFIX) :] + biencoder_state_dict[self._PEFT_PREFIX + "lm_q." + suffix] = value + else: + q_key = self._swap_key(key, "model.", "lm_q.", self._PEFT_PREFIX) + if q_key is not None: + p_key = self._swap_key(key, "model.", "lm_p.", self._PEFT_PREFIX) + biencoder_state_dict[q_key] = value + biencoder_state_dict[p_key] = value return biencoder_state_dict def convert_single_tensor_to_hf(self, fqn: str, tensor: Any, **kwargs) -> list[tuple[str, Any]]: diff --git a/nemo_automodel/recipes/biencoder/train_biencoder.py b/nemo_automodel/recipes/biencoder/train_biencoder.py index 27bdfde2f..020cb7eff 100644 --- a/nemo_automodel/recipes/biencoder/train_biencoder.py +++ b/nemo_automodel/recipes/biencoder/train_biencoder.py @@ -43,6 +43,7 @@ build_step_scheduler, build_wandb, ) +from nemo_automodel.shared.te_patches import apply_te_patches logger = logging.getLogger(__name__) @@ -151,6 +152,7 @@ def setup(self): setup_logging() apply_cache_compatibility_patches() + apply_te_patches() self.rng = StatefulRNG(seed=self.cfg.get("seed", 42), ranked=True) self.dist_setup = setup_distributed(self.cfg, world_size=self.dist_env.world_size) diff --git a/tests/unit_tests/checkpoint/test_addons.py b/tests/unit_tests/checkpoint/test_addons.py index 8836ba27e..ccd2e341e 100644 --- a/tests/unit_tests/checkpoint/test_addons.py +++ b/tests/unit_tests/checkpoint/test_addons.py @@ -213,4 +213,22 @@ def test_result_is_sorted(self): result = _extract_target_modules(model) assert result == sorted(result) + def test_biencoder_target_modules_remapped(self): + """Biencoder lm_q.* target modules have lm_q. prefix stripped.""" + from nemo_automodel.components.models.common.bidirectional import BiencoderStateDictAdapter + + model = _make_model_with_named_modules( + [ + "lm_q.layers.0.self_attn.q_proj.lora_A", + "lm_q.layers.0.self_attn.k_proj.lora_A", + "lm_q.layers.0.mlp.down_proj.lora_A", + ] + ) + model.state_dict_adapter = BiencoderStateDictAdapter() + result = _extract_target_modules(model) + assert "layers.0.self_attn.q_proj" in result + assert "layers.0.self_attn.k_proj" in result + assert "layers.0.mlp.down_proj" in result + assert all("lm_q" not in m for m in result) + diff --git a/tests/unit_tests/models/biencoder/test_state_dict_adapter.py b/tests/unit_tests/models/biencoder/test_state_dict_adapter.py index c7f05c16f..49bf3b76d 100644 --- a/tests/unit_tests/models/biencoder/test_state_dict_adapter.py +++ b/tests/unit_tests/models/biencoder/test_state_dict_adapter.py @@ -190,3 +190,45 @@ def test_prefix_replacement_accuracy(self, adapter): # Should only replace the first occurrence of lm_q. assert "model.model.layer.sublayer.weight" in hf_state_dict assert "lm_q.model.layer.sublayer.weight" not in hf_state_dict + + def test_to_hf_peft_keys_strip_lm_q(self, adapter): + """PEFT keys should have lm_q. stripped without adding model. prefix.""" + state_dict = { + "base_model.model.lm_q.layers.0.self_attn.q_proj.lora_A.weight": torch.randn(8, 64), + "base_model.model.lm_q.layers.0.self_attn.q_proj.lora_B.weight": torch.randn(64, 8), + } + + hf_state_dict = adapter.to_hf(state_dict) + + assert "base_model.model.layers.0.self_attn.q_proj.lora_A.weight" in hf_state_dict + assert "base_model.model.layers.0.self_attn.q_proj.lora_B.weight" in hf_state_dict + # Should NOT have extra model. prefix + assert "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight" not in hf_state_dict + + def test_from_hf_peft_keys_add_lm_q(self, adapter): + """PEFT keys should get lm_q. inserted after base_model.model. prefix (not lm_p).""" + hf_state_dict = { + "base_model.model.layers.0.self_attn.q_proj.lora_A.weight": torch.randn(8, 64), + } + + biencoder_state_dict = adapter.from_hf(hf_state_dict) + + assert "base_model.model.lm_q.layers.0.self_attn.q_proj.lora_A.weight" in biencoder_state_dict + # PEFT keys should NOT be fanned out to lm_p (shared encoder) + assert "base_model.model.lm_p.layers.0.self_attn.q_proj.lora_A.weight" not in biencoder_state_dict + + def test_peft_roundtrip(self, adapter): + """PEFT lm_q keys should survive a to_hf → from_hf roundtrip.""" + original = { + "base_model.model.lm_q.layers.0.self_attn.q_proj.lora_A.weight": torch.randn(8, 64), + "base_model.model.lm_q.layers.0.mlp.down_proj.lora_B.weight": torch.randn(64, 8), + } + + hf = adapter.to_hf(original) + restored = adapter.from_hf(hf) + + for key in original: + assert key in restored, f"Missing key {key} after roundtrip" + torch.testing.assert_close(restored[key], original[key]) + # lm_p should NOT appear in PEFT roundtrip + assert not any("lm_p" in k for k in restored)