Propagate multimodal (vision) features into the Python BERT forward pass.

* PyWrappedModel::buildBertEmbeddingInputs copies inputs.multimodal_features and
  inputs.mm_features_locs into BertEmbeddingInputs through tensorHoldHostAndToCuda, and
  switches combo_position_ids / combo_tokens_type_ids from .cuda() to the same helper.
* BertEmbeddingInputs gains the multimodal_features and mm_features_locs fields, which are
  exposed to Python with pybind11 def_readwrite.
* BertModel.forward runs hidden_states through MultimodalEmbeddingInjector right after the
  pre-decoder layernorm.
* Adds the bert_multimodal_test py_test target and its test module.

NOTE(review): indentation/column alignment of unchanged context lines was rebuilt from a
whitespace-collapsed copy; if strict context matching fails, apply with whitespace tolerance.

diff --git a/rtp_llm/cpp/models/PyWrappedModel.cc b/rtp_llm/cpp/models/PyWrappedModel.cc
index a6bb94f83b..168dd1a679 100644
--- a/rtp_llm/cpp/models/PyWrappedModel.cc
+++ b/rtp_llm/cpp/models/PyWrappedModel.cc
@@ -213,15 +213,12 @@ torch_ext::BertEmbeddingInputs PyWrappedModel::buildBertEmbeddingInputs(const Gp
 
     // Convert combo_position_ids from Buffer to torch::Tensor
     if (inputs.combo_position_ids.defined()) {
-        bert_embedding_inputs.combo_position_ids = inputs.combo_position_ids.cuda();
+        bert_embedding_inputs.combo_position_ids = tensorHoldHostAndToCuda(inputs.combo_position_ids);
     }
 
     // Convert combo_tokens_type_ids from Buffer to torch::Tensor
     if (inputs.combo_tokens_type_ids.defined()) {
-        {
-            DevicePerfWrapper wrapper(enable_device_perf_, "py model combo_tokens.cuda()");
-            bert_embedding_inputs.combo_tokens_type_ids = inputs.combo_tokens_type_ids.cuda();
-        }
+        bert_embedding_inputs.combo_tokens_type_ids = tensorHoldHostAndToCuda(inputs.combo_tokens_type_ids);
     }
 
     // Get position_encoding from model weights (no clone needed for weights)
@@ -238,6 +235,21 @@ torch_ext::BertEmbeddingInputs PyWrappedModel::buildBertEmbeddingInputs(const Gp
 
     // Set input_embedding_scalar
     bert_embedding_inputs.input_embedding_scalar = description_.input_embedding_scalar;
+
+    // Propagate multimodal features so Python BertModel.forward can splice them
+    // into hidden_states (without this, vision token positions get garbage from
+    // placeholder ID lookups in the word embedding table).
+    if (inputs.multimodal_features && !inputs.multimodal_features.value().empty()) {
+        std::vector<torch::Tensor> mm_feats;
+        mm_feats.reserve(inputs.multimodal_features.value().size());
+        for (const auto& feature : inputs.multimodal_features.value()) {
+            mm_feats.emplace_back(tensorHoldHostAndToCuda(feature));
+        }
+        bert_embedding_inputs.multimodal_features = std::move(mm_feats);
+    }
+    if (inputs.mm_features_locs.defined()) {
+        bert_embedding_inputs.mm_features_locs = tensorHoldHostAndToCuda(inputs.mm_features_locs);
+    }
 
     return bert_embedding_inputs;
 }
diff --git a/rtp_llm/models_py/bindings/OpDefs.cc b/rtp_llm/models_py/bindings/OpDefs.cc
index a70dd00bf4..781e35541b 100644
--- a/rtp_llm/models_py/bindings/OpDefs.cc
+++ b/rtp_llm/models_py/bindings/OpDefs.cc
@@ -142,6 +142,12 @@ void registerPyOpDefs(pybind11::module& m) {
             "token_type_embedding", &BertEmbeddingInputs::token_type_embedding, "Token type embedding tensor")
         .def_readwrite(
             "input_embedding_scalar", &BertEmbeddingInputs::input_embedding_scalar, "Input embedding scalar value")
+        .def_readwrite("multimodal_features",
+                       &BertEmbeddingInputs::multimodal_features,
+                       "Optional list of multimodal feature tensors to splice into hidden states")
+        .def_readwrite("mm_features_locs",
+                       &BertEmbeddingInputs::mm_features_locs,
+                       "Token-index locations for multimodal features splicing")
         .def("__repr__", [](const BertEmbeddingInputs& self) { return "BertEmbeddingInputs"; });
 
     pybind11::class_<PyEmbeddingInputs>(m, "PyEmbeddingInputs")
diff --git a/rtp_llm/models_py/bindings/OpDefs.h b/rtp_llm/models_py/bindings/OpDefs.h
index 8b67ba86b4..593428cd38 100644
--- a/rtp_llm/models_py/bindings/OpDefs.h
+++ b/rtp_llm/models_py/bindings/OpDefs.h
@@ -216,6 +216,11 @@ struct BertEmbeddingInputs {
     torch::Tensor combo_tokens_type_ids;
     torch::Tensor token_type_embedding;
     float input_embedding_scalar{1.0};
+    // Multimodal (vision) embeddings to be spliced into hidden states after pre-LN.
+    // Each tensor in multimodal_features is shape [seq_len_i, hidden_size].
+    // mm_features_locs is the flat token index where each multimodal feature should be placed.
+    std::vector<torch::Tensor> multimodal_features;
+    torch::Tensor mm_features_locs;
 };
 
 struct PyEmbeddingInputs {
diff --git a/rtp_llm/models_py/model_desc/bert.py b/rtp_llm/models_py/model_desc/bert.py
index 3a14762b5e..6f4c1856dc 100644
--- a/rtp_llm/models_py/model_desc/bert.py
+++ b/rtp_llm/models_py/model_desc/bert.py
@@ -14,6 +14,7 @@
     EmbeddingBert,
     FMHAImplBase,
     LayerNorm,
+    MultimodalEmbeddingInjector,
 )
 from rtp_llm.ops import HWKernelConfig, ParallelismConfig
 from rtp_llm.ops.compute_ops import (
@@ -116,6 +117,7 @@ def __init__(
             beta=weights.get_global_weight(W.pre_decoder_ln_beta),
             eps=config.layernorm_eps,
         )
+        self.multimodal_embedding_injector = MultimodalEmbeddingInjector()
         self.layers = nn.ModuleList(
             [
                 BertDecoderLayer(
@@ -143,6 +145,12 @@ def forward(
             bert_embedding_inputs.input_embedding_scalar,
         )
         hidden_states = self.pre_decoder_layernorm(inputs_embeds)
+        hidden_states = self.multimodal_embedding_injector(
+            hidden_states,
+            bert_embedding_inputs.multimodal_features,
+            bert_embedding_inputs.mm_features_locs,
+        )
+
         if fmha_impl is None:
             fmha_impl = self.prepare_fmha_impl(inputs)
         for i, decoder_layer in enumerate(self.layers[: self.layer_num]):
diff --git a/rtp_llm/models_py/model_desc/test/BUILD b/rtp_llm/models_py/model_desc/test/BUILD
index b431a5097d..8c74e60d9a 100644
--- a/rtp_llm/models_py/model_desc/test/BUILD
+++ b/rtp_llm/models_py/model_desc/test/BUILD
@@ -8,3 +8,14 @@ py_test(
     tags = ["H20"],
     exec_properties = {'gpu': 'H20', 'gpu_count': '1'},
 )
+
+py_test(
+    name = "bert_multimodal_test",
+    srcs = ["bert_multimodal_test.py"],
+    deps = [
+        "//rtp_llm/models_py/standalone:py_standalone_testlib",
+    ],
+    env = {"GPU_COUNT": "1"},
+    tags = ["H20"],
+    exec_properties = {'gpu': 'H20', 'gpu_count': '1'},
+)
diff --git a/rtp_llm/models_py/model_desc/test/bert_multimodal_test.py b/rtp_llm/models_py/model_desc/test/bert_multimodal_test.py
new file mode 100644
index 0000000000..19ffdfb146
--- /dev/null
+++ b/rtp_llm/models_py/model_desc/test/bert_multimodal_test.py
@@ -0,0 +1,139 @@
+"""Tests for BertModel multimodal-feature splicing.
+
+BertModel.forward must inject vision features into hidden_states at the
+positions advertised by bert_embedding_inputs.mm_features_locs, leaving all
+other token positions untouched. The forward delegates to the shared
+MultimodalEmbeddingInjector, so this test verifies the wiring end-to-end
+without needing real model weights.
+"""
+
+import unittest
+from types import SimpleNamespace
+
+import torch
+
+from rtp_llm.models_py.model_desc.bert import BertModel
+from rtp_llm.models_py.modules import MultimodalEmbeddingInjector
+
+
+class _FmhaStub:
+    def __init__(self):
+        self.fmha_params = None
+
+
+class BertMultimodalForwardTest(unittest.TestCase):
+    def setUp(self) -> None:
+        if not torch.cuda.is_available():
+            raise unittest.SkipTest("CUDA is not available")
+        self.device = torch.device("cuda:0")
+        torch.manual_seed(0)
+
+    def _build_model(self, hidden_size: int, dtype: torch.dtype) -> BertModel:
+        model = BertModel.__new__(BertModel)
+        torch.nn.Module.__init__(model)
+        # Replace heavy submodules with identity-style stubs so forward exercises
+        # only the multimodal splice path that we want to test.
+        model.embed_tokens = lambda input_ids, *args, **kwargs: torch.zeros(
+            input_ids.shape[0], hidden_size, device=self.device, dtype=dtype
+        )
+        model.pre_decoder_layernorm = lambda x: x
+        model.multimodal_embedding_injector = MultimodalEmbeddingInjector()
+        model.layers = []
+        model.layer_num = 0
+        model.kv_cache = None
+        model.prepare_fmha_impl = lambda inputs: _FmhaStub()
+        return model
+
+    @staticmethod
+    def _make_inputs(seq_len: int, features, locs, device, dtype):
+        bert_inputs = SimpleNamespace(
+            combo_position_ids=torch.empty(0),
+            position_encoding=torch.empty(0),
+            combo_tokens_type_ids=torch.empty(0),
+            token_type_embedding=torch.empty(0),
+            input_embedding_scalar=1.0,
+            multimodal_features=features,
+            mm_features_locs=(
+                torch.tensor(locs, device=device, dtype=torch.int32)
+                if locs
+                else torch.empty(0, device=device, dtype=torch.int32)
+            ),
+        )
+        return SimpleNamespace(
+            input_ids=torch.zeros(seq_len, device=device, dtype=torch.int64),
+            bert_embedding_inputs=bert_inputs,
+        )
+
+    def test_features_overwrite_target_positions_only(self) -> None:
+        hidden_size = 16
+        seq_len = 12
+        dtype = torch.float16
+        model = self._build_model(hidden_size, dtype)
+
+        feat0 = torch.randn(3, hidden_size, device=self.device, dtype=dtype)
+        feat1 = torch.randn(2, hidden_size, device=self.device, dtype=dtype)
+        locs = [1, 7]
+
+        inputs = self._make_inputs(seq_len, [feat0, feat1], locs, self.device, dtype)
+        out = model.forward(inputs).hidden_states
+
+        expected = torch.zeros(seq_len, hidden_size, device=self.device, dtype=dtype)
+        expected[1:4] = feat0
+        expected[7:9] = feat1
+        torch.testing.assert_close(out, expected)
+
+    def test_empty_features_leaves_hidden_states_untouched(self) -> None:
+        hidden_size = 8
+        seq_len = 5
+        dtype = torch.float16
+        model = self._build_model(hidden_size, dtype)
+
+        inputs = self._make_inputs(seq_len, [], [], self.device, dtype)
+        out = model.forward(inputs).hidden_states
+
+        torch.testing.assert_close(
+            out,
+            torch.zeros(seq_len, hidden_size, device=self.device, dtype=dtype),
+        )
+
+    def test_negative_loc_truncates_prefix(self) -> None:
+        # loc < 0 means part of the feature is already in the reused KV prefix;
+        # the injector must drop the head rows and place the tail at position 0.
+        hidden_size = 4
+        seq_len = 6
+        dtype = torch.float16
+        model = self._build_model(hidden_size, dtype)
+
+        feat = torch.randn(4, hidden_size, device=self.device, dtype=dtype)
+        inputs = self._make_inputs(seq_len, [feat], [-2], self.device, dtype)
+        out = model.forward(inputs).hidden_states
+
+        expected = torch.zeros(seq_len, hidden_size, device=self.device, dtype=dtype)
+        expected[0:2] = feat[2:]
+        torch.testing.assert_close(out, expected)
+
+    def test_out_of_range_loc_raises(self) -> None:
+        hidden_size = 4
+        seq_len = 4
+        dtype = torch.float16
+        model = self._build_model(hidden_size, dtype)
+
+        feat = torch.randn(3, hidden_size, device=self.device, dtype=dtype)
+        inputs = self._make_inputs(seq_len, [feat], [2], self.device, dtype)
+        with self.assertRaisesRegex(IndexError, r"cannot be placed"):
+            model.forward(inputs)
+
+    def test_locs_length_mismatch_raises(self) -> None:
+        hidden_size = 4
+        seq_len = 6
+        dtype = torch.float16
+        model = self._build_model(hidden_size, dtype)
+
+        feat = torch.randn(2, hidden_size, device=self.device, dtype=dtype)
+        inputs = self._make_inputs(seq_len, [feat], [0, 3], self.device, dtype)
+        with self.assertRaisesRegex(ValueError, r"entries .* features"):
+            model.forward(inputs)
+
+
+if __name__ == "__main__":
+    unittest.main()