Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 26 additions & 9 deletions python/sglang/srt/managers/mm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,18 +325,35 @@ def get_embedding_chunk(
extend_start_index = extend_prefix_len
extend_end_index = extend_prefix_len + extend_seq_len - 1

for start, end in items_offset:
if extend_start_index >= start and extend_start_index <= end:
start_index += extend_start_index - start
elif extend_start_index > end:
# Walk the (start, end) item ranges; `ranges` is only an alias of items_offset, nothing is precomputed
ranges = items_offset
for start, end in ranges:
# start_index: accumulate the number of item positions that lie strictly before extend_start_index
if extend_start_index <= end:
if extend_start_index >= start:
start_index += extend_start_index - start
else:
# extend_start_index < start <= end : don't add here
pass
else:
# extend_start_index > end
start_index += end - start + 1

if extend_end_index >= start and extend_end_index <= end:
end_index += extend_end_index - start + 1
elif extend_end_index > end:
# end_index: accumulate the number of item positions up to and including extend_end_index
if extend_end_index <= end:
if extend_end_index >= start:
end_index += extend_end_index - start + 1
else:
# extend_end_index < start <= end : don't add here
pass
else:
# extend_end_index > end
end_index += end - start + 1
# some models' embedding is 3-dim, reshape it to 2-dim
embedding = embedding.reshape(-1, embedding.shape[-1])

# Some models' embedding is 3-dim; reshape to 2-dim only when it is not already 2-dim
if embedding.ndim != 2:
embedding = embedding.reshape(-1, embedding.shape[-1])

embedding_chunk = embedding[start_index:end_index]
return embedding_chunk, start_index, end_index

Expand Down