⚡️ Speed up function get_embedding_chunk by 47%
#343
📄 47% (0.47x) speedup for `get_embedding_chunk` in `python/sglang/srt/managers/mm_utils.py`

⏱️ Runtime: 359 microseconds → 244 microseconds (best of 250 runs)

📝 Explanation and details
The optimization achieves a 47% speedup by eliminating redundant tensor operations and reducing branching overhead in the hot path loop.
Key optimizations:
- **Conditional reshape avoidance:** The original code always called `embedding.reshape()`, even on tensors that were already 2D. The optimized version checks `embedding.ndim != 2` first, skipping the unnecessary tensor operation when the embedding is already 2D (the common case in batch inference).
- **Reduced Python branching:** The loop logic was restructured to minimize conditional checks. Instead of nested if-elif chains, the optimized version uses flatter if-else patterns that are cheaper for the Python interpreter.
- **Variable elimination:** The redundant `ranges = items_offset` assignment, which added no value, was removed.

**Performance impact by test case:** see the per-call runtime annotations in the generated regression tests below.

**Hot path significance:** The function is called from `_get_chunked_prefill_embedding()` during multimodal batch-processing loops, where it processes embeddings for each request. Because this runs in inference pipelines where latency matters, the 47% reduction in per-call overhead translates into meaningful throughput gains, especially when serving many multimodal requests with predominantly 2D embeddings.

The optimization is most effective for typical workloads where embeddings are already 2D tensors, which makes the reshape check a worthwhile trade-off despite the minor cost it adds for 3D inputs.
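To make the description concrete, here is a minimal sketch of the optimized function, reconstructed from the changes described above and the behavior exercised by the regression tests below; it is illustrative, not the verbatim upstream diff:

```python
from typing import List, Tuple

import torch


def get_embedding_chunk(
    embedding: torch.Tensor,
    extend_prefix_len: int,
    extend_seq_len: int,
    items_offset: List[Tuple[int, int]],
) -> Tuple[torch.Tensor, int, int]:
    # Reshape only when needed: already-2D embeddings (the common case in
    # batch inference) skip the tensor operation entirely.
    if embedding.ndim != 2:
        embedding = embedding.reshape(-1, embedding.shape[-1])

    extend_start = extend_prefix_len
    extend_end = extend_prefix_len + extend_seq_len - 1

    # Accumulate how many embedding rows precede the requested window's
    # start and end, walking the per-item offset ranges with flat branching.
    start_index, end_index = 0, 0
    for start, end in items_offset:
        if start <= extend_start <= end:
            start_index += extend_start - start
        elif extend_start > end:
            start_index += end - start + 1

        if start <= extend_end <= end:
            end_index += extend_end - start + 1
        elif extend_end > end:
            end_index += end - start + 1

    return embedding[start_index:end_index], start_index, end_index
```

The `ndim != 2` guard is where most of the win comes from: for 2D inputs the reshape is skipped entirely, consistent with the per-test timings below (roughly 50-60% faster for 2D embeddings, a few percent slower for 3D inputs where the extra check buys nothing).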
✅ Correctness verification report:
🌀 Generated Regression Tests and Runtime
```python
import pytest  # used for our unit tests
import torch
from sglang.srt.managers.mm_utils import get_embedding_chunk
# unit tests

# -------------------- BASIC TEST CASES --------------------
def test_basic_single_offset_full_overlap():
# Single offset, extraction fully inside the offset
embedding = torch.arange(10 * 4).reshape(10, 4) # 10 tokens, 4-dim
items_offset = [(0, 9)] # covers all tokens
chunk, start, end = get_embedding_chunk(embedding, 2, 5, items_offset) # 10.6μs -> 6.74μs (57.1% faster)
def test_basic_multiple_offsets_partial_overlap():
# Multiple offsets, extraction starts in second offset
embedding = torch.arange(12 * 2).reshape(12, 2)
items_offset = [(0, 3), (4, 7), (8, 11)]
chunk, start, end = get_embedding_chunk(embedding, 4, 3, items_offset) # 10.1μs -> 6.68μs (50.4% faster)
def test_basic_single_offset_exact_end():
# Extraction ends exactly at the end of the offset
embedding = torch.arange(8 * 3).reshape(8, 3)
items_offset = [(2, 7)]
chunk, start, end = get_embedding_chunk(embedding, 5, 3, items_offset) # 9.46μs -> 5.98μs (58.2% faster)
def test_basic_3d_embedding_reshape():
# 3D embedding should be reshaped to 2D
embedding = torch.arange(2 * 5 * 4).reshape(2, 5, 4)
items_offset = [(0, 9)]
chunk, start, end = get_embedding_chunk(embedding, 3, 4, items_offset) # 9.10μs -> 9.82μs (7.37% slower)
# -------------------- EDGE TEST CASES --------------------
def test_edge_no_overlap_returns_empty():
# Extraction range does not overlap any items_offset
embedding = torch.arange(6 * 2).reshape(6, 2)
items_offset = [(0, 1), (2, 3)]
chunk, start, end = get_embedding_chunk(embedding, 4, 2, items_offset) # 9.61μs -> 6.05μs (58.9% faster)
def test_edge_empty_items_offset():
# items_offset is empty
embedding = torch.arange(5 * 2).reshape(5, 2)
items_offset = []
chunk, start, end = get_embedding_chunk(embedding, 1, 3, items_offset) # 8.71μs -> 5.58μs (56.1% faster)
def test_edge_zero_length_extraction():
# extend_seq_len is zero
embedding = torch.arange(7 * 2).reshape(7, 2)
items_offset = [(0, 6)]
chunk, start, end = get_embedding_chunk(embedding, 2, 0, items_offset) # 9.33μs -> 5.93μs (57.3% faster)
def test_edge_extraction_starts_before_first_offset():
# Extraction starts before any offset
embedding = torch.arange(10 * 3).reshape(10, 3)
items_offset = [(3, 6), (7, 9)]
chunk, start, end = get_embedding_chunk(embedding, 0, 2, items_offset) # 9.38μs -> 5.46μs (71.8% faster)
def test_edge_extraction_ends_after_last_offset():
# Extraction ends after last offset
embedding = torch.arange(10 * 3).reshape(10, 3)
items_offset = [(0, 2), (3, 5)]
chunk, start, end = get_embedding_chunk(embedding, 4, 5, items_offset) # 9.79μs -> 6.17μs (58.8% faster)
def test_edge_items_offset_with_single_element():
# Offset range is a single element
embedding = torch.arange(5 * 2).reshape(5, 2)
items_offset = [(2, 2)]
chunk, start, end = get_embedding_chunk(embedding, 2, 1, items_offset) # 9.51μs -> 6.07μs (56.5% faster)
def test_edge_negative_indices():
# Negative extend_prefix_len, should not overlap
embedding = torch.arange(6 * 2).reshape(6, 2)
items_offset = [(0, 5)]
chunk, start, end = get_embedding_chunk(embedding, -2, 3, items_offset) # 9.15μs -> 5.88μs (55.6% faster)
def test_edge_extend_seq_len_greater_than_embedding():
# Extraction length longer than embedding size
embedding = torch.arange(8 * 2).reshape(8, 2)
items_offset = [(0, 7)]
chunk, start, end = get_embedding_chunk(embedding, 3, 10, items_offset) # 9.28μs -> 5.90μs (57.2% faster)
def test_edge_items_offset_unsorted():
# Unsorted offsets
embedding = torch.arange(10 * 2).reshape(10, 2)
items_offset = [(5, 7), (0, 4)]
chunk, start, end = get_embedding_chunk(embedding, 2, 4, items_offset) # 9.43μs -> 6.17μs (52.9% faster)
def test_edge_items_offset_overlap():
# Overlapping offsets
embedding = torch.arange(8 * 2).reshape(8, 2)
items_offset = [(2, 5), (4, 7)]
chunk, start, end = get_embedding_chunk(embedding, 4, 3, items_offset) # 9.44μs -> 6.14μs (53.7% faster)
# -------------------- LARGE SCALE TEST CASES --------------------
def test_large_scale_embedding_and_offsets():
# Large embedding and offsets, but <100MB
tokens = 999
dim = 32
embedding = torch.arange(tokens * dim).reshape(tokens, dim)
items_offset = [(0, 499), (500, 998)] # covers all tokens
chunk, start, end = get_embedding_chunk(embedding, 450, 100, items_offset) # 10.6μs -> 6.70μs (58.5% faster)
def test_large_scale_many_offsets():
# Many small offsets
tokens = 600
dim = 8
embedding = torch.arange(tokens * dim).reshape(tokens, dim)
# 100 offsets of 6 tokens each
items_offset = [(i * 6, i * 6 + 5) for i in range(100)]
chunk, start, end = get_embedding_chunk(embedding, 300, 50, items_offset) # 18.1μs -> 13.0μs (39.1% faster)
def test_large_scale_3d_embedding():
# Large 3D embedding
embedding = torch.arange(20 * 20 * 5).reshape(20, 20, 5)  # flattens to (400, 5)
items_offset = [(0, 399)]
chunk, start, end = get_embedding_chunk(embedding, 100, 200, items_offset) # 9.81μs -> 10.2μs (4.21% slower)
def test_large_scale_extraction_at_end():
# Extraction at the end of a large embedding
tokens = 999
dim = 16
embedding = torch.arange(tokens * dim).reshape(tokens, dim)
items_offset = [(0, 998)]
chunk, start, end = get_embedding_chunk(embedding, 950, 49, items_offset) # 10.3μs -> 6.35μs (62.8% faster)
def test_large_scale_no_overlap():
# Large embedding, but no overlap
tokens = 500
dim = 10
embedding = torch.arange(tokens * dim).reshape(tokens, dim)
items_offset = [(0, 249), (250, 299)]
chunk, start, end = get_embedding_chunk(embedding, 400, 50, items_offset) # 10.2μs -> 6.46μs (57.2% faster)
```

codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
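To illustrate what that check amounts to, a hypothetical equivalence harness might look like this (the names `original` and `optimized` are placeholders; the actual codeflash machinery differs):

```python
import torch

def assert_equivalent(original, optimized, embedding, prefix_len, seq_len, offsets):
    # Run both implementations on identical inputs and require that the
    # returned chunk and both indices match exactly.
    chunk_a, start_a, end_a = original(embedding, prefix_len, seq_len, offsets)
    chunk_b, start_b, end_b = optimized(embedding, prefix_len, seq_len, offsets)
    assert torch.equal(chunk_a, chunk_b)
    assert (start_a, end_a) == (start_b, end_b)
```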
```python
# ------------------------------------------------
import pytest
import torch
from sglang.srt.managers.mm_utils import get_embedding_chunk
# unit tests

# --- Basic Test Cases ---
def test_basic_single_range_full_overlap():
# Test: extraction within a single offset range, fully inside
embedding = torch.arange(20).reshape(10, 2)
items_offset = [(0, 9)]
chunk, start, end = get_embedding_chunk(embedding, 2, 5, items_offset) # 9.64μs -> 5.94μs (62.4% faster)
def test_basic_single_range_partial_overlap():
# Test: extraction partially overlaps offset range
embedding = torch.arange(30).reshape(10, 3)
items_offset = [(3, 7)]
chunk, start, end = get_embedding_chunk(embedding, 5, 2, items_offset) # 9.49μs -> 5.89μs (61.2% faster)
def test_basic_multiple_ranges():
# Test: extraction across multiple offset ranges
embedding = torch.arange(60).reshape(20, 3)
items_offset = [(0, 4), (5, 9), (10, 14)]
chunk, start, end = get_embedding_chunk(embedding, 7, 5, items_offset) # 9.55μs -> 6.30μs (51.6% faster)
def test_basic_3d_embedding():
# Test: input embedding is 3D, should be reshaped to 2D
embedding = torch.arange(24).reshape(2, 4, 3)
items_offset = [(0, 7)]
chunk, start, end = get_embedding_chunk(embedding, 3, 4, items_offset) # 9.35μs -> 10.2μs (8.63% slower)
# --- Edge Test Cases ---
def test_edge_no_overlap():
# Test: requested range does not overlap any items_offset
embedding = torch.arange(20).reshape(10, 2)
items_offset = [(0, 3), (7, 9)]
chunk, start, end = get_embedding_chunk(embedding, 4, 2, items_offset) # 9.37μs -> 5.99μs (56.4% faster)
def test_edge_empty_items_offset():
# Test: items_offset is empty
embedding = torch.arange(20).reshape(10, 2)
items_offset = []
chunk, start, end = get_embedding_chunk(embedding, 2, 3, items_offset) # 9.20μs -> 5.59μs (64.4% faster)
def test_edge_zero_length_extraction():
# Test: extend_seq_len is zero
embedding = torch.arange(20).reshape(10, 2)
items_offset = [(0, 9)]
chunk, start, end = get_embedding_chunk(embedding, 5, 0, items_offset) # 9.46μs -> 5.96μs (58.7% faster)
def test_edge_negative_extend_prefix_len():
# Test: negative extend_prefix_len
embedding = torch.arange(20).reshape(10, 2)
items_offset = [(0, 9)]
chunk, start, end = get_embedding_chunk(embedding, -2, 3, items_offset) # 9.70μs -> 6.09μs (59.2% faster)
def test_edge_extend_seq_len_exceeds_embedding():
# Test: extend_seq_len exceeds embedding length
embedding = torch.arange(20).reshape(10, 2)
items_offset = [(0, 9)]
chunk, start, end = get_embedding_chunk(embedding, 7, 10, items_offset) # 9.37μs -> 5.75μs (62.8% faster)
def test_edge_items_offset_out_of_bounds():
# Test: items_offset contains out-of-bounds indices
embedding = torch.arange(20).reshape(10, 2)
items_offset = [(-5, -1), (0, 4)]
chunk, start, end = get_embedding_chunk(embedding, 2, 3, items_offset) # 9.92μs -> 6.26μs (58.4% faster)
def test_edge_embedding_smaller_than_offsets():
# Test: embedding smaller than items_offset ranges
embedding = torch.arange(8).reshape(4, 2)
items_offset = [(0, 10)]
chunk, start, end = get_embedding_chunk(embedding, 2, 3, items_offset) # 9.23μs -> 6.11μs (51.0% faster)
def test_edge_empty_embedding():
# Test: embedding tensor is empty
embedding = torch.empty((0, 2))
items_offset = [(0, 3)]
chunk, start, end = get_embedding_chunk(embedding, 1, 2, items_offset) # 12.0μs -> 6.27μs (90.6% faster)
# --- Large Scale Test Cases ---
def test_large_embedding_and_offsets():
# Test: large embedding and multiple offset ranges
embedding = torch.arange(3000).reshape(1000, 3)
items_offset = [(0, 299), (300, 699), (700, 999)]
chunk, start, end = get_embedding_chunk(embedding, 500, 100, items_offset) # 10.5μs -> 6.93μs (51.6% faster)
def test_large_embedding_partial_overlap():
# Test: extraction at the very end of a large embedding
embedding = torch.arange(3000).reshape(1000, 3)
items_offset = [(0, 999)]
chunk, start, end = get_embedding_chunk(embedding, 990, 20, items_offset) # 9.81μs -> 6.34μs (54.8% faster)
def test_large_embedding_no_overlap():
# Test: large embedding, offset ranges do not overlap extraction
embedding = torch.arange(3000).reshape(1000, 3)
items_offset = [(0, 499)]
chunk, start, end = get_embedding_chunk(embedding, 700, 100, items_offset) # 9.94μs -> 6.11μs (62.5% faster)
def test_large_embedding_3d():
# Test: large 3D embedding
embedding = torch.arange(6000).reshape(2, 1000, 3)
items_offset = [(0, 999)]
chunk, start, end = get_embedding_chunk(embedding, 500, 100, items_offset) # 9.94μs -> 10.5μs (5.24% slower)
def test_large_embedding_max_size():
# Test: largest embedding in the suite, still far below the 100MB cap (1000 x 128 = 128,000 float32 values ≈ 512KB)
embedding = torch.arange(128000, dtype=torch.float32).reshape(1000, 128)
items_offset = [(0, 999)]
chunk, start, end = get_embedding_chunk(embedding, 100, 900, items_offset) # 11.0μs -> 6.71μs (63.6% faster)
```

codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
To edit these changes, run `git checkout codeflash/optimize-get_embedding_chunk-mhv5wc5s` and push.