Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions rtp_llm/cpp/models/PyWrappedModel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -213,15 +213,12 @@ torch_ext::BertEmbeddingInputs PyWrappedModel::buildBertEmbeddingInputs(const Gp

// Convert combo_position_ids from Buffer to torch::Tensor
if (inputs.combo_position_ids.defined()) {
bert_embedding_inputs.combo_position_ids = inputs.combo_position_ids.cuda();
bert_embedding_inputs.combo_position_ids = tensorHoldHostAndToCuda(inputs.combo_position_ids);
}

// Convert combo_tokens_type_ids from Buffer to torch::Tensor
if (inputs.combo_tokens_type_ids.defined()) {
{
DevicePerfWrapper wrapper(enable_device_perf_, "py model combo_tokens.cuda()");
bert_embedding_inputs.combo_tokens_type_ids = inputs.combo_tokens_type_ids.cuda();
}
bert_embedding_inputs.combo_tokens_type_ids = tensorHoldHostAndToCuda(inputs.combo_tokens_type_ids);
}

// Get position_encoding from model weights (no clone needed for weights)
Expand All @@ -238,6 +235,21 @@ torch_ext::BertEmbeddingInputs PyWrappedModel::buildBertEmbeddingInputs(const Gp

// Set input_embedding_scalar
bert_embedding_inputs.input_embedding_scalar = description_.input_embedding_scalar;

// Propagate multimodal features so Python BertModel.forward can splice them
// into hidden_states (without this, vision token positions get garbage from
// placeholder ID lookups in the word embedding table).
if (inputs.multimodal_features && !inputs.multimodal_features.value().empty()) {
std::vector<torch::Tensor> mm_feats;
mm_feats.reserve(inputs.multimodal_features.value().size());
for (const auto& feature : inputs.multimodal_features.value()) {
mm_feats.emplace_back(tensorHoldHostAndToCuda(feature));
}
bert_embedding_inputs.multimodal_features = std::move(mm_feats);
}
if (inputs.mm_features_locs.defined()) {
bert_embedding_inputs.mm_features_locs = tensorHoldHostAndToCuda(inputs.mm_features_locs);
}
return bert_embedding_inputs;
}

Expand Down
6 changes: 6 additions & 0 deletions rtp_llm/models_py/bindings/OpDefs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,12 @@ void registerPyOpDefs(pybind11::module& m) {
"token_type_embedding", &BertEmbeddingInputs::token_type_embedding, "Token type embedding tensor")
.def_readwrite(
"input_embedding_scalar", &BertEmbeddingInputs::input_embedding_scalar, "Input embedding scalar value")
.def_readwrite("multimodal_features",
&BertEmbeddingInputs::multimodal_features,
"Optional list of multimodal feature tensors to splice into hidden states")
.def_readwrite("mm_features_locs",
&BertEmbeddingInputs::mm_features_locs,
"Token-index locations for multimodal features splicing")
.def("__repr__", [](const BertEmbeddingInputs& self) { return "BertEmbeddingInputs"; });

pybind11::class_<PyEmbeddingInputs>(m, "PyEmbeddingInputs")
Expand Down
5 changes: 5 additions & 0 deletions rtp_llm/models_py/bindings/OpDefs.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,11 @@ struct BertEmbeddingInputs {
torch::Tensor combo_tokens_type_ids;
torch::Tensor token_type_embedding;
float input_embedding_scalar{1.0};
// Multimodal (vision) embeddings to be spliced into hidden states after pre-LN.
// Each tensor in multimodal_features is shape [seq_len_i, hidden_size].
// mm_features_locs is the flat token index where each multimodal feature should be placed.
std::vector<torch::Tensor> multimodal_features;
torch::Tensor mm_features_locs;
};

struct PyEmbeddingInputs {
Expand Down
8 changes: 8 additions & 0 deletions rtp_llm/models_py/model_desc/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
EmbeddingBert,
FMHAImplBase,
LayerNorm,
MultimodalEmbeddingInjector,
)
from rtp_llm.ops import HWKernelConfig, ParallelismConfig
from rtp_llm.ops.compute_ops import (
Expand Down Expand Up @@ -116,6 +117,7 @@ def __init__(
beta=weights.get_global_weight(W.pre_decoder_ln_beta),
eps=config.layernorm_eps,
)
self.multimodal_embedding_injector = MultimodalEmbeddingInjector()
self.layers = nn.ModuleList(
[
BertDecoderLayer(
Expand Down Expand Up @@ -143,6 +145,12 @@ def forward(
bert_embedding_inputs.input_embedding_scalar,
)
hidden_states = self.pre_decoder_layernorm(inputs_embeds)
hidden_states = self.multimodal_embedding_injector(
hidden_states,
bert_embedding_inputs.multimodal_features,
bert_embedding_inputs.mm_features_locs,
)

if fmha_impl is None:
fmha_impl = self.prepare_fmha_impl(inputs)
for i, decoder_layer in enumerate(self.layers[: self.layer_num]):
Expand Down
11 changes: 11 additions & 0 deletions rtp_llm/models_py/model_desc/test/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,14 @@ py_test(
tags = ["H20"],
exec_properties = {'gpu': 'H20', 'gpu_count': '1'},
)

# GPU-backed unit test for the BertModel multimodal splice path; scheduled on a
# single H20 device via exec_properties.
py_test(
    name = "bert_multimodal_test",
    srcs = ["bert_multimodal_test.py"],
    env = {"GPU_COUNT": "1"},
    exec_properties = {
        "gpu": "H20",
        "gpu_count": "1",
    },
    tags = ["H20"],
    deps = [
        "//rtp_llm/models_py/standalone:py_standalone_testlib",
    ],
)
139 changes: 139 additions & 0 deletions rtp_llm/models_py/model_desc/test/bert_multimodal_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
"""Tests for BertModel multimodal-feature splicing.

BertModel.forward must inject vision features into hidden_states at the
positions advertised by bert_embedding_inputs.mm_features_locs, leaving all
other token positions untouched. The forward delegates to the shared
MultimodalEmbeddingInjector, so this test verifies the wiring end-to-end
without needing real model weights.
"""

import unittest
from types import SimpleNamespace

import torch

from rtp_llm.models_py.model_desc.bert import BertModel
from rtp_llm.models_py.modules import MultimodalEmbeddingInjector


class _FmhaStub:
def __init__(self):
self.fmha_params = None


class BertMultimodalForwardTest(unittest.TestCase):
    """Checks that ``BertModel.forward`` splices multimodal features correctly.

    The model under test is assembled without real weights: the embedding
    lookup yields zeros, the pre-decoder layer norm is the identity and there
    are no decoder layers. Whatever differs from zero in the returned
    ``hidden_states`` was therefore written by ``MultimodalEmbeddingInjector``.
    """

    def setUp(self) -> None:
        """Skip without CUDA; otherwise pin the device and seed the RNG."""
        if not torch.cuda.is_available():
            self.skipTest("CUDA is not available")
        self.device = torch.device("cuda:0")
        torch.manual_seed(0)

    def _build_model(self, hidden_size: int, dtype: torch.dtype) -> "BertModel":
        """Create a ``BertModel`` whose heavy submodules are replaced by stubs.

        ``__init__`` is bypassed via ``__new__``; only ``nn.Module`` state is
        initialised, then every attribute that ``forward`` is expected to read
        is assigned by hand.
        """
        model = BertModel.__new__(BertModel)
        torch.nn.Module.__init__(model)

        def zero_embeddings(input_ids, *args, **kwargs):
            # One all-zero row per token, so spliced rows are easy to spot.
            return torch.zeros(
                input_ids.shape[0], hidden_size, device=self.device, dtype=dtype
            )

        # Replace heavy submodules with identity-style stubs so forward exercises
        # only the multimodal splice path that we want to test.
        model.embed_tokens = zero_embeddings
        model.pre_decoder_layernorm = lambda x: x
        model.multimodal_embedding_injector = MultimodalEmbeddingInjector()
        model.layers = []
        model.layer_num = 0
        model.kv_cache = None
        model.prepare_fmha_impl = lambda inputs: _FmhaStub()
        return model

    @staticmethod
    def _make_inputs(seq_len: int, features, locs, device, dtype):
        """Build a duck-typed stand-in for the model inputs.

        Args:
            seq_len: Number of (all-zero, int64) token ids to generate.
            features: List of multimodal feature tensors, passed through as-is.
            locs: Sequence of int locations, one per feature; an empty or
                ``None`` value yields an empty int32 tensor.
            device: Device for ``input_ids`` and ``mm_features_locs``.
            dtype: Accepted for call-site symmetry; not used here.

        Returns:
            A ``SimpleNamespace`` exposing ``input_ids`` and
            ``bert_embedding_inputs``.
        """
        # torch.tensor([]) with an explicit dtype already yields a shape-(0,)
        # tensor, so the empty case needs no separate branch.
        mm_features_locs = torch.tensor(locs or [], device=device, dtype=torch.int32)
        bert_inputs = SimpleNamespace(
            combo_position_ids=torch.empty(0),
            position_encoding=torch.empty(0),
            combo_tokens_type_ids=torch.empty(0),
            token_type_embedding=torch.empty(0),
            input_embedding_scalar=1.0,
            multimodal_features=features,
            mm_features_locs=mm_features_locs,
        )
        return SimpleNamespace(
            input_ids=torch.zeros(seq_len, device=device, dtype=torch.int64),
            bert_embedding_inputs=bert_inputs,
        )

    def _zeros(self, seq_len: int, hidden_size: int, dtype: torch.dtype) -> torch.Tensor:
        """Return the all-zero hidden states the stubbed model produces unspliced."""
        return torch.zeros(seq_len, hidden_size, device=self.device, dtype=dtype)

    def test_features_overwrite_target_positions_only(self) -> None:
        """Two features land at their locs; every other row stays zero."""
        hidden_size = 16
        seq_len = 12
        dtype = torch.float16
        model = self._build_model(hidden_size, dtype)

        feat0 = torch.randn(3, hidden_size, device=self.device, dtype=dtype)
        feat1 = torch.randn(2, hidden_size, device=self.device, dtype=dtype)
        locs = [1, 7]

        inputs = self._make_inputs(seq_len, [feat0, feat1], locs, self.device, dtype)
        out = model.forward(inputs).hidden_states

        expected = self._zeros(seq_len, hidden_size, dtype)
        expected[1:4] = feat0
        expected[7:9] = feat1
        torch.testing.assert_close(out, expected)

    def test_empty_features_leaves_hidden_states_untouched(self) -> None:
        """With no features and no locs the output equals the zero baseline."""
        hidden_size = 8
        seq_len = 5
        dtype = torch.float16
        model = self._build_model(hidden_size, dtype)

        inputs = self._make_inputs(seq_len, [], [], self.device, dtype)
        out = model.forward(inputs).hidden_states

        torch.testing.assert_close(out, self._zeros(seq_len, hidden_size, dtype))

    def test_negative_loc_truncates_prefix(self) -> None:
        """A negative loc drops the feature's head rows and writes the tail at 0."""
        # loc < 0 means part of the feature is already in the reused KV prefix;
        # the injector must drop the head rows and place the tail at position 0.
        hidden_size = 4
        seq_len = 6
        dtype = torch.float16
        model = self._build_model(hidden_size, dtype)

        feat = torch.randn(4, hidden_size, device=self.device, dtype=dtype)
        inputs = self._make_inputs(seq_len, [feat], [-2], self.device, dtype)
        out = model.forward(inputs).hidden_states

        expected = self._zeros(seq_len, hidden_size, dtype)
        expected[0:2] = feat[2:]
        torch.testing.assert_close(out, expected)

    def test_out_of_range_loc_raises(self) -> None:
        """A feature that would run past the sequence end raises ``IndexError``."""
        hidden_size = 4
        seq_len = 4
        dtype = torch.float16
        model = self._build_model(hidden_size, dtype)

        # Three rows starting at index 2 do not fit in a sequence of length 4.
        feat = torch.randn(3, hidden_size, device=self.device, dtype=dtype)
        inputs = self._make_inputs(seq_len, [feat], [2], self.device, dtype)
        with self.assertRaisesRegex(IndexError, r"cannot be placed"):
            model.forward(inputs)

    def test_locs_length_mismatch_raises(self) -> None:
        """One feature paired with two locs raises ``ValueError``."""
        hidden_size = 4
        seq_len = 6
        dtype = torch.float16
        model = self._build_model(hidden_size, dtype)

        feat = torch.randn(2, hidden_size, device=self.device, dtype=dtype)
        inputs = self._make_inputs(seq_len, [feat], [0, 3], self.device, dtype)
        with self.assertRaisesRegex(ValueError, r"entries .* features"):
            model.forward(inputs)


# Allow running this file directly as a script, outside the Bazel test runner.
if __name__ == "__main__":
    unittest.main()
Loading