fix(bert): propagate multimodal features to Python BertModel forward#1137
fix(bert): propagate multimodal features to Python BertModel forward#1137parkerpang wants to merge 1 commit into
Conversation
AI Code Review - PR #1137Status: BLOCKING Summary: P0/0 · P1/1 · P2/3 · P3/0 Blocking IssuesP1
Non-blocking SuggestionsP2
Checklist Violations (9 fail / 88 total)General Principles Checklist
RTP-LLM Checklist
Python Static-First Checklist
Strengths
|
eab8702 to
0366816
Compare
AI Code Review - PR #1137Status: BLOCKING Summary: P0/0 · P1/1 · P2/0 · P3/1 Blocking IssuesP1
Non-blocking SuggestionsP3
Checklist Violations (2 fail / 88 total)RTP-LLM Checklist
Strengths
|
0366816 to
c055017
Compare
AI Code Review - PR #1137Status: LGTM Summary: P0/0 · P1/0 · P2/0 · P3/1 lgtm ready to ci Non-blocking SuggestionsP3
Checklist Violations (3 fail / 67 total)General Principles Checklist
RTP-LLM Checklist
Python Static-First Checklist
Strengths
|
c055017 to
eb4b444
Compare
AI Code Review - PR #1137Status: LGTM Summary: P0/0 · P1/0 · P2/1 · P3/0 lgtm ready to ci Non-blocking SuggestionsP2
Checklist Violations (2 fail / 88 total)General Principles Checklist
RTP-LLM Checklist
Strengths
|
eb4b444 to
68badd0
Compare
AI Code Review - PR #1137Status: LGTM Summary: P0/0 · P1/0 · P2/2 · P3/1 lgtm ready to ci Non-blocking SuggestionsP2
P3
Checklist Violations (4 fail / 94 total)General Principles Checklist
RTP-LLM Checklist
Python Static-First Checklist
Strengths
|
68badd0 to
ac06fd0
Compare
AI Code Review - PR #1137Status: LGTM Summary: P0/0 · P1/0 · P2/1 · P3/1 lgtm ready to ci Non-blocking SuggestionsP2
P3
Checklist Violations (1 fail / 67 total)Python Static-First Checklist
Strengths
|
| mm_feats.reserve(inputs.multimodal_features.value().size()); | ||
| for (const auto& feature : inputs.multimodal_features.value()) { | ||
| DevicePerfWrapper wrapper(enable_device_perf_, "py model multimodal_features.cuda()"); | ||
| mm_feats.emplace_back(feature.cuda()); |
4429fa6 to
6c58893
Compare
AI Code Review - PR #1137Status: LGTM Summary: P0/0 · P1/0 · P2/1 · P3/0 lgtm ready to ci Non-blocking SuggestionsP2
Checklist ✅ (56 items passed)Strengths
|
6c58893 to
39f06a2
Compare
AI Code Review - PR #1137Status: LGTM Summary: P0/0 · P1/0 · P2/1 · P3/0 lgtm ready to ci Non-blocking SuggestionsP2
Checklist Violations (2 fail / 67 total)General Principles Checklist
RTP-LLM Checklist
Strengths
|
39f06a2 to
46f7dd8
Compare
AI Code Review - PR #1137Status: LGTM Summary: P0/0 · P1/0 · P2/3 · P3/2 lgtm ready to ci Non-blocking SuggestionsP2
P3
Checklist Violations (1 fail / 67 total)Python Static-First Checklist
Strengths
|
46f7dd8 to
5e7766b
Compare
AI Code Review - PR #1137Status: LGTM Summary: P0/0 · P1/0 · P2/0 · P3/0 lgtm ready to ci Checklist Violations (2 fail / 88 total)General Principles Checklist
RTP-LLM Checklist
Strengths
|
5e7766b to
5d9d5e6
Compare
AI Code Review - PR #1137Status: LGTM Summary: P0/0 · P1/0 · P2/0 · P3/0 lgtm ready to ci Checklist ✅ (84 items passed)Strengths
|
|
internal source has been updated, please review the changes! |
After the C++ GptModel removal (b6575e4), the Python BertModel.forward path never received computed vision embeddings from the C++ engine. Vision token positions (placeholder IDs -1/-2/-3) went through word embedding lookup producing garbage hidden states, causing ~5% argmax disagreement vs PyTorch baseline on VisionBert. Fix: - Add multimodal_features and mm_features_locs fields to BertEmbeddingInputs - Expose them via pybind11 - Populate them in PyWrappedModel::buildBertEmbeddingInputs - Splice vision embeddings into hidden_states after pre-LN in BertModel.forward Verified: argmax disagreement drops from 4.90% to 0.18% (full sweep, 62K items), matching mi308 reference (0.15%).
5d9d5e6 to
c9f128b
Compare
AI Code Review - PR #1137Status: LGTM Summary: P0/0 · P1/0 · P2/0 · P3/0 lgtm ready to ci Checklist ✅ (67 items passed)Strengths
|
|
internal source has been updated, please review the changes! |
After the C++ GptModel removal (b6575e4), the Python BertModel.forward path never received computed vision embeddings from the C++ engine. Vision token positions (placeholder IDs -1/-2/-3) went through word embedding lookup producing garbage hidden states, causing ~5% argmax disagreement vs PyTorch baseline on VisionBert.
Fix: