96 changes: 87 additions & 9 deletions convert_hf_to_gguf.py
@@ -160,8 +160,6 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
self.ftype = gguf.LlamaFileType.MOSTLY_F16
logger.info("heuristics unable to detect tensor dtype, defaulting to --outtype f16")

self.dequant_model()

# Configure GGUF Writer
self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard)
@@ -527,6 +525,8 @@ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
return ()

def prepare_tensors(self):
self.dequant_model()

# Handle empty tensor_map for models with block_count=0 (like MobileNetV5)
if self.tensor_map.mapping:
max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,")
@@ -1808,7 +1808,7 @@ class MmprojModel(ModelBase):
preprocessor_config: dict[str, Any]
global_config: dict[str, Any]

n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "encoder_layers"]
n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "encoder_layers", "vt_num_hidden_layers"]

has_vision_encoder: bool = True # by default
has_audio_encoder: bool = False
@@ -1863,7 +1863,15 @@ def __init__(self, *args, **kwargs):
preprocessor_config_path = self.dir_model / "preprocessor_config.json"
if preprocessor_config_path.is_file():
with open(preprocessor_config_path, "r", encoding="utf-8") as f:
self.preprocessor_config = json.load(f)
cfg = json.load(f)
# move media_proc_cfg to root level for compat
if "media_proc_cfg" in cfg:
cfg = {
**cfg,
**cfg["media_proc_cfg"],
}
# merge configs
self.preprocessor_config = {**self.preprocessor_config, **cfg}
Collaborator comment: self.preprocessor_config is empty at this point, so the merge is not strictly necessary, but it is allowed here for consistency.

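For reference, a minimal sketch of the merge behaviour discussed above (the nested keys shown are hypothetical, purely for illustration):

cfg = {"image_mean": [0.5, 0.5, 0.5], "media_proc_cfg": {"patch_size": 14, "merge_kernel_size": [2, 2]}}
if "media_proc_cfg" in cfg:
    cfg = {**cfg, **cfg["media_proc_cfg"]}
# the nested keys are now also reachable at the root:
# cfg["patch_size"] == 14, cfg["merge_kernel_size"] == [2, 2]; "media_proc_cfg" itself is kept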

# prefer processor_config.json if possible
processor_config_path = self.dir_model / "processor_config.json"
@@ -1912,10 +1920,10 @@ def set_gguf_parameters(self):
self.image_size = self.find_vparam(["image_size"])
self.gguf_writer.add_vision_image_size(self.image_size)
self.gguf_writer.add_vision_patch_size(self.find_vparam(["patch_size"]))
self.gguf_writer.add_vision_embedding_length(self.find_vparam(["hidden_size"]))
self.gguf_writer.add_vision_feed_forward_length(self.find_vparam(["intermediate_size"]))
self.gguf_writer.add_vision_embedding_length(self.find_vparam(["hidden_size", "vt_hidden_size"]))
self.gguf_writer.add_vision_feed_forward_length(self.find_vparam(["intermediate_size", "vt_intermediate_size"]))
self.gguf_writer.add_vision_block_count(self.find_vparam(self.n_block_keys))
self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads", "num_heads"]))
self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads", "num_heads", "vt_num_attention_heads"]))

# preprocessor config
image_mean = _MISTRAL_COMMON_DATASET_MEAN if self.is_mistral_format else self.preprocessor_config["image_mean"]
@@ -7360,6 +7368,7 @@ def prepare_tensors(self):
"DeepseekV2ForCausalLM",
"DeepseekV3ForCausalLM",
"KimiVLForConditionalGeneration",
"KimiK25ForConditionalGeneration",
"YoutuForCausalLM",
"YoutuVLForConditionalGeneration",
)
@@ -7478,8 +7487,8 @@ def set_gguf_parameters(self):
_experts: list[dict[str, Tensor]] | None = None

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# skip vision tensors and remove "language_model." for Kimi-VL
if "vision_tower" in name or "multi_modal_projector" in name:
# skip vision tensors and remove "language_model." for Kimi-VL and Kimi-K2.5
if "vision_tower" in name or "multi_modal_projector" in name or "mm_projector" in name:
return
if name.startswith("siglip2.") or name.startswith("merger."):
return
@@ -10712,6 +10721,75 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
yield from super().modify_tensors(data_torch, name, bid)


@ModelBase.register("KimiK25ForConditionalGeneration")
class KimiK25Model(MmprojModel):
"""Kimi-K2.5 with MoonViT3d vision encoder"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

assert self.hparams_vision is not None, "Kimi-K2.5 requires vision_config in model config"

self.merge_kernel_size = tuple(self.hparams_vision.get("merge_kernel_size", [2, 2]))
self.patch_size = self.hparams_vision.get("patch_size", 14)

# Set image_size for compatibility with base class
# Use position embedding dimensions as image_size reference
pos_emb_h = self.hparams_vision.get("init_pos_emb_height", 64)
self.hparams_vision["image_size"] = pos_emb_h * self.patch_size
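# e.g. with the defaults above: image_size = 64 * 14 = 896 (illustrative only; checkpoints may override init_pos_emb_height)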

def set_gguf_parameters(self):
# Base class MmprojModel.set_gguf_parameters() already writes:
# - vision_block_count, vision_head_count, vision_embedding_length
# - vision_feed_forward_length, vision_patch_size, image_mean, image_std
# via find_vparam() which handles the vt_* prefixed keys in Kimi-K2.5's config
super().set_gguf_parameters()
assert self.hparams_vision is not None

self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.KIMIK25)

# Position embedding parameters (for interpolation) - KimiK25-specific
self.gguf_writer.add_uint32("vision.pos_emb_height", self.hparams_vision.get("init_pos_emb_height", 64))
self.gguf_writer.add_uint32("vision.pos_emb_width", self.hparams_vision.get("init_pos_emb_width", 64))
self.gguf_writer.add_uint32("vision.pos_emb_time", self.hparams_vision.get("init_pos_emb_time", 4))

# Projector parameters
self.gguf_writer.add_vision_use_gelu(self.hparams_vision.get("projector_hidden_act", "gelu") == "gelu")
self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams_vision.get("projector_ln_eps", 1e-5))
self.gguf_writer.add_vision_projector_scale_factor(self.merge_kernel_size[0])

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# Only process vision and projector tensors
is_vision = any(x in name for x in ["vision_tower", "mm_projector"])

if not is_vision:
return

# Split fused QKV tensors in vision encoder
if "wqkv" in name:
split_dim = 0 if "weight" in name else -1
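# assumption: the fused wqkv weight is laid out [3 * n_embd, n_embd] (out-features first), so chunking dim 0 separates q/k/v; the 1-D bias chunks along its last dim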
wq, wk, wv = data_torch.chunk(3, dim=split_dim)
yield from super().modify_tensors(wq, name.replace("wqkv", "wq"), bid)
yield from super().modify_tensors(wk, name.replace("wqkv", "wk"), bid)
yield from super().modify_tensors(wv, name.replace("wqkv", "wv"), bid)
return

# Temporal embeddings: (T, 1, C) → (T, C)
if "pos_emb.time_weight" in name:
T, _, C = data_torch.shape
data_torch = data_torch.reshape(T, C)

# PatchMergerMLP tensor name mapping
# proj.0.weight → proj.linear_1.weight
# proj.2.weight → proj.linear_2.weight
if "mm_projector.proj.0." in name:
name = name.replace(".proj.0.", ".proj.linear_1.")
elif "mm_projector.proj.2." in name:
name = name.replace(".proj.2.", ".proj.linear_2.")

yield from super().modify_tensors(data_torch, name, bid)


@ModelBase.register("CogVLMForCausalLM")
class CogVLMVisionModel(MmprojModel):

1 change: 1 addition & 0 deletions gguf-py/gguf/constants.py
@@ -3610,6 +3610,7 @@ class VisionProjectorType:
VOXTRAL = "voxtral"
LFM2 = "lfm2"
KIMIVL = "kimivl"
KIMIK25 = "kimik25"
LIGHTONOCR = "lightonocr"
COGVLM = "cogvlm"
JANUS_PRO = "janus_pro"
2 changes: 2 additions & 0 deletions gguf-py/gguf/tensor_mapping.py
@@ -1255,6 +1255,7 @@ class TensorNameMap:

MODEL_TENSOR.V_MMPROJ: (
"multi_modal_projector.linear_{bid}",
"mm_projector.proj.linear_{bid}", # Kimi-K2.5
"visual.merger.mlp.{bid}", # qwen2vl
"merger.mlp.{bid}",
),
@@ -1490,6 +1491,7 @@
"multi_modal_projector.norm",
"multi_modal_projector.layer_norm",
"multi_modal_projector.pre_norm",
"mm_projector.pre_norm", # Kimi-K2.5
"pre_mm_projector_norm",
"model.vision.linear_proj.norm1", # cogvlm
"merger.ln_q",
1 change: 1 addition & 0 deletions tools/mtmd/CMakeLists.txt
@@ -19,6 +19,7 @@ add_library(mtmd
models/glm4v.cpp
models/internvl.cpp
models/kimivl.cpp
models/kimik25.cpp
models/llama4.cpp
models/llava.cpp
models/minicpmv.cpp
11 changes: 11 additions & 0 deletions tools/mtmd/clip-graph.h
@@ -107,6 +107,17 @@ struct clip_graph {
const bool interleave_freq
);

// 2D RoPE with interleaved frequency
// Pattern: [x_freq0, y_freq0, x_freq1, y_freq1, ...]
// (for comparison, build_rope_2d uses the split pattern: [x_freq0, x_freq1, ..., y_freq0, y_freq1, ...])
ggml_tensor * build_rope_2d_interleaved(
ggml_context * ctx0,
ggml_tensor * cur, // [n_dim, n_head, n_pos]
ggml_tensor * pos_w, // [n_pos] - X/width positions
ggml_tensor * pos_h, // [n_pos] - Y/height positions
const float freq_base
);

// aka pixel_shuffle / pixel_unshuffle / patch_merger (Kimi-VL)
// support dynamic resolution
ggml_tensor * build_patch_merge_permute(ggml_tensor * cur, int scale_factor);
2 changes: 2 additions & 0 deletions tools/mtmd/clip-impl.h
@@ -233,6 +233,7 @@ enum projector_type {
PROJECTOR_TYPE_LFM2A,
PROJECTOR_TYPE_GLM4V,
PROJECTOR_TYPE_YOUTUVL,
PROJECTOR_TYPE_KIMIK25,
PROJECTOR_TYPE_UNKNOWN,
};

@@ -266,6 +267,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_LFM2A, "lfm2a"},
{ PROJECTOR_TYPE_GLM4V, "glm4v"},
{ PROJECTOR_TYPE_YOUTUVL, "youtuvl"},
{ PROJECTOR_TYPE_KIMIK25, "kimik25"},
};

static projector_type clip_projector_type_from_string(const std::string & str) {
109 changes: 109 additions & 0 deletions tools/mtmd/clip.cpp
@@ -710,6 +710,83 @@ ggml_tensor * clip_graph::build_rope_2d(
return cur;
}

// 2D RoPE with interleaved frequency
// Pattern: [x_freq0, y_freq0, x_freq1, y_freq1, ...]
// (for comparison, build_rope_2d uses the split pattern: [x_freq0, x_freq1, ..., y_freq0, y_freq1, ...])
ggml_tensor * clip_graph::build_rope_2d_interleaved(
ggml_context * ctx0,
ggml_tensor * cur, // [n_dim, n_head, n_pos]
ggml_tensor * pos_w, // [n_pos] - X/width positions
ggml_tensor * pos_h, // [n_pos] - Y/height positions
const float freq_base
) {
const int64_t n_dim = cur->ne[0];
const int64_t n_head = cur->ne[1];
const int64_t n_pos = cur->ne[2];

GGML_ASSERT(n_dim % 4 == 0); // Must be divisible by 4 for interleaved x,y pairs

// Step 1: Reshape to expose interleaved structure
// cur: [n_dim, n_head, n_pos] -> [4, n_dim/4, n_head, n_pos]
ggml_tensor * reshaped = ggml_reshape_4d(ctx0, cur, 4, n_dim/4, n_head, n_pos);

// Step 2: Extract X pairs (elements 0,1 of each group of 4)
// x_pairs: [2, n_dim/4, n_head, n_pos]
ggml_tensor * x_pairs = ggml_view_4d(ctx0, reshaped,
2, n_dim/4, n_head, n_pos,
reshaped->nb[1], reshaped->nb[2], reshaped->nb[3],
0);

// Step 3: Extract Y pairs (elements 2,3 of each group of 4)
// y_pairs: [2, n_dim/4, n_head, n_pos]
ggml_tensor * y_pairs = ggml_view_4d(ctx0, reshaped,
2, n_dim/4, n_head, n_pos,
reshaped->nb[1], reshaped->nb[2], reshaped->nb[3],
2 * ggml_element_size(reshaped));

// Step 4: Make contiguous and reshape for rope_ext
// [2, n_dim/4, n_head, n_pos] -> [n_dim/2, n_head, n_pos]
x_pairs = ggml_cont(ctx0, x_pairs);
x_pairs = ggml_reshape_3d(ctx0, x_pairs, n_dim/2, n_head, n_pos);

y_pairs = ggml_cont(ctx0, y_pairs);
y_pairs = ggml_reshape_3d(ctx0, y_pairs, n_dim/2, n_head, n_pos);

// Step 5: Apply RoPE to X pairs using pos_w, Y pairs using pos_h
x_pairs = ggml_rope_ext(
ctx0,
x_pairs,
pos_w,
nullptr,
n_dim/2,
0, 0, freq_base,
1.0f, 0.0f, 1.0f, 0.0f, 0.0f
);

y_pairs = ggml_rope_ext(
ctx0,
y_pairs,
pos_h,
nullptr,
n_dim/2,
0, 0, freq_base,
1.0f, 0.0f, 1.0f, 0.0f, 0.0f
);

// Step 6: Reshape back to [2, n_dim/4, n_head, n_pos] for interleaving
x_pairs = ggml_reshape_4d(ctx0, x_pairs, 2, n_dim/4, n_head, n_pos);
y_pairs = ggml_reshape_4d(ctx0, y_pairs, 2, n_dim/4, n_head, n_pos);

// Step 7: Interleave X and Y pairs back together
// Concatenate along dimension 0: [4, n_dim/4, n_head, n_pos]
ggml_tensor * result = ggml_concat(ctx0, x_pairs, y_pairs, 0);

// Step 8: Reshape back to original: [n_dim, n_head, n_pos]
result = ggml_reshape_3d(ctx0, result, n_dim, n_head, n_pos);

return result;
}
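
// Illustration (assumed n_dim = 8 for concreteness): the interleaved channel order is
// [x0a, x0b, y0a, y0b, x1a, x1b, y1a, y1b]. The two views above extract
// [x0a, x0b, x1a, x1b] and [y0a, y0b, y1a, y1b], which are roped with pos_w and pos_h
// respectively; the final concat along dim 0 plus the reshape restores the interleaved order.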

// Generic function to stack frames for audio processing
// Abstracts out the StackAudioFrames logic used by ultravox
ggml_tensor * clip_graph::build_stack(ggml_tensor * cur, int32_t stack_factor, int32_t n_embed) {
@@ -825,6 +902,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
{
builder = std::make_unique<clip_graph_kimivl>(ctx, img);
} break;
case PROJECTOR_TYPE_KIMIK25:
{
builder = std::make_unique<clip_graph_kimik25>(ctx, img);
} break;
case PROJECTOR_TYPE_COGVLM:
{
builder = std::make_unique<clip_graph_cogvlm>(ctx, img);
@@ -1139,6 +1220,13 @@ struct clip_model_loader {
hparams.set_limit_image_tokens(8, 1024);
hparams.set_warmup_n_tokens(256); // avoid OOM on warmup
} break;
case PROJECTOR_TYPE_KIMIK25:
{
hparams.rope_theta = 10000.0f;
get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false);
hparams.set_limit_image_tokens(8, 4096);
hparams.set_warmup_n_tokens(256);
} break;
case PROJECTOR_TYPE_GEMMA3:
{
// default value (used by all model sizes in gemma 3 family)
@@ -1668,6 +1756,7 @@ struct clip_model_loader {
model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
} break;
case PROJECTOR_TYPE_KIMIVL:
case PROJECTOR_TYPE_KIMIK25:
{
model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM);
model.mm_input_norm_b = get_tensor(TN_MM_INP_NORM_B);
@@ -3039,6 +3128,23 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
res_imgs->entries.push_back(std::move(res));
} break;

case PROJECTOR_TYPE_KIMIK25:
{
GGML_ASSERT(params.image_min_pixels > 0 && params.image_max_pixels > 0);
const clip_image_size target_size = img_tool::calc_size_preserved_ratio(
original_size,
params.patch_size * params.n_merge,
params.image_min_pixels,
params.image_max_pixels);
const std::array<uint8_t, 3> pad_color = {0, 0, 0};

clip_image_u8 resized_img;
img_tool::resize(*img, resized_img, target_size, img_tool::RESIZE_ALGO_BICUBIC, true, pad_color);
clip_image_f32_ptr res(clip_image_f32_init());
normalize_image_u8_to_f32(resized_img, *res, params.image_mean, params.image_std);
res_imgs->entries.push_back(std::move(res));
} break;
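
// Rough worked example (illustrative values only; exact rounding is up to calc_size_preserved_ratio):
// with patch_size = 14 and n_merge = 2 the resize target is a multiple of 28 per side, so a
// hypothetical 1120x840 result yields (1120/28) * (840/28) = 40 * 30 = 1200 merged patch tokens,
// matching the dynamic-size computation in clip_n_output_tokens further down.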

case PROJECTOR_TYPE_MLP:
case PROJECTOR_TYPE_MLP_NORM:
case PROJECTOR_TYPE_LDP:
@@ -3247,6 +3353,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
} break;
case PROJECTOR_TYPE_LFM2:
case PROJECTOR_TYPE_KIMIVL:
case PROJECTOR_TYPE_KIMIK25:
{
// dynamic size
int out_patch_size = params.patch_size * ctx->model.hparams.n_merge;
@@ -3588,6 +3695,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
} break;
case PROJECTOR_TYPE_PIXTRAL:
case PROJECTOR_TYPE_KIMIVL:
case PROJECTOR_TYPE_KIMIK25:
case PROJECTOR_TYPE_LIGHTONOCR:
{
// set the 2D positions
@@ -3770,6 +3878,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
return ctx->model.mm_2_w->ne[1];
case PROJECTOR_TYPE_LFM2:
case PROJECTOR_TYPE_KIMIVL:
case PROJECTOR_TYPE_KIMIK25:
return ctx->model.mm_2_w->ne[1];
case PROJECTOR_TYPE_COGVLM:
return ctx->model.mm_4h_to_h_w->ne[1];