diff --git a/lightllm/models/__init__.py b/lightllm/models/__init__.py
index 4ee02f003..825bd535e 100644
--- a/lightllm/models/__init__.py
+++ b/lightllm/models/__init__.py
@@ -32,6 +32,7 @@
from lightllm.models.qwen3_vl.model import Qwen3VLTpPartModel
from lightllm.models.qwen3_vl_moe.model import Qwen3VLMOETpPartModel
from lightllm.models.gemma3.model import Gemma3TpPartModel
+from lightllm.models.glm4v.model import GLM4VTpPartModel
from lightllm.models.tarsier2.model import (
Tarsier2Qwen2TpPartModel,
Tarsier2Qwen2VLTpPartModel,
diff --git a/lightllm/models/glm4v/__init__.py b/lightllm/models/glm4v/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/lightllm/models/glm4v/glm4v_visual.py b/lightllm/models/glm4v/glm4v_visual.py
new file mode 100644
index 000000000..ea5e592de
--- /dev/null
+++ b/lightllm/models/glm4v/glm4v_visual.py
@@ -0,0 +1,437 @@
+import os
+import json
+import torch
+import torch.nn as nn
+from PIL import Image
+from io import BytesIO
+from typing import List, Optional
+from torch.nn import LayerNorm
+import torch.nn.functional as F
+from safetensors import safe_open
+from transformers.activations import ACT2FN
+from lightllm.server.multimodal_params import MultimodalParams, ImageItem
+from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data
+from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
+from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd
+from lightllm.models.qwen2_vl.triton_kernel.rotary_pos_emb import apply_rotary_pos_emb_triton
+from lightllm.models.qwen2_vl.vision_process import Qwen2VLImageProcessor
+
+
+class Glm4vRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ Glm4vRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class Glm4VisionMlp(nn.Module):
+ def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str, bias: bool = False):
+ super().__init__()
+ self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
+ self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
+ self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=bias)
+ self.act_fn = ACT2FN[hidden_act]
+
+ def forward(self, hidden_state):
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
+
+
+class Glm4vVisionPatchEmbed(nn.Module):
+ def __init__(self, patch_size: int, temporal_patch_size: int, in_channels: int, embed_dim: int) -> None:
+ super().__init__()
+ self.patch_size = patch_size
+ self.temporal_patch_size = temporal_patch_size
+ self.in_channels = in_channels
+ self.embed_dim = embed_dim
+
+ kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
+ self.proj = nn.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = hidden_states.view(
+ -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
+ )
+ hidden_states = self.proj(hidden_states).view(-1, self.embed_dim)
+ return hidden_states
+
+
+class Glm4vVisionRotaryEmbedding(nn.Module):
+ def __init__(self, dim: int, theta: float = 10000.0) -> None:
+ super().__init__()
+ self.dim = dim
+ self.theta = theta
+ self.inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
+ self._seq_len_cached = 0
+ self._freqs_cos_cached = None
+ self._freqs_sin_cached = None
+
+ def update_freqs_cache(self, seqlen: int) -> None:
+ if seqlen > self._seq_len_cached:
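+            # grow the cache to 2x the requested length so small increases don't rebuild it every call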
+ seqlen *= 2
+ self._seq_len_cached = seqlen
+ self.inv_freq = 1.0 / (
+ self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device) / self.dim)
+ )
+ seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+ freqs = torch.outer(seq, self.inv_freq)
+ self._freqs_cos_cached = freqs.cos()
+ self._freqs_sin_cached = freqs.sin()
+
+ def forward(self, seqlen: int) -> torch.Tensor:
+ self.update_freqs_cache(seqlen)
+ return self._freqs_cos_cached[:seqlen], self._freqs_sin_cached[:seqlen]
+
+
+class Glm4vVisionPatchMerger(nn.Module):
+ def __init__(self, dim: int, context_dim: int, hidden_act: str, bias: bool = False) -> None:
+ super().__init__()
+ self.proj = nn.Linear(dim, dim, bias=bias)
+ self.post_projection_norm = LayerNorm(dim)
+ self.gate_proj = nn.Linear(dim, context_dim, bias=bias)
+ self.up_proj = nn.Linear(dim, context_dim, bias=bias)
+ self.down_proj = nn.Linear(context_dim, dim, bias=bias)
+ self.act1 = nn.GELU()
+ self.act_fn = ACT2FN[hidden_act]
+
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+ hidden_state = self.proj(hidden_state)
+ hidden_state = self.act1(self.post_projection_norm(hidden_state))
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
+
+
+class Glm4vVisionEmbeddings(nn.Module):
+ def __init__(self, hidden_size: int, image_size: int, patch_size: int):
+ super().__init__()
+ self.embed_dim = hidden_size
+ self.image_size = image_size
+ self.patch_size = patch_size
+
+ self.num_patches = (self.image_size // self.patch_size) ** 2
+ self.num_positions = self.num_patches
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+ self.position_ids = torch.arange(self.num_positions).expand((1, -1))
+
+ def forward(self, embeddings, lengths, image_shapes, h_coords, w_coords) -> torch.Tensor:
+ """
+ Forward pass with integrated position encoding adaptation using 2D interpolation.
+
+ Args:
+ embeddings: Input embeddings tensor
+ lengths (torch.Tensor): Sequence lengths for each image in the batch.
+ image_shapes (torch.Tensor): Tensor of shape [batch_size, 3] representing the image shapes (t, h, w).
+ h_coords (torch.Tensor): Tensor of shape [total_seq] representing the h coordinate for each patch.
+ w_coords (torch.Tensor): Tensor of shape [total_seq] representing the w coordinate for each patch.
+
+ Returns:
+ torch.Tensor: Embeddings with adapted position encoding added.
+ """
+ # Get position embedding parameters
+ pos_embed_weight = self.position_embedding.weight
+ hidden_size = pos_embed_weight.shape[1]
+ total_seq = h_coords.shape[0]
+ device = pos_embed_weight.device
+
+ # Move coordinates to correct device
+ h_coords, w_coords = h_coords.to(device), w_coords.to(device)
+
+ # Handle empty sequence case
+ if total_seq == 0:
+ adapted_pos_embed = torch.empty(0, hidden_size, device=device, dtype=pos_embed_weight.dtype)
+ else:
+ # Convert inputs to tensors if needed
+ if isinstance(lengths, list):
+ lengths = torch.tensor(lengths, device=device, dtype=torch.long)
+ if not isinstance(image_shapes, torch.Tensor):
+ image_shapes = torch.tensor(image_shapes, device=device, dtype=torch.long)
+
+ # Prepare 2D position embedding
+ orig_size_sq = pos_embed_weight.shape[0]
+ orig_size = int(orig_size_sq ** 0.5)
+ pos_embed_2d = (
+ pos_embed_weight.view(orig_size, orig_size, hidden_size)
+ .permute(2, 0, 1)
+ .unsqueeze(0)
+ .to(device=device, dtype=torch.float32)
+ )
+
+ # Calculate target dimensions for each patch
+ target_h = torch.cat([image_shapes[i, 1].repeat(lengths[i]) for i in range(len(lengths))]).to(
+ device=device, dtype=torch.float32
+ )
+ target_w = torch.cat([image_shapes[i, 2].repeat(lengths[i]) for i in range(len(lengths))]).to(
+ device=device, dtype=torch.float32
+ )
+
+ # Normalize coordinates to [-1, 1] range for grid_sample
+ h_coords = h_coords.to(device=device, dtype=torch.float32)
+ w_coords = w_coords.to(device=device, dtype=torch.float32)
+ norm_w = ((w_coords + 0.5) / target_w) * 2 - 1
+ norm_h = ((h_coords + 0.5) / target_h) * 2 - 1
+
+ # Create sampling grid
+ grid = torch.stack((norm_w, norm_h), dim=-1).unsqueeze(0).unsqueeze(2)
+
+ # Perform bicubic interpolation
+ interpolated_embed_fp32 = F.grid_sample(
+ pos_embed_2d, grid, mode="bicubic", align_corners=False, padding_mode="border"
+ )
+
+ # Reshape and convert back to original dtype
+ adapted_pos_embed_fp32 = interpolated_embed_fp32.squeeze(0).squeeze(-1).permute(1, 0)
+ adapted_pos_embed = adapted_pos_embed_fp32.to(pos_embed_weight.dtype).to(embeddings.device)
+
+ # Add adapted position encoding to embeddings
+ embeddings = embeddings + adapted_pos_embed
+ return embeddings
+
+
+class Glm4vVisionAttention(nn.Module):
+ def __init__(self, dim: int, num_heads: int, attention_bias: bool = False, attention_dropout: float = 0.0) -> None:
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.head_dim = self.dim // self.num_heads
+ self.num_key_value_groups = 1 # needed for eager attention
+ self.qkv = nn.Linear(dim, dim * 3, bias=attention_bias)
+ self.proj = nn.Linear(dim, dim, bias=False)
+ self.scaling = self.head_dim ** -0.5
+ self.attention_dropout = attention_dropout
+ self.is_causal = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ max_seqlen: int = 0,
+ rotary_cos: torch.Tensor = None,
+ rotary_sin: torch.Tensor = None,
+ ) -> torch.Tensor:
+ seq_length = hidden_states.shape[0]
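+        # fused QKV projection, split into q, k, v of shape (seq_len, num_heads, head_dim)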
+ q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
+ q = apply_rotary_pos_emb_triton(q, rotary_cos, rotary_sin)
+ k = apply_rotary_pos_emb_triton(k, rotary_cos, rotary_sin)
+
+ attn_output = g_cache_manager.alloc_tensor(q.shape, q.dtype, device=q.device)
+
+ flash_attention_fwd(q, k, v, attn_output, cu_seqlens, max_seqlen)
+ attn_output = attn_output.reshape(seq_length, -1)
+ attn_output = self.proj(attn_output)
+ return attn_output
+
+
+class Glm4vVisionBlock(nn.Module):
+ def __init__(self, embed_dim, intermediate_size, num_heads, hidden_act, rms_norm_eps) -> None:
+ super().__init__()
+ self.norm1 = Glm4vRMSNorm(embed_dim, eps=rms_norm_eps)
+ self.norm2 = Glm4vRMSNorm(embed_dim, eps=rms_norm_eps)
+ self.attn = Glm4vVisionAttention(embed_dim, num_heads=num_heads)
+ self.mlp = Glm4VisionMlp(
+ hidden_size=embed_dim, intermediate_size=intermediate_size, hidden_act=hidden_act, bias=False
+ )
+
+ def forward(self, hidden_states, cu_seqlens, max_seqlen, rotary_cos, rotary_sin) -> torch.Tensor:
+ hidden_states = hidden_states + self.attn(
+ self.norm1(hidden_states),
+ cu_seqlens=cu_seqlens,
+ max_seqlen=max_seqlen,
+ rotary_cos=rotary_cos,
+ rotary_sin=rotary_sin,
+ )
+ hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
+ return hidden_states
+
+
+class Glm4vVisionTransformerPretrainedModel(nn.Module):
+ def __init__(
+ self,
+ kvargs,
+ depth=24,
+ image_size=336,
+ hidden_size=1536,
+ intermediate_size=13696,
+ out_hidden_size=4096,
+ hidden_act="silu",
+ num_heads=12,
+ in_channels=3,
+ patch_size=14,
+ spatial_merge_size=2,
+ temporal_patch_size=2,
+ rms_norm_eps=1e-5,
+ **kwargs,
+ ):
+ super().__init__()
+ self.data_type = kvargs.get("data_type", "bfloat16")
+ self.depth = depth
+ self.intermediate_size = intermediate_size
+ self.out_hidden_size = out_hidden_size
+ self.hidden_size = hidden_size
+ self.hidden_act = hidden_act
+ self.num_heads = num_heads
+ self.in_channels = in_channels
+ self.patch_size = patch_size
+ self.spatial_merge_size = spatial_merge_size
+ self.temporal_patch_size = temporal_patch_size
+
+ self.embeddings = Glm4vVisionEmbeddings(hidden_size, image_size, patch_size)
+ self.patch_embed = Glm4vVisionPatchEmbed(patch_size, temporal_patch_size, in_channels, self.hidden_size)
+
+ head_dim = self.hidden_size // self.num_heads
+ self.rotary_pos_emb = Glm4vVisionRotaryEmbedding(head_dim // 2)
+
+ self.blocks = nn.ModuleList(
+ [
+ Glm4vVisionBlock(self.hidden_size, self.out_hidden_size, num_heads, hidden_act, rms_norm_eps)
+ for _ in range(self.depth)
+ ]
+ )
+ self.merger = Glm4vVisionPatchMerger(
+ dim=self.out_hidden_size, context_dim=self.intermediate_size, hidden_act=hidden_act
+ )
+
+ self.post_conv_layernorm = Glm4vRMSNorm(hidden_size, eps=rms_norm_eps)
+ self.downsample = nn.Conv2d(
+ in_channels=hidden_size,
+ out_channels=out_hidden_size,
+ kernel_size=spatial_merge_size,
+ stride=spatial_merge_size,
+ )
+ self.post_layernorm = Glm4vRMSNorm(hidden_size, eps=rms_norm_eps)
+
+ self._init_datatype()
+
+ def _init_datatype(self):
+ if isinstance(self.data_type, torch.dtype):
+ return
+ if self.data_type in ["fp16", "float16"]:
+ self.data_type = torch.float16
+ elif self.data_type in ["bf16", "bfloat16"]:
+ self.data_type = torch.bfloat16
+ elif self.data_type in ["fp32", "float32"]:
+ self.data_type = torch.float32
+ else:
+ raise ValueError(f"Unsupport datatype {self.data_type}!")
+ return
+
+ def load_model(self, weight_dir):
+
+ processor_config_path = os.path.join(weight_dir, "preprocessor_config.json")
+ with open(processor_config_path, "r") as f:
+ processor_config_dict = json.load(f)
+ self.processor = Qwen2VLImageProcessor(**processor_config_dict)
+
+ bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")]
+ if bin_weight_files:
+ weight_dict = {}
+ for file_ in bin_weight_files:
+ f = torch.load(os.path.join(weight_dir, file_), "cpu")
+ for k, v in f.items():
+ if "model.visual" in k:
+ weight_dict[k[len("model.visual.") :]] = v
+ else:
+ hf_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".safetensors")]
+ weight_dict = {}
+ for file_ in hf_weight_files:
+ f = safe_open(os.path.join(weight_dir, file_), "pt", "cpu")
+ for k in f.keys():
+ if "model.visual" in k:
+ weight_dict[k[len("model.visual.") :]] = f.get_tensor(k)
+
+ self.load_state_dict(weight_dict)
+
+ def rot_pos_emb(self, grid_thw):
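+        # build (h, w) patch position ids in spatial-merge-block order so rotary phases match the merged patch layout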
+ pos_ids = []
+ s = self.spatial_merge_size
+ for _, h, w in grid_thw:
+ pos_shape = (h // s, s, w // s, s)
+ hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+ wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+ hpos_ids = hpos_ids.reshape(pos_shape).permute(0, 2, 1, 3).flatten()
+ wpos_ids = wpos_ids.reshape(pos_shape).permute(0, 2, 1, 3).flatten()
+
+ pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1))
+ pos_ids = torch.cat(pos_ids, dim=0)
+ max_grid_size = grid_thw[:, 1:].max()
+ cos_full, sin_full = self.rotary_pos_emb(max_grid_size)
+ cos = cos_full[pos_ids].flatten(1)
+ sin = sin_full[pos_ids].flatten(1)
+ return cos, sin, pos_ids
+
+ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.patch_embed(hidden_states)
+ hidden_states = self.post_conv_layernorm(hidden_states)
+ rotary_cos, rotary_sin, pos_ids = self.rot_pos_emb(grid_thw)
+ rotary_cos = rotary_cos.to("cuda", non_blocking=True)
+ rotary_sin = rotary_sin.to("cuda", non_blocking=True)
+ cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
+ dim=0, dtype=torch.int32
+ )
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+ seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+ cu_seqlens = cu_seqlens.to("cuda", non_blocking=True)
+ hidden_states = self.embeddings(hidden_states, seqlens, grid_thw, pos_ids[:, 0], pos_ids[:, 1])
+
+ for blk in self.blocks:
+ hidden_states = blk(
+ hidden_states,
+ cu_seqlens=cu_seqlens,
+ max_seqlen=max_seqlen,
+ rotary_cos=rotary_cos,
+ rotary_sin=rotary_sin,
+ )
+ hidden_states = self.post_layernorm(hidden_states)
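+        # regroup patches into spatial_merge_size x spatial_merge_size tiles and downsample them to out_hidden_size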
+ hidden_states = hidden_states.view(
+ -1, self.spatial_merge_size, self.spatial_merge_size, hidden_states.shape[-1]
+ )
+ hidden_states = hidden_states.permute(0, 3, 1, 2)
+ hidden_states = self.downsample(hidden_states).view(-1, self.out_hidden_size)
+ return self.merger(hidden_states)
+
+ def encode(self, images: List[ImageItem]):
+ img_tensors = []
+ valid_ids = []
+ valid_id = 0
+ img_grids = []
+ uuids = []
+ for i, img in enumerate(images):
+ if isinstance(img, ImageItem):
+ uuids.append(img.uuid)
+ image_data = read_shm(get_shm_name_data(img.uuid))
+ image_data = Image.open(BytesIO(image_data))
+ pixel_values, image_grid_thw = self.processor.preprocess(image_data)
+ img_tensors.append(pixel_values)
+ img_grids.append(image_grid_thw)
+ else:
+ raise Exception("Unsupport input types: {} for {}".format(type(img), img))
+
+            # patch count must be divisible by the merge length (spatial_merge_size ** 2)
+ cur_num = img_tensors[-1].shape[0] // (self.spatial_merge_size ** 2)
+
+ valid_ids.append([valid_id, valid_id + cur_num])
+ valid_id += cur_num
+
+ if len(img_tensors) <= 0:
+ return None
+
+ imgs = torch.cat(img_tensors, dim=0)
+ grid_thw = torch.cat(img_grids, dim=0)
+
+ pixel_values = imgs.to("cuda", dtype=self.data_type, non_blocking=True)
+ image_grid_thw = grid_thw.to("cuda", non_blocking=True)
+
+ all_img_embeds = self.forward(pixel_values, grid_thw=image_grid_thw)
+
+ return all_img_embeds, uuids, valid_ids
diff --git a/lightllm/models/glm4v/layer_infer/__init__.py b/lightllm/models/glm4v/layer_infer/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/lightllm/models/glm4v/layer_infer/transformer_layer_infer.py b/lightllm/models/glm4v/layer_infer/transformer_layer_infer.py
new file mode 100644
index 000000000..70c884381
--- /dev/null
+++ b/lightllm/models/glm4v/layer_infer/transformer_layer_infer.py
@@ -0,0 +1,104 @@
+import torch
+import torch.nn.functional as F
+import torch.distributed as dist
+import numpy as np
+from typing import Tuple
+from functools import partial
+
+from lightllm.distributed import all_reduce
+from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward
+from lightllm.models.qwen2_vl.triton_kernel.mrope import mrope_triton_fused
+from lightllm.models.qwen2_vl.infer_struct import Qwen2VLInferStateInfo
+from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
+from lightllm.models.glm4v.layer_weight.transformer_layer_weight import Glm4VTransformerLayerWeight
+
+
+class Glm4VTransformerLayerInfer(LlamaTransformerLayerInfer):
+ def __init__(self, layer_num, network_config, mode=[]):
+ super().__init__(layer_num, network_config, mode)
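+        # GLM4V stores its mRoPE settings under config["rope_parameters"]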
+ mrope_section = network_config["rope_parameters"]["mrope_section"]
+ self.mrope_section = torch.tensor(mrope_section, dtype=torch.int32, device="cuda")
+ self.partial_rotary_factor = network_config["rope_parameters"]["partial_rotary_factor"]
+
+ def _post_self_att_norm(
+ self, input, infer_state: Qwen2VLInferStateInfo, layer_weight: Glm4VTransformerLayerWeight
+ ) -> torch.Tensor:
+ out = self.alloc_tensor(input.shape, input.dtype)
+ rmsnorm_forward(input, weight=layer_weight._post_self_att_norm_weight_.weight, eps=self.eps_, out=out)
+ return out
+
+ def _post_mlp_norm(
+ self, input, infer_state: Qwen2VLInferStateInfo, layer_weight: Glm4VTransformerLayerWeight
+ ) -> torch.Tensor:
+ out = self.alloc_tensor(input.shape, input.dtype)
+ rmsnorm_forward(input, weight=layer_weight._post_mlp_norm_weight_.weight, eps=self.eps_, out=out)
+ return out
+
+ def _get_qkv(self, input, infer_state, layer_weight):
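+        # apply partial-rotary mRoPE to q and to the k half of the fused kv cache; v is left unrotated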
+ q = layer_weight.q_proj.mm(input)
+ cache_kv = layer_weight.kv_proj.mm(input).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
+ mrope_triton_fused(
+ q.view(-1, self.tp_q_head_num_, self.head_dim_),
+ cache_kv[:, : self.tp_k_head_num_, :],
+ infer_state.position_cos,
+ infer_state.position_sin,
+ self.mrope_section,
+ partial_rotary_factor=self.partial_rotary_factor,
+ is_interleaved=False,
+ is_glm4v=True,
+ )
+ return q, cache_kv
+
+ def context_forward(self, input_embdings, infer_state: Qwen2VLInferStateInfo, layer_weight):
+ input1 = self._att_norm(input_embdings, infer_state, layer_weight)
+ q, cache_kv = self._get_qkv(input1, infer_state, layer_weight)
+ input1 = None
+ self._post_cache_kv(cache_kv, infer_state, layer_weight)
+
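+        # invoke the base template's name-mangled private context-attention wrapper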
+ o = self._TransformerLayerInferTpl__context_attention_wrapper_run(
+ q=q, cache_kv=cache_kv, infer_state=infer_state, layer_weight=layer_weight
+ )
+
+ q = None
+ o = self._get_o(o, infer_state, layer_weight)
+ if self.tp_world_size_ > 1:
+ all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
+        o = self._post_self_att_norm(o, infer_state, layer_weight)  # extra norm before the residual add
+ input_embdings.add_(o.view(-1, self.embed_dim_))
+ o = None
+
+ input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
+ ffn_out = self._ffn(input1, infer_state, layer_weight)
+        ffn_out = self._post_mlp_norm(ffn_out, infer_state, layer_weight)  # extra norm after the MLP
+ input1 = None
+ if self.tp_world_size_ > 1:
+ all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
+ input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
+ return input_embdings
+
+ def token_forward(self, input_embdings, infer_state: Qwen2VLInferStateInfo, layer_weight):
+ input1 = self._att_norm(input_embdings, infer_state, layer_weight)
+ q, cache_kv = self._get_qkv(input1, infer_state, layer_weight)
+ input1 = None
+ self._post_cache_kv(cache_kv, infer_state, layer_weight)
+ o = self._token_attention_kernel(q, infer_state, layer_weight)
+ q = None
+ o = self._get_o(o, infer_state, layer_weight)
+ if self.tp_world_size_ > 1:
+ all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
+        o = self._post_self_att_norm(o, infer_state, layer_weight)  # extra norm before the residual add
+ input_embdings.add_(o.view(-1, self.embed_dim_))
+ o = None
+
+ input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
+ ffn_out = self._ffn(input1, infer_state, layer_weight)
+        ffn_out = self._post_mlp_norm(ffn_out, infer_state, layer_weight)  # extra norm after the MLP
+ input1 = None
+ if self.tp_world_size_ > 1:
+ all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
+ input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
+ return input_embdings
+
+ def _tpsp_get_qkv(self, input, infer_state, layer_weight) -> Tuple[torch.Tensor, torch.Tensor]:
+ # TODO
+ raise Exception("not impl")
diff --git a/lightllm/models/glm4v/layer_weight/__init__.py b/lightllm/models/glm4v/layer_weight/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/lightllm/models/glm4v/layer_weight/pre_and_post_layer_weight.py b/lightllm/models/glm4v/layer_weight/pre_and_post_layer_weight.py
new file mode 100644
index 000000000..52bfd76f5
--- /dev/null
+++ b/lightllm/models/glm4v/layer_weight/pre_and_post_layer_weight.py
@@ -0,0 +1,14 @@
+import numpy as np
+from lightllm.models.qwen2.layer_weights.pre_and_post_layer_weight import Qwen2PreAndPostLayerWeight
+from lightllm.models.qwen3_vl.layer_weights.pre_and_post_layer_weight import rename_weight_keys
+
+
+class Glm4VPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight):
+ def __init__(self, data_type, network_config, mode):
+ super().__init__(data_type, network_config, mode)
+ return
+
+ def load_hf_weights(self, weights):
+ rename_weight_keys(weights)
+ super().load_hf_weights(weights)
+ return
diff --git a/lightllm/models/glm4v/layer_weight/transformer_layer_weight.py b/lightllm/models/glm4v/layer_weight/transformer_layer_weight.py
new file mode 100644
index 000000000..8302f3eea
--- /dev/null
+++ b/lightllm/models/glm4v/layer_weight/transformer_layer_weight.py
@@ -0,0 +1,35 @@
+from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, COLMMWeight, NormWeight
+from lightllm.models.qwen2.layer_weights.transformer_layer_weight import Qwen2TransformerLayerWeight
+
+
+class Glm4VTransformerLayerWeight(Qwen2TransformerLayerWeight):
+ def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None):
+ super().__init__(layer_num, data_type, network_config, mode, quant_cfg)
+
+ def _init_weight_names(self):
+ self._post_self_att_norm_weight_name = f"model.layers.{self.layer_num_}.post_self_attn_layernorm.weight"
+ self._post_self_att_norm_bias_name = None
+ self._post_mlp_norm_weight_name = f"model.layers.{self.layer_num_}.post_mlp_layernorm.weight"
+ self._post_mlp_norm_bias_name = None
+ super()._init_weight_names()
+
+ def load_hf_weights(self, weights):
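+        # GLM4V ships a fused gate_up_proj; split it into the separate gate/up weights the Qwen2 loader expects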
+ gate_up_weight_name = f"model.layers.{self.layer_num_}.mlp.gate_up_proj.weight"
+ if gate_up_weight_name in weights:
+ intermediate_size = self.network_config_["intermediate_size"]
+ gate_up_proj = weights[gate_up_weight_name]
+ gate_weight_ = gate_up_proj[0:intermediate_size, :]
+ up_weight_ = gate_up_proj[intermediate_size:, :]
+ weights[self._gate_weight_name] = gate_weight_
+ weights[self._up_weight_name] = up_weight_
+ del weights[gate_up_weight_name]
+ super().load_hf_weights(weights)
+
+ def _init_norm(self):
+ self._post_self_att_norm_weight_ = NormWeight(
+ self._post_self_att_norm_weight_name, self.data_type_, bias_name=self._post_self_att_norm_bias_name
+ )
+ self._post_mlp_norm_weight_ = NormWeight(
+ self._post_mlp_norm_weight_name, self.data_type_, bias_name=self._post_mlp_norm_bias_name
+ )
+ super()._init_norm()
diff --git a/lightllm/models/glm4v/model.py b/lightllm/models/glm4v/model.py
new file mode 100644
index 000000000..78157fdf7
--- /dev/null
+++ b/lightllm/models/glm4v/model.py
@@ -0,0 +1,87 @@
+import os
+import json
+import numpy as np
+from lightllm.common.build_utils import repair_config
+from lightllm.models.registry import ModelRegistry
+from lightllm.models.qwen2_vl.infer_struct import Qwen2VLInferStateInfo
+from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer
+from lightllm.models.qwen2_vl.layer_infer.transformer_layer_infer import Qwen2VLTransformerLayerInfer
+from lightllm.models.glm4v.layer_infer.transformer_layer_infer import Glm4VTransformerLayerInfer
+from lightllm.models.glm4v.layer_weight.pre_and_post_layer_weight import Glm4VPreAndPostLayerWeight
+from lightllm.models.glm4v.layer_weight.transformer_layer_weight import Glm4VTransformerLayerWeight
+from lightllm.server.multimodal_params import MultimodalParams
+from lightllm.models.qwen2_vl.model import QWen2VLTokenizer
+from lightllm.models.qwen2.model import Qwen2TpPartModel
+
+
+class GLM4VTokenizer(QWen2VLTokenizer):
+ def __init__(self, tokenizer=None, image_processor=None, **kwargs):
+ self.tokenizer = tokenizer
+ self.image_processor = image_processor
+ self.min_pixel = self.image_processor.size["shortest_edge"]
+ self.max_pixel = self.image_processor.size["longest_edge"]
+ self.patch_size = self.image_processor.patch_size
+ self.merge_size = self.image_processor.merge_size
+ self.image_start_id = kwargs["model_cfg"]["image_start_token_id"]
+ self.image_end_id = kwargs["model_cfg"]["image_end_token_id"]
+ self.image_token_id = kwargs["model_cfg"]["image_token_id"]
+
+ def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs):
+ origin_ids = self.tokenizer.encode(prompt)
+
+        # drop any image placeholder tokens from the prompt; they are re-inserted below
+        origin_ids = [token for token in origin_ids if token != self.image_token_id]
+        # expand each <image_start><image_end> pair to <image_start>, token_id, token_id+1, ..., token_id+token_num-1, <image_end>
+ input_ids = []
+ image_id = 0
+ while True:
+ try:
+ start_idx = origin_ids.index(self.image_start_id)
+ if start_idx + 1 >= len(origin_ids):
+ break
+ if origin_ids[start_idx + 1] == self.image_end_id:
+ input_ids.extend(origin_ids[: start_idx + 1])
+ token_id = multimodal_params.images[image_id].token_id
+ token_num = multimodal_params.images[image_id].token_num
+ multimodal_params.images[image_id].start_idx = len(input_ids)
+ input_ids.extend(range(token_id, token_id + token_num))
+ input_ids.append(self.image_end_id)
+ origin_ids = origin_ids[start_idx + 2 :]
+ image_id += 1
+ else:
+ raise ValueError("image token error")
+ except ValueError:
+ break
+ input_ids.extend(origin_ids)
+ return input_ids
+
+
+@ModelRegistry(["glm4v"], is_multimodal=True)
+class GLM4VTpPartModel(Qwen2TpPartModel):
+
+ pre_layer_infer_class = LlamaMultimodalPreLayerInfer
+ transformer_layer_infer_class = Glm4VTransformerLayerInfer
+
+ pre_and_post_weight_class = Glm4VPreAndPostLayerWeight
+ transformer_weight_class = Glm4VTransformerLayerWeight
+
+ infer_state_class = Qwen2VLInferStateInfo
+
+ def __init__(self, kvargs):
+ super().__init__(kvargs)
+ return
+
+ def _init_inferstate_cls(self):
+ pass
+
+ def _init_config(self):
+ with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file:
+ all_config = json.load(json_file)
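+        # GLM4V nests the language-model settings under "text_config"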
+ self.config = all_config["text_config"]
+ # rename keys
+ repair_config(self.config, same_names=["num_attention_heads", "n_head"])
+ repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"])
+ repair_config(self.config, same_names=["num_hidden_layers", "n_layer"])
+ if self.finetune_config:
+ self.config["vocab_size"] = self.finetune_config.vocab_size
+ return
diff --git a/lightllm/models/llama/model.py b/lightllm/models/llama/model.py
index a228e0025..420e12e51 100644
--- a/lightllm/models/llama/model.py
+++ b/lightllm/models/llama/model.py
@@ -108,6 +108,8 @@ def _init_custom(self):
模型特殊的一些初始化
"""
rope_scaling = self.config.get("rope_scaling", None)
+ if rope_scaling is None:
+ rope_scaling = self.config.get("rope_parameters", None)
if rope_scaling is None:
self._init_to_get_rotary()
return
@@ -171,14 +173,21 @@ def _init_weights(self):
return
def _init_to_get_rotary(self, default_base=10000):
- partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_)
+ rope_params = self.config.get("rope_parameters")
+ if rope_params is not None:
+ partial_rotary_factor = rope_params.get("partial_rotary_factor", 1)
+ base = rope_params.get("rope_theta", float(default_base))
+ else:
+ partial_rotary_factor = self.config.get("partial_rotary_factor", 1)
+ base = self.config.get("rope_theta", float(default_base))
+
+ partial_head_dim = int(partial_rotary_factor * self.head_dim_)
+
if self.config.get("rope_scaling", {}) is None:
rope_scaling_factor = 1.0
else:
rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0)
- base = self.config.get("rope_theta", float(default_base))
-
if "max_sequence_length" in self.config:
max_seq_len = self.config["max_sequence_length"]
else:
diff --git a/lightllm/models/qwen2_vl/triton_kernel/mrope.py b/lightllm/models/qwen2_vl/triton_kernel/mrope.py
index 5aed65862..1d85b84c3 100644
--- a/lightllm/models/qwen2_vl/triton_kernel/mrope.py
+++ b/lightllm/models/qwen2_vl/triton_kernel/mrope.py
@@ -85,6 +85,7 @@ def _mrope_triton_fused_kernel(
stride_kh,
stride_kd,
is_interleaved: tl.constexpr,
+ is_glm4v: tl.constexpr,
HEAD_Q: tl.constexpr,
HEAD_K: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
@@ -95,6 +96,10 @@ def _mrope_triton_fused_kernel(
dim_range0 = tl.arange(0, BLOCK_DMODEL // 2)
dim_range1 = dim_range0 + BLOCK_DMODEL // 2
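+    # GLM4V rotates interleaved (even, odd) channel pairs instead of the first/second half split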
+ if is_glm4v:
+ dim_range0 = dim_range0 * 2
+ dim_range1 = dim_range0 + 1
+
t_cos = Cos + seq_index * stride_cosd
h_cos = Cos + stride_cosld + seq_index * stride_cosd
w_cos = Cos + 2 * stride_cosld + seq_index * stride_cosd
@@ -192,11 +197,13 @@ def mrope_triton_fused(
cos: torch.Tensor,
sin: torch.Tensor,
mrope_section: torch.Tensor,
- is_interleaved: bool,
+ partial_rotary_factor: float = 1.0,
+ is_interleaved: bool = False,
+ is_glm4v: bool = False,
run_config: Optional[dict] = None,
):
head_num_q, head_num_k = q.shape[1], k.shape[1]
- head_dim = int(q.shape[2])
+ head_dim = int(q.shape[2] * partial_rotary_factor)
num_tokens = q.shape[0]
if not run_config:
@@ -228,6 +235,7 @@ def mrope_triton_fused(
stride_kh=k.stride(1),
stride_kd=k.stride(2),
is_interleaved=is_interleaved,
+ is_glm4v=is_glm4v,
HEAD_Q=head_num_q,
HEAD_K=head_num_k,
BLOCK_DMODEL=head_dim,
diff --git a/lightllm/server/tokenizer.py b/lightllm/server/tokenizer.py
index e0b2bd425..e668156b9 100644
--- a/lightllm/server/tokenizer.py
+++ b/lightllm/server/tokenizer.py
@@ -29,6 +29,7 @@
from ..models.qwen_vl.model import QWenVLTokenizer
from ..models.qwen2_vl.model import QWen2VLTokenizer
from ..models.qwen3_vl.model import QWen3VLTokenizer
+from ..models.glm4v.model import GLM4VTokenizer
from ..models.internvl.model import InternvlTokenizer
from ..models.gemma3.model import Gemma3Tokenizer
@@ -104,5 +105,10 @@ def get_tokenizer(
tokenizer = InternvlTokenizer(tokenizer, model_cfg, weight_dir=tokenizer_name)
elif model_type == "gemma3":
tokenizer = Gemma3Tokenizer(tokenizer, model_cfg)
+ elif model_type == "glm4v":
+ from transformers import AutoProcessor
+
+ processor = AutoProcessor.from_pretrained(tokenizer_name)
+ tokenizer = GLM4VTokenizer(tokenizer=tokenizer, image_processor=processor.image_processor, model_cfg=model_cfg)
return tokenizer
diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py
index d3d1610f3..11c9b15c4 100644
--- a/lightllm/server/visualserver/model_infer/model_rpc.py
+++ b/lightllm/server/visualserver/model_infer/model_rpc.py
@@ -19,6 +19,7 @@
from lightllm.models.qwen2_5_vl.qwen2_5_visual import Qwen2_5_VisionTransformerPretrainedModel
from lightllm.models.qwen3_vl.qwen3_visual import Qwen3VisionTransformerPretrainedModel
from lightllm.models.tarsier2.tarsier2_visual import TarsierVisionTransformerPretrainedModel
+from lightllm.models.glm4v.glm4v_visual import Glm4vVisionTransformerPretrainedModel
from lightllm.utils.infer_utils import set_random_seed
from lightllm.utils.dist_utils import init_vision_distributed_env
from lightllm.utils.graceful_utils import graceful_registry
@@ -78,6 +79,10 @@ def exposed_init_model(self, kvargs):
# self.model = InternVLVisionModel()
elif self.model_type == "gemma3":
self.model = Gemma3VisionModel()
+ elif self.model_type == "glm4v":
+ self.model = (
+ Glm4vVisionTransformerPretrainedModel(kvargs, **model_cfg["vision_config"]).eval().bfloat16()
+ )
else:
raise Exception(f"can not support {self.model_type} now")