codeflash-ai bot commented Nov 11, 2025

📄 10% (0.10x) speedup for _get_precomputed_embedding in python/sglang/srt/managers/mm_utils.py

⏱️ Runtime : 715 microseconds → 647 microseconds (best of 114 runs)

📝 Explanation and details

The optimized code achieves a 10% speedup through several key micro-optimizations that reduce overhead and improve memory access patterns:

Key Optimizations:

  1. Single-pass algorithm: instead of building an intermediate list precomputed_embeddings and making two passes with any() and all(), the optimized version uses one loop (sketched below) that simultaneously:

    • collects non-None embeddings in all_embeddings
    • tracks whether any None value was found with found_none

  2. Reduced function-call overhead: eliminates the generator expressions inside any() and all(), which allocate iterator objects and access item.precomputed_embeddings twice per item in the original.

  3. More efficient PyTorch call: uses torch.cat() instead of torch.concat(). The two are functionally identical, but torch.cat() is the canonical spelling and may carry slightly less dispatch overhead.
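
A minimal sketch of that single-pass shape, assuming items expose a precomputed_embeddings attribute; the helper name and the bare NotImplementedError are illustrative, not the exact sglang source:

import torch

def single_pass_precomputed_embedding(items):
    all_embeddings = []
    found_none = False
    for item in items:
        emb = item.precomputed_embeddings  # one attribute access per item
        if emb is None:
            found_none = True
        else:
            all_embeddings.append(emb)
    if not all_embeddings:
        return None  # empty batch, or every item lacked an embedding
    if found_none:
        # Mixing precomputed and non-precomputed items is unsupported
        raise NotImplementedError
    result = torch.cat(all_embeddings)  # canonical alias of torch.concat
    return result.reshape(-1, result.shape[-1])  # flatten leading dims to 2D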

Why This Matters:

The function is called from get_embedding_and_mask() which processes multimodal embeddings in ML inference pipelines. Based on the test results, the optimization is particularly effective for:

  • Empty/all-None cases (70-116% faster): Critical for batch processing where many items may lack precomputed embeddings
  • Mixed None/embedding cases (84-92% faster): Common in real workloads where only some items are precomputed
  • Large batches (76% faster for 1000 None items): Important for high-throughput inference

The single-pass approach reduces memory allocations and CPU cycles, which compounds when processing large batches of multimodal data in production ML systems.
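
For context, the all-None batch numbers can be reproduced with a timeit loop like the one below. This reuses the sketch above; the Item class and 1000-item batch mirror test_large_batch_all_none_returns_none, and the figures quoted in this report came from codeflash's own harness rather than this snippet.

import timeit

class Item:
    def __init__(self, emb=None):
        self.precomputed_embeddings = emb

# 1000 items with no precomputed embeddings, as in the large-batch test
batch = [Item() for _ in range(1000)]

elapsed = timeit.timeit(lambda: single_pass_precomputed_embedding(batch), number=10_000)
print(f"{elapsed / 10_000 * 1e6:.1f} microseconds per call")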

Correctness verification report:

Test                            Status
⚙️ Existing Unit Tests          🔘 None Found
🌀 Generated Regression Tests   27 Passed
⏪ Replay Tests                 🔘 None Found
🔎 Concolic Coverage Tests      🔘 None Found
📊 Tests Coverage               100.0%
🌀 Generated Regression Tests and Runtime

from typing import List, Optional

# imports
import pytest  # used for our unit tests
import torch

from sglang.srt.managers.mm_utils import _get_precomputed_embedding

# function to test
class MultimodalDataItem:
    def __init__(self, precomputed_embeddings: Optional[torch.Tensor]):
        self.precomputed_embeddings = precomputed_embeddings

from sglang.srt.managers.mm_utils import _get_precomputed_embedding

# unit tests

# 1. BASIC TEST CASES

def test_all_none_returns_none():
    # All items have precomputed_embeddings=None, should return None
    items = [MultimodalDataItem(None) for _ in range(3)]
    codeflash_output = _get_precomputed_embedding(items)  # 1.75μs -> 1.02μs (71.2% faster)

def test_all_have_embeddings_2d():
    # All items have 2D embeddings, should concatenate along the first dimension
    emb1 = torch.ones((2, 4))
    emb2 = torch.zeros((3, 4))
    items = [MultimodalDataItem(emb1), MultimodalDataItem(emb2)]
    codeflash_output = _get_precomputed_embedding(items); result = codeflash_output  # 22.4μs -> 19.7μs (13.4% faster)

def test_empty_list_returns_none():
    # Empty list should return None
    codeflash_output = _get_precomputed_embedding([])  # 1.18μs -> 547ns (116% faster)

# 2. EDGE TEST CASES

def test_some_none_some_embeddings_raises():
    # Some items have embeddings, some don't: should raise NotImplementedError
    emb = torch.ones((2, 3))
    items = [MultimodalDataItem(emb), MultimodalDataItem(None)]
    with pytest.raises(NotImplementedError):
        _get_precomputed_embedding(items)  # 2.60μs -> 1.35μs (92.2% faster)

def test_single_item_none_returns_none():
    # Single item with None embedding should return None
    items = [MultimodalDataItem(None)]
    codeflash_output = _get_precomputed_embedding(items)  # 1.43μs -> 716ns (100% faster)

def test_single_item_embedding_returns_reshaped():
    # Single item with 3D embedding should be reshaped to 2D
    emb = torch.arange(12).reshape(2, 2, 3)
    items = [MultimodalDataItem(emb)]
    codeflash_output = _get_precomputed_embedding(items); result = codeflash_output  # 15.7μs -> 13.4μs (17.1% faster)

def test_inconsistent_embedding_shapes_raises():
    # Embeddings with incompatible shapes for concat should raise
    emb1 = torch.ones((2, 4))
    emb2 = torch.ones((2, 5))  # Different last dim
    items = [MultimodalDataItem(emb1), MultimodalDataItem(emb2)]
    with pytest.raises(RuntimeError):
        _get_precomputed_embedding(items)  # 61.9μs -> 56.2μs (10.2% faster)

def test_embeddings_with_zero_rows():
    # Embeddings with zero rows should be handled correctly
    emb1 = torch.empty((0, 4))
    emb2 = torch.ones((2, 4))
    items = [MultimodalDataItem(emb1), MultimodalDataItem(emb2)]
    codeflash_output = _get_precomputed_embedding(items); result = codeflash_output  # 20.1μs -> 17.3μs (16.2% faster)

def test_large_number_of_small_embeddings():
    # Many small embeddings, test performance and correctness
    items = [MultimodalDataItem(torch.ones((1, 8))) for _ in range(500)]
    codeflash_output = _get_precomputed_embedding(items); result = codeflash_output  # 86.3μs -> 85.0μs (1.45% faster)

def test_large_embeddings_but_under_100mb():
    # Each embedding is (10, 1280) float32: 10 * 1280 * 4 = 51,200 bytes/item, 512,000 bytes (~0.5 MB) total for 10 items
    items = [MultimodalDataItem(torch.full((10, 1280), fill_value=i, dtype=torch.float32)) for i in range(10)]
    codeflash_output = _get_precomputed_embedding(items); result = codeflash_output  # 41.6μs -> 38.3μs (8.59% faster)
    # Check that rows are filled with correct values
    for i in range(10):
        pass

def test_large_3d_embeddings_reshaped():
    # 3D embeddings, many items, test reshape and concat: 50 * 2 * 5 * 32 * 4 = 64KB
    items = [MultimodalDataItem(torch.ones((2, 5, 32))) for _ in range(50)]
    codeflash_output = _get_precomputed_embedding(items); result = codeflash_output  # 27.6μs -> 26.0μs (6.29% faster)

def test_large_batch_with_some_none_raises():
    # Large batch with one None should raise
    items = [MultimodalDataItem(torch.ones((1, 16))) for _ in range(999)]
    items.append(MultimodalDataItem(None))
    with pytest.raises(NotImplementedError):
        _get_precomputed_embedding(items)  # 38.7μs -> 40.1μs (3.53% slower)

def test_large_batch_all_none_returns_none():
    # Large batch all None should return None
    items = [MultimodalDataItem(None) for _ in range(1000)]
    codeflash_output = _get_precomputed_embedding(items)  # 42.0μs -> 23.8μs (76.6% faster)

codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

#------------------------------------------------
import pytest  # used for our unit tests
import torch

from sglang.srt.managers.mm_utils import _get_precomputed_embedding

# function to test
class MultimodalDataItem:
    def __init__(self, precomputed_embeddings=None):
        self.precomputed_embeddings = precomputed_embeddings

from sglang.srt.managers.mm_utils import _get_precomputed_embedding

# unit tests

# ------------------------
# Basic Test Cases
# ------------------------

def test_all_none_returns_none():
    # All items have None for precomputed_embeddings
    items = [MultimodalDataItem(None) for _ in range(3)]
    codeflash_output = _get_precomputed_embedding(items)  # 1.73μs -> 905ns (90.8% faster)

def test_all_have_embeddings_2d():
    # All items have 2D embeddings
    emb1 = torch.ones((2, 5))
    emb2 = torch.zeros((3, 5))
    items = [MultimodalDataItem(emb1), MultimodalDataItem(emb2)]
    codeflash_output = _get_precomputed_embedding(items); result = codeflash_output  # 22.9μs -> 19.7μs (16.1% faster)

def test_single_item_embedding_2d():
    # Single item, 2D embedding
    emb = torch.arange(6).reshape(2, 3)
    items = [MultimodalDataItem(emb)]
    codeflash_output = _get_precomputed_embedding(items); result = codeflash_output  # 16.5μs -> 14.1μs (17.5% faster)

def test_single_item_embedding_3d():
    # Single item, 3D embedding
    emb = torch.arange(12).reshape(1, 4, 3)
    items = [MultimodalDataItem(emb)]
    codeflash_output = _get_precomputed_embedding(items); result = codeflash_output  # 13.3μs -> 11.1μs (19.8% faster)
    # Should flatten to (4, 3)
    expected = emb.reshape(-1, emb.shape[-1])

# ------------------------
# Edge Test Cases
# ------------------------

def test_some_none_some_tensor_raises():
    # Some items have None, some have embeddings
    emb = torch.ones((2, 5))
    items = [MultimodalDataItem(emb), MultimodalDataItem(None)]
    with pytest.raises(NotImplementedError):
        _get_precomputed_embedding(items)  # 2.60μs -> 1.41μs (84.7% faster)

def test_empty_list_returns_none():
    # Empty input list should return None
    items = []
    codeflash_output = _get_precomputed_embedding(items)  # 1.18μs -> 563ns (109% faster)

def test_all_have_embeddings_different_shapes():
    # All items have embeddings, but shapes are incompatible for concat
    emb1 = torch.ones((2, 4))
    emb2 = torch.zeros((3, 5))  # incompatible last dim
    items = [MultimodalDataItem(emb1), MultimodalDataItem(emb2)]
    with pytest.raises(RuntimeError):
        _get_precomputed_embedding(items)  # 63.4μs -> 58.1μs (9.17% faster)

def test_all_have_embeddings_zero_length():
    # All items have zero-length embeddings
    emb1 = torch.empty((0, 5))
    emb2 = torch.empty((0, 5))
    items = [MultimodalDataItem(emb1), MultimodalDataItem(emb2)]
    codeflash_output = _get_precomputed_embedding(items); result = codeflash_output  # 19.1μs -> 16.5μs (15.3% faster)

def test_all_have_embeddings_different_batch_dim():
    # All items have embeddings, batch dimension differs
    emb1 = torch.ones((1, 2, 4))
    emb2 = torch.zeros((2, 2, 4))
    items = [MultimodalDataItem(emb1), MultimodalDataItem(emb2)]
    codeflash_output = _get_precomputed_embedding(items); result = codeflash_output  # 23.6μs -> 20.7μs (13.9% faster)

# ------------------------
# Large Scale Test Cases
# ------------------------

def test_large_number_of_items_2d():
    # Test with a large number of items, 2D embeddings
    n_items = 500
    emb_dim = 16
    items = [MultimodalDataItem(torch.ones((1, emb_dim)) * i) for i in range(n_items)]
    codeflash_output = _get_precomputed_embedding(items); result = codeflash_output  # 85.2μs -> 84.8μs (0.502% faster)
    # Each row should be filled with its index
    for i in range(n_items):
        pass

def test_large_number_of_items_3d():
    # Test with a large number of items, 3D embeddings
    n_items = 200
    seq_len = 2
    emb_dim = 8
    items = [MultimodalDataItem(torch.ones((1, seq_len, emb_dim)) * i) for i in range(n_items)]
    codeflash_output = _get_precomputed_embedding(items); result = codeflash_output  # 48.5μs -> 46.3μs (4.70% faster)
    # Each pair of rows should be filled with its index
    for i in range(n_items):
        pass

def test_large_embedding_size():
    # Large embedding size, but < 100MB
    n_items = 10
    seq_len = 50
    emb_dim = 200
    items = [MultimodalDataItem(torch.ones((1, seq_len, emb_dim)) * i) for i in range(n_items)]
    codeflash_output = _get_precomputed_embedding(items); result = codeflash_output  # 32.7μs -> 31.4μs (4.16% faster)
    for i in range(n_items):
        pass

def test_large_number_some_none_raises():
    # Large number of items, some None, should raise
    n_items = 100
    emb_dim = 10
    items = [MultimodalDataItem(torch.ones((1, emb_dim)) if i % 2 == 0 else None) for i in range(n_items)]
    with pytest.raises(NotImplementedError):
        _get_precomputed_embedding(items)  # 5.46μs -> 5.06μs (7.91% faster)

codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes, check out the branch with git checkout codeflash/optimize-_get_precomputed_embedding-mhv68s1j and push.


codeflash-ai bot requested a review from mashraf-222 on Nov 11, 2025 at 22:57
codeflash-ai bot added the ⚡️ codeflash (Optimization PR opened by Codeflash AI) and 🎯 Quality: High (Optimization Quality according to Codeflash) labels on Nov 11, 2025