vllm/model_executor/models/jina_vl.py (2 changes: 1 addition & 1 deletion)
@@ -29,7 +29,7 @@
 class JinaVLScorer(nn.Module):
     def __init__(self, model_config: "ModelConfig"):
         super().__init__()
-        config = model_config.hf_config
+        config = model_config.hf_config.get_text_config()
         head_dtype = model_config.head_dtype
         self.dense = ColumnParallelLinear(
             config.hidden_size, config.hidden_size, params_dtype=head_dtype, bias=True
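The hunk above (and the identical one in qwen2.py further down) swaps `hf_config` for `hf_config.get_text_config()`. A minimal sketch of why, assuming the standard Transformers `PretrainedConfig` API: composite multimodal configs keep text attributes such as `hidden_size` on a nested text config, while `get_text_config()` returns the config itself for text-only models, so callers can read those attributes uniformly. The helper below is illustrative, not part of the PR.

```python
# Illustrative sketch, not from this PR: reading text-model attributes
# uniformly via PretrainedConfig.get_text_config().
from transformers import PretrainedConfig


def text_hidden_size(hf_config: PretrainedConfig) -> int:
    # For composite multimodal configs this returns the nested text config;
    # for plain text configs it returns hf_config itself.
    return hf_config.get_text_config().hidden_size
```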
vllm/model_executor/models/modernbert.py (53 changes: 27 additions & 26 deletions)
@@ -20,7 +20,7 @@
     PoolingParamsUpdate,
     PoolingType,
 )
-from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
+from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.sequence import IntermediateTensors
@@ -62,19 +62,6 @@ def forward(
         return embeddings
 
 
-class ModernBertRotaryEmbedding(RotaryEmbedding):
-    def __init__(self, config: ModernBertConfig, head_size: int, dim: int, base: float):
-        super().__init__(
-            head_size=head_size,
-            rotary_dim=dim,
-            max_position_embeddings=config.max_position_embeddings,
-            base=base,
-            is_neox_style=True,
-            dtype=torch.float16,
-        )
-        self.config = config
-
-
 class ModernBertAttention(nn.Module):
     def __init__(self, config: ModernBertConfig, layer_id: int | None = None):
         super().__init__()
@@ -95,19 +82,33 @@ def __init__(self, config: ModernBertConfig, layer_id: int | None = None):
             bias=config.attention_bias,
         )
 
-        sliding_window = None
-        if layer_id % config.global_attn_every_n_layers != 0:
-            sliding_window = config.local_attention // 2
-            rope_theta = (
-                config.local_rope_theta
-                if config.local_rope_theta is not None
-                else config.global_rope_theta
-            )
+        if layer_types := getattr(config, "layer_types", None):
+            # Transformers v5
+            layer_type = layer_types[layer_id]
+            rope_parameters = config.rope_parameters[layer_type]
+            sliding_window: int | None = None
+            if layer_type == "sliding_attention":
+                sliding_window = config.local_attention // 2
         else:
-            rope_theta = config.global_rope_theta
-
-        self.rotary_emb = ModernBertRotaryEmbedding(
-            config=config, head_size=self.head_dim, dim=self.head_dim, base=rope_theta
+            # Transformers v4
+            sliding_window = None
+            if layer_id % config.global_attn_every_n_layers != 0:
+                sliding_window = config.local_attention // 2
+                rope_theta = (
+                    config.local_rope_theta
+                    if config.local_rope_theta is not None
+                    else config.global_rope_theta
+                )
+            else:
+                rope_theta = config.global_rope_theta
+            rope_parameters = {"rope_type": "default", "rope_theta": rope_theta}
+
+        self.rotary_emb = get_rope(
+            head_size=self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=config.max_position_embeddings,
+            rope_parameters=rope_parameters,
+            dtype=torch.float16,
         )
         self.attn = EncoderOnlyAttention(
             self.num_heads,
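To make the new branching in `ModernBertAttention.__init__` easier to follow, here is a standalone sketch of how the per-layer RoPE parameters and sliding window are resolved: Transformers v5 configs expose `layer_types` plus a per-layer-type `rope_parameters` dict that can be passed straight to `get_rope`, while v4 configs only carry `global_attn_every_n_layers` and the two theta fields, from which an equivalent dict is built. The function name and the `SimpleNamespace` stand-in for the config are illustrative, not part of the diff.

```python
# Standalone sketch (names assumed, mirroring the diff) of the v4/v5 config
# handling in ModernBertAttention.__init__.
from types import SimpleNamespace


def resolve_rope_and_window(config, layer_id: int):
    if layer_types := getattr(config, "layer_types", None):
        # Transformers v5: per-layer-type rope dicts are already on the config.
        layer_type = layer_types[layer_id]
        rope_parameters = config.rope_parameters[layer_type]
        sliding_window = (
            config.local_attention // 2 if layer_type == "sliding_attention" else None
        )
    else:
        # Transformers v4: derive an equivalent rope dict from the theta fields.
        sliding_window = None
        rope_theta = config.global_rope_theta
        if layer_id % config.global_attn_every_n_layers != 0:
            sliding_window = config.local_attention // 2
            if config.local_rope_theta is not None:
                rope_theta = config.local_rope_theta
        rope_parameters = {"rope_type": "default", "rope_theta": rope_theta}
    return rope_parameters, sliding_window


# Example with a v4-style config: layer 1 is a local-attention layer.
v4_cfg = SimpleNamespace(
    global_attn_every_n_layers=3,
    local_attention=128,
    local_rope_theta=10000.0,
    global_rope_theta=160000.0,
)
print(resolve_rope_and_window(v4_cfg, layer_id=1))
# ({'rope_type': 'default', 'rope_theta': 10000.0}, 64)
```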
vllm/model_executor/models/qwen2.py (2 changes: 1 addition & 1 deletion)
@@ -503,7 +503,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
-        config = vllm_config.model_config.hf_config
+        config = vllm_config.model_config.hf_config.get_text_config()
         quant_config = vllm_config.quant_config
 
         self.config = config