⚡️ Speed up function get_embedding_and_mask by 16%
#346
📄 16% (0.16x) speedup for `get_embedding_and_mask` in `python/sglang/srt/managers/mm_utils.py`
⏱️ Runtime: 127 microseconds → 110 microseconds (best of 17 runs)
📝 Explanation and details
The optimized code achieves a 15% speedup through several key algorithmic and memory efficiency improvements:
Primary Optimizations:

1. Single-pass precomputed embedding detection: instead of creating a list and then calling `any()` and `all()` separately (which iterates the list twice), the optimized version uses a single loop with boolean flags to track both conditions simultaneously, eliminating the redundant iteration (see the sketch after this list).
2. Generator expression in the `all()` check: the list comprehension `[offset_end < prefix_length[i] for _, offset_end in items_offset]` is replaced with a generator expression inside `all()`, avoiding the memory allocation for an intermediate list.
3. Empty tensor filtering: an `if embedding_per_req_chunk.numel() > 0:` guard before appending chunks keeps empty tensors out of the concatenation list, reducing memory overhead and tensor operations (also shown in the sketch below).
4. Optimized tensor concatenation: `torch.cat` with an explicit `dim=0` replaces `torch.concat`; since `torch.concat` is an alias for `torch.cat`, calling `torch.cat` directly avoids one level of function-call indirection.
5. Loop structure optimization: the for-loop becomes a while-loop with explicit increment control, so the `continue` cases are handled without redundant iterator advancement.
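For illustration, here is a minimal runnable sketch of the single-pass detection (item 1) and the empty-chunk filtering with `torch.cat` (items 3 and 4). The `Item` class and every variable name below are hypothetical stand-ins, not the actual code in `mm_utils.py`:

```python
import torch
from dataclasses import dataclass
from typing import Optional

@dataclass
class Item:
    # Hypothetical stand-in for MultimodalDataItem, for illustration only.
    precomputed_embeddings: Optional[torch.Tensor] = None

embedding_items = [Item(torch.ones(1, 4)), Item()]

# Before: build a list, then walk it twice via any() and all().
flags = [it.precomputed_embeddings is not None for it in embedding_items]
mixed_before = any(flags) and not all(flags)

# After: one pass with two boolean flags, no intermediate list.
any_pre, all_pre = False, True
for it in embedding_items:
    if it.precomputed_embeddings is not None:
        any_pre = True
    else:
        all_pre = False
mixed_after = any_pre and not all_pre
assert mixed_before == mixed_after  # same decision, computed in a single pass

# Empty-chunk filtering before concatenation (items 3 and 4): skip
# zero-element tensors and concatenate with torch.cat(dim=0) directly.
chunks = [torch.ones(2, 4), torch.empty(0, 4), torch.ones(1, 4)]
combined = torch.cat([c for c in chunks if c.numel() > 0], dim=0)
assert combined.shape == (3, 4)
```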
Performance Impact:
Based on the annotated tests, the optimizations show improvements across different scenarios; per-test timings are annotated inline in the regression tests below, with the largest gains on the edge cases (e.g., 56.6% faster with no items, 67.6% faster on the mixed precomputed/raw error path).
Hot Path Relevance:
The function is called from `embed_mm_inputs()`, which processes multimodal data for each request in the inference pipeline. Since multimodal models handle image/video tokens frequently, these micro-optimizations compound in production workloads where this function may be called hundreds of times per second. The improvements are especially valuable for the edge cases (empty inputs, precomputed embeddings) that occur frequently in real inference scenarios.
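As a hedged illustration of the placeholder-mask pattern these code paths revolve around (not the exact implementation in `mm_utils.py`), a boolean mask over placeholder token positions can be built with `torch.isin`:

```python
import torch

input_ids = torch.tensor([1, 99, 2, 3])
placeholder_tensor = torch.tensor([99])

# Mark the positions that hold multimodal placeholder tokens; unsqueeze(-1)
# yields the (seq_len, 1) mask shape the edge-case tests below expect.
mask = torch.isin(input_ids, placeholder_tensor).unsqueeze(-1)
assert mask.shape == (4, 1)
assert mask.squeeze(-1).tolist() == [False, True, False, False]
```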
✅ Correctness verification report:
🌀 Generated Regression Tests and Runtime
import pytest
import torch
from sglang.srt.managers.mm_utils import get_embedding_and_mask
# --- Minimal stubs for dependencies ---

# MultimodalDataItem stub
class MultimodalDataItem:
    def __init__(self, hash_val, precomputed_embeddings=None):
        self.hash = hash_val
        self.precomputed_embeddings = precomputed_embeddings

# MultiModalCache stub
class MultiModalCache:
    def __init__(self, max_size=100000000):  # 100MB
        self.cache = {}

    def get(self, mm_hash):
        return self.cache.get(mm_hash, None)

    def put(self, mm_hash, embedding):
        self.cache[mm_hash] = embedding
        return True

# Logger stub
class DummyLogger:
    def warning(self, msg, stacklevel=2):
        pass

logger = DummyLogger()
embedding_cache = MultiModalCache()

# --- Function under test and helpers ---
def get_embedding_hash(embedding_items):
    hash_list = [item.hash for item in embedding_items]
    return hash(tuple(hash_list))
from sglang.srt.managers.mm_utils import get_embedding_and_mask
# --- Unit Tests ---

# Helper function for embedding generation
def simple_embedding_func(items):
    # Each item gets a deterministic embedding vector of size 4,
    # derived from the item's hash value.
    result = []
    for item in items:
        torch.manual_seed(item.hash)
        result.append(torch.ones(1, 4) * item.hash)
    return torch.cat(result, dim=0)
# ---------------- BASIC TEST CASES ----------------
def test_basic_single_item_precomputed():
    # Single item, precomputed embedding
    emb = torch.ones(1, 4) * 7
    items = [MultimodalDataItem(hash_val=7, precomputed_embeddings=emb)]
    placeholder_tensor = torch.tensor([99])
    input_ids = torch.tensor([1, 99, 2, 3])
    items_size = [0, 1]
    prefix_length = [0]
    extend_length = [1]
    items_offset_list = [[(1, 1)]]
    embedding, mask = get_embedding_and_mask(simple_embedding_func, items, placeholder_tensor, input_ids, items_size, prefix_length, extend_length, items_offset_list)  # 54.4μs -> 48.6μs (11.9% faster)
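As rendered here, the generated tests call the function without explicit assertions (the output-equivalence check happens through `codeflash_output`, noted at the end of this report). Purely as a hedged illustration, post-conditions like the following could be appended inside the test above, assuming the mask flags placeholder positions with shape `(len(input_ids), 1)` and that a precomputed embedding passes through unchanged:

```python
# Hypothetical post-conditions (not part of the generated suite; these
# reuse emb, embedding, mask, and input_ids from the test body above):
assert embedding.shape == (1, 4)              # assumption: the (1, 4) precomputed embedding is returned as-is
assert torch.equal(embedding, emb)            # assumption: pass-through of precomputed embeddings
assert mask.shape == (input_ids.shape[0], 1)  # assumption: (seq_len, 1) mask shape
assert mask.squeeze(-1).tolist() == [False, True, False, False]  # placeholder 99 sits at index 1
```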
def test_basic_multiple_items_precomputed():
    # Multiple items, all precomputed
    emb1 = torch.ones(1, 4) * 5
    emb2 = torch.ones(1, 4) * 6
    items = [
        MultimodalDataItem(hash_val=5, precomputed_embeddings=emb1),
        MultimodalDataItem(hash_val=6, precomputed_embeddings=emb2),
    ]
    placeholder_tensor = torch.tensor([100])
    input_ids = torch.tensor([100, 2, 100, 3])
    items_size = [0, 2]
    prefix_length = [0]
    extend_length = [2]
    items_offset_list = [[(0, 0), (2, 2)]]
    embedding, mask = get_embedding_and_mask(simple_embedding_func, items, placeholder_tensor, input_ids, items_size, prefix_length, extend_length, items_offset_list)  # 38.1μs -> 36.9μs (3.41% faster)
def test_edge_no_items():
    # No items, should return (None, None)
    items = []
    placeholder_tensor = torch.tensor([99])
    input_ids = torch.tensor([1, 2, 3])
    items_size = [0]
    prefix_length = [0]
    extend_length = [0]
    items_offset_list = [[]]
    embedding, mask = get_embedding_and_mask(simple_embedding_func, items, placeholder_tensor, input_ids, items_size, prefix_length, extend_length, items_offset_list)  # 4.48μs -> 2.86μs (56.6% faster)
def test_edge_some_precomputed_some_not():
    # Some items have precomputed embeddings, some don't: should raise NotImplementedError
    emb = torch.ones(1, 4) * 7
    items = [
        MultimodalDataItem(hash_val=7, precomputed_embeddings=emb),
        MultimodalDataItem(hash_val=8),
    ]
    placeholder_tensor = torch.tensor([99])
    input_ids = torch.tensor([99, 99])
    items_size = [0, 2]
    prefix_length = [0]
    extend_length = [2]
    items_offset_list = [[(0, 0), (1, 1)]]
    with pytest.raises(NotImplementedError):
        get_embedding_and_mask(simple_embedding_func, items, placeholder_tensor, input_ids, items_size, prefix_length, extend_length, items_offset_list)  # 3.56μs -> 2.18μs (63.7% faster)
def test_edge_empty_input_ids():
    # Empty input_ids: mask should be shape (0, 1)
    items = [MultimodalDataItem(hash_val=71)]
    placeholder_tensor = torch.tensor([99])
    input_ids = torch.tensor([])
    items_size = [0, 1]
    prefix_length = [0]
    extend_length = [1]
    items_offset_list = [[]]
    embedding, mask = get_embedding_and_mask(simple_embedding_func, items, placeholder_tensor, input_ids, items_size, prefix_length, extend_length, items_offset_list)  # 7.09μs -> 5.75μs (23.3% faster)
#------------------------------------------------
import sys
# imports
import pytest
import torch
from sglang.srt.managers.mm_utils import get_embedding_and_mask
# Dummy MultimodalDataItem for testing
class MultimodalDataItem:
    def __init__(self, hash_value, precomputed_embeddings=None):
        self.hash = hash_value
        self.precomputed_embeddings = precomputed_embeddings

# Dummy logger for testing
class DummyLogger:
    def __init__(self):
        self.warnings = []

    def warning(self, msg, stacklevel=1):
        self.warnings.append(msg)
from sglang.srt.managers.mm_utils import get_embedding_and_mask
# ------------------- UNIT TESTS -------------------

# Basic test: no multimodal items
def test_no_multimodal_items():
    def embedding_func(items): return torch.empty((0, 4))
    items = []
    input_ids = torch.tensor([1, 2, 3])
    placeholder_tensor = torch.tensor([999])
    items_size = [0]
    prefix_length = []
    extend_length = []
    items_offset_list = []
    embedding, mask = get_embedding_and_mask(
        embedding_func, items, placeholder_tensor, input_ids, items_size, prefix_length, extend_length, items_offset_list
    )  # 4.47μs -> 2.77μs (61.7% faster)
# Edge case: some but not all items have precomputed embeddings
def test_some_precomputed_embeddings_raises():
    emb1 = torch.tensor([[1.0, 2.0]])
    items = [
        MultimodalDataItem(hash_value=1, precomputed_embeddings=emb1),
        MultimodalDataItem(hash_value=2, precomputed_embeddings=None),
    ]
    def embedding_func(items): raise Exception("Should not be called")
    input_ids = torch.tensor([1, 2])
    placeholder_tensor = torch.tensor([2])
    items_size = [0, 2]
    prefix_length = [0]
    extend_length = [1]
    items_offset_list = [[(1, 1)]]
    with pytest.raises(NotImplementedError):
        get_embedding_and_mask(
            embedding_func, items, placeholder_tensor, input_ids, items_size, prefix_length, extend_length, items_offset_list
        )  # 3.40μs -> 2.03μs (67.6% faster)
# Edge case: empty items_offset list for a request
def test_empty_items_offset_list():
    def embedding_func(items): return torch.tensor([[1.0, 2.0]])
    items = [MultimodalDataItem(hash_value=1)]
    input_ids = torch.tensor([1, 2])
    placeholder_tensor = torch.tensor([2])
    items_size = [0, 1]
    prefix_length = [0]
    extend_length = [1]
    items_offset_list = [[]]
    embedding, mask = get_embedding_and_mask(
        embedding_func, items, placeholder_tensor, input_ids, items_size, prefix_length, extend_length, items_offset_list
    )  # 7.18μs -> 5.79μs (23.8% faster)
# Edge case: empty input_ids and placeholder_tensor
def test_input_ids_and_placeholder_tensor_empty():
    def embedding_func(items): return torch.empty((0, 4))
    items = []
    input_ids = torch.tensor([], dtype=torch.long)
    placeholder_tensor = torch.tensor([], dtype=torch.long)
    items_size = [0]
    prefix_length = []
    extend_length = []
    items_offset_list = []
    embedding, mask = get_embedding_and_mask(
        embedding_func, items, placeholder_tensor, input_ids, items_size, prefix_length, extend_length, items_offset_list
    )  # 4.42μs -> 2.91μs (51.9% faster)
`codeflash_output` is used to check that the output of the original code is the same as that of the optimized code.
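As a hedged sketch of what that equivalence check amounts to for this function's `(embedding, mask)` return value (illustrative only; `outputs_match` is not a codeflash API):

```python
import torch

def outputs_match(ref, new):
    # Compare two (embedding, mask) pairs: corresponding elements must
    # either both be None or be elementwise equal.
    for a, b in zip(ref, new):
        if (a is None) != (b is None):
            return False
        if a is not None and not torch.equal(a, b):
            return False
    return True
```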
To edit these changes, run `git checkout codeflash/optimize-get_embedding_and_mask-mhv6ykds` and push.