⚡️ Speed up function _get_precomputed_embedding by 10%
#344
📄 10% (0.10x) speedup for `_get_precomputed_embedding` in `python/sglang/srt/managers/mm_utils.py`

⏱️ Runtime: 715 microseconds → 647 microseconds (best of 114 runs)

📝 Explanation and details
The optimized code achieves a 10% speedup through several key micro-optimizations that reduce overhead and improve memory access patterns.

**Key Optimizations:**

- **Single-pass algorithm:** instead of building an intermediate `precomputed_embeddings` list and making two passes over it with `any()` and `all()`, the optimized version uses one loop that simultaneously collects `all_embeddings` and tracks a `found_none` flag.
- **Reduced function call overhead:** dropping the generator expressions inside the `any()` and `all()` calls avoids creating iterator objects and performing redundant attribute access (`item.precomputed_embeddings` is read twice per item in the original).
- **More efficient PyTorch operation:** uses `torch.cat()` instead of `torch.concat()`; the two are functionally identical, but `torch.cat()` is the canonical PyTorch method and may have a slight performance advantage.

**Why This Matters:**

The function is called from `get_embedding_and_mask()`, which processes multimodal embeddings in ML inference pipelines. Judging by the per-test timings below, the optimization is most effective for empty lists and all-`None` batches, where the original paid for two full generator passes before returning. The single-pass approach also reduces memory allocations and CPU cycles, which compounds when processing large batches of multimodal data in production ML systems. A sketch of both versions follows.
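To make the before/after concrete, here is a minimal sketch of both versions, reconstructed from the description and the test behavior below; the function names, the exact error message, and other details are illustrative assumptions, not the repository's actual code.

```python
from typing import Optional

import torch


# Original structure (as described): an intermediate list plus two generator passes.
def _get_precomputed_embedding_original(items) -> Optional[torch.Tensor]:
    precomputed_embeddings = [item.precomputed_embeddings for item in items]
    if any(e is not None for e in precomputed_embeddings):
        if not all(e is not None for e in precomputed_embeddings):
            # Mixed precomputed / non-precomputed batches are unsupported.
            raise NotImplementedError()
        result = torch.concat(precomputed_embeddings)
        return result.reshape(-1, result.shape[-1])  # flatten leading dims to 2D
    return None


# Optimized structure: one loop collects embeddings and tracks None simultaneously,
# with a single attribute access per item.
def _get_precomputed_embedding_optimized(items) -> Optional[torch.Tensor]:
    all_embeddings = []
    found_none = False
    for item in items:
        emb = item.precomputed_embeddings
        if emb is None:
            found_none = True
        else:
            all_embeddings.append(emb)
    if not all_embeddings:  # empty input or all None
        return None
    if found_none:  # mixed batch
        raise NotImplementedError()
    result = torch.cat(all_embeddings)  # torch.cat is the canonical spelling
    return result.reshape(-1, result.shape[-1])
```

Both versions return `None` for empty or all-`None` inputs, raise `NotImplementedError` for mixed batches, and otherwise concatenate and flatten the embeddings to 2D, which is exactly the contract the regression tests below exercise.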
✅ Correctness verification report:
🌀 Generated Regression Tests and Runtime
```python
from typing import List, Optional

# imports
import pytest  # used for our unit tests
import torch
from sglang.srt.managers.mm_utils import _get_precomputed_embedding


# function to test
class MultimodalDataItem:
    def __init__(self, precomputed_embeddings: Optional[torch.Tensor]):
        self.precomputed_embeddings = precomputed_embeddings


# unit tests

# 1. BASIC TEST CASES

def test_all_none_returns_none():
    # All items have precomputed_embeddings=None, should return None
    items = [MultimodalDataItem(None) for _ in range(3)]
    codeflash_output = _get_precomputed_embedding(items)  # 1.75μs -> 1.02μs (71.2% faster)

def test_all_have_embeddings_2d():
    # All items have 2D embeddings, should concatenate along first dimension
    emb1 = torch.ones((2, 4))
    emb2 = torch.zeros((3, 4))
    items = [MultimodalDataItem(emb1), MultimodalDataItem(emb2)]
    codeflash_output = _get_precomputed_embedding(items); result = codeflash_output  # 22.4μs -> 19.7μs (13.4% faster)

def test_empty_list_returns_none():
    # Empty list should return None
    codeflash_output = _get_precomputed_embedding([])  # 1.18μs -> 547ns (116% faster)

# 2. EDGE TEST CASES

def test_some_none_some_embeddings_raises():
    # Some items have embeddings, some don't: should raise NotImplementedError
    emb = torch.ones((2, 3))
    items = [MultimodalDataItem(emb), MultimodalDataItem(None)]
    with pytest.raises(NotImplementedError):
        _get_precomputed_embedding(items)  # 2.60μs -> 1.35μs (92.2% faster)

def test_single_item_none_returns_none():
    # Single item with None embedding should return None
    items = [MultimodalDataItem(None)]
    codeflash_output = _get_precomputed_embedding(items)  # 1.43μs -> 716ns (100% faster)

def test_single_item_embedding_returns_reshaped():
    # Single item with 3D embedding should be reshaped to 2D
    emb = torch.arange(12).reshape(2, 2, 3)
    items = [MultimodalDataItem(emb)]
    codeflash_output = _get_precomputed_embedding(items); result = codeflash_output  # 15.7μs -> 13.4μs (17.1% faster)

def test_inconsistent_embedding_shapes_raises():
    # Embeddings with incompatible shapes for concat should raise
    emb1 = torch.ones((2, 4))
    emb2 = torch.ones((2, 5))  # Different last dim
    items = [MultimodalDataItem(emb1), MultimodalDataItem(emb2)]
    with pytest.raises(RuntimeError):
        _get_precomputed_embedding(items)  # 61.9μs -> 56.2μs (10.2% faster)

def test_embeddings_with_zero_rows():
    # Embeddings with zero rows should be handled correctly
    emb1 = torch.empty((0, 4))
    emb2 = torch.ones((2, 4))
    items = [MultimodalDataItem(emb1), MultimodalDataItem(emb2)]
    codeflash_output = _get_precomputed_embedding(items); result = codeflash_output  # 20.1μs -> 17.3μs (16.2% faster)

def test_large_number_of_small_embeddings():
    # Many small embeddings, test performance and correctness
    items = [MultimodalDataItem(torch.ones((1, 8))) for _ in range(500)]
    codeflash_output = _get_precomputed_embedding(items); result = codeflash_output  # 86.3μs -> 85.0μs (1.45% faster)

def test_large_embeddings_but_under_100mb():
    # Each embedding is (10, 1280): 10*1280*4 = 51,200 bytes/item, 512,000 bytes (~0.5MB) total
    items = [MultimodalDataItem(torch.full((10, 1280), fill_value=i, dtype=torch.float32)) for i in range(10)]
    codeflash_output = _get_precomputed_embedding(items); result = codeflash_output  # 41.6μs -> 38.3μs (8.59% faster)
    # Check that rows are filled with correct values
    for i in range(10):
        assert torch.all(result[i * 10 : (i + 1) * 10] == i)

def test_large_3d_embeddings_reshaped():
    # 3D embeddings, many items, test reshape and concat: 50*2*5*32*4 = 64KB
    items = [MultimodalDataItem(torch.ones((2, 5, 32))) for _ in range(50)]
    codeflash_output = _get_precomputed_embedding(items); result = codeflash_output  # 27.6μs -> 26.0μs (6.29% faster)

def test_large_batch_with_some_none_raises():
    # Large batch with one None should raise
    items = [MultimodalDataItem(torch.ones((1, 16))) for _ in range(999)]
    items.append(MultimodalDataItem(None))
    with pytest.raises(NotImplementedError):
        _get_precomputed_embedding(items)  # 38.7μs -> 40.1μs (3.53% slower)

def test_large_batch_all_none_returns_none():
    # Large batch all None should return None
    items = [MultimodalDataItem(None) for _ in range(1000)]
    codeflash_output = _get_precomputed_embedding(items)  # 42.0μs -> 23.8μs (76.6% faster)
```
codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
```python
import pytest  # used for our unit tests
import torch
from sglang.srt.managers.mm_utils import _get_precomputed_embedding


# function to test
class MultimodalDataItem:
    def __init__(self, precomputed_embeddings=None):
        self.precomputed_embeddings = precomputed_embeddings


# unit tests

# ------------------------
# Basic Test Cases
# ------------------------

def test_all_none_returns_none():
    # All items have None for precomputed_embeddings
    items = [MultimodalDataItem(None) for _ in range(3)]
    codeflash_output = _get_precomputed_embedding(items)  # 1.73μs -> 905ns (90.8% faster)

def test_all_have_embeddings_2d():
    # All items have 2D embeddings
    emb1 = torch.ones((2, 5))
    emb2 = torch.zeros((3, 5))
    items = [MultimodalDataItem(emb1), MultimodalDataItem(emb2)]
    codeflash_output = _get_precomputed_embedding(items); result = codeflash_output  # 22.9μs -> 19.7μs (16.1% faster)

def test_single_item_embedding_2d():
    # Single item, 2D embedding
    emb = torch.arange(6).reshape(2, 3)
    items = [MultimodalDataItem(emb)]
    codeflash_output = _get_precomputed_embedding(items); result = codeflash_output  # 16.5μs -> 14.1μs (17.5% faster)

def test_single_item_embedding_3d():
    # Single item, 3D embedding
    emb = torch.arange(12).reshape(1, 4, 3)
    items = [MultimodalDataItem(emb)]
    codeflash_output = _get_precomputed_embedding(items); result = codeflash_output  # 13.3μs -> 11.1μs (19.8% faster)
    # Should flatten to (4, 3)
    expected = emb.reshape(-1, emb.shape[-1])
    assert torch.equal(result, expected)

# ------------------------
# Edge Test Cases
# ------------------------

def test_some_none_some_tensor_raises():
    # Some items have None, some have embeddings
    emb = torch.ones((2, 5))
    items = [MultimodalDataItem(emb), MultimodalDataItem(None)]
    with pytest.raises(NotImplementedError):
        _get_precomputed_embedding(items)  # 2.60μs -> 1.41μs (84.7% faster)

def test_empty_list_returns_none():
    # Empty input list should return None
    items = []
    codeflash_output = _get_precomputed_embedding(items)  # 1.18μs -> 563ns (109% faster)

def test_all_have_embeddings_different_shapes():
    # All items have embeddings, but shapes are incompatible for concat
    emb1 = torch.ones((2, 4))
    emb2 = torch.zeros((3, 5))  # incompatible last dim
    items = [MultimodalDataItem(emb1), MultimodalDataItem(emb2)]
    with pytest.raises(RuntimeError):
        _get_precomputed_embedding(items)  # 63.4μs -> 58.1μs (9.17% faster)

def test_all_have_embeddings_zero_length():
    # All items have zero-length embeddings
    emb1 = torch.empty((0, 5))
    emb2 = torch.empty((0, 5))
    items = [MultimodalDataItem(emb1), MultimodalDataItem(emb2)]
    codeflash_output = _get_precomputed_embedding(items); result = codeflash_output  # 19.1μs -> 16.5μs (15.3% faster)

def test_all_have_embeddings_different_batch_dim():
    # All items have embeddings, batch dimension differs
    emb1 = torch.ones((1, 2, 4))
    emb2 = torch.zeros((2, 2, 4))
    items = [MultimodalDataItem(emb1), MultimodalDataItem(emb2)]
    codeflash_output = _get_precomputed_embedding(items); result = codeflash_output  # 23.6μs -> 20.7μs (13.9% faster)

# ------------------------
# Large Scale Test Cases
# ------------------------

def test_large_number_of_items_2d():
    # Test with a large number of items, 2D embeddings
    n_items = 500
    emb_dim = 16
    items = [MultimodalDataItem(torch.ones((1, emb_dim)) * i) for i in range(n_items)]
    codeflash_output = _get_precomputed_embedding(items); result = codeflash_output  # 85.2μs -> 84.8μs (0.502% faster)
    # Each row should be filled with its index
    for i in range(n_items):
        assert torch.all(result[i] == i)

def test_large_number_of_items_3d():
    # Test with a large number of items, 3D embeddings
    n_items = 200
    seq_len = 2
    emb_dim = 8
    items = [MultimodalDataItem(torch.ones((1, seq_len, emb_dim)) * i) for i in range(n_items)]
    codeflash_output = _get_precomputed_embedding(items); result = codeflash_output  # 48.5μs -> 46.3μs (4.70% faster)
    # Each pair of rows should be filled with its index
    for i in range(n_items):
        assert torch.all(result[i * seq_len : (i + 1) * seq_len] == i)

def test_large_embedding_size():
    # Large embedding size, but < 100MB
    n_items = 10
    seq_len = 50
    emb_dim = 200
    items = [MultimodalDataItem(torch.ones((1, seq_len, emb_dim)) * i) for i in range(n_items)]
    codeflash_output = _get_precomputed_embedding(items); result = codeflash_output  # 32.7μs -> 31.4μs (4.16% faster)
    # Each block of seq_len rows should be filled with its index
    for i in range(n_items):
        assert torch.all(result[i * seq_len : (i + 1) * seq_len] == i)

def test_large_number_some_none_raises():
    # Large number of items, some None, should raise
    n_items = 100
    emb_dim = 10
    items = [MultimodalDataItem(torch.ones((1, emb_dim)) if i % 2 == 0 else None) for i in range(n_items)]
    with pytest.raises(NotImplementedError):
        _get_precomputed_embedding(items)  # 5.46μs -> 5.06μs (7.91% faster)
```
codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
To edit these changes, run `git checkout codeflash/optimize-_get_precomputed_embedding-mhv68s1j` and push.