⚡️ Speed up method MultiModalityDataPaddingPatternTokenPairs.pad_input_tokens by 19%
#342
📄 19% (0.19x) speedup for `MultiModalityDataPaddingPatternTokenPairs.pad_input_tokens` in `python/sglang/srt/managers/mm_utils.py`

⏱️ Runtime: 697 microseconds → 587 microseconds (best of 67 runs)

📝 Explanation and details
The optimized code achieves an 18% speedup through three key algorithmic and implementation improvements (a rough sketch of the resulting shape follows the performance notes below):

1. Eliminated redundant set recreations
The original code recreated the `start_token_ids` and `end_tokens_ids` sets using comprehensions on each method call. The optimization moves this to a simple loop that builds the sets once, eliminating the overhead of comprehension execution and temporary list creation.

2. Single-pass token scanning
Instead of making two separate O(n) passes through `input_ids` to find start and end indices via list comprehensions, the optimized version scans the input once, collecting both start and end indices simultaneously. This halves the number of token lookups and reduces cache misses for large inputs.

3. Optimized list operations
Replaced `mm_inputs.data_offsets += [start_idx]` with a direct `append()` to avoid list concatenation overhead, and cached `self.data_start_token_ids` in a local variable to eliminate repeated attribute lookups inside the loop.

Performance characteristics:
The optimizations are particularly effective for the multimodal token processing workload, where input sequences can contain hundreds of tokens with multiple data regions that need padding. Since this function processes token sequences during model inference, the cumulative effect of these micro-optimizations can significantly impact overall throughput.
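For illustration, here is a minimal sketch of the single-pass shape described above. It is not the code from this PR: the class and parameter names are invented, and the padding semantics assumed here (keep the boundary tokens, overwrite only the region between them with each item's pad value, reusing the last value when there are more regions than items) are inferred from the comments in the regression tests below.

```python
# A minimal sketch of the optimizations described above, not the actual PR code.
# Names and padding semantics are assumptions for illustration only.
from typing import List, Tuple


class TokenPairPaddingSketch:
    def __init__(self, data_token_pairs: List[Tuple[int, int]]):
        # Improvement 1: build the lookup sets once instead of per call.
        self.start_token_ids = set()
        self.end_token_ids = set()
        for start, end in data_token_pairs:
            self.start_token_ids.add(start)
            self.end_token_ids.add(end)

    def pad_input_tokens(self, input_ids: List[int], pad_values: List[int]) -> List[int]:
        # Improvement 3: cache attribute lookups in locals before the loop.
        start_ids, end_ids = self.start_token_ids, self.end_token_ids
        starts: List[int] = []
        ends: List[int] = []
        # Improvement 2: a single pass collects both start and end indices.
        for i, tok in enumerate(input_ids):
            if tok in start_ids:
                starts.append(i)  # append() instead of `+= [i]`
            elif tok in end_ids:
                ends.append(i)
        if not pad_values or len(starts) != len(ends):
            # Mismatched or missing markers: return the input unchanged.
            return input_ids
        output = list(input_ids)
        for region_idx, (s, e) in enumerate(zip(starts, ends)):
            # Reuse the last pad value if there are more regions than items.
            pad = pad_values[min(region_idx, len(pad_values) - 1)]
            for j in range(s + 1, e):
                output[j] = pad
        return output
```

With this sketch, the pair `(101, 102)` and a single pad value `99` turn `[1, 101, 5, 6, 7, 102, 2]` into `[1, 101, 99, 99, 99, 102, 2]`, the same shape exercised by `test_basic_single_pair` below.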
✅ Correctness verification report:
🌀 Generated Regression Tests and Runtime
from typing import List, Optional, Tuple
# imports
import pytest
from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternTokenPairs
# Minimal stubs for dependencies
class MultimodalDataItem:
    def __init__(self, pad_value):
        self.pad_value = pad_value
class MultimodalInputs:
    def __init__(self, mm_items=None, im_start_id=None, im_end_id=None):
        self.mm_items = mm_items if mm_items is not None else []
        self.im_start_id = im_start_id
        self.im_end_id = im_end_id
        self.data_offsets = []
from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternTokenPairs
# Unit tests
# Basic Test Cases
def test_basic_single_pair():
# Single token pair, one data item
input_ids = [1, 101, 5, 6, 7, 102, 2]
mm_items = [MultimodalDataItem(pad_value=99)]
mm_inputs = MultimodalInputs(mm_items=mm_items)
pattern = MultiModalityDataPaddingPatternTokenPairs(data_token_pairs=[(101, 102)])
codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output # 5.22μs -> 4.51μs (15.7% faster)
def test_basic_multiple_pairs():
# Multiple token pairs, two data items
input_ids = [1, 101, 5, 6, 102, 2, 101, 7, 8, 102, 3]
mm_items = [MultimodalDataItem(99), MultimodalDataItem(88)]
mm_inputs = MultimodalInputs(mm_items=mm_items)
pattern = MultiModalityDataPaddingPatternTokenPairs(data_token_pairs=[(101, 102)])
codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output # 6.08μs -> 5.26μs (15.5% faster)
def test_basic_no_data_tokens():
# Input with no data token pairs
input_ids = [1, 2, 3, 4]
mm_items = [MultimodalDataItem(99)]
mm_inputs = MultimodalInputs(mm_items=mm_items)
pattern = MultiModalityDataPaddingPatternTokenPairs(data_token_pairs=[(101, 102)])
codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output # 3.58μs -> 2.99μs (20.0% faster)
def test_basic_multiple_pair_types():
# Two types of pairs
input_ids = [1, 101, 5, 102, 2, 201, 6, 202, 3]
mm_items = [MultimodalDataItem(99), MultimodalDataItem(88)]
mm_inputs = MultimodalInputs(mm_items=mm_items)
pattern = MultiModalityDataPaddingPatternTokenPairs(data_token_pairs=[(101, 102), (201, 202)])
codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output # 6.29μs -> 5.54μs (13.3% faster)
def test_basic_start_token_ids_override():
# Custom start_token_ids
input_ids = [1, 101, 5, 102, 2, 201, 6, 202, 3]
mm_items = [MultimodalDataItem(99), MultimodalDataItem(88)]
mm_inputs = MultimodalInputs(mm_items=mm_items)
# Only 201 is considered a start
pattern = MultiModalityDataPaddingPatternTokenPairs(data_token_pairs=[(101, 102), (201, 202)], data_start_token_ids=[201])
codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output # 5.91μs -> 5.25μs (12.6% faster)
# Edge Test Cases
def test_edge_unmatched_pairs():
# Unmatched start/end tokens
input_ids = [1, 101, 5, 6, 102, 101, 7, 8, 2]
mm_items = [MultimodalDataItem(99)]
mm_inputs = MultimodalInputs(mm_items=mm_items)
pattern = MultiModalityDataPaddingPatternTokenPairs(data_token_pairs=[(101, 102)])
codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output # 3.35μs -> 2.76μs (21.4% faster)
def test_edge_no_mm_items():
# No multimodal items
input_ids = [1, 101, 5, 6, 102, 2]
mm_items = []
mm_inputs = MultimodalInputs(mm_items=mm_items)
pattern = MultiModalityDataPaddingPatternTokenPairs(data_token_pairs=[(101, 102)])
# Should fail due to pad_values being empty, but function will try to access pad_value
with pytest.raises(IndexError):
pattern.pad_input_tokens(input_ids, mm_inputs) # 4.71μs -> 4.14μs (13.7% faster)
def test_edge_pad_values_shorter_than_pairs():
# Fewer pad_values than pairs: should use last pad_value for remaining
input_ids = [1, 101, 5, 102, 2, 101, 6, 102, 3]
mm_items = [MultimodalDataItem(99)]
mm_inputs = MultimodalInputs(mm_items=mm_items)
pattern = MultiModalityDataPaddingPatternTokenPairs(data_token_pairs=[(101, 102)])
# Both regions replaced by 99
codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output # 6.07μs -> 5.42μs (12.0% faster)
def test_edge_empty_input_ids():
# Empty input_ids
input_ids = []
mm_items = [MultimodalDataItem(99)]
mm_inputs = MultimodalInputs(mm_items=mm_items)
pattern = MultiModalityDataPaddingPatternTokenPairs(data_token_pairs=[(101, 102)])
codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output # 3.09μs -> 2.73μs (13.2% faster)
def test_edge_zero_length_data_region():
# Data region with zero length (start token immediately followed by end token)
input_ids = [1, 101, 102, 2]
mm_items = [MultimodalDataItem(99)]
mm_inputs = MultimodalInputs(mm_items=mm_items)
pattern = MultiModalityDataPaddingPatternTokenPairs(data_token_pairs=[(101, 102)])
codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output # 5.92μs -> 4.98μs (18.9% faster)
def test_edge_multiple_consecutive_pairs():
# Multiple consecutive pairs
input_ids = [101, 5, 102, 101, 6, 102]
mm_items = [MultimodalDataItem(99), MultimodalDataItem(88)]
mm_inputs = MultimodalInputs(mm_items=mm_items)
pattern = MultiModalityDataPaddingPatternTokenPairs(data_token_pairs=[(101, 102)])
codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output # 6.25μs -> 5.42μs (15.4% faster)
# Large Scale Test Cases
def test_large_scale_long_data_region():
# One large data region
n = 900
input_ids = [1, 101] + [i for i in range(n)] + [102, 2]
mm_items = [MultimodalDataItem(77)]
mm_inputs = MultimodalInputs(mm_items=mm_items)
pattern = MultiModalityDataPaddingPatternTokenPairs(data_token_pairs=[(101, 102)])
codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output # 51.0μs -> 37.9μs (34.5% faster)
# All n tokens replaced by 77
expected = [1, 101] + [77] * n + [102, 2]
def test_large_scale_multiple_pair_types():
# Multiple pair types, alternating
n = 400
input_ids = []
for i in range(n):
if i % 2 == 0:
input_ids.extend([101, i, i+1, 102])
else:
input_ids.extend([201, i, 202])
mm_items = [MultimodalDataItem(99), MultimodalDataItem(88)] * (n // 2)
mm_inputs = MultimodalInputs(mm_items=mm_items)
pattern = MultiModalityDataPaddingPatternTokenPairs(data_token_pairs=[(101, 102), (201, 202)])
codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output # 83.4μs -> 75.4μs (10.6% faster)
# Regions replaced by 99 or 88, alternating
expected = []
idx = -1
for i in range(n):
if i % 2 == 0:
idx += 1
val = mm_items[idx].pad_value
expected.extend([101, val, val, 102])
else:
idx += 1
val = mm_items[idx].pad_value
expected.extend([201, val, 202])
def test_large_scale_all_pad_value_same():
# All pad values are the same
n = 700
input_ids = []
for i in range(n):
input_ids.extend([101, i, 102])
mm_items = [MultimodalDataItem(55)] * n
mm_inputs = MultimodalInputs(mm_items=mm_items)
pattern = MultiModalityDataPaddingPatternTokenPairs(data_token_pairs=[(101, 102)])
codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output # 320μs -> 278μs (15.1% faster)
expected = []
for i in range(n):
expected.extend([101, 55, 102])
def test_large_scale_no_pairs():
# Large input, no pairs
n = 1000
input_ids = list(range(n))
mm_items = [MultimodalDataItem(99)]
mm_inputs = MultimodalInputs(mm_items=mm_items)
pattern = MultiModalityDataPaddingPatternTokenPairs(data_token_pairs=[(101, 102)])
codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output # 54.5μs -> 40.4μs (35.0% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
from typing import List, Optional, Tuple
# imports
import pytest
from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternTokenPairs
# Mocks and minimal stubs to support testing
class DummyMMItem:
    def __init__(self, pad_value):
        self.pad_value = pad_value
class DummyMMInputs:
    def __init__(self, mm_items, im_start_id=None, im_end_id=None):
        self.mm_items = mm_items
        self.im_start_id = im_start_id
        self.im_end_id = im_end_id
        self.data_offsets = []
class MultiModalityDataPaddingPattern:
    pass  # placeholder for base class
from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternTokenPairs
# -------------------- UNIT TESTS --------------------
# 1. Basic Test Cases
def test_basic_single_pair():
# Single pair, one data item, simple replacement
# tokens: [10, 100, 1, 2, 3, 200, 20]
# pairs: (100, 200)
# pad_value: 0
# expected: [10, 100, 0, 0, 0, 200, 20]
pattern = MultiModalityDataPaddingPatternTokenPairs([(100, 200)])
mm_inputs = DummyMMInputs([DummyMMItem(0)])
input_ids = [10, 100, 1, 2, 3, 200, 20]
codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output # 5.49μs -> 4.62μs (19.0% faster)
def test_basic_multiple_pairs_multiple_data_items():
# Two pairs, two data items, two pad_values
# tokens: [100, 1, 2, 200, 100, 3, 4, 200]
# pairs: (100, 200)
# pad_values: 7, 8
# expected: [100, 7, 7, 200, 100, 8, 8, 200]
pattern = MultiModalityDataPaddingPatternTokenPairs([(100, 200)])
mm_inputs = DummyMMInputs([DummyMMItem(7), DummyMMItem(8)])
input_ids = [100, 1, 2, 200, 100, 3, 4, 200]
codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output # 6.00μs -> 5.40μs (11.0% faster)
def test_basic_different_token_pairs():
# Two different token pairs
# tokens: [10, 100, 1, 2, 200, 20, 101, 3, 4, 201, 30]
# pairs: (100, 200), (101, 201)
# pad_values: 5, 6
# expected: [10, 100, 5, 5, 200, 20, 101, 6, 6, 201, 30]
pattern = MultiModalityDataPaddingPatternTokenPairs([(100, 200), (101, 201)])
mm_inputs = DummyMMInputs([DummyMMItem(5), DummyMMItem(6)])
input_ids = [10, 100, 1, 2, 200, 20, 101, 3, 4, 201, 30]
codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output # 6.19μs -> 5.59μs (10.8% faster)
def test_edge_mismatched_pairs():
# Start/end tokens count mismatch, should return input_ids unchanged
pattern = MultiModalityDataPaddingPatternTokenPairs([(100, 200)])
mm_inputs = DummyMMInputs([DummyMMItem(0)])
input_ids = [10, 100, 1, 2, 3, 20] # missing end token
codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output # 3.75μs -> 3.08μs (21.9% faster)
def test_edge_more_pairs_than_pad_values():
# More pairs than pad_values: should use last pad_value for extra pairs
pattern = MultiModalityDataPaddingPatternTokenPairs([(100, 200)])
mm_inputs = DummyMMInputs([DummyMMItem(8)])
input_ids = [100, 1, 2, 200, 100, 3, 4, 200]
# Only one pad_value, two pairs: both should use 8
codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output # 7.13μs -> 6.11μs (16.8% faster)
def test_edge_empty_input():
# Empty input_ids
pattern = MultiModalityDataPaddingPatternTokenPairs([(100, 200)])
mm_inputs = DummyMMInputs([DummyMMItem(0)])
input_ids = []
codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output # 3.33μs -> 2.72μs (22.2% faster)
def test_edge_empty_token_pairs():
# Empty data_token_pairs (should fallback to mm_inputs.im_start_id/im_end_id)
pattern = MultiModalityDataPaddingPatternTokenPairs([])
mm_inputs = DummyMMInputs([DummyMMItem(2)], im_start_id=101, im_end_id=201)
input_ids = [101, 1, 2, 201]
codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output # 3.56μs -> 2.86μs (24.5% faster)
def test_edge_start_token_not_in_data_start_token_ids():
# If a start token is not in data_start_token_ids, data_idx should not increment
pattern = MultiModalityDataPaddingPatternTokenPairs([(100, 200)], data_start_token_ids=[101])
mm_inputs = DummyMMInputs([DummyMMItem(5)])
input_ids = [100, 1, 2, 200]
# data_idx remains -1, pad_value = pad_values[-1] = 5
codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output # 4.98μs -> 4.34μs (14.8% faster)
def test_edge_pad_value_is_none():
# pad_value is None: should insert None
pattern = MultiModalityDataPaddingPatternTokenPairs([(100, 200)])
mm_inputs = DummyMMInputs([DummyMMItem(None)])
input_ids = [10, 100, 1, 2, 200, 20]
codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output # 5.17μs -> 4.33μs (19.4% faster)
# 3. Large Scale Test Cases
def test_large_scale_long_input_ids():
# Long input_ids, but only a few pairs
pattern = MultiModalityDataPaddingPatternTokenPairs([(100, 200)])
mm_inputs = DummyMMInputs([DummyMMItem(42)])
# input_ids: 500 normal tokens, one pair in the middle, 500 more
input_ids = list(range(500)) + [100, 1, 2, 3, 200] + list(range(500, 1000))
expected = list(range(500)) + [100, 42, 42, 42, 200] + list(range(500, 1000))
codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output # 57.8μs -> 42.0μs (37.6% faster)
def test_large_scale_large_pad_value():
# Large pad_value (e.g. 106), ensure values are correct and not truncated
pattern = MultiModalityDataPaddingPatternTokenPairs([(100, 200)])
mm_inputs = DummyMMInputs([DummyMMItem(106)])
input_ids = [100, 1, 2, 3, 200]
codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output # 5.31μs -> 4.64μs (14.6% faster)
def test_large_scale_multiple_token_pairs():
# Many different token pairs, each with its own pad_value
pairs = [(i, i+1000) for i in range(100, 110)]
pattern = MultiModalityDataPaddingPatternTokenPairs(pairs)
mm_inputs = DummyMMInputs([DummyMMItem(i) for i in range(10)])
input_ids = []
expected = []
for i, (start, end) in enumerate(pairs):
input_ids.extend([start, 1, 2, end])
expected.extend([start, i, i, end])
codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output # 11.8μs -> 10.6μs (10.9% faster)
def test_large_scale_pad_values_list_shorter_than_pairs():
# 10 pairs, only 2 pad_values, last pad_value should be used for all extra pairs
pattern = MultiModalityDataPaddingPatternTokenPairs([(100, 200)]*10)
mm_inputs = DummyMMInputs([DummyMMItem(7), DummyMMItem(8)])
input_ids = []
expected = []
for i in range(10):
input_ids.extend([100, 1, 2, 200])
expected.extend([100, 7 if i==0 else 8, 7 if i==0 else 8, 200])
codeflash_output = pattern.pad_input_tokens(input_ids, mm_inputs); result = codeflash_output # 10.5μs -> 9.60μs (9.05% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
To edit these changes, check out the branch with `git checkout codeflash/optimize-MultiModalityDataPaddingPatternTokenPairs.pad_input_tokens-mhv5jt52` and push.