
Conversation


@codeflash-ai codeflash-ai bot commented Nov 11, 2025

📄 19% (0.19x) speedup for MultiModalityDataPaddingPatternTokenPairs.pad_input_tokens in python/sglang/srt/managers/mm_utils.py

⏱️ Runtime : 697 microseconds → 587 microseconds (best of 67 runs)

📝 Explanation and details

The optimized code achieves an 18% speedup through three key algorithmic and implementation improvements:

1. Eliminated redundant set recreations
The original code recreated start_token_ids and end_tokens_ids sets using comprehensions on each method call. The optimization moves this to a simple loop that builds the sets once, eliminating the overhead of comprehension execution and temporary list creation.
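As a hedged sketch (the pair values and variable names below are illustrative, not sglang's real token ids), the difference looks like this:

```python
# Illustrative sketch: replacing two per-call comprehensions with one loop
# that fills both sets in a single pass over the pair list.
data_token_pairs = [(101, 102), (201, 202)]

# Original style: two comprehensions, each a separate pass over the pairs.
start_token_ids = {s for s, _ in data_token_pairs}
end_token_ids = {e for _, e in data_token_pairs}

# Optimized style: one plain loop builds both sets at once.
start_token_ids, end_token_ids = set(), set()
for start_id, end_id in data_token_pairs:
    start_token_ids.add(start_id)
    end_token_ids.add(end_id)
```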

2. Single-pass token scanning
Instead of making two separate O(n) passes through input_ids to find start and end indices via list comprehensions, the optimized version scans the input once, collecting both start and end indices simultaneously. This halves the number of token lookups and reduces cache misses for large inputs.
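A minimal sketch of the single-pass scan (the helper name `find_pair_indices` is invented for illustration; the real method inlines this logic):

```python
def find_pair_indices(input_ids, start_token_ids, end_token_ids):
    """Collect start and end token positions in one pass over input_ids."""
    start_indices, end_indices = [], []
    for i, token in enumerate(input_ids):
        if token in start_token_ids:
            start_indices.append(i)
        elif token in end_token_ids:
            end_indices.append(i)
    return start_indices, end_indices

# Tokens 101/102 delimit one data region in this toy sequence.
starts, ends = find_pair_indices([1, 101, 5, 6, 102, 2], {101}, {102})
```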

3. Optimized list operations

  • Replaced mm_inputs.data_offsets += [start_idx] with direct append() to avoid list concatenation overhead
  • Cached self.data_start_token_ids in a local variable to eliminate repeated attribute lookups inside the loop
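A small sketch of both micro-optimizations side by side (the class and method names here are stand-ins, not the sglang originals):

```python
class PatternSketch:
    def __init__(self):
        self.data_start_token_ids = {101, 201}

    def collect_slow(self, tokens):
        offsets = []
        for t in tokens:
            if t in self.data_start_token_ids:  # attribute lookup every iteration
                offsets += [t]                  # allocates a one-element temporary list
        return offsets

    def collect_fast(self, tokens):
        offsets = []
        start_ids = self.data_start_token_ids   # attribute cached in a local once
        for t in tokens:
            if t in start_ids:
                offsets.append(t)               # direct append, no temporary list
        return offsets
```

Both variants return the same result; the fast one simply avoids the per-iteration attribute lookup and the list-concatenation temporaries.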

Performance characteristics:

  • Small inputs (4-10 tokens): 15-25% improvement due to reduced overhead
  • Medium inputs (100-1000 tokens): 10-20% improvement from single-pass scanning
  • Large inputs (900+ tokens): 30-35% improvement as the algorithmic benefits compound

The optimizations are particularly effective for the multimodal token processing workload, where input sequences can contain hundreds of tokens with multiple data regions that need padding. Since this function processes token sequences during model inference, the cumulative effect of these micro-optimizations can significantly impact overall throughput.

Correctness verification report:

Test                           Status
⚙️ Existing Unit Tests         🔘 None Found
🌀 Generated Regression Tests  60 Passed
⏪ Replay Tests                🔘 None Found
🔎 Concolic Coverage Tests     🔘 None Found
📊 Tests Coverage              90.3%
🌀 Generated Regression Tests and Runtime

from typing import List, Optional, Tuple

# imports
import pytest
from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternTokenPairs

# Minimal stubs for dependencies

class MultimodalDataItem:
    def __init__(self, pad_value):
        self.pad_value = pad_value

class MultimodalInputs:
    def __init__(self, mm_items=None, im_start_id=None, im_end_id=None):
        self.mm_items = mm_items if mm_items is not None else []
        self.im_start_id = im_start_id
        self.im_end_id = im_end_id
        self.data_offsets = []

# Unit tests

# Basic Test Cases

def test_basic_single_pair():
    # Single token pair, one data item
    input_ids = [1, 101, 5, 6, 7, 102, 2]
    mm_items = [MultimodalDataItem(pad_value=99)]
    mm_inputs = MultimodalInputs(mm_items=mm_items)
    pattern = MultiModalityDataPaddingPatternTokenPairs(data_token_pairs=[(101, 102)])
    codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output  # 5.22μs -> 4.51μs (15.7% faster)

def test_basic_multiple_pairs():
    # Multiple token pairs, two data items
    input_ids = [1, 101, 5, 6, 102, 2, 101, 7, 8, 102, 3]
    mm_items = [MultimodalDataItem(99), MultimodalDataItem(88)]
    mm_inputs = MultimodalInputs(mm_items=mm_items)
    pattern = MultiModalityDataPaddingPatternTokenPairs(data_token_pairs=[(101, 102)])
    codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output  # 6.08μs -> 5.26μs (15.5% faster)

def test_basic_no_data_tokens():
    # Input with no data token pairs
    input_ids = [1, 2, 3, 4]
    mm_items = [MultimodalDataItem(99)]
    mm_inputs = MultimodalInputs(mm_items=mm_items)
    pattern = MultiModalityDataPaddingPatternTokenPairs(data_token_pairs=[(101, 102)])
    codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output  # 3.58μs -> 2.99μs (20.0% faster)

def test_basic_multiple_pair_types():
    # Two types of pairs
    input_ids = [1, 101, 5, 102, 2, 201, 6, 202, 3]
    mm_items = [MultimodalDataItem(99), MultimodalDataItem(88)]
    mm_inputs = MultimodalInputs(mm_items=mm_items)
    pattern = MultiModalityDataPaddingPatternTokenPairs(data_token_pairs=[(101, 102), (201, 202)])
    codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output  # 6.29μs -> 5.54μs (13.3% faster)

def test_basic_start_token_ids_override():
    # Custom start_token_ids
    input_ids = [1, 101, 5, 102, 2, 201, 6, 202, 3]
    mm_items = [MultimodalDataItem(99), MultimodalDataItem(88)]
    mm_inputs = MultimodalInputs(mm_items=mm_items)
    # Only 201 is considered a start
    pattern = MultiModalityDataPaddingPatternTokenPairs(data_token_pairs=[(101, 102), (201, 202)], data_start_token_ids=[201])
    codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output  # 5.91μs -> 5.25μs (12.6% faster)

# Edge Test Cases

def test_edge_unmatched_pairs():
    # Unmatched start/end tokens
    input_ids = [1, 101, 5, 6, 102, 101, 7, 8, 2]
    mm_items = [MultimodalDataItem(99)]
    mm_inputs = MultimodalInputs(mm_items=mm_items)
    pattern = MultiModalityDataPaddingPatternTokenPairs(data_token_pairs=[(101, 102)])
    codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output  # 3.35μs -> 2.76μs (21.4% faster)

def test_edge_no_mm_items():
    # No multimodal items
    input_ids = [1, 101, 5, 6, 102, 2]
    mm_items = []
    mm_inputs = MultimodalInputs(mm_items=mm_items)
    pattern = MultiModalityDataPaddingPatternTokenPairs(data_token_pairs=[(101, 102)])
    # Should fail due to pad_values being empty, but function will try to access pad_value
    with pytest.raises(IndexError):
        pattern.pad_input_tokens(input_ids, mm_inputs)  # 4.71μs -> 4.14μs (13.7% faster)

def test_edge_pad_values_shorter_than_pairs():
    # Fewer pad_values than pairs: should use last pad_value for remaining
    input_ids = [1, 101, 5, 102, 2, 101, 6, 102, 3]
    mm_items = [MultimodalDataItem(99)]
    mm_inputs = MultimodalInputs(mm_items=mm_items)
    pattern = MultiModalityDataPaddingPatternTokenPairs(data_token_pairs=[(101, 102)])
    # Both regions replaced by 99
    codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output  # 6.07μs -> 5.42μs (12.0% faster)

def test_edge_empty_input_ids():
    # Empty input_ids
    input_ids = []
    mm_items = [MultimodalDataItem(99)]
    mm_inputs = MultimodalInputs(mm_items=mm_items)
    pattern = MultiModalityDataPaddingPatternTokenPairs(data_token_pairs=[(101, 102)])
    codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output  # 3.09μs -> 2.73μs (13.2% faster)

def test_edge_zero_length_data_region():
    # Data region with zero length (start token immediately followed by end token)
    input_ids = [1, 101, 102, 2]
    mm_items = [MultimodalDataItem(99)]
    mm_inputs = MultimodalInputs(mm_items=mm_items)
    pattern = MultiModalityDataPaddingPatternTokenPairs(data_token_pairs=[(101, 102)])
    codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output  # 5.92μs -> 4.98μs (18.9% faster)

def test_edge_multiple_consecutive_pairs():
    # Multiple consecutive pairs
    input_ids = [101, 5, 102, 101, 6, 102]
    mm_items = [MultimodalDataItem(99), MultimodalDataItem(88)]
    mm_inputs = MultimodalInputs(mm_items=mm_items)
    pattern = MultiModalityDataPaddingPatternTokenPairs(data_token_pairs=[(101, 102)])
    codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output  # 6.25μs -> 5.42μs (15.4% faster)

# Large Scale Test Cases

def test_large_scale_long_data_region():
    # One large data region
    n = 900
    input_ids = [1, 101] + [i for i in range(n)] + [102, 2]
    mm_items = [MultimodalDataItem(77)]
    mm_inputs = MultimodalInputs(mm_items=mm_items)
    pattern = MultiModalityDataPaddingPatternTokenPairs(data_token_pairs=[(101, 102)])
    codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output  # 51.0μs -> 37.9μs (34.5% faster)
    # All n tokens replaced by 77
    expected = [1, 101] + [77] * n + [102, 2]

def test_large_scale_multiple_pair_types():
    # Multiple pair types, alternating
    n = 400
    input_ids = []
    for i in range(n):
        if i % 2 == 0:
            input_ids.extend([101, i, i + 1, 102])
        else:
            input_ids.extend([201, i, 202])
    mm_items = [MultimodalDataItem(99), MultimodalDataItem(88)] * (n // 2)
    mm_inputs = MultimodalInputs(mm_items=mm_items)
    pattern = MultiModalityDataPaddingPatternTokenPairs(data_token_pairs=[(101, 102), (201, 202)])
    codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output  # 83.4μs -> 75.4μs (10.6% faster)
    # Regions replaced by 99 or 88, alternating
    expected = []
    idx = -1
    for i in range(n):
        if i % 2 == 0:
            idx += 1
            val = mm_items[idx].pad_value
            expected.extend([101, val, val, 102])
        else:
            idx += 1
            val = mm_items[idx].pad_value
            expected.extend([201, val, 202])

def test_large_scale_all_pad_value_same():
    # All pad values are the same
    n = 700
    input_ids = []
    for i in range(n):
        input_ids.extend([101, i, 102])
    mm_items = [MultimodalDataItem(55)] * n
    mm_inputs = MultimodalInputs(mm_items=mm_items)
    pattern = MultiModalityDataPaddingPatternTokenPairs(data_token_pairs=[(101, 102)])
    codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output  # 320μs -> 278μs (15.1% faster)
    expected = []
    for i in range(n):
        expected.extend([101, 55, 102])

def test_large_scale_no_pairs():
    # Large input, no pairs
    n = 1000
    input_ids = list(range(n))
    mm_items = [MultimodalDataItem(99)]
    mm_inputs = MultimodalInputs(mm_items=mm_items)
    pattern = MultiModalityDataPaddingPatternTokenPairs(data_token_pairs=[(101, 102)])
    codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output  # 54.5μs -> 40.4μs (35.0% faster)
codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

#------------------------------------------------
from typing import List, Optional, Tuple

# imports
import pytest
from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternTokenPairs

# Mocks and minimal stubs to support testing

class DummyMMItem:
    def __init__(self, pad_value):
        self.pad_value = pad_value

class DummyMMInputs:
    def __init__(self, mm_items, im_start_id=None, im_end_id=None):
        self.mm_items = mm_items
        self.im_start_id = im_start_id
        self.im_end_id = im_end_id
        self.data_offsets = []

class MultiModalityDataPaddingPattern:
    pass  # placeholder for base class

# -------------------- UNIT TESTS --------------------

# 1. Basic Test Cases

def test_basic_single_pair():
    # Single pair, one data item, simple replacement
    # tokens: [10, 100, 1, 2, 3, 200, 20]
    # pairs: (100, 200)
    # pad_value: 0
    # expected: [10, 100, 0, 0, 0, 200, 20]
    pattern = MultiModalityDataPaddingPatternTokenPairs([(100, 200)])
    mm_inputs = DummyMMInputs([DummyMMItem(0)])
    input_ids = [10, 100, 1, 2, 3, 200, 20]
    codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output  # 5.49μs -> 4.62μs (19.0% faster)

def test_basic_multiple_pairs_multiple_data_items():
    # Two pairs, two data items, two pad_values
    # tokens: [100, 1, 2, 200, 100, 3, 4, 200]
    # pairs: (100, 200)
    # pad_values: 7, 8
    # expected: [100, 7, 7, 200, 100, 8, 8, 200]
    pattern = MultiModalityDataPaddingPatternTokenPairs([(100, 200)])
    mm_inputs = DummyMMInputs([DummyMMItem(7), DummyMMItem(8)])
    input_ids = [100, 1, 2, 200, 100, 3, 4, 200]
    codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output  # 6.00μs -> 5.40μs (11.0% faster)

def test_basic_different_token_pairs():
    # Two different token pairs
    # tokens: [10, 100, 1, 2, 200, 20, 101, 3, 4, 201, 30]
    # pairs: (100, 200), (101, 201)
    # pad_values: 5, 6
    # expected: [10, 100, 5, 5, 200, 20, 101, 6, 6, 201, 30]
    pattern = MultiModalityDataPaddingPatternTokenPairs([(100, 200), (101, 201)])
    mm_inputs = DummyMMInputs([DummyMMItem(5), DummyMMItem(6)])
    input_ids = [10, 100, 1, 2, 200, 20, 101, 3, 4, 201, 30]
    codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output  # 6.19μs -> 5.59μs (10.8% faster)

def test_edge_mismatched_pairs():
    # Start/end tokens count mismatch, should return input_ids unchanged
    pattern = MultiModalityDataPaddingPatternTokenPairs([(100, 200)])
    mm_inputs = DummyMMInputs([DummyMMItem(0)])
    input_ids = [10, 100, 1, 2, 3, 20]  # missing end token
    codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output  # 3.75μs -> 3.08μs (21.9% faster)

def test_edge_more_pairs_than_pad_values():
    # More pairs than pad_values: should use last pad_value for extra pairs
    pattern = MultiModalityDataPaddingPatternTokenPairs([(100, 200)])
    mm_inputs = DummyMMInputs([DummyMMItem(8)])
    input_ids = [100, 1, 2, 200, 100, 3, 4, 200]
    # Only one pad_value, two pairs: both should use 8
    codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output  # 7.13μs -> 6.11μs (16.8% faster)

def test_edge_empty_input():
    # Empty input_ids
    pattern = MultiModalityDataPaddingPatternTokenPairs([(100, 200)])
    mm_inputs = DummyMMInputs([DummyMMItem(0)])
    input_ids = []
    codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output  # 3.33μs -> 2.72μs (22.2% faster)

def test_edge_empty_token_pairs():
    # Empty data_token_pairs (should fallback to mm_inputs.im_start_id/im_end_id)
    pattern = MultiModalityDataPaddingPatternTokenPairs([])
    mm_inputs = DummyMMInputs([DummyMMItem(2)], im_start_id=101, im_end_id=201)
    input_ids = [101, 1, 2, 201]
    codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output  # 3.56μs -> 2.86μs (24.5% faster)

def test_edge_start_token_not_in_data_start_token_ids():
    # If a start token is not in data_start_token_ids, data_idx should not increment
    pattern = MultiModalityDataPaddingPatternTokenPairs([(100, 200)], data_start_token_ids=[101])
    mm_inputs = DummyMMInputs([DummyMMItem(5)])
    input_ids = [100, 1, 2, 200]
    # data_idx remains -1, pad_value = pad_values[-1] = 5
    codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output  # 4.98μs -> 4.34μs (14.8% faster)

def test_edge_pad_value_is_none():
    # pad_value is None: should insert None
    pattern = MultiModalityDataPaddingPatternTokenPairs([(100, 200)])
    mm_inputs = DummyMMInputs([DummyMMItem(None)])
    input_ids = [10, 100, 1, 2, 200, 20]
    codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output  # 5.17μs -> 4.33μs (19.4% faster)

# 3. Large Scale Test Cases

def test_large_scale_long_input_ids():
    # Long input_ids, but only a few pairs
    pattern = MultiModalityDataPaddingPatternTokenPairs([(100, 200)])
    mm_inputs = DummyMMInputs([DummyMMItem(42)])
    # input_ids: 500 normal tokens, one pair in the middle, 500 more
    input_ids = list(range(500)) + [100, 1, 2, 3, 200] + list(range(500, 1000))
    expected = list(range(500)) + [100, 42, 42, 42, 200] + list(range(500, 1000))
    codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output  # 57.8μs -> 42.0μs (37.6% faster)

def test_large_scale_large_pad_value():
    # Large pad_value (e.g. 10**6), ensure values are correct and not truncated
    pattern = MultiModalityDataPaddingPatternTokenPairs([(100, 200)])
    mm_inputs = DummyMMInputs([DummyMMItem(10**6)])
    input_ids = [100, 1, 2, 3, 200]
    codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output  # 5.31μs -> 4.64μs (14.6% faster)

def test_large_scale_multiple_token_pairs():
    # Many different token pairs, each with its own pad_value
    pairs = [(i, i + 1000) for i in range(100, 110)]
    pattern = MultiModalityDataPaddingPatternTokenPairs(pairs)
    mm_inputs = DummyMMInputs([DummyMMItem(i) for i in range(10)])
    input_ids = []
    expected = []
    for i, (start, end) in enumerate(pairs):
        input_ids.extend([start, 1, 2, end])
        expected.extend([start, i, i, end])
    codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output  # 11.8μs -> 10.6μs (10.9% faster)

def test_large_scale_pad_values_list_shorter_than_pairs():
    # 10 pairs, only 2 pad_values, last pad_value should be used for all extra pairs
    pattern = MultiModalityDataPaddingPatternTokenPairs([(100, 200)] * 10)
    mm_inputs = DummyMMInputs([DummyMMItem(7), DummyMMItem(8)])
    input_ids = []
    expected = []
    for i in range(10):
        input_ids.extend([100, 1, 2, 200])
        expected.extend([100, 7 if i == 0 else 8, 7 if i == 0 else 8, 200])
    codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output  # 10.5μs -> 9.60μs (9.05% faster)

codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes, run `git checkout codeflash/optimize-MultiModalityDataPaddingPatternTokenPairs.pad_input_tokens-mhv5jt52` and push.


@codeflash-ai codeflash-ai bot requested a review from mashraf-222 November 11, 2025 22:38
@codeflash-ai codeflash-ai bot added ⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash labels Nov 11, 2025