From af9a3609fd51d6b5c2a294e14a4386042a18a1c2 Mon Sep 17 00:00:00 2001 From: Vinay Raman Date: Tue, 27 Jan 2026 23:48:05 -0800 Subject: [PATCH 1/8] nemotron embedding model last pooling --- .../biencoder/nemotron_bidirectional_model.py | 540 ++++++ .../nemotron/configuration_nemotron_h.py | 243 +++ .../models/nemotron/modeling_nemotron_h.py | 1638 +++++++++++++++++ 3 files changed, 2421 insertions(+) create mode 100644 nemo_automodel/components/models/biencoder/nemotron_bidirectional_model.py create mode 100644 nemo_automodel/components/models/nemotron/configuration_nemotron_h.py create mode 100644 nemo_automodel/components/models/nemotron/modeling_nemotron_h.py diff --git a/nemo_automodel/components/models/biencoder/nemotron_bidirectional_model.py b/nemo_automodel/components/models/biencoder/nemotron_bidirectional_model.py new file mode 100644 index 000000000..7f6534d1f --- /dev/null +++ b/nemo_automodel/components/models/biencoder/nemotron_bidirectional_model.py @@ -0,0 +1,540 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. team. +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Nemotron Bidirectional Model for NeMo AutoModel. + +This module provides a bidirectional attention variant of Nemotron that is useful +for embedding and retrieval tasks. Unlike the standard causal Nemotron model, +this version can attend to all tokens bidirectionally. 
+""" + +import copy +import os +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Union + +from transformers import PreTrainedModel +from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + ModelOutput, + SequenceClassifierOutputWithPast, +) +from nemo_automodel.components.models.nemotron.configuration_nemotron_h import NemotronHConfig +from nemo_automodel.components.models.nemotron.modeling_nemotron_h import ( + NemotronHModel, + NemotronHOutput, +) + +try: + from nemo_automodel.components.models.biencoder.state_dict_adapter import BiencoderStateDictAdapter +except ImportError: + BiencoderStateDictAdapter = object + +from nemo_automodel.shared.import_utils import get_check_model_inputs_decorator + +logger = logging.get_logger(__name__) +check_model_inputs = get_check_model_inputs_decorator() + +def contrastive_scores_and_labels( + query: torch.Tensor, key: torch.Tensor, current_train_n_passages: int +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Compute contrastive scores and labels without in-batch negatives. 
+ + Args: + query: Query embeddings [batch_size, hidden_dim] + key: Key/passage embeddings [batch_size * n_passages, hidden_dim] + current_train_n_passages: Number of passages per query + + Returns: + Tuple of (scores, labels) where scores is [batch_size, n_passages] + and labels is [batch_size] of zeros (positive is first passage) + """ + assert key.shape[0] % query.shape[0] == 0, "{} % {} > 0".format(key.shape[0], query.shape[0]) + query_shape = query.shape + repeated_query = query.repeat(1, 1, current_train_n_passages).reshape( + query_shape[0] * current_train_n_passages, query_shape[1] + ) + qk = torch.sum(repeated_query * key, dim=-1).reshape(query_shape[0], current_train_n_passages) + labels = torch.zeros(query_shape[0], dtype=torch.long, device=query.device) + return qk, labels + + + +def pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor, pool_type: str) -> torch.Tensor: + """ + Pool hidden states using the specified pooling method. + + Args: + last_hidden_states: Hidden states from the model [batch_size, seq_len, hidden_size] + attention_mask: Attention mask [batch_size, seq_len] + pool_type: Type of pooling to apply + + Returns: + Pooled embeddings [batch_size, hidden_size] + """ + last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0) + + if pool_type == "avg": + emb = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] + elif pool_type == "weighted_avg": + emb = last_hidden.sum(dim=1) + elif pool_type == "cls": + emb = last_hidden[:, 0] + elif pool_type == "last": + left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0] + if left_padding: + emb = last_hidden[:, -1] + else: + sequence_lengths = attention_mask.sum(dim=1) - 1 + batch_size = last_hidden.shape[0] + emb = last_hidden[torch.arange(batch_size, device=last_hidden.device), sequence_lengths] + elif pool_type == "cls_last": + emb = last_hidden[:, 0] + elif pool_type == "colbert": + emb = last_hidden + else: + raise 
ValueError(f"pool_type {pool_type} not supported") + + return emb + + + + + +class NemotronBidirectionalConfig(NemotronHConfig): + """ + Configuration class for NemotronBidirectionalModel. + + Extends NemotronHConfig with additional parameters for bidirectional attention + and pooling configurations. + """ + + model_type = "nemotron_bidirectional" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.pooling = kwargs.get("pooling", "last") + self.temperature = kwargs.get("temperature", 1.0) + self.use_cache = kwargs.get("use_cache", False) + logger.info(f"NemotronBidirectionalConfig initialized with pooling: {pooling} and temperature: {temperature}") + logger.info(f"NemotronBidirectionalConfig initialized with kwargs: {kwargs}") + + + +class NemotronBidirectionalModel(NemotronHModel): + """ + Nemotron Bidirectional Model. + + This model is a bidirectional Nemotron model for embedding tasks. + """ + config_class = NemotronBidirectionalConfig + + def __init__(self, config: NemotronBidirectionalConfig): + super().__init__(config) + self.config = config + self.model = None + self.tokenizer = None + + + @check_model_inputs + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + cache_params: Optional[HybridMambaAttentionDynamicCache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[Tuple, NemotronHOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + # use_cache = use_cache if 
use_cache is not None else self.config.use_cache + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): # ^ is python for xor + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + # From zamba_modeling.py + if use_cache and cache_params is None: + logger.warning_once( + "NemotronH requires an initialized `NemotronHHybridDynamicCache` to return a cache. None was " + "provided, so no cache will be returned." + ) + + hidden_states = inputs_embeds + + if cache_position is None: + cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) + mamba_mask = self._update_mamba_mask(attention_mask, cache_position) + + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + # Until HERE + + for layer_idx, mixer_block in enumerate(self.layers): + # Depending on the layer type we opt for 2D base attention mask (Mamba) or 4D causal mask (Attention) + if mixer_block.block_type == "mamba": + layer_mask = mamba_mask + elif mixer_block.block_type == "attention": + layer_mask = causal_mask + elif mixer_block.block_type == "mlp": + layer_mask = None + else: + raise ValueError(f"Invalid block_type: {self.block_type}") + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + 
hidden_states = self._gradient_checkpointing_func( + mixer_block.__call__, hidden_states, cache_params, cache_position, layer_mask + ) + else: + hidden_states = mixer_block( + hidden_states, + cache_params=cache_params, + cache_position=cache_position, + attention_mask=layer_mask, + ) + + + hidden_states = self.norm_f(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None) + + return NemotronHOutput( + last_hidden_state=hidden_states, + cache_params=cache_params if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + +@dataclass +class BiencoderOutput(ModelOutput): + """Output dataclass for biencoder model.""" + + q_reps: Optional[Tensor] = None + p_reps: Optional[Tensor] = None + loss: Optional[Tensor] = None + labels: Optional[Tensor] = None + scores: Optional[Tensor] = None + + +class BiencoderModel(nn.Module): + """ + Biencoder Model with essential functions for training. + + This model encodes queries and passages separately and computes contrastive loss. 
+ """ + + def __init__( + self, + lm_q: PreTrainedModel, + lm_p: PreTrainedModel, + linear_pooler: nn.Module = None, + train_n_passages: int = 1, + eval_negative_size: int = 0, + pooling: str = "last", + l2_normalize: bool = True, + t: float = 1.0, + share_encoder: bool = True, + add_linear_pooler: bool = False, + ): + super().__init__() + self.lm_q = lm_q + self.lm_p = lm_p + self.train_n_passages = train_n_passages + self.eval_negative_size = eval_negative_size + self.pooling = pooling + self.l2_normalize = l2_normalize + self.t = t + self.share_encoder = share_encoder + self.add_linear_pooler = add_linear_pooler + self.cross_entropy = nn.CrossEntropyLoss(reduction="mean") + self.linear_pooler = linear_pooler if linear_pooler is not None else nn.Identity() + self.config = self.lm_q.config + self.trainer = None + + # For HuggingFace consolidated checkpoint compatibility + self.name_or_path = os.path.abspath(__file__) + self.state_dict_adapter = BiencoderStateDictAdapter() + self.config.architectures = ["NemotronBidirectionalModel"] + self.config.auto_map = { + "AutoModel": "nemotron_bidirectional_model.NemotronBidirectionalModel", + "AutoConfig": "nemotron_bidirectional_model.NemotronBidirectionalConfig", + } + + def forward(self, query: Dict[str, Tensor] = None, passage: Dict[str, Tensor] = None): + """Forward pass for training.""" + + # Get current number of passages per query + if self.training: + current_train_n_passages = self.train_n_passages + else: + current_train_n_passages = self.eval_negative_size + 1 + + # Compute scores (encoding happens inside _compute_scores) + scores, labels, q_reps, p_reps = self._compute_scores( + query=query, + passage=passage, + current_train_n_passages=current_train_n_passages, + ) + loss = self.cross_entropy(scores, labels) + + # Adding Dummy Gradients for vlm-based models + if hasattr(self.lm_q, "module") and hasattr(self.lm_q.module, "post_loss"): + loss = self.lm_q.module.post_loss(loss, passage) + elif hasattr(self.lm_q, 
"post_loss"): + # Not tested this branch + loss = self.lm_q.post_loss(loss, passage) + + return BiencoderOutput( + loss=loss, + q_reps=q_reps, + p_reps=p_reps, + labels=labels.contiguous(), + scores=scores, + ) + + def _encode(self, encoder: PreTrainedModel, input_dict: dict) -> Optional[torch.Tensor]: + """Encode input using the encoder.""" + if not input_dict: + return None + + import inspect + + # Remove token_type_ids if encoder doesn't support it + if ( + "token_type_ids" not in inspect.getfullargspec(encoder.forward).args + and "token_type_ids" in input_dict.keys() + ): + input_dict = {k: v for k, v in input_dict.items() if k != "token_type_ids"} + + # Get encoder outputs + outputs = encoder( + **{k: v for k, v in input_dict.items() if k not in ["kd_labels"]}, + return_dict=True, + output_hidden_states=True, + ) + + # Extract hidden states + if hasattr(outputs, "last_hidden_state"): + hidden_state = outputs.last_hidden_state + else: + hidden_state = outputs.hidden_states[-1] + + # Pool the representations + embeds = pool( + last_hidden_states=hidden_state, + attention_mask=input_dict["attention_mask"], + pool_type=self.pooling, + ) + + # Apply linear pooler + embeds = self.linear_pooler(embeds) + + # L2 normalize if required + if self.l2_normalize: + embeds = F.normalize(embeds, dim=-1) + + return embeds.contiguous() + + def _compute_scores( + self, + current_train_n_passages: int, + query: Dict[str, Tensor] = None, + passage: Dict[str, Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute similarity scores and labels.""" + + # Encode query and passage + q_reps = self._encode(self.lm_q, query) + p_reps = self._encode(self.lm_p, passage) + + # Compute similarity scores using contrastive_scores_and_labels + scores, labels = contrastive_scores_and_labels( + query=q_reps, + key=p_reps, + current_train_n_passages=current_train_n_passages, + ) + + if self.l2_normalize: + scores = scores / self.t + + return scores, 
labels, q_reps, p_reps + + @classmethod + def build( + cls, + model_name_or_path: str, + share_encoder: bool = True, + add_linear_pooler: bool = False, + out_dimension: int = 768, + do_gradient_checkpointing: bool = False, + train_n_passages: int = 1, + eval_negative_size: int = 0, + pooling: str = "avg", + l2_normalize: bool = True, + t: float = 1.0, + **hf_kwargs, + ): + """ + Build biencoder model from pretrained. + + Args: + model_name_or_path: Path to pretrained model or model identifier + share_encoder: Whether to share encoder weights between query and passage + add_linear_pooler: Whether to add a linear pooler layer + out_dimension: Output dimension for linear pooler + do_gradient_checkpointing: Whether to enable gradient checkpointing + train_n_passages: Number of passages per query during training + eval_negative_size: Number of negative samples during evaluation + pooling: Pooling strategy ('avg', 'cls', 'last', etc.) + l2_normalize: Whether to L2 normalize embeddings + t: Temperature for scaling similarity scores + **hf_kwargs: Additional arguments passed to model loading + """ + + logger.info(f"Building BiencoderModel from {model_name_or_path}") + + # Infer model class from model_name_or_path + # Check config.json if it exists + config_path = os.path.join(model_name_or_path, "config.json") if os.path.isdir(model_name_or_path) else None + + if config_path and os.path.exists(config_path): + import json + + with open(config_path, "r") as f: + config = json.load(f) + model_type = config.get("model_type", "") + else: + # If no config, infer from model name + model_type = "" + + # Select model class based on model type + if model_type == "nemotron" or "nemotron" in model_name_or_path.lower(): + ModelClass = NemotronBidirectionalModel + logger.info("Using NemotronBidirectionalModel") + else: + raise ValueError( + f"Unsupported model type: {model_type}. 
Cannot infer model class from {model_name_or_path}" + ) + + # Load model locally or from hub using selected model class + if os.path.isdir(model_name_or_path): + if share_encoder: + lm_q = ModelClass.from_pretrained(model_name_or_path, trust_remote_code=True, **hf_kwargs) + lm_p = lm_q + else: + _qry_model_path = os.path.join(model_name_or_path, "query_model") + _psg_model_path = os.path.join(model_name_or_path, "passage_model") + + if not os.path.exists(_qry_model_path): + _qry_model_path = model_name_or_path + _psg_model_path = model_name_or_path + + lm_q = ModelClass.from_pretrained(_qry_model_path, trust_remote_code=True, **hf_kwargs) + lm_p = ModelClass.from_pretrained(_psg_model_path, trust_remote_code=True, **hf_kwargs) + else: + # Load from hub + lm_q = ModelClass.from_pretrained(model_name_or_path, **hf_kwargs) + + if share_encoder: + lm_p = lm_q + else: + lm_p = copy.deepcopy(lm_q) + + # Enable gradient checkpointing if requested + if do_gradient_checkpointing: + lm_q.gradient_checkpointing_enable() + if lm_p is not lm_q: + lm_p.gradient_checkpointing_enable() + + # Create linear pooler if needed + if add_linear_pooler: + linear_pooler = nn.Linear(lm_q.config.hidden_size, out_dimension) + + pooler_path = os.path.join(model_name_or_path, "pooler.pt") + if os.path.exists(pooler_path): + logger.info("Loading pooler weights from local files") + state_dict = torch.load(pooler_path, map_location="cpu") + linear_pooler.load_state_dict(state_dict) + else: + linear_pooler = nn.Identity() + + model = cls( + lm_q=lm_q, + lm_p=lm_p, + linear_pooler=linear_pooler, + train_n_passages=train_n_passages, + eval_negative_size=eval_negative_size, + pooling=pooling, + l2_normalize=l2_normalize, + t=t, + share_encoder=share_encoder, + add_linear_pooler=add_linear_pooler, + ) + return model + + def save(self, output_dir: str): + """Save model to output directory.""" + + logger.info(f"Saving BiencoderModel to {output_dir}") + + # Save the model + if self.share_encoder: + 
self.lm_q.save_pretrained(output_dir) + else: + os.makedirs(os.path.join(output_dir, "query_model"), exist_ok=True) + os.makedirs(os.path.join(output_dir, "passage_model"), exist_ok=True) + self.lm_q.save_pretrained(os.path.join(output_dir, "query_model")) + self.lm_p.save_pretrained(os.path.join(output_dir, "passage_model")) + + # Save linear pooler if exists + if self.add_linear_pooler: + pooler_path = os.path.join(output_dir, "pooler.pt") + logger.info(f"Saving linear pooler to {pooler_path}") + torch.save(self.linear_pooler.state_dict(), pooler_path) diff --git a/nemo_automodel/components/models/nemotron/configuration_nemotron_h.py b/nemo_automodel/components/models/nemotron/configuration_nemotron_h.py new file mode 100644 index 000000000..21098cbc3 --- /dev/null +++ b/nemo_automodel/components/models/nemotron/configuration_nemotron_h.py @@ -0,0 +1,243 @@ +# coding=utf-8 +# Copyright 2024 AI21 Labs Ltd. and the HuggingFace Inc. team. All rights reserved. +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""NemotronH model configuration""" + +import re + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +class NemotronHConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`NemotronHModel`]. 
It is used to instantiate a + NemotronH model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the NemotronH-v0.1 model. + + [todo](todo) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 131072): + Vocabulary size of the NemotronH model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`NemotronHModel`] + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the + model has a output word embedding layer. + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 21504): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 52): + Number of hidden layers in the Transformer encoder. + hybrid_override_pattern (`str`, *optional*, defaults to `"M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-"`): + The pattern of the hybrid model. The pattern is a string of characters where each character represents M: Mamba2, *: Attention, -: MLP + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + attention_head_dim (`int`, *optional*, defaults to 128): + Dimension of each attention head. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. 
If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. + mlp_hidden_act (`str`, *optional*, defaults to "relu2"): + The non-linear activation function in the MLP layers. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in attention layers. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in MLP layers. + use_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in the model. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-5): + The epsilon used by the layer normalization layers. + residual_in_fp32 (`bool`, *optional*, defaults to `False`): + Whether or not residuals should be in `float32`. If set to `False` residuals will keep the same `dtype` as the rest of the model. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + num_logits_to_keep (`int` or `None`, *optional*, defaults to 1): + Number of prompt logits to calculate during generation. If `None`, all logits will be calculated. If an + integer value, only last `num_logits_to_keep` logits will be calculated. + pad_token_id (`int`, *optional*, defaults to 0): + The id of the padding token. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the "end-of-sequence" token. + sliding_window (`int`, *optional*, defaults to None): + Sliding window attention window size. 
+ max_position_embeddings (`int`, *optional*, defaults to 4096): + The maximum sequence length that this model might ever be used with. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + hidden_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the hidden states. + use_mamba_kernels (`bool`, *optional*, defaults to `True`): + Flag indicating whether or not to use the fast mamba kernels. These are available only if `mamba-ssm` and + `causal-conv1d` are installed, and the mamba modules are running on a CUDA device. + ssm_state_size (`int`, *optional*, defaults to 128): + The dimension of the mamba state space latents. + mamba_num_heads (`int`, *optional*, defaults to 128): + Number of heads in Mamba layers. + mamba_n_groups (`int`, *optional*, defaults to 8): + Number of groups in Mamba layers. + mamba_head_dim (`int`, *optional*, defaults to 64): + Dimension of each Mamba head. + mamba_d_conv (`int`, *optional*, defaults to 4): + The size of the mamba convolution kernel. + mamba_expand (`int`, *optional*, defaults to 2): + Expanding factor used to determine the mamba intermediate size. + mamba_hidden_act (`str`, *optional*, defaults to "silu"): + The non-linear activation function in the Mamba layers. + mamba_dt_min (`float`, *optional*, defaults to 0.001): + Minimum value for the time step in Mamba. + mamba_dt_max (`float`, *optional*, defaults to 0.1): + Maximum value for the time step in Mamba. + mamba_dt_limit (`tuple`, *optional*, defaults to (0.0, float("inf"))): + Limits for the time step in Mamba. + mamba_dt_init_floor (`float`, *optional*, defaults to 1e-4): + Floor value for time step initialization in Mamba. + mamba_conv_bias (`bool`, *optional*, defaults to `True`): + Whether to use bias in the convolution layer of the mamba mixer block. 
+ mamba_proj_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in the input and output projections of the mamba mixer block. + mamba_chunk_size (`int`, *optional*, defaults to 256): + Size of chunks for Mamba processing. + rescale_prenorm_residual (`bool`, *optional*, defaults to `True`): + Whether to rescale the pre-normalization residual connections. + """ + + model_type = "nemotron_h" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=131072, + tie_word_embeddings=False, + hidden_size=4096, + intermediate_size=21504, + num_hidden_layers=52, + hybrid_override_pattern="M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-", + num_attention_heads=32, + attention_head_dim=128, + num_key_value_heads=8, # nemo: num_query_groups + mlp_hidden_act="relu2", + attention_bias=False, + mlp_bias=False, + use_bias=False, + initializer_range=0.02, # nemo: init_method_std + layer_norm_epsilon=1e-5, # nemo: layernorm_epsilon + residual_in_fp32=False, # Megatron Core default value + use_cache=True, + num_logits_to_keep=1, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + sliding_window=None, + max_position_embeddings=4096, + attention_dropout=0.0, + hidden_dropout=0.0, # * ADDED + use_mamba_kernels=True, + ssm_state_size=128, # mamba_state_size + mamba_num_heads=128, + mamba_n_groups=8, # nemo: mamba_ssm_ngroups = num_heads + mamba_head_dim=64, + mamba_d_conv=4, + mamba_expand=2, + mamba_hidden_act="silu", + mamba_dt_min=0.001, + mamba_dt_max=0.1, + mamba_dt_limit=(0.0, float("inf")), + mamba_dt_init_floor=1e-4, + mamba_conv_bias=True, + mamba_proj_bias=False, + mamba_chunk_size=256, + rescale_prenorm_residual=True, + **kwargs, + ): + self.vocab_size = vocab_size + self.tie_word_embeddings = tie_word_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.hybrid_override_pattern = hybrid_override_pattern + self.num_attention_heads = 
num_attention_heads + self.attention_head_dim = attention_head_dim + self.sliding_window = sliding_window + self.max_position_embeddings = max_position_embeddings + self.attention_dropout = attention_dropout + self.hidden_dropout = hidden_dropout + + # Validate hybrid_override_pattern + # M: Mamba2, *: Attention, -: MLP + assert len(self.hybrid_override_pattern) == self.num_hidden_layers, "hybrid_override_pattern must have the same length as num_hidden_layers" + assert re.match(r"^[*-M]+$", self.hybrid_override_pattern), "hybrid_override_pattern must only contain characters 'M', '*', or '-'" + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.mlp_hidden_act = mlp_hidden_act + self.attention_bias = attention_bias + self.mlp_bias = mlp_bias + self.use_bias = use_bias + self.initializer_range = initializer_range + self.layer_norm_epsilon = layer_norm_epsilon + self.residual_in_fp32 = residual_in_fp32 + + self.use_cache = use_cache + self.num_logits_to_keep = num_logits_to_keep + + self.use_mamba_kernels = use_mamba_kernels + self.n_groups = mamba_n_groups + self.mamba_head_dim = mamba_head_dim + self.ssm_state_size = ssm_state_size + self.mamba_num_heads = mamba_num_heads + self.conv_kernel = mamba_d_conv + self.expand = mamba_expand + self.mamba_hidden_act = mamba_hidden_act + self.time_step_min = mamba_dt_min + self.time_step_max = mamba_dt_max + self.time_step_limit = mamba_dt_limit + self.time_step_floor = mamba_dt_init_floor + self.use_conv_bias = mamba_conv_bias + self.mamba_proj_bias = mamba_proj_bias + self.chunk_size = mamba_chunk_size + self.rescale_prenorm_residual = rescale_prenorm_residual + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + @property + def layers_block_type(self): + return [ + "mamba" if 
self.hybrid_override_pattern[i] == "M" else + "attention" if self.hybrid_override_pattern[i] == "*" else "mlp" + for i in range(self.num_hidden_layers)] \ No newline at end of file diff --git a/nemo_automodel/components/models/nemotron/modeling_nemotron_h.py b/nemo_automodel/components/models/nemotron/modeling_nemotron_h.py new file mode 100644 index 000000000..f04b17d36 --- /dev/null +++ b/nemo_automodel/components/models/nemotron/modeling_nemotron_h.py @@ -0,0 +1,1638 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. team. +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch NemotronH model.""" + +import math +from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss + +from transformers.activations import ACT2FN +from transformers.cache_utils import DynamicCache # we need __iter__ and __len__ of pkv +from transformers.generation import GenerationMixin +from transformers.modeling_attn_mask_utils import ( + AttentionMaskConverter, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, +) +from transformers.utils.import_utils import ( + is_causal_conv1d_available, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + is_mamba_2_ssm_available, +) +from .configuration_nemotron_h import NemotronHConfig + + +logger = logging.get_logger(__name__) + + +# Copied from transformers.models.mamba.modeling_mamba2.modeling_mamba2.py with MAMBA2->NEMOTRONH,Mamba2->NemotronH +# For Mamba2 components Mamba2->NemotronHMamba2 +if is_mamba_2_ssm_available(): + from mamba_ssm.ops.triton.selective_state_update import selective_state_update + from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined +else: + mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined, selective_state_update = None, None, None + +try: + #from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated + from mamba_ssm.ops.triton.layernorm_gated import rmsnorm_fn +except ImportError: + raise ImportError("mamba-ssm is required by the Mamba model but cannot be imported") + +if is_causal_conv1d_available(): + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +else: + causal_conv1d_update, causal_conv1d_fn = None, None + +if is_flash_attn_2_available(): + from 
transformers.modeling_flash_attention_utils import _flash_attention_forward

# True only when every optional fused kernel (mamba-ssm triton ops + causal-conv1d)
# imported successfully above; gates the CUDA fast path in NemotronHMamba2Mixer.forward.
is_fast_path_available = all(
    (
        selective_state_update,
        mamba_chunk_scan_combined,
        mamba_split_conv1d_scan_combined,
        causal_conv1d_fn,
        causal_conv1d_update,
    )
)


_CHECKPOINT_FOR_DOC = "nvidia/Nemotron-H-56B-Base-8K"
_CONFIG_FOR_DOC = "NemotronHConfig"


# Helper methods for segment sum computation


def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int):
    """
    Padding x tensor with `pad_size` on the seq_len dim (dim=1)

    Assumes that we only have tensors of either size 4 or 3
    """
    # Zero-pad only the seq_len dimension (dim=1); the pad tuple is ordered
    # last-dim-first, which is why the rank-4 and rank-3 tuples differ.
    pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0)

    return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0)


def reshape_into_chunks(input_tensor, pad_size, chunk_size):
    """
    Padding input_tensor with `pad_size` on the seq_len dim (dim=1) and
    simultaneously splitting it into chunk sequences.

    Assumes that we only have tensors of either size 4 or 3
    """
    # [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...]
    input_tensor = pad_tensor_by_size(input_tensor, pad_size)

    if len(input_tensor.shape) == 3:
        # [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads]
        return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2])
    else:
        # [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size] -> [bsz, -1, chunk_size, num_heads, head_dim or state_size]
        return input_tensor.reshape(
            input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2], input_tensor.shape[3]
        )


def segment_sum(input_tensor):
    """
    More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions.

    Returns a [..., chunk_size, chunk_size] tensor where entry (i, j) holds
    sum(input[j+1..i]) for j < i, 0 on the diagonal, and -inf above it.
    """
    chunk_size = input_tensor.size(-1)
    # 1. expand input tensor to have an additional dimension and repeat along that dimension
    # [..., chunk_size] -> [..., chunk_size, chunk_size]
    input_tensor = input_tensor[..., None].expand(*input_tensor.size(), chunk_size)
    # 2. create a lower triangular mask with the diagonal set to 0 to 0 out elements above diag
    mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=-1)
    input_tensor = input_tensor.masked_fill(~mask, 0)
    # 3. compute actual cumsum
    tensor_segsum = torch.cumsum(input_tensor, dim=-2)

    # 4. apply mask to keep only the lower triangular part of the cumulative sum result (incl diagonal this time)
    mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=0)
    # -inf above the diagonal so a later exp() turns disallowed (future) positions into 0.
    tensor_segsum = tensor_segsum.masked_fill(~mask, -torch.inf)
    return tensor_segsum


def apply_mask_to_padding_states(hidden_states, attention_mask):
    """
    Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66
    """
    # Only applied when there is both a batch dimension > 1 and a seq dimension > 1;
    # single-token decoding steps and unbatched inputs skip the multiply.
    if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
        dtype = hidden_states.dtype
        hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)

    return hidden_states

# Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/jamba/modeling_jamba.py
class HybridMambaAttentionDynamicCache(DynamicCache):
    """
    A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache
    (which has a constant shape regardless of seq_len).

    This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states`
    and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors.
The expected shape for each tensor + For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, + while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). + For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), + while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, + and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. + """ + + def __init__(self, config, batch_size, dtype=torch.float16, device=None): + super().__init__() + self.dtype = dtype + self.hybrid_override_pattern = config.hybrid_override_pattern + self.has_previous_state = False # only used by mamba + intermediate_size = config.expand * config.hidden_size + ssm_state_size = config.ssm_state_size + conv_kernel_size = config.conv_kernel + self.conv_states = [] + self.ssm_states = [] + self.transformer_layers = [] + for i in range(config.num_hidden_layers): + if self.hybrid_override_pattern[i] == "M": + # Mamba layer + self.conv_states += [ + torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype) + ] + self.ssm_states += [ + torch.zeros(batch_size, intermediate_size, ssm_state_size, device=device, dtype=dtype) + ] + else: + # Attention or MLP layer + self.conv_states += [torch.tensor([[]] * batch_size, device=device)] + self.ssm_states += [torch.tensor([[]] * batch_size, device=device)] + self.transformer_layers.append(i) + + self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Update the cache + if 
self.key_cache[layer_idx].shape[-1] == 0: + self.key_cache[layer_idx] = key_states + self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + + device = self.conv_states[layer_idx].device + self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) + device = self.ssm_states[layer_idx].device + self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. 
A layer index can be optionally passed.""" + # take any layer that contains cache and not empty tensor + layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx + if len(self.key_cache) <= layer_idx: + return 0 + return self.key_cache[layer_idx].shape[-2] + + def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: + raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") + + @classmethod + def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache": + raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") + + # Copied from modeling_mamba2.py + def update_conv_state( + self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False + ) -> torch.Tensor: + if cache_init: + self.conv_states[layer_idx] = new_conv_state.to(self.conv_states.device) + else: + self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1) + self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states.device) + return self.conv_states[layer_idx] + + def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): + self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device) + return self.ssm_states[layer_idx] + + def reset(self): + self.conv_states.zero_() + self.ssm_states.zero_() + +class MambaRMSNormGated(torch.nn.Module): + def __init__(self, hidden_size, group_size, eps=1e-5): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + self.group_size = group_size + + # jan28b version + def forward(self, hidden_states, gate=None): + return rmsnorm_fn(x=hidden_states, + weight=self.weight, + bias=None, # No bias + z=gate, + eps=self.variance_epsilon, + group_size=self.group_size, + norm_before_gate=False + ) + +class NemotronHMamba2Mixer(nn.Module): + """ + 
    Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
    A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
    ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
    and is why Mamba is called **selective** state spaces)
    """

    def __init__(self, config: NemotronHConfig, layer_idx: int):
        super().__init__()
        self.num_heads = config.mamba_num_heads
        self.hidden_size = config.hidden_size
        self.ssm_state_size = config.ssm_state_size
        self.conv_kernel_size = config.conv_kernel
        # NOTE(review): intermediate size here is num_heads * head_dim, while the cache
        # allocates conv/ssm states with expand * hidden_size — confirm these agree in config.
        self.intermediate_size = config.mamba_num_heads * config.mamba_head_dim
        self.layer_idx = layer_idx
        self.use_conv_bias = config.use_conv_bias
        self.activation = config.mamba_hidden_act
        self.act = ACT2FN[config.mamba_hidden_act]

        self.layer_norm_epsilon = config.layer_norm_epsilon

        self.n_groups = config.n_groups
        self.head_dim = config.mamba_head_dim
        self.chunk_size = config.chunk_size

        self.time_step_limit = config.time_step_limit
        self.time_step_min = config.time_step_min
        self.time_step_max = config.time_step_max

        # conv acts jointly on x and on the grouped B and C projections
        self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size
        self.conv1d = nn.Conv1d(
            in_channels=self.conv_dim,
            out_channels=self.conv_dim,
            bias=config.use_conv_bias,
            kernel_size=config.conv_kernel,
            groups=self.conv_dim,
            padding=config.conv_kernel - 1,
        )

        # projection of the input hidden states
        projection_size = self.intermediate_size + self.conv_dim + self.num_heads
        self.in_proj = nn.Linear(
            self.hidden_size,
            projection_size,
            bias=config.use_bias,
        )
        # selective projection used to make dt, B and C input dependant

        # time step projection (discretization)
        # instantiate once and copy inv_dt in init_weights of PretrainedModel
        self.dt_bias = nn.Parameter(torch.ones(self.num_heads))

        # S4D real initialization. These are not discretized!
        # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
        A = torch.arange(1, self.num_heads + 1)
        self.A_log = nn.Parameter(torch.log(A))
        self.A_log._no_weight_decay = True
        self.norm = MambaRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon, group_size=self.intermediate_size // self.n_groups)
        self.D = nn.Parameter(torch.ones(self.num_heads))
        self.D._no_weight_decay = True

        self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
        self.use_bias = config.use_bias

        if not is_fast_path_available:
            logger.warning_once(
                "The fast path is not available because on of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`"
                " is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and"
                " https://github.com/Dao-AILab/causal-conv1d"
            )

    def cuda_kernels_forward(
        self,
        hidden_states: torch.Tensor,
        cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        """Fused mamba-ssm/causal-conv1d forward path (requires CUDA + the triton kernels)."""
        # 1. Gated MLP's linear projection
        hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
        projected_states = self.in_proj(hidden_states)

        # Set up dimensions for reshapes later
        batch_size, seq_len, _ = hidden_states.shape
        groups_time_state_size = self.n_groups * self.ssm_state_size
        # d_mlp: leftover width of in_proj beyond (2*intermediate + 2*groups*state + heads);
        # zero for this projection_size, kept for compatibility with the reference split.
        d_mlp = (
            projected_states.shape[-1]
            - 2 * self.intermediate_size
            - 2 * self.n_groups * self.ssm_state_size
            - self.num_heads
        ) // 2

        # Single step calculations via cache
        if cache_params is not None and cache_position is not None and cache_position[0] > 0:
            _, _, gate, hidden_states_B_C, dt = projected_states.squeeze(1).split(
                [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
            )

            # 2. Convolution sequence transformation
            hidden_states_B_C = causal_conv1d_update(
                hidden_states_B_C,
                cache_params.conv_states[self.layer_idx],
                self.conv1d.weight.squeeze(1),
                self.conv1d.bias,
                self.activation,
            )

            hidden_states, B, C = torch.split(
                hidden_states_B_C,
                [self.intermediate_size, groups_time_state_size, groups_time_state_size],
                dim=-1,
            )

            # 3. SSM transformation
            A = -torch.exp(self.A_log.float())  # (nheads,)
            A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
            dt = dt[:, :, None].expand(-1, -1, self.head_dim)
            dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
            D = self.D[:, None, ...].expand(-1, self.head_dim)
            B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups)
            C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups)
            hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim)
            # Updates cache_params.ssm_states[layer_idx] in place and returns the step output.
            hidden_states = selective_state_update(
                cache_params.ssm_states[self.layer_idx],
                hidden_states_reshaped,
                dt,
                A,
                B,
                C,
                D,
                z=None,
                dt_bias=dt_bias,
                dt_softplus=True,
            )
            hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim)
            hidden_states = self.norm(hidden_states, gate)

            # 4. Final linear projection
            out = self.out_proj(hidden_states)[:, None, ...]

        # Fused calculations or step by step if no initialized cache is found
        else:
            A = -torch.exp(self.A_log.float())  # (num_heads) or (intermediate_size, state_size)
            dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit}

            # 2-4. Fused kernel for conv1d, SSM, and the final projection
            if self.training and cache_params is None:
                out = mamba_split_conv1d_scan_combined(
                    projected_states,
                    self.conv1d.weight.squeeze(1),
                    self.conv1d.bias,
                    self.dt_bias,
                    A,
                    D=self.D,
                    chunk_size=self.chunk_size,
                    seq_idx=None,  # was seq_idx
                    activation=self.activation,
                    rmsnorm_weight=self.norm.weight,
                    rmsnorm_eps=self.norm.variance_epsilon,
                    outproj_weight=self.out_proj.weight,
                    outproj_bias=self.out_proj.bias,
                    headdim=self.head_dim,
                    ngroups=self.n_groups,
                    norm_before_gate=False,
                    return_final_states=False,
                    **dt_limit_kwargs,
                )

            else:
                _, _, gate, hidden_states_B_C, dt = projected_states.split(
                    [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
                )

                # 2. Convolution sequence transformation
                # Init cache
                if cache_params is not None:
                    hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2)
                    # Left-pad to the conv kernel width so the stored state has a fixed shape.
                    conv_states = nn.functional.pad(
                        hidden_states_B_C_transposed,
                        (cache_params.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0),
                    )
                    cache_params.update_conv_state(
                        layer_idx=self.layer_idx, new_conv_state=conv_states, cache_init=True
                    )

                if self.activation not in ["silu", "swish"]:
                    hidden_states_B_C = self.act(
                        self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)
                    )
                else:
                    hidden_states_B_C = causal_conv1d_fn(
                        x=hidden_states_B_C.transpose(1, 2),
                        weight=self.conv1d.weight.squeeze(1),
                        bias=self.conv1d.bias,
                        activation=self.activation,
                    ).transpose(1, 2)
                hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
                hidden_states, B, C = torch.split(
                    hidden_states_B_C,
                    [self.intermediate_size, groups_time_state_size, groups_time_state_size],
                    dim=-1,
                )

                # 3. SSM transformation
                scan_output, ssm_state = mamba_chunk_scan_combined(
                    hidden_states.view(batch_size, seq_len, -1, self.head_dim),
                    dt,
                    A,
                    B.view(batch_size, seq_len, self.n_groups, -1),
                    C.view(batch_size, seq_len, self.n_groups, -1),
                    chunk_size=self.chunk_size,
                    D=self.D,
                    z=None,
                    seq_idx=None,
                    return_final_states=True,
                    dt_bias=self.dt_bias,
                    dt_softplus=True,
                    **dt_limit_kwargs,
                )

                # Init cache
                if ssm_state is not None and cache_params is not None:
                    cache_params.update_ssm_state(layer_idx=self.layer_idx, new_ssm_state=ssm_state)

                scan_output = scan_output.view(batch_size, seq_len, -1)

                # Multiply "gate" branch and apply extra normalization layer
                scan_output = self.norm(scan_output, gate)

                # 4. Final linear projection
                out = self.out_proj(scan_output)
        return out

    # fmt: off
    def torch_forward(self, input_states, cache_params: Optional[HybridMambaAttentionDynamicCache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None):
        """Pure-PyTorch fallback for the fused path above (slow; used when kernels are unavailable)."""
        batch_size, seq_len, _ = input_states.shape
        dtype = input_states.dtype

        # 1. Gated MLP's linear projection
        input_states = apply_mask_to_padding_states(input_states, attention_mask)
        projected_states = self.in_proj(input_states)
        d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size-self.num_heads) // 2
        _, _, gate, hidden_states_B_C, dt = projected_states.split(
                    [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
        )

        # 2.
Convolution sequence transformation + if cache_params is not None and cache_position is not None and cache_position[0] > 0: + cache_params.update_conv_state(layer_idx=self.layer_idx, new_conv_state=hidden_states_B_C, cache_init=False) + + # We need to guarantee that anything regarding the cache is on the same device + conv_states = cache_params.conv_states[self.layer_idx].to(device=self.conv1d.weight.device) + + hidden_states_B_C = torch.sum( + conv_states * self.conv1d.weight.squeeze(1), dim=-1 + ) + if self.use_conv_bias: + hidden_states_B_C = hidden_states_B_C + self.conv1d.bias + hidden_states_B_C = self.act(hidden_states_B_C) + else: + # Init cache + if cache_params is not None: + hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2) + conv_states = nn.functional.pad( + hidden_states_B_C_transposed, (cache_params.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0) + ) + cache_params.update_conv_state(layer_idx=self.layer_idx, new_conv_state=conv_states, cache_init=True) + + hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)) + + hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask) + hidden_states, B, C = torch.split( + hidden_states_B_C, + [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], + dim=-1 + ) + + # 3. SSM transformation + A = -torch.exp(self.A_log.float()) # [num_heads] + if cache_params is not None and cache_position is not None and cache_position[0] > 0: + # We need to guarantee that anything regarding the cache is on the same device + cache_device = cache_params.ssm_states.device + + # Note: there is no need to pad parameter matrices here, as there is just one new token + # for batched generation + dt = dt[:, 0, :][:, None, ...] 
+ dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim) + # [num_heads] -> [num_heads, head_dim] + dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim) + + dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype)) + dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) + A = A[..., None, None].expand(self.num_heads, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) + # [bsz, num_heads, head_dim, state_size] + dA = (torch.exp(dt[..., None] * A)).to(device=cache_device) + + # Discretize B + # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] -> + # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size] + B = B.reshape(batch_size, self.n_groups, -1)[..., None, :] + B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous() + B = B.reshape(batch_size, -1, B.shape[-1]) + # [bsz, num_heads, head_dim, state_size] + dB = dt[..., None] * B[..., None, :] + + # Discretize x into dB + # [bsz, intermediate_size] -> [bsz, num_heads, head_dim] + hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim) + dBx = (dB * hidden_states[..., None]).to(device=cache_device) + + # State calculation + cache_params.update_ssm_state( + layer_idx=self.layer_idx, + new_ssm_state=cache_params.ssm_states[self.layer_idx] * dA + dBx + ) + + # Subsequent output + # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] + C = C.reshape(batch_size, self.n_groups, -1)[..., None, :] + C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous() + C = C.reshape(batch_size, -1, C.shape[-1]) + # [bsz, num_heads, head_dim] + + ssm_states = cache_params.ssm_states[self.layer_idx].to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n] + # Reshape ssm_states to merge the first two dimensions + ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, 
self.ssm_state_size) # Shape: [b*h, d, n] + C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] + y = torch.bmm(ssm_states_reshaped, C_reshaped) + y = y.view(batch_size, self.num_heads, self.head_dim) + + # D skip connection + # [num_heads] -> [num_heads, head_dim] + D = self.D[..., None].expand(self.D.shape[0], self.head_dim) + y = (y + hidden_states * D).to(y.dtype) + + # [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size] + y = y.reshape(batch_size, -1)[:, None, ...] + else: + # begin ssd naive implementation without einsums + dt = nn.functional.softplus(dt + self.dt_bias) + dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) + hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float() + B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() + C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() + B = B.repeat(1, 1, self.num_heads // self.n_groups, 1) + C = C.repeat(1, 1, self.num_heads // self.n_groups, 1) + pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size + + D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size) + + # Discretize x and A + hidden_states = hidden_states * dt[..., None] + A = A.to(hidden_states.dtype) * dt + + # Rearrange into blocks/chunks + hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)] + + # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size] + A = A.permute(0, 3, 1, 2) + A_cumsum = torch.cumsum(A, dim=-1) + + # 1. 
Compute the output for each intra-chunk (diagonal blocks) + # This is the analog of a causal mask + L = torch.exp(segment_sum(A)) + + # Contraction of C and B to get G (attention-weights like) + G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, :, :] # shape: (b, c, l, s, h, n) + G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h) + + # Compute M, equivalent to applying attention mask to weights + M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None] + M = M_intermediate.sum(dim=-1) + + # Compute Y_diag (apply to values) + Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(dim=3) + + # 2. Compute the state for each intra-chunk + # (right term of low-rank factorization of off-diagonal blocks; B terms) + decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum)) + B_decay = B * decay_states.permute(0, -2, -1, 1)[..., None] + states = (B_decay[..., None, :] * hidden_states[..., None]).sum(dim=2) + + # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries + # (middle term of factorization of off-diag blocks; A terms) + if cache_params is not None and cache_position is not None and cache_position[0] > 0: + previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...].to(device=states.device) + else: + previous_states = torch.zeros_like(states[:, :1]) + states = torch.cat([previous_states, states], dim=1) + decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0)))) + decay_chunk = decay_chunk.transpose(1, 3) + new_states = (decay_chunk[..., None, None] * states[:, :, None, ...]).sum(dim=1) + states, ssm_state = new_states[:, :-1], new_states[:, -1] + + # 4. 
Compute state -> output conversion per chunk + # (left term of low-rank factorization of off-diagonal blocks; C terms) + state_decay_out = torch.exp(A_cumsum) + C_times_states = (C[..., None, :] * states[:, :, None, ...]) + state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1) + Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None]) + + # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks) + y = Y_diag + Y_off + # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim] + y = y.reshape(batch_size, -1, self.num_heads, self.head_dim) + + y = y + D_residual + # Cutting off padded chunks + if pad_size > 0: + y = y[:, :seq_len, :, :] + y = y.reshape(batch_size, seq_len, -1) + + # Init cache + if ssm_state is not None and cache_params is not None: + cache_params.update_ssm_state(layer_idx=self.layer_idx, new_ssm_state=ssm_state) + + scan_output = self.norm(y, gate) + + # end ssd naive + + # 4. Final linear projection + contextualized_states = self.out_proj(scan_output.to(dtype)) # [batch, seq_len, hidden_size] + return contextualized_states + # fmt: on + + def forward( + self, + hidden_states, + cache_params: Optional[HybridMambaAttentionDynamicCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + if is_fast_path_available and "cuda" in self.in_proj.weight.device.type: + return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask) + dtype = hidden_states.dtype + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask) + + +class NemotronHRMSNorm(nn.Module): + def __init__(self, 
hidden_size, eps=1e-6):
        """
        NemotronHRMSNorm is equivalent to T5LayerNorm and LlamaRMSNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Normalize in float32 for numerical stability, then cast back.
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # Weights are in float32
        return (self.weight.to(torch.float32) * hidden_states).to(input_dtype)

class NemotronHBlock(nn.Module):
    """One hybrid layer: pre-RMSNorm + (Mamba2 | attention | MLP) mixer + residual add."""

    def __init__(self, config, layer_idx):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.residual_in_fp32 = config.residual_in_fp32
        self.norm = NemotronHRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        # M: Mamba2, *: Attention, -: MLP
        self.block_type = config.layers_block_type[layer_idx]
        if self.block_type == "mamba":
            self.mixer = NemotronHMamba2Mixer(config, layer_idx=layer_idx)
        elif self.block_type == "attention":
            self.mixer = NEMOTRONH_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
        elif self.block_type == "mlp":
            self.mixer = NemotronHMLP(config, layer_idx=layer_idx)
        else:
            raise ValueError(f"Invalid layer pattern {config.hybrid_override_pattern[layer_idx]}")

    def forward(
        self,
        hidden_states,
        cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        with torch.cuda.stream(torch.cuda.default_stream(hidden_states.device)):
            # * Use torch.cuda.stream() to avoid NaN issues when using multiple GPUs
            residual = hidden_states
            hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)

            # NOTE(review): `attention_mask` is accepted but not forwarded to either the
            # mamba or the attention mixer here — confirm whether padding masking is
            # intentionally handled elsewhere for this model.
            if self.block_type == "mamba":
                hidden_states = self.mixer(
                    hidden_states, cache_params=cache_params, cache_position=cache_position
                )
            elif self.block_type == "attention":
                hidden_states = self.mixer(
                    hidden_states, cache_position=cache_position
                )
                # Attention mixer returns (attn_output, attn_weights, past_key_value).
                hidden_states = hidden_states[0]
            elif self.block_type == "mlp":
                hidden_states = self.mixer(
                    hidden_states
                )
            else:
                raise ValueError(f"Invalid block_type: {self.block_type}")

            hidden_states = residual + hidden_states
            return hidden_states


# Copied from transformers.models.nemotron.modeling_nemotron Nemotron->NemotronH
class NemotronHMLP(nn.Module):
    """Two-layer gate-free MLP: up_proj -> activation -> down_proj."""

    def __init__(self, config, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )
        self.hidden_size = config.hidden_size
        # Index of this MLP among the "-" entries of the hybrid pattern, used to pick a
        # per-MLP intermediate size when the config supplies a list.
        self.mlp_idx = config.hybrid_override_pattern[:layer_idx+1].count("-")-1
        if isinstance(config.intermediate_size, list):
            if len(config.intermediate_size) == 1:
                self.intermediate_size = config.intermediate_size[0]
            else:
                self.intermediate_size = config.intermediate_size[self.mlp_idx]
        else:
            self.intermediate_size = config.intermediate_size
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
        self.act_fn = ACT2FN[config.mlp_hidden_act]

    def forward(self, x):
        return self.down_proj(self.act_fn(self.up_proj(x)))


# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class NemotronHAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: NemotronHConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + if config.attention_head_dim is not None: + self.head_dim = config.attention_head_dim + else: + self.head_dim = config.hidden_size // config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.is_causal = True + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.head_dim * self.num_heads, self.hidden_size, bias=config.attention_bias) + + def forward( + self, + hidden_states: 
torch.Tensor, + # position_embeddings: Tuple[torch.Tensor, torch.Tensor], #TODO + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + attn_output = attn_output.transpose(1, 2).contiguous() + #attn_output = attn_output.view(bsz, q_len, self.hidden_size) + attn_output = 
attn_output.view(bsz, q_len, self.num_heads * self.head_dim) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +# Adapted from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Jamba +#class JambaFlashAttention2(JambaAttention): +class NemotronHFlashAttention2(NemotronHAttention): + """ + Jamba flash attention module. This module inherits from `JambaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ): + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. 
+ input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + sliding_window=getattr(self.config, "sliding_window", None), + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) + + #attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Adapted from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Jamba +#class JambaSdpaAttention(JambaAttention): +class NemotronHSdpaAttention(NemotronHAttention): + """ + Jamba attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `JambaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. 
+ """ + + # Adapted from NemotronHAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "NemotronHModel is using NemotronHSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. 
+        is_causal = True if self.is_causal and causal_mask is None and q_len > 1 else False
+
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=causal_mask,
+            dropout_p=self.attention_dropout if self.training else 0.0,
+            is_causal=is_causal,
+        )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.view(bsz, q_len, self.num_heads * self.head_dim)  # not hidden_size: config.attention_head_dim may override head_dim (keeps SDPA consistent with eager/FA2 paths)
+
+        attn_output = self.o_proj(attn_output)
+
+        return attn_output, None, past_key_value
+
+
+NEMOTRONH_ATTENTION_CLASSES = {
+    "eager": NemotronHAttention,
+    "flash_attention_2": NemotronHFlashAttention2,
+    "sdpa": NemotronHSdpaAttention,
+}
+
+# Copied from transformers.models.mamba.modeling_mamba2.Mamba2PreTrainedModel
+class NemotronHPreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = NemotronHConfig
+    base_model_prefix = "backbone"
+    _no_split_modules = ["NemotronHBlock"]
+    supports_gradient_checkpointing = True
+    _is_stateful = True
+
+    def _init_weights(self, module):
+        """Initialize the weights."""
+        if isinstance(module, NemotronHMamba2Mixer):
+            module.A_log._no_weight_decay = True
+            module.D._no_weight_decay = True
+
+            dt = torch.exp(
+                torch.rand(self.config.mamba_num_heads)
+                * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min))
+                + math.log(self.config.time_step_min)
+            ).clamp(min=self.config.time_step_floor)
+
+            # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
+            inv_dt = dt + torch.log(-torch.expm1(-dt))
+            with torch.no_grad():
+                module.dt_bias.copy_(inv_dt)
+            module.dt_bias._no_reinit = True
+
+        if isinstance(module, nn.Linear):
+            if module.bias is not None:
+                if not getattr(module.bias, "_no_reinit", False):
+                    nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.Embedding):
+            nn.init.normal_(module.weight, std=self.config.initializer_range)
+
+        # TODO: Check
+        if self.config.rescale_prenorm_residual:
+            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
+            #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
+            #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
+            #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
+            #
+            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
+            for name, p in module.named_parameters():
+                if name in ["out_proj.weight"]:
+                    # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
+                    # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
+                    # We need to reinit p since this code could be called multiple times
+                    # Having just p *= scale would repeatedly scale it down
+                    nn.init.kaiming_uniform_(p, a=math.sqrt(5))
+                    with torch.no_grad():
+                        p /= math.sqrt(self.config.num_hidden_layers)
+
+
+@dataclass
+# Copied from transformers.models.mamba.modeling_mamba2.Mamba2Output with MAMBA2->NemotronH,Mamba2->NemotronH
+class NemotronHOutput(ModelOutput):
+    """
+    Class for the NemotronH model outputs.
+
+    Args:
+        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        cache_params (`HybridMambaAttentionDynamicCache`):
+            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
+            avoid providing the old `input_ids`.
+ + Includes both the State space model state matrices after the selective scan, and the Convolutional states + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + cache_params: Optional[HybridMambaAttentionDynamicCache] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +# Copied from transformers.models.mamba2.modeling_mamba2.MambaCausalLMOutput with Mamba2->NemotronH +class NemotronHCausalLMOutput(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + cache_params (`HybridMambaAttentionDynamicCache`): + The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to + avoid providing the old `input_ids`. 
+ + Includes both the State space model state matrices after the selective scan, and the Convolutional states + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + cache_params: Optional[HybridMambaAttentionDynamicCache] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +NEMOTRONH_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`NemotronHConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +NEMOTRONH_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*): + Indices of input sequence tokens in the vocabulary. 
+ + If `cache_params.seqlen_offset>0`, only `input_ids` that do not have their past calculated should be passed as + `input_ids`. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + position_ids (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. + cache_params (`HybridMambaAttentionDynamicCache`, *optional*): + If passed along, the model uses the previous state in all the blocks (which will give the output for the + `input_ids` provided as if the model add `state_input_ids + input_ids` as context). + use_cache (`bool`, *optional*): + If set to `True`, the `cache_params` is returned and can be used to quickly generate the next logits. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + The position of the current input in the cache. This is used to ensure that the cache is correctly updated. + If `cache_params` is passed, `cache_position` should also be passed. 
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) +""" + + +@add_start_docstrings( + "The bare NemotronH Model transformer outputting raw hidden-states without any specific head on top.", + NEMOTRONH_START_DOCSTRING, +) +class NemotronHModel(NemotronHPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = nn.ModuleList([NemotronHBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)]) + + self.gradient_checkpointing = False + self.norm_f = NemotronHRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + # Initialize weights and apply final processing + self._register_load_state_dict_pre_hook(self.load_hook) + self.post_init() + + def load_hook(self, state_dict, prefix, *args): + for k in state_dict: + if "embedding." 
in k: + state_dict[k.replace("embedding.", "embeddings.")] = state_dict.pop(k) + break + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, new_embeddings): + self.embeddings = new_embeddings + + @add_start_docstrings_to_model_forward(NEMOTRONH_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=NemotronHOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + cache_params: Optional[HybridMambaAttentionDynamicCache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[Tuple, NemotronHOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + # use_cache = use_cache if use_cache is not None else self.config.use_cache + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): # ^ is python for xor + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
+            )
+            use_cache = False
+
+        # From zamba_modeling.py
+        if use_cache and cache_params is None:
+            logger.warning_once(
+                "NemotronH requires an initialized `NemotronHHybridDynamicCache` to return a cache. None was "
+                "provided, so no cache will be returned."
+            )
+
+        hidden_states = inputs_embeds
+
+        if cache_position is None:
+            cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device)
+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)
+
+        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
+        mamba_mask = self._update_mamba_mask(attention_mask, cache_position)
+
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        # Until HERE
+
+        for layer_idx, mixer_block in enumerate(self.layers):
+            # Depending on the layer type we opt for 2D base attention mask (Mamba) or 4D causal mask (Attention)
+            if mixer_block.block_type == "mamba":
+                layer_mask = mamba_mask
+            elif mixer_block.block_type == "attention":
+                layer_mask = causal_mask
+            elif mixer_block.block_type == "mlp":
+                layer_mask = None
+            else:
+                raise ValueError(f"Invalid block_type: {mixer_block.block_type}")  # was self.block_type: the model has no block_type attribute (would raise AttributeError instead of the intended error)
+
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            if self.gradient_checkpointing and self.training:
+                hidden_states = self._gradient_checkpointing_func(
+                    mixer_block.__call__, hidden_states, cache_params, cache_position, layer_mask
+                )
+            else:
+                hidden_states = mixer_block(
+                    hidden_states,
+                    cache_params=cache_params,
+                    cache_position=cache_position,
+                    attention_mask=layer_mask,
+                )
+
+            # TODO: Store attentions
+            # if output_attentions:
+            #     if layer_outputs[1] is not None:
+            #         # append attentions only of attention layers. Mamba layers return `None` as the attention weights
+            #         all_self_attns += (layer_outputs[1],)
+
+        # TODO (Check): should it happen before the forward pass?
+ # if output_hidden_states: + # all_hidden_states = all_hidden_states + (hidden_states,) + + hidden_states = self.norm_f(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None) + + return NemotronHOutput( + last_hidden_state=hidden_states, + cache_params=cache_params if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + # Copied from transformers.models.jamba.modeling_jamba.JambaModel._update_causal_mask + def _update_causal_mask(self, attention_mask, input_tensor, cache_position): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + target_length = cache_position[-1] + 1 + + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.dim() == 2: + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) + causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + ): + # Attend to all tokens in fully masked rows in the causal_mask, for 
example the relevant first rows when
+            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+            # Details: https://github.com/pytorch/pytorch/issues/110213
+            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+        return causal_mask
+
+    def _update_mamba_mask(self, attention_mask, cache_position):
+        """
+        No need for zeroing states when
+            1. Cached forward
+            2. Attending to all inputs
+        """
+        mamba_mask = attention_mask
+        if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)):
+            mamba_mask = None
+        return mamba_mask
+
+
+@add_start_docstrings(
+    """
+    The NEMOTRONH Model transformer with a language modeling head on top (linear layer with weights not tied to the input
+    embeddings).
+    """,
+    NEMOTRONH_START_DOCSTRING,
+)
+class NemotronHForCausalLM(NemotronHPreTrainedModel, GenerationMixin):
+    _tied_weights_keys = ["lm_head.weight"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.backbone = NemotronHModel(config)
+        self.vocab_size = config.vocab_size
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.backbone.get_input_embeddings()
+
+    def set_input_embeddings(self, new_embeddings):
+        return self.backbone.set_input_embeddings(new_embeddings)
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    def get_decoder(self):
+        return self.backbone  # fix: the decoder attribute created in __init__ is `backbone`; `self.model` never exists here
+
+    def set_decoder(self, decoder):
+        self.backbone = decoder  # fix: keep symmetric with get_decoder (was writing a dangling `self.model`)
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        cache_position=None,
+        position_ids=None,
+        use_cache=True,
+        **kwargs,
+    ):
+        # Copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/jamba/modeling_jamba.py
+        # Overwritten -- uses `cache_params` as opposed to `past_key_values`
+        empty_past_kv = past_key_values is None
+
+        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
+        # Exception 1: when passing input_embeds, input_ids may be missing entries
+        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
+        # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
+        #              (we can't check exception 3 while compiling)
+        if not empty_past_kv:
+            if (
+                inputs_embeds is not None  # Exception 1
+                or cache_position[-1] >= input_ids.shape[1]  # Exception 3
+            ):
+                input_ids = input_ids[:, -cache_position.shape[0] :]
+            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
+                input_ids = input_ids[:, cache_position]
+        else:
+            past_key_values = HybridMambaAttentionDynamicCache(
+                self.config, input_ids.shape[0], self.dtype, device=self.device
+            )
+
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if not empty_past_kv:
+                position_ids = position_ids[:, -input_ids.shape[1] :]
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and empty_past_kv:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases
+
+        model_inputs.update(
+            {
+                "position_ids": position_ids,
+                "past_key_values": past_key_values,
+                "use_cache": use_cache,
+                "attention_mask": attention_mask,
+                "logits_to_keep": self.config.num_logits_to_keep,
+                "cache_position": cache_position,
} + ) + return model_inputs + + @add_start_docstrings_to_model_forward(NEMOTRONH_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=NemotronHCausalLMOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + cache_params: Optional[HybridMambaAttentionDynamicCache] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, # for now we need this for generation + ) -> Union[Tuple, NemotronHCausalLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. 
you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + nemotron_h_outputs = self.backbone( + input_ids, + cache_params=cache_params, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + use_cache=use_cache, + cache_position=cache_position, + attention_mask=attention_mask, + ) + hidden_states = nemotron_h_outputs[0] + + # TODO: Check zamba_modeling.py: https://github.com/huggingface/transformers/blob/d7188ba600e36d3fd191b12e19f1b3bb81a8404f/src/transformers/models/zamba/modeling_zamba.py#L1284C1-L1286C2 + #logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float() + logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float() + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (logits,) + nemotron_h_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return NemotronHCausalLMOutput( + loss=loss, + logits=logits, + cache_params=nemotron_h_outputs.cache_params, + hidden_states=nemotron_h_outputs.hidden_states, + 
attentions=nemotron_h_outputs.attentions, + ) From d97718bcdecc2e11315f0ec91720416fb0cfb13c Mon Sep 17 00:00:00 2001 From: Vinay Raman Date: Wed, 28 Jan 2026 00:07:13 -0800 Subject: [PATCH 2/8] use nemotron_biencoder --- nemo_automodel/components/models/biencoder/biencoder_model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nemo_automodel/components/models/biencoder/biencoder_model.py b/nemo_automodel/components/models/biencoder/biencoder_model.py index acbbb0dec..3cc25c879 100644 --- a/nemo_automodel/components/models/biencoder/biencoder_model.py +++ b/nemo_automodel/components/models/biencoder/biencoder_model.py @@ -25,7 +25,8 @@ _patch_liger_kernel, ) -from .llama_bidirectional_model import BiencoderModel +# from .llama_bidirectional_model import BiencoderModel +from .nemotron_bidirectional_model import BiencoderModel logger = logging.getLogger(__name__) From 280f6c071323db481c06d11a4e09f3d41d009ae3 Mon Sep 17 00:00:00 2001 From: VINAY RAMAN Date: Wed, 28 Jan 2026 19:48:36 -0800 Subject: [PATCH 3/8] fixed issues, training job works --- .../biencoder/nemotron_nanov3_biencoder.yaml | 95 +++++++++++++++++++ .../components/models/biencoder/__init__.py | 14 ++- .../models/biencoder/biencoder_model.py | 6 +- .../biencoder/nemotron_bidirectional_model.py | 56 +++++++---- 4 files changed, 146 insertions(+), 25 deletions(-) create mode 100644 examples/biencoder/nemotron_nanov3_biencoder.yaml diff --git a/examples/biencoder/nemotron_nanov3_biencoder.yaml b/examples/biencoder/nemotron_nanov3_biencoder.yaml new file mode 100644 index 000000000..7a860a785 --- /dev/null +++ b/examples/biencoder/nemotron_nanov3_biencoder.yaml @@ -0,0 +1,95 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# To run this recipe, please use the following command: +# python examples/biencoder/finetune.py --config examples/biencoder/nemotron_nanov3_biencoder.yaml +# Or with torchrun for multi-GPU: +# torchrun --nproc-per-node=8 examples/biencoder/finetune.py --config examples/biencoder/nemotron_nanov3_biencoder.yaml + +seed: 125 + +step_scheduler: + global_batch_size: 128 + local_batch_size: 1 + ckpt_every_steps: 3000 + val_every_steps: 3000 + num_epochs: 1 + +dist_env: + backend: nccl + timeout_minutes: 1 + +model: + _target_: nemo_automodel.components.models.biencoder.NeMoAutoModelBiencoder.from_pretrained + pretrained_model_name_or_path: nvidia/NVIDIA-Nemotron-Nano-9B-v2 + train_n_passages: 5 + pooling: last + t: 0.02 + torch_dtype: bfloat16 + attn_implementation: eager + + +tokenizer: + _target_: transformers.AutoTokenizer.from_pretrained + pretrained_model_name_or_path: nvidia/NVIDIA-Nemotron-Nano-9B-v2 + +dataloader: + _target_: torchdata.stateful_dataloader.StatefulDataLoader + dataset: + _target_: nemo_automodel.components.datasets.llm.make_retrieval_dataset + data_dir_list: + - ./embed_nemotron_dataset_v1/TriviaQA/TriviaQA.json + - ./embed_nemotron_dataset_v1/SyntheticClassificationData/SyntheticClassificationData.json + + data_type: train + train_n_passages: 5 + seed: 125 + do_shuffle: true + use_dataset_instruction: true + collate_fn: + _target_: nemo_automodel.components.datasets.llm.RetrievalBiencoderCollator + q_max_len: 512 + p_max_len: 512 + query_prefix: "" + passage_prefix: "" + pad_to_multiple_of: 8 + use_dataset_instruction: true + shuffle: true + 
num_workers: 8 + +optimizer: + _target_: torch.optim.AdamW + lr: 2.0e-6 + weight_decay: 0.01 + +lr_scheduler: + lr_warmup_steps: 2 + +checkpoint: + enabled: true + checkpoint_dir: ./output/llama_embed_nemotron_8b + model_save_format: safetensors + save_consolidated: true + +distributed: + _target_: nemo_automodel.components.distributed.fsdp2.FSDP2Manager + dp_size: none + dp_replicate_size: 1 + tp_size: 1 + cp_size: 1 + sequence_parallel: false + +dist_env: + backend: nccl + timeout_minutes: 30 diff --git a/nemo_automodel/components/models/biencoder/__init__.py b/nemo_automodel/components/models/biencoder/__init__.py index 20cd00f45..3d38093cf 100644 --- a/nemo_automodel/components/models/biencoder/__init__.py +++ b/nemo_automodel/components/models/biencoder/__init__.py @@ -20,12 +20,18 @@ """ from .biencoder_model import NeMoAutoModelBiencoder # noqa: F401 -from .llama_bidirectional_model import ( # noqa: F401 +# from .llama_bidirectional_model import ( # noqa: F401 +# BiencoderModel, +# BiencoderOutput, +# LlamaBidirectionalConfig, +# LlamaBidirectionalForSequenceClassification, +# LlamaBidirectionalModel, +# ) +from .nemotron_bidirectional_model import ( # noqa: F401 BiencoderModel, BiencoderOutput, - LlamaBidirectionalConfig, - LlamaBidirectionalForSequenceClassification, - LlamaBidirectionalModel, + NemotronBidirectionalConfig, + NemotronBidirectionalModel, ) __all__ = [ diff --git a/nemo_automodel/components/models/biencoder/biencoder_model.py b/nemo_automodel/components/models/biencoder/biencoder_model.py index 3cc25c879..c7069ac38 100644 --- a/nemo_automodel/components/models/biencoder/biencoder_model.py +++ b/nemo_automodel/components/models/biencoder/biencoder_model.py @@ -15,6 +15,7 @@ import gc import logging from typing import List, Optional +import torch from torch.nn.attention import SDPBackend @@ -114,8 +115,9 @@ def _retry(**override): ) # Use BiencoderModel.build to initialize model with base encoders - hf_kwargs = {"attn_implementation": 
"flash_attention_2"} - kwargs.update(hf_kwargs) + # Only set attn_implementation if not already provided in kwargs + if "attn_implementation" not in kwargs: + kwargs["attn_implementation"] = "eager" model = BiencoderModel.build( model_name_or_path=pretrained_model_name_or_path, share_encoder=share_encoder, diff --git a/nemo_automodel/components/models/biencoder/nemotron_bidirectional_model.py b/nemo_automodel/components/models/biencoder/nemotron_bidirectional_model.py index 7f6534d1f..5d5377072 100644 --- a/nemo_automodel/components/models/biencoder/nemotron_bidirectional_model.py +++ b/nemo_automodel/components/models/biencoder/nemotron_bidirectional_model.py @@ -27,6 +27,12 @@ from dataclasses import dataclass from typing import Dict, List, Optional, Tuple, Union +import torch +from torch import Tensor +import torch.nn as nn +import torch.nn.functional as F +import pdb + from transformers import PreTrainedModel from transformers.cache_utils import Cache, DynamicCache from transformers.modeling_outputs import ( @@ -34,8 +40,12 @@ ModelOutput, SequenceClassifierOutputWithPast, ) +from transformers.processing_utils import Unpack +from transformers.utils import TransformersKwargs, logging + from nemo_automodel.components.models.nemotron.configuration_nemotron_h import NemotronHConfig from nemo_automodel.components.models.nemotron.modeling_nemotron_h import ( + HybridMambaAttentionDynamicCache, NemotronHModel, NemotronHOutput, ) @@ -50,6 +60,9 @@ logger = logging.get_logger(__name__) check_model_inputs = get_check_model_inputs_decorator() + + + def contrastive_scores_and_labels( query: torch.Tensor, key: torch.Tensor, current_train_n_passages: int ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -132,7 +145,7 @@ def __init__(self, **kwargs): self.pooling = kwargs.get("pooling", "last") self.temperature = kwargs.get("temperature", 1.0) self.use_cache = kwargs.get("use_cache", False) - logger.info(f"NemotronBidirectionalConfig initialized with pooling: {pooling} and 
temperature: {temperature}") + logger.info(f"NemotronBidirectionalConfig initialized with pooling: {self.pooling} and temperature: {self.temperature}") logger.info(f"NemotronBidirectionalConfig initialized with kwargs: {kwargs}") @@ -144,6 +157,7 @@ class NemotronBidirectionalModel(NemotronHModel): This model is a bidirectional Nemotron model for embedding tasks. """ config_class = NemotronBidirectionalConfig + main_input_name = "input_ids" def __init__(self, config: NemotronBidirectionalConfig): super().__init__(config) @@ -152,22 +166,22 @@ def __init__(self, config: NemotronBidirectionalConfig): self.tokenizer = None - @check_model_inputs - @auto_docstring + # @check_model_inputs def forward( self, input_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, cache_params: Optional[HybridMambaAttentionDynamicCache] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = True, + return_dict: Optional[bool] = True, cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> Union[Tuple, NemotronHOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -179,7 +193,6 @@ def forward( if (input_ids is None) ^ (inputs_embeds is not None): # ^ is python for xor raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - if inputs_embeds is None: inputs_embeds = self.embeddings(input_ids) @@ -197,7 +210,7 @@ def forward( ) hidden_states = inputs_embeds - + if cache_position is None: cache_position = 
torch.arange(hidden_states.shape[1], device=hidden_states.device) if position_ids is None: @@ -209,8 +222,8 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None # Until HERE - - for layer_idx, mixer_block in enumerate(self.layers): + + for mixer_block in self.layers: # Depending on the layer type we opt for 2D base attention mask (Mamba) or 4D causal mask (Attention) if mixer_block.block_type == "mamba": layer_mask = mamba_mask @@ -223,12 +236,13 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - + if self.gradient_checkpointing and self.training: hidden_states = self._gradient_checkpointing_func( mixer_block.__call__, hidden_states, cache_params, cache_position, layer_mask ) else: + hidden_states = mixer_block( hidden_states, cache_params=cache_params, @@ -345,20 +359,24 @@ def _encode(self, encoder: PreTrainedModel, input_dict: dict) -> Optional[torch. return None import inspect - + # Remove token_type_ids if encoder doesn't support it if ( "token_type_ids" not in inspect.getfullargspec(encoder.forward).args and "token_type_ids" in input_dict.keys() ): input_dict = {k: v for k, v in input_dict.items() if k != "token_type_ids"} - + # Get encoder outputs - outputs = encoder( - **{k: v for k, v in input_dict.items() if k not in ["kd_labels"]}, - return_dict=True, - output_hidden_states=True, - ) + # outputs = encoder( + # **{k: v for k, v in input_dict.items() if k not in ["kd_labels"]}, + # return_dict=True, + # output_hidden_states=True, + # ) + outputs = encoder(input_ids=input_dict["input_ids"], + attention_mask=input_dict["attention_mask"], + output_hidden_states=True, + return_dict=True,) # Extract hidden states if hasattr(outputs, "last_hidden_state"): From 443c0f1c4fe46c1be5a4d74cbe0551c2897d8db2 Mon Sep 17 00:00:00 2001 From: VINAY RAMAN Date: Wed, 28 Jan 2026 21:17:12 -0800 Subject: [PATCH 4/8] added wandb to config.yaml --- 
.../biencoder/nemotron_nanov3_biencoder.yaml | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/examples/biencoder/nemotron_nanov3_biencoder.yaml b/examples/biencoder/nemotron_nanov3_biencoder.yaml index 7a860a785..4883a26b6 100644 --- a/examples/biencoder/nemotron_nanov3_biencoder.yaml +++ b/examples/biencoder/nemotron_nanov3_biencoder.yaml @@ -22,8 +22,8 @@ seed: 125 step_scheduler: global_batch_size: 128 local_batch_size: 1 - ckpt_every_steps: 3000 - val_every_steps: 3000 + ckpt_every_steps: 100 + val_every_steps: 100 num_epochs: 1 dist_env: @@ -78,10 +78,20 @@ lr_scheduler: checkpoint: enabled: true - checkpoint_dir: ./output/llama_embed_nemotron_8b + checkpoint_dir: ./output/nemotron_nano_9b_biencoder model_save_format: safetensors save_consolidated: true +wandb: + project: nemotron-finetuning + entity: nvidia-merlin # Replace with your wandb entity/username + name: nemotron_nano_9b_biencoder + tags: + - biencoder + - nemotron + - retrieval + notes: "Finetuning Nemotron Nano 9B for biencoder retrieval" + distributed: _target_: nemo_automodel.components.distributed.fsdp2.FSDP2Manager dp_size: none From e3b65e66007ae66440bf2c9f96578a7904b20699 Mon Sep 17 00:00:00 2001 From: VINAY RAMAN Date: Wed, 28 Jan 2026 21:54:57 -0800 Subject: [PATCH 5/8] fixed bug relating to partial saving of ckpt --- examples/biencoder/nemotron_nanov3_biencoder.yaml | 4 ++-- nemo_automodel/components/checkpoint/checkpointing.py | 2 +- .../components/models/biencoder/biencoder_model.py | 1 + .../models/biencoder/nemotron_bidirectional_model.py | 6 +----- 4 files changed, 5 insertions(+), 8 deletions(-) diff --git a/examples/biencoder/nemotron_nanov3_biencoder.yaml b/examples/biencoder/nemotron_nanov3_biencoder.yaml index 4883a26b6..fb7651a15 100644 --- a/examples/biencoder/nemotron_nanov3_biencoder.yaml +++ b/examples/biencoder/nemotron_nanov3_biencoder.yaml @@ -22,8 +22,8 @@ seed: 125 step_scheduler: global_batch_size: 128 local_batch_size: 1 - 
ckpt_every_steps: 100 - val_every_steps: 100 + ckpt_every_steps: 2 + val_every_steps: 2 num_epochs: 1 dist_env: diff --git a/nemo_automodel/components/checkpoint/checkpointing.py b/nemo_automodel/components/checkpoint/checkpointing.py index b166e0339..8225e249b 100644 --- a/nemo_automodel/components/checkpoint/checkpointing.py +++ b/nemo_automodel/components/checkpoint/checkpointing.py @@ -580,7 +580,7 @@ def _maybe_build_consolidated_index( # Add any missing keys from the model_state_dict # These will go to the same file as the last file (or file 1 for single-file models) - default_index = max(fqn_to_file_index_mapping.values()) + default_index = max(fqn_to_file_index_mapping.values()) if fqn_to_file_index_mapping else 1 # add any additional keys that are not in the base checkpoint for fqn in list(state_dict.keys()): diff --git a/nemo_automodel/components/models/biencoder/biencoder_model.py b/nemo_automodel/components/models/biencoder/biencoder_model.py index c7069ac38..5d501617e 100644 --- a/nemo_automodel/components/models/biencoder/biencoder_model.py +++ b/nemo_automodel/components/models/biencoder/biencoder_model.py @@ -118,6 +118,7 @@ def _retry(**override): # Only set attn_implementation if not already provided in kwargs if "attn_implementation" not in kwargs: kwargs["attn_implementation"] = "eager" + model = BiencoderModel.build( model_name_or_path=pretrained_model_name_or_path, share_encoder=share_encoder, diff --git a/nemo_automodel/components/models/biencoder/nemotron_bidirectional_model.py b/nemo_automodel/components/models/biencoder/nemotron_bidirectional_model.py index 5d5377072..29ed014bc 100644 --- a/nemo_automodel/components/models/biencoder/nemotron_bidirectional_model.py +++ b/nemo_automodel/components/models/biencoder/nemotron_bidirectional_model.py @@ -368,11 +368,7 @@ def _encode(self, encoder: PreTrainedModel, input_dict: dict) -> Optional[torch. 
input_dict = {k: v for k, v in input_dict.items() if k != "token_type_ids"} # Get encoder outputs - # outputs = encoder( - # **{k: v for k, v in input_dict.items() if k not in ["kd_labels"]}, - # return_dict=True, - # output_hidden_states=True, - # ) + outputs = encoder(input_ids=input_dict["input_ids"], attention_mask=input_dict["attention_mask"], output_hidden_states=True, From 54c68e226bc7c53e9aa8b322ff82337fcf727082 Mon Sep 17 00:00:00 2001 From: VINAY RAMAN Date: Thu, 29 Jan 2026 20:09:15 -0800 Subject: [PATCH 6/8] added support for avg pooling --- .../biencoder/nemotron_nanov3_biencoder.yaml | 23 +++- .../biencoder/nemotron_bidirectional_model.py | 124 ++++++++++++++++-- 2 files changed, 131 insertions(+), 16 deletions(-) diff --git a/examples/biencoder/nemotron_nanov3_biencoder.yaml b/examples/biencoder/nemotron_nanov3_biencoder.yaml index fb7651a15..496c5c452 100644 --- a/examples/biencoder/nemotron_nanov3_biencoder.yaml +++ b/examples/biencoder/nemotron_nanov3_biencoder.yaml @@ -22,8 +22,8 @@ seed: 125 step_scheduler: global_batch_size: 128 local_batch_size: 1 - ckpt_every_steps: 2 - val_every_steps: 2 + ckpt_every_steps: 100 + val_every_steps: 100 num_epochs: 1 dist_env: @@ -34,11 +34,28 @@ model: _target_: nemo_automodel.components.models.biencoder.NeMoAutoModelBiencoder.from_pretrained pretrained_model_name_or_path: nvidia/NVIDIA-Nemotron-Nano-9B-v2 train_n_passages: 5 - pooling: last + pooling: avg t: 0.02 torch_dtype: bfloat16 attn_implementation: eager + # Bidirectional processing strategy for Mamba layers + # Options: + # - average: Simple average of forward and backward passes + # - concat: Concatenate forward and backward (doubles hidden size) + # - weighted: Weighted average using forward_weight + # - gated: Learned gating mechanism (requires training) + mamba_bidirectional_strategy: average + + # Forward weight (only used if mamba_bidirectional_strategy=weighted) + # Value between 0 and 1, backward_weight = 1 - forward_weight + forward_weight: 
0.5 + + # Bidirectional attention for attention layers + # If true, attention layers use bidirectional attention (all tokens attend to all tokens) + # If false, attention layers use causal attention (tokens only attend to past tokens) + bidirectional_attention: true + tokenizer: _target_: transformers.AutoTokenizer.from_pretrained diff --git a/nemo_automodel/components/models/biencoder/nemotron_bidirectional_model.py b/nemo_automodel/components/models/biencoder/nemotron_bidirectional_model.py index 29ed014bc..2d4d43f76 100644 --- a/nemo_automodel/components/models/biencoder/nemotron_bidirectional_model.py +++ b/nemo_automodel/components/models/biencoder/nemotron_bidirectional_model.py @@ -142,10 +142,15 @@ class NemotronBidirectionalConfig(NemotronHConfig): def __init__(self, **kwargs): super().__init__(**kwargs) - self.pooling = kwargs.get("pooling", "last") + self.pooling = kwargs.get("pooling", "avg") self.temperature = kwargs.get("temperature", 1.0) self.use_cache = kwargs.get("use_cache", False) + self.mamba_bidirectional_strategy = kwargs.get("mamba_bidirectional_strategy", "average") # Options: average, concat, weighted, gated + self.forward_weight = kwargs.get("forward_weight", 0.5) # For weighted strategy + self.bidirectional_attention = kwargs.get("bidirectional_attention", True) # Use bidirectional attention for attention layers logger.info(f"NemotronBidirectionalConfig initialized with pooling: {self.pooling} and temperature: {self.temperature}") + logger.info(f"NemotronBidirectionalConfig initialized with mamba_bidirectional_strategy: {self.mamba_bidirectional_strategy}") + logger.info(f"NemotronBidirectionalConfig initialized with bidirectional_attention: {self.bidirectional_attention}") logger.info(f"NemotronBidirectionalConfig initialized with kwargs: {kwargs}") @@ -164,7 +169,62 @@ def __init__(self, config: NemotronBidirectionalConfig): self.config = config self.model = None self.tokenizer = None + + # Initialize gating layer if using gated 
bidirectional strategy + if hasattr(config, 'mamba_bidirectional_strategy') and config.mamba_bidirectional_strategy == "gated": + self.gate_layer = nn.Linear(config.hidden_size * 2, 1) + + def _update_causal_mask(self, attention_mask, input_tensor, cache_position): + """ + Override parent's causal mask to optionally create a bidirectional attention mask. + If bidirectional_attention is True, all tokens can attend to all other tokens (no causal masking). + If bidirectional_attention is False, uses standard causal masking (parent behavior). + In both cases, padding tokens are properly masked. + """ + # If bidirectional attention is disabled, use parent's causal mask + use_bidirectional = getattr(self.config, 'bidirectional_attention', True) + if not use_bidirectional: + return super()._update_causal_mask(attention_mask, input_tensor, cache_position) + + # Bidirectional attention implementation + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + batch_size, sequence_length = input_tensor.shape[0], input_tensor.shape[1] + target_length = cache_position[-1] + 1 + + # Create a full attention mask (all zeros, meaning all tokens can attend to all tokens) + # This is the key difference from causal attention + bidirectional_mask = torch.zeros((sequence_length, target_length), dtype=dtype, device=device) + + # Expand to 4D: [batch, 1, seq_len, target_len] + bidirectional_mask = bidirectional_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + + # Apply padding mask if provided + if attention_mask is not None: + bidirectional_mask = bidirectional_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.dim() == 2: + mask_length = attention_mask.shape[-1] + # Mask out padding tokens (where attention_mask is 0) + padding_mask = attention_mask[:, None, 
None, :].eq(0.0) + bidirectional_mask[..., :mask_length] = bidirectional_mask[..., :mask_length].masked_fill( + padding_mask, min_dtype + ) + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + ): + # For SDPA memory-efficient attention path + from transformers.modeling_attn_mask_utils import AttentionMaskConverter + bidirectional_mask = AttentionMaskConverter._unmask_unattended(bidirectional_mask, min_dtype) + + return bidirectional_mask # @check_model_inputs def forward( @@ -227,28 +287,66 @@ def forward( # Depending on the layer type we opt for 2D base attention mask (Mamba) or 4D causal mask (Attention) if mixer_block.block_type == "mamba": layer_mask = mamba_mask + + # Bidirectional processing for Mamba layers + strategy = getattr(self.config, 'mamba_bidirectional_strategy', 'average') + + # Forward pass + hidden_states_forward = mixer_block( + hidden_states, + cache_params=cache_params, + cache_position=cache_position, + attention_mask=layer_mask, + ) + + # Backward pass (flip input and output) + hidden_states_reverse = hidden_states.flip(dims=[1]) + hidden_states_reverse = mixer_block( + hidden_states_reverse, + cache_params=cache_params, + cache_position=cache_position, + attention_mask=layer_mask, + ) + hidden_states_backward = hidden_states_reverse.flip(dims=[1]) # Flip back to align positions + + # Combine forward and backward based on strategy + if strategy == "average": + hidden_states = (hidden_states_forward + hidden_states_backward) / 2 + elif strategy == "concat": + hidden_states = torch.cat([hidden_states_forward, hidden_states_backward], dim=-1) + elif strategy == "weighted": + forward_weight = getattr(self.config, 'forward_weight', 0.5) + backward_weight = 1.0 - forward_weight + hidden_states = forward_weight * hidden_states_forward + backward_weight * hidden_states_backward + elif strategy == "gated": + # Learned gating mechanism + combined = 
torch.cat([hidden_states_forward, hidden_states_backward], dim=-1) + gate = torch.sigmoid(self.gate_layer(combined)) + hidden_states = gate * hidden_states_forward + (1 - gate) * hidden_states_backward + else: + raise ValueError(f"Invalid mamba_bidirectional_strategy: {strategy}. Choose from: average, concat, weighted, gated") + elif mixer_block.block_type == "attention": layer_mask = causal_mask + hidden_states = mixer_block( + hidden_states, + cache_params=cache_params, + cache_position=cache_position, + attention_mask=layer_mask, + ) elif mixer_block.block_type == "mlp": layer_mask = None - else: - raise ValueError(f"Invalid block_type: {self.block_type}") - - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - mixer_block.__call__, hidden_states, cache_params, cache_position, layer_mask - ) - else: - hidden_states = mixer_block( hidden_states, cache_params=cache_params, cache_position=cache_position, attention_mask=layer_mask, ) + else: + raise ValueError(f"Invalid block_type: {mixer_block.block_type}") + + if output_hidden_states: + all_hidden_states += (hidden_states,) hidden_states = self.norm_f(hidden_states) From 5f7a29ab3604474c15fb9dac0898da7387d263c8 Mon Sep 17 00:00:00 2001 From: VINAY RAMAN Date: Sun, 1 Feb 2026 22:05:30 -0800 Subject: [PATCH 7/8] added minor changes to config --- .../biencoder/nemotron_bidirectional_model.py | 116 +++++++++++------- 1 file changed, 75 insertions(+), 41 deletions(-) diff --git a/nemo_automodel/components/models/biencoder/nemotron_bidirectional_model.py b/nemo_automodel/components/models/biencoder/nemotron_bidirectional_model.py index 2d4d43f76..4d593162a 100644 --- a/nemo_automodel/components/models/biencoder/nemotron_bidirectional_model.py +++ b/nemo_automodel/components/models/biencoder/nemotron_bidirectional_model.py @@ -97,9 +97,20 @@ def pool(last_hidden_states: torch.Tensor, attention_mask: 
torch.Tensor, pool_ty last_hidden_states: Hidden states from the model [batch_size, seq_len, hidden_size] attention_mask: Attention mask [batch_size, seq_len] pool_type: Type of pooling to apply + - "avg": Average pooling over all non-padded tokens + - "weighted_avg": Weighted average pooling + - "cls": Use the [CLS] token (first token) + - "last": Use the last non-padded token + - "eos": Use the EOS token (requires EOS token at end of sequence) + - "cls_last": Use the [CLS] token + - "colbert": Return all hidden states (for ColBERT-style models) Returns: - Pooled embeddings [batch_size, hidden_size] + Pooled embeddings [batch_size, hidden_size] or [batch_size, seq_len, hidden_size] for colbert + + Note: + For "eos" pooling, ensure your tokenizer adds EOS tokens to the end of sequences. + The tokenizer should be configured with add_eos_token=True or similar parameter. """ last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0) @@ -119,6 +130,18 @@ def pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor, pool_ty emb = last_hidden[torch.arange(batch_size, device=last_hidden.device), sequence_lengths] elif pool_type == "cls_last": emb = last_hidden[:, 0] + elif pool_type == "eos": + # Extract hidden state at EOS token position (last non-padded position) + # Similar to "last" pooling, but explicitly for EOS token + left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0] + if left_padding: + # If left padding, EOS is at the last position + emb = last_hidden[:, -1] + else: + # If right padding, EOS is at the last non-padded position + sequence_lengths = attention_mask.sum(dim=1) - 1 + batch_size = last_hidden.shape[0] + emb = last_hidden[torch.arange(batch_size, device=last_hidden.device), sequence_lengths] elif pool_type == "colbert": emb = last_hidden else: @@ -145,7 +168,7 @@ def __init__(self, **kwargs): self.pooling = kwargs.get("pooling", "avg") self.temperature = kwargs.get("temperature", 1.0) 
self.use_cache = kwargs.get("use_cache", False) - self.mamba_bidirectional_strategy = kwargs.get("mamba_bidirectional_strategy", "average") # Options: average, concat, weighted, gated + self.mamba_bidirectional_strategy = kwargs.get("mamba_bidirectional_strategy", "average") # Options: unidirectional, average, concat, weighted, gated self.forward_weight = kwargs.get("forward_weight", 0.5) # For weighted strategy self.bidirectional_attention = kwargs.get("bidirectional_attention", True) # Use bidirectional attention for attention layers logger.info(f"NemotronBidirectionalConfig initialized with pooling: {self.pooling} and temperature: {self.temperature}") @@ -169,7 +192,7 @@ def __init__(self, config: NemotronBidirectionalConfig): self.config = config self.model = None self.tokenizer = None - + # Initialize gating layer if using gated bidirectional strategy if hasattr(config, 'mamba_bidirectional_strategy') and config.mamba_bidirectional_strategy == "gated": self.gate_layer = nn.Linear(config.hidden_size * 2, 1) @@ -270,7 +293,7 @@ def forward( ) hidden_states = inputs_embeds - + if cache_position is None: cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device) if position_ids is None: @@ -282,7 +305,7 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None # Until HERE - + for mixer_block in self.layers: # Depending on the layer type we opt for 2D base attention mask (Mamba) or 4D causal mask (Attention) if mixer_block.block_type == "mamba": @@ -291,40 +314,51 @@ def forward( # Bidirectional processing for Mamba layers strategy = getattr(self.config, 'mamba_bidirectional_strategy', 'average') - # Forward pass - hidden_states_forward = mixer_block( - hidden_states, - cache_params=cache_params, - cache_position=cache_position, - attention_mask=layer_mask, - ) - - # Backward pass (flip input and output) - hidden_states_reverse = hidden_states.flip(dims=[1]) - 
hidden_states_reverse = mixer_block( - hidden_states_reverse, - cache_params=cache_params, - cache_position=cache_position, - attention_mask=layer_mask, - ) - hidden_states_backward = hidden_states_reverse.flip(dims=[1]) # Flip back to align positions - - # Combine forward and backward based on strategy - if strategy == "average": - hidden_states = (hidden_states_forward + hidden_states_backward) / 2 - elif strategy == "concat": - hidden_states = torch.cat([hidden_states_forward, hidden_states_backward], dim=-1) - elif strategy == "weighted": - forward_weight = getattr(self.config, 'forward_weight', 0.5) - backward_weight = 1.0 - forward_weight - hidden_states = forward_weight * hidden_states_forward + backward_weight * hidden_states_backward - elif strategy == "gated": - # Learned gating mechanism - combined = torch.cat([hidden_states_forward, hidden_states_backward], dim=-1) - gate = torch.sigmoid(self.gate_layer(combined)) - hidden_states = gate * hidden_states_forward + (1 - gate) * hidden_states_backward + # Check if unidirectional (no backward pass needed) + if strategy == "unidirectional": + # Standard unidirectional forward pass only + hidden_states = mixer_block( + hidden_states, + cache_params=cache_params, + cache_position=cache_position, + attention_mask=layer_mask, + ) else: - raise ValueError(f"Invalid mamba_bidirectional_strategy: {strategy}. 
Choose from: average, concat, weighted, gated") + # Bidirectional processing: forward + backward passes + # Forward pass + hidden_states_forward = mixer_block( + hidden_states, + cache_params=cache_params, + cache_position=cache_position, + attention_mask=layer_mask, + ) + + # Backward pass (flip input and output) + hidden_states_reverse = hidden_states.flip(dims=[1]) + hidden_states_reverse = mixer_block( + hidden_states_reverse, + cache_params=cache_params, + cache_position=cache_position, + attention_mask=layer_mask, + ) + hidden_states_backward = hidden_states_reverse.flip(dims=[1]) # Flip back to align positions + + # Combine forward and backward based on strategy + if strategy == "average": + hidden_states = (hidden_states_forward + hidden_states_backward) / 2 + elif strategy == "concat": + hidden_states = torch.cat([hidden_states_forward, hidden_states_backward], dim=-1) + elif strategy == "weighted": + forward_weight = getattr(self.config, 'forward_weight', 0.5) + backward_weight = 1.0 - forward_weight + hidden_states = forward_weight * hidden_states_forward + backward_weight * hidden_states_backward + elif strategy == "gated": + # Learned gating mechanism + combined = torch.cat([hidden_states_forward, hidden_states_backward], dim=-1) + gate = torch.sigmoid(self.gate_layer(combined)) + hidden_states = gate * hidden_states_forward + (1 - gate) * hidden_states_backward + else: + raise ValueError(f"Invalid mamba_bidirectional_strategy: {strategy}. Choose from: unidirectional, average, concat, weighted, gated") elif mixer_block.block_type == "attention": layer_mask = causal_mask @@ -457,19 +491,19 @@ def _encode(self, encoder: PreTrainedModel, input_dict: dict) -> Optional[torch. 
return None import inspect - + # Remove token_type_ids if encoder doesn't support it if ( "token_type_ids" not in inspect.getfullargspec(encoder.forward).args and "token_type_ids" in input_dict.keys() ): input_dict = {k: v for k, v in input_dict.items() if k != "token_type_ids"} - + # Get encoder outputs outputs = encoder(input_ids=input_dict["input_ids"], attention_mask=input_dict["attention_mask"], - output_hidden_states=True, + output_hidden_states=True, return_dict=True,) # Extract hidden states From 335658f9a7f195f184e4aaa8dc7643c0ff7f3353 Mon Sep 17 00:00:00 2001 From: VINAY RAMAN Date: Wed, 4 Feb 2026 12:34:04 -0800 Subject: [PATCH 8/8] minor change to config.yaml name --- ...otron_nanov3_biencoder.yaml => nemotron_nanov2_biencoder.yaml} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename examples/biencoder/{nemotron_nanov3_biencoder.yaml => nemotron_nanov2_biencoder.yaml} (100%) diff --git a/examples/biencoder/nemotron_nanov3_biencoder.yaml b/examples/biencoder/nemotron_nanov2_biencoder.yaml similarity index 100% rename from examples/biencoder/nemotron_nanov3_biencoder.yaml rename to examples/biencoder/nemotron_nanov2_biencoder.yaml