Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 122 additions & 0 deletions examples/biencoder/nemotron_nanov2_biencoder.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# To run this recipe, please use the following command:
# python examples/biencoder/finetune.py --config examples/biencoder/nemotron_nanov2_biencoder.yaml
# Or with torchrun for multi-GPU:
# torchrun --nproc-per-node=8 examples/biencoder/finetune.py --config examples/biencoder/nemotron_nanov2_biencoder.yaml

seed: 125

step_scheduler:
  global_batch_size: 128
  local_batch_size: 1
  ckpt_every_steps: 100
  val_every_steps: 100
  num_epochs: 1

# Distributed process-group environment.
# NOTE(review): this file previously declared `dist_env` twice (once with
# timeout_minutes: 1 near the top and once with 30 at the bottom). Duplicate
# mapping keys are invalid YAML, and common loaders silently keep only the
# last occurrence — consolidated here to the effective value (30 minutes).
dist_env:
  backend: nccl
  timeout_minutes: 30

model:
  _target_: nemo_automodel.components.models.biencoder.NeMoAutoModelBiencoder.from_pretrained
  pretrained_model_name_or_path: nvidia/NVIDIA-Nemotron-Nano-9B-v2
  train_n_passages: 5
  pooling: avg
  t: 0.02
  torch_dtype: bfloat16
  attn_implementation: eager

  # Bidirectional processing strategy for Mamba layers
  # Options:
  #   - average: Simple average of forward and backward passes
  #   - concat: Concatenate forward and backward (doubles hidden size)
  #   - weighted: Weighted average using forward_weight
  #   - gated: Learned gating mechanism (requires training)
  mamba_bidirectional_strategy: average

  # Forward weight (only used if mamba_bidirectional_strategy=weighted)
  # Value between 0 and 1, backward_weight = 1 - forward_weight
  forward_weight: 0.5

  # Bidirectional attention for attention layers
  # If true, attention layers use bidirectional attention (all tokens attend to all tokens)
  # If false, attention layers use causal attention (tokens only attend to past tokens)
  bidirectional_attention: true


tokenizer:
  _target_: transformers.AutoTokenizer.from_pretrained
  pretrained_model_name_or_path: nvidia/NVIDIA-Nemotron-Nano-9B-v2

dataloader:
  _target_: torchdata.stateful_dataloader.StatefulDataLoader
  dataset:
    _target_: nemo_automodel.components.datasets.llm.make_retrieval_dataset
    data_dir_list:
      - ./embed_nemotron_dataset_v1/TriviaQA/TriviaQA.json
      - ./embed_nemotron_dataset_v1/SyntheticClassificationData/SyntheticClassificationData.json
    data_type: train
    train_n_passages: 5
    seed: 125
    do_shuffle: true
    use_dataset_instruction: true
  collate_fn:
    _target_: nemo_automodel.components.datasets.llm.RetrievalBiencoderCollator
    q_max_len: 512
    p_max_len: 512
    query_prefix: ""
    passage_prefix: ""
    pad_to_multiple_of: 8
    use_dataset_instruction: true
  shuffle: true
  num_workers: 8

optimizer:
  _target_: torch.optim.AdamW
  lr: 2.0e-6
  weight_decay: 0.01

lr_scheduler:
  lr_warmup_steps: 2

checkpoint:
  enabled: true
  checkpoint_dir: ./output/nemotron_nano_9b_biencoder
  model_save_format: safetensors
  save_consolidated: true

wandb:
  project: nemotron-finetuning
  entity: nvidia-merlin # Replace with your wandb entity/username
  name: nemotron_nano_9b_biencoder
  tags:
    - biencoder
    - nemotron
    - retrieval
  notes: "Finetuning Nemotron Nano 9B for biencoder retrieval"

distributed:
  _target_: nemo_automodel.components.distributed.fsdp2.FSDP2Manager
  # NOTE(review): YAML parses `none` as the string "none", not a null value.
  # Confirm FSDP2Manager coerces this; use `null` if a true None is intended.
  dp_size: none
  dp_replicate_size: 1
  tp_size: 1
  cp_size: 1
  sequence_parallel: false
2 changes: 1 addition & 1 deletion nemo_automodel/components/checkpoint/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,7 @@ def _maybe_build_consolidated_index(

# Add any missing keys from the model_state_dict
# These will go to the same file as the last file (or file 1 for single-file models)
default_index = max(fqn_to_file_index_mapping.values())
default_index = max(fqn_to_file_index_mapping.values()) if fqn_to_file_index_mapping else 1

# add any additional keys that are not in the base checkpoint
for fqn in list(state_dict.keys()):
Expand Down
14 changes: 10 additions & 4 deletions nemo_automodel/components/models/biencoder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,18 @@
"""

from .biencoder_model import NeMoAutoModelBiencoder # noqa: F401
from .llama_bidirectional_model import ( # noqa: F401
# from .llama_bidirectional_model import ( # noqa: F401
# BiencoderModel,
# BiencoderOutput,
# LlamaBidirectionalConfig,
# LlamaBidirectionalForSequenceClassification,
# LlamaBidirectionalModel,
# )
from .nemotron_bidirectional_model import ( # noqa: F401
BiencoderModel,
BiencoderOutput,
LlamaBidirectionalConfig,
LlamaBidirectionalForSequenceClassification,
LlamaBidirectionalModel,
NemotronBidirectionalConfig,
NemotronBidirectionalModel,
)

__all__ = [
Expand Down
10 changes: 7 additions & 3 deletions nemo_automodel/components/models/biencoder/biencoder_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import gc
import logging
from typing import List, Optional
import torch

from torch.nn.attention import SDPBackend

Expand All @@ -25,7 +26,8 @@
_patch_liger_kernel,
)

from .llama_bidirectional_model import BiencoderModel
# from .llama_bidirectional_model import BiencoderModel
from .nemotron_bidirectional_model import BiencoderModel

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -113,8 +115,10 @@ def _retry(**override):
)

# Use BiencoderModel.build to initialize model with base encoders
hf_kwargs = {"attn_implementation": "flash_attention_2"}
kwargs.update(hf_kwargs)
# Only set attn_implementation if not already provided in kwargs
if "attn_implementation" not in kwargs:
kwargs["attn_implementation"] = "eager"

model = BiencoderModel.build(
model_name_or_path=pretrained_model_name_or_path,
share_encoder=share_encoder,
Expand Down
Loading
Loading