⚡️ Speed up function _quantize_k_cache_slow by 55%
#341
📄 **55% (0.55x) speedup** for `_quantize_k_cache_slow` in `python/sglang/srt/layers/attention/nsa/quant_k_cache.py`

⏱️ Runtime: 16.4 milliseconds → 10.6 milliseconds (best of 169 runs)

📝 Explanation and details
The optimized code achieves a 55% speedup through a handful of changes that reduce computational overhead in the quantization loop.

**Main performance improvements:**

- **Vectorized tensor operations:** Replaced `torch.abs(...).max(dim=-1).values` with `torch.abs(input_tile).amax(dim=-1)`, eliminating the intermediate `.values` access and using the more efficient `amax` reduction (the line profiler shows the ~47% of time spent on `torch.abs` dropping to ~38%).
- **Eliminated in-place operations:** Removed `cur_scale_factors_inv.unsqueeze_(-1)` and used direct broadcasting with `.unsqueeze(-1)` in the division, avoiding unnecessary tensor mutations and temporary allocations.
- **Pre-computed slice boundaries:** Stored the `start` and `end` indices once per tile iteration instead of computing `tile_idx * tile_size` multiple times, reducing redundant arithmetic.
- **Streamlined tensor slicing:** Cached the slice in an `input_tile` variable to avoid repeated indexing of `input_k_cache[..., start:end]`, which the profiler showed was expensive (12.9% of time in the original). A sketch of the resulting loop is shown below.

**Why this matters:**

The function is called from `quantize_k_cache()` as part of attention-cache quantization in neural network inference. The test results show a 65.9% speedup on large tensors (512 blocks, 32 sequences) and consistent 6-85% improvements across the other input sizes, so the optimization pays off across a wide range of workloads. It maintains identical output correctness while reducing the computational bottleneck in the tile-wise quantization loop, making it particularly valuable for production inference scenarios.
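To make the bullet points concrete, here is a minimal, self-contained sketch of the optimized tile loop. It reuses the variable names quoted above and the 448 FP8-e4m3 bound that the regression tests rely on; the function name `_quantize_value_tiles_sketch` is hypothetical, and the real `_quantize_k_cache_slow` additionally casts to an FP8 dtype and packs the tiles, scale factors, and RoPE bytes into a single per-entry buffer.

```python
import torch


def _quantize_value_tiles_sketch(input_k_cache: torch.Tensor, dv: int, tile_size: int):
    """Illustrative sketch of the per-tile scaling described above (not the real code)."""
    assert dv % tile_size == 0
    num_tiles = dv // tile_size
    values = input_k_cache[..., :dv].float()
    scale_factors = torch.empty(*values.shape[:-1], num_tiles, dtype=torch.float32)
    quantized = torch.empty_like(values)

    for tile_idx in range(num_tiles):
        start = tile_idx * tile_size           # slice bounds computed once per tile
        end = start + tile_size
        input_tile = values[..., start:end]    # tile sliced once and reused
        # amax(dim=-1) replaces max(dim=-1).values; 448 is the FP8-e4m3 maximum
        cur_scale_factors_inv = torch.abs(input_tile).amax(dim=-1) / 448.0
        scale_factors[..., tile_idx] = cur_scale_factors_inv
        # plain unsqueeze(-1) broadcast instead of the in-place unsqueeze_(-1)
        quantized[..., start:end] = input_tile / cur_scale_factors_inv.unsqueeze(-1)
    return quantized, scale_factors
```

The change is confined to the loop body: one cached slice, one `amax` reduction, and one broadcasted division per tile, with no in-place mutation of the scale tensor.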
✅ Correctness verification report:
🌀 Generated Regression Tests and Runtime
```python
import pytest
import torch
from sglang.srt.layers.attention.nsa.quant_k_cache import _quantize_k_cache_slow

# unit tests

# ----------- BASIC TEST CASES -----------

def test_basic_shape_and_dtype_cpu():
    # Basic test: 1 block, 1 sequence, dv=128, tile_size=64, d=192
    num_blocks, block_size, h_k, d = 1, 1, 1, 192
    dv, tile_size = 128, 64
    # dv % tile_size == 0
    input_k_cache = torch.randn((num_blocks, block_size, h_k, d), dtype=torch.float32, device='cpu')
    codeflash_output = _quantize_k_cache_slow(input_k_cache, dv=dv, tile_size=tile_size); output = codeflash_output # 104μs -> 95.6μs (9.11% faster)
    # Output shape
    num_tiles = dv // tile_size
    input_elem_size = input_k_cache.element_size()
    expected_shape = (num_blocks, block_size, h_k, dv + num_tiles * 4 + input_elem_size * (d - dv))
    assert output.shape == expected_shape


def test_basic_values_no_rope():
    # Basic test: d=dv, so no rope part
    num_blocks, block_size, h_k, d = 2, 3, 1, 128
    dv, tile_size = 128, 64
    input_k_cache = torch.ones((num_blocks, block_size, h_k, d), dtype=torch.float32, device='cpu')
    codeflash_output = _quantize_k_cache_slow(input_k_cache, dv=dv, tile_size=tile_size); output = codeflash_output # 112μs -> 100μs (12.7% faster)
    # Rope part should be empty
    input_elem_size = input_k_cache.element_size()
    num_tiles = dv // tile_size
    expected_shape = (num_blocks, block_size, h_k, dv + num_tiles * 4)
    assert output.shape == expected_shape
    # All quantized values should be the same (since input is all ones)
    # The scale factor for each tile should be 1/448
    scale_factors = output[0, 0, 0, dv:dv + num_tiles * 4].view(torch.float32)
    assert torch.allclose(scale_factors, torch.full_like(scale_factors, 1.0 / 448.0))


def test_edge_minimum_sizes():
    # Edge case: minimum allowed sizes
    num_blocks, block_size, h_k, d = 1, 1, 1, 128
    dv, tile_size = 128, 128
    input_k_cache = torch.zeros((num_blocks, block_size, h_k, d), dtype=torch.float32)
    codeflash_output = _quantize_k_cache_slow(input_k_cache, dv=dv, tile_size=tile_size); output = codeflash_output # 115μs -> 108μs (6.32% faster)
    # Output shape
    num_tiles = dv // tile_size
    input_elem_size = input_k_cache.element_size()
    expected_shape = (num_blocks, block_size, h_k, dv + num_tiles * 4)
    assert output.shape == expected_shape
    # All scale factors should be 0 (since max(abs(0)) == 0)
    scale_factors = output[0, 0, 0, dv:dv + num_tiles * 4].view(torch.float32)
    assert torch.all(scale_factors == 0)


def test_edge_tile_size_not_divisor():
    # Edge case: dv % tile_size != 0 should raise AssertionError
    num_blocks, block_size, h_k, d = 1, 1, 1, 128
    dv, tile_size = 128, 60
    input_k_cache = torch.randn((num_blocks, block_size, h_k, d), dtype=torch.float32)
    with pytest.raises(AssertionError):
        _quantize_k_cache_slow(input_k_cache, dv=dv, tile_size=tile_size) # 1.41μs -> 1.40μs (0.357% faster)


def test_edge_h_k_not_one():
    # Edge case: h_k != 1 should raise AssertionError
    num_blocks, block_size, h_k, d = 1, 1, 2, 128
    dv, tile_size = 128, 64
    input_k_cache = torch.randn((num_blocks, block_size, h_k, d), dtype=torch.float32)
    with pytest.raises(AssertionError):
        _quantize_k_cache_slow(input_k_cache, dv=dv, tile_size=tile_size) # 2.63μs -> 2.68μs (1.87% slower)


def test_edge_device_cuda():
    # Edge case: input on CUDA device
    if torch.cuda.is_available():
        num_blocks, block_size, h_k, d = 2, 2, 1, 128
        dv, tile_size = 128, 64
        input_k_cache = torch.randn((num_blocks, block_size, h_k, d), dtype=torch.float32, device='cuda')
        codeflash_output = _quantize_k_cache_slow(input_k_cache, dv=dv, tile_size=tile_size); output = codeflash_output


# ----------- LARGE SCALE TEST CASES -----------

def test_large_scale_512_blocks_32_seq():
    # Large scale: 512 blocks, 32 sequences, dv=256, tile_size=64, d=300
    num_blocks, block_size, h_k, d = 512, 32, 1, 300
    dv, tile_size = 256, 64
    input_k_cache = torch.randn((num_blocks, block_size, h_k, d), dtype=torch.float32)
    codeflash_output = _quantize_k_cache_slow(input_k_cache, dv=dv, tile_size=tile_size); output = codeflash_output # 11.9ms -> 7.14ms (65.9% faster)
    num_tiles = dv // tile_size
    input_elem_size = input_k_cache.element_size()
    expected_shape = (num_blocks, block_size, h_k, dv + num_tiles * 4 + input_elem_size * (d - dv))
    assert output.shape == expected_shape


def test_large_scale_max_size_under_100mb():
    # Large scale: max size under 100MB
    # float32: 4 bytes, so 100MB/4 = 25M elements
    # Let's use: num_blocks=64, block_size=64, h_k=1, d=64
    num_blocks, block_size, h_k, d = 64, 64, 1, 64
    dv, tile_size = 64, 32
    input_k_cache = torch.randn((num_blocks, block_size, h_k, d), dtype=torch.float32)
    codeflash_output = _quantize_k_cache_slow(input_k_cache, dv=dv, tile_size=tile_size); output = codeflash_output # 804μs -> 433μs (85.6% faster)
    num_tiles = dv // tile_size
    input_elem_size = input_k_cache.element_size()
    expected_shape = (num_blocks, block_size, h_k, dv + num_tiles * 4)
    assert output.shape == expected_shape


def test_large_scale_randomized():
    # Large scale: random values, check deterministic output shape and dtype
    num_blocks, block_size, h_k, d = 128, 8, 1, 256
    dv, tile_size = 256, 64
    input_k_cache = torch.randn((num_blocks, block_size, h_k, d), dtype=torch.float32)
    codeflash_output = _quantize_k_cache_slow(input_k_cache, dv=dv, tile_size=tile_size); output = codeflash_output # 767μs -> 441μs (73.8% faster)
    num_tiles = dv // tile_size
    input_elem_size = input_k_cache.element_size()
    expected_shape = (num_blocks, block_size, h_k, dv + num_tiles * 4)
    assert output.shape == expected_shape
    assert output.element_size() == 1


def test_large_scale_rope_part_consistency():
    # Large scale: d > dv, check rope part consistency
    num_blocks, block_size, h_k, d = 16, 16, 1, 300
    dv, tile_size = 256, 64
    input_k_cache = torch.randn((num_blocks, block_size, h_k, d), dtype=torch.float32)
    codeflash_output = _quantize_k_cache_slow(input_k_cache, dv=dv, tile_size=tile_size); output = codeflash_output # 319μs -> 225μs (41.7% faster)
    num_tiles = dv // tile_size
    input_elem_size = input_k_cache.element_size()
    rope_part = output[0, 0, 0, dv + num_tiles * 4:].view(torch.float32)
    assert torch.equal(rope_part, input_k_cache[0, 0, 0, dv:])
```

`codeflash_output` is used to check that the output of the original code is the same as that of the optimized code.
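For readers of the tests above: the per-entry width they assert on, `dv + num_tiles * 4 + input_elem_size * (d - dv)`, corresponds to a packed byte layout of quantized values, per-tile float32 scale factors, and the raw RoPE bytes. The sketch below shows one way such a row could be decoded; it is an assumption-laden illustration, not part of the module: `unpack_row` is a hypothetical helper, and the FP8-e4m3 payload dtype is inferred from the 448 bound (it needs a PyTorch build that provides `torch.float8_e4m3fn`).

```python
import torch


def unpack_row(row: torch.Tensor, dv: int, tile_size: int, input_dtype: torch.dtype):
    """Hypothetical decoder for one packed uint8 row, following the layout the tests check.

    Assumed layout per entry:
      [0, dv)                      quantized value payload, 1 byte per element
      [dv, dv + 4 * num_tiles)     per-tile float32 scale factors
      [dv + 4 * num_tiles, ...)    RoPE part, raw bytes of the input dtype
    """
    num_tiles = dv // tile_size
    payload = row[:dv].view(torch.float8_e4m3fn)        # assumed FP8-e4m3 storage
    scales = row[dv:dv + 4 * num_tiles].view(torch.float32)
    rope = row[dv + 4 * num_tiles:].view(input_dtype)
    # Dequantize: each tile is multiplied back by its stored scale factor.
    dequant = payload.float().reshape(num_tiles, tile_size) * scales.unsqueeze(-1)
    return dequant.reshape(dv), rope
```

This mirrors the byte-level views (`.view(torch.float32)`, `.view(input_k_cache.dtype)`) that the generated tests use to check the scale-factor and RoPE regions.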
```python
# ------------------------------------------------
import pytest
import torch
from sglang.srt.layers.attention.nsa.quant_k_cache import _quantize_k_cache_slow

# unit tests

# Helper: test device selection (CPU if no CUDA, else CUDA)
def get_test_device():
    return "cuda" if torch.cuda.is_available() else "cpu"


# Basic Test Cases

def test_basic_shape_and_dtype():
    # Test with typical parameters, small tensor
    num_blocks, block_size, h_k, d = 2, 3, 1, 520
    dv, tile_size = 512, 128
    device = get_test_device()
    input_k_cache = torch.randn((num_blocks, block_size, h_k, d), dtype=torch.bfloat16, device=device)
    codeflash_output = _quantize_k_cache_slow(input_k_cache, dv=dv, tile_size=tile_size); result = codeflash_output # 197μs -> 180μs (9.88% faster)
    # Check output shape
    expected_shape = (num_blocks, block_size, h_k, dv + 4 * (dv // tile_size) + input_k_cache.element_size() * (d - dv))
    assert result.shape == expected_shape


def test_basic_content_rope_copy():
    # Test that the rope part is copied exactly
    num_blocks, block_size, h_k, d = 1, 2, 1, 520
    dv, tile_size = 512, 128
    device = get_test_device()
    input_k_cache = torch.randn((num_blocks, block_size, h_k, d), dtype=torch.bfloat16, device=device)
    codeflash_output = _quantize_k_cache_slow(input_k_cache, dv=dv, tile_size=tile_size); result = codeflash_output # 183μs -> 163μs (11.9% faster)
    # Rope part in output
    rope_start = dv + 4 * (dv // tile_size)
    rope_end = rope_start + input_k_cache.element_size() * (d - dv)
    # The rope part should match input_k_cache[..., dv:]
    result_rope = result.view(num_blocks, block_size, rope_end)[..., rope_start:rope_end].view(input_k_cache.dtype)
    input_rope = input_k_cache.squeeze(2)[..., dv:]
    assert torch.equal(result_rope, input_rope)


def test_basic_scale_factor_shape_and_type():
    # Test scale factor shape and type
    num_blocks, block_size, h_k, d = 1, 1, 1, 512
    dv, tile_size = 512, 128
    device = get_test_device()
    input_k_cache = torch.randn((num_blocks, block_size, h_k, d), dtype=torch.bfloat16, device=device)
    codeflash_output = _quantize_k_cache_slow(input_k_cache, dv=dv, tile_size=tile_size); result = codeflash_output # 168μs -> 149μs (12.6% faster)
    # scale factor part
    scale_start = dv
    scale_end = dv + 4 * (dv // tile_size)
    scale_part = result.view(num_blocks, block_size, scale_end)[..., scale_start:scale_end].view(torch.float32)
    # Should be shape [num_blocks, block_size, dv//tile_size]
    expected_shape = (num_blocks, block_size, dv // tile_size)
    assert scale_part.shape == expected_shape
    assert scale_part.dtype == torch.float32


# Edge Test Cases

def test_edge_zero_input():
    # All zeros input, scale factor should be zero
    num_blocks, block_size, h_k, d = 1, 1, 1, 512
    dv, tile_size = 512, 128
    device = get_test_device()
    input_k_cache = torch.zeros((num_blocks, block_size, h_k, d), dtype=torch.bfloat16, device=device)
    codeflash_output = _quantize_k_cache_slow(input_k_cache, dv=dv, tile_size=tile_size); result = codeflash_output # 164μs -> 147μs (11.5% faster)
    scale_start = dv
    scale_end = dv + 4 * (dv // tile_size)
    scale_part = result.view(num_blocks, block_size, scale_end)[..., scale_start:scale_end].view(torch.float32)
    assert torch.all(scale_part == 0)


def test_edge_negative_input():
    # Negative values, scale should be max(abs(...))
    num_blocks, block_size, h_k, d = 1, 1, 1, 128
    dv, tile_size = 128, 128
    device = get_test_device()
    input_k_cache = -torch.ones((num_blocks, block_size, h_k, d), dtype=torch.bfloat16, device=device)
    codeflash_output = _quantize_k_cache_slow(input_k_cache, dv=dv, tile_size=tile_size); result = codeflash_output # 80.7μs -> 75.5μs (6.94% faster)
    scale_start = dv
    scale_end = dv + 4 * (dv // tile_size)
    scale_part = result.view(num_blocks, block_size, scale_end)[..., scale_start:scale_end].view(torch.float32)
    assert torch.allclose(scale_part, torch.full_like(scale_part, 1.0 / 448.0))


def test_edge_single_tile():
    # dv == tile_size, only one tile
    num_blocks, block_size, h_k, d = 2, 2, 1, 128
    dv, tile_size = 128, 128
    device = get_test_device()
    input_k_cache = torch.randn((num_blocks, block_size, h_k, d), dtype=torch.bfloat16, device=device)
    codeflash_output = _quantize_k_cache_slow(input_k_cache, dv=dv, tile_size=tile_size); result = codeflash_output # 89.4μs -> 85.5μs (4.46% faster)
    expected_shape = (num_blocks, block_size, h_k, dv + 4 * (dv // tile_size) + input_k_cache.element_size() * (d - dv))
    assert result.shape == expected_shape


def test_edge_minimal_block():
    # Minimal block size and num_blocks
    num_blocks, block_size, h_k, d = 1, 1, 1, 128
    dv, tile_size = 128, 128
    device = get_test_device()
    input_k_cache = torch.randn((num_blocks, block_size, h_k, d), dtype=torch.bfloat16, device=device)
    codeflash_output = _quantize_k_cache_slow(input_k_cache, dv=dv, tile_size=tile_size); result = codeflash_output # 86.3μs -> 81.0μs (6.60% faster)


def test_edge_invalid_h_k():
    # Should assert if h_k != 1
    num_blocks, block_size, h_k, d = 1, 1, 2, 128
    dv, tile_size = 128, 128
    device = get_test_device()
    input_k_cache = torch.randn((num_blocks, block_size, h_k, d), dtype=torch.bfloat16, device=device)
    with pytest.raises(AssertionError):
        _quantize_k_cache_slow(input_k_cache, dv=dv, tile_size=tile_size) # 2.66μs -> 2.45μs (8.28% faster)


def test_edge_invalid_tile_size():
    # Should assert if dv % tile_size != 0
    num_blocks, block_size, h_k, d = 1, 1, 1, 130
    dv, tile_size = 128, 7  # 128 % 7 != 0
    device = get_test_device()
    input_k_cache = torch.randn((num_blocks, block_size, h_k, d), dtype=torch.bfloat16, device=device)
    with pytest.raises(AssertionError):
        _quantize_k_cache_slow(input_k_cache, dv=dv, tile_size=tile_size) # 1.28μs -> 1.22μs (5.02% faster)


def test_edge_d_less_than_dv():
    # d < dv, rope part should be empty
    num_blocks, block_size, h_k, d = 1, 1, 1, 100
    dv, tile_size = 100, 100
    device = get_test_device()
    input_k_cache = torch.randn((num_blocks, block_size, h_k, d), dtype=torch.bfloat16, device=device)
    codeflash_output = _quantize_k_cache_slow(input_k_cache, dv=dv, tile_size=tile_size); result = codeflash_output # 113μs -> 106μs (6.37% faster)
    # Rope part should be empty
    rope_start = dv + 4 * (dv // tile_size)
    expected_shape = (num_blocks, block_size, h_k, dv + 4 * (dv // tile_size))
    assert result.shape == expected_shape


# Large Scale Test Cases

def test_large_scale_max_size():
    # Test with near-maximum allowed size (under 100MB)
    num_blocks, block_size, h_k, d = 8, 16, 1, 512
    dv, tile_size = 512, 128
    device = get_test_device()
    input_k_cache = torch.randn((num_blocks, block_size, h_k, d), dtype=torch.bfloat16, device=device)
    codeflash_output = _quantize_k_cache_slow(input_k_cache, dv=dv, tile_size=tile_size); result = codeflash_output # 328μs -> 263μs (24.5% faster)
    expected_shape = (num_blocks, block_size, h_k, dv + 4 * (dv // tile_size) + input_k_cache.element_size() * (d - dv))
    assert result.shape == expected_shape


def test_large_scale_many_blocks():
    # Test with many blocks and block_size, but <1000 elements
    num_blocks, block_size, h_k, d = 32, 16, 1, 256
    dv, tile_size = 256, 128
    device = get_test_device()
    input_k_cache = torch.randn((num_blocks, block_size, h_k, d), dtype=torch.bfloat16, device=device)
    codeflash_output = _quantize_k_cache_slow(input_k_cache, dv=dv, tile_size=tile_size); result = codeflash_output # 405μs -> 307μs (32.0% faster)
    expected_shape = (num_blocks, block_size, h_k, dv + 4 * (dv // tile_size) + input_k_cache.element_size() * (d - dv))
    assert result.shape == expected_shape


def test_large_scale_dv_not_512():
    # Test with dv not equal to 512
    num_blocks, block_size, h_k, d = 4, 4, 1, 256
    dv, tile_size = 256, 64
    device = get_test_device()
    input_k_cache = torch.randn((num_blocks, block_size, h_k, d), dtype=torch.bfloat16, device=device)
    codeflash_output = _quantize_k_cache_slow(input_k_cache, dv=dv, tile_size=tile_size); result = codeflash_output # 201μs -> 182μs (10.5% faster)
    expected_shape = (num_blocks, block_size, h_k, dv + 4 * (dv // tile_size) + input_k_cache.element_size() * (d - dv))
    assert result.shape == expected_shape


def test_large_scale_varied_dtype():
    # Test with float16 and float32 input
    num_blocks, block_size, h_k, d = 2, 2, 1, 128
    dv, tile_size = 128, 64
    device = get_test_device()
    for dtype in [torch.float16, torch.float32]:
        input_k_cache = torch.randn((num_blocks, block_size, h_k, d), dtype=dtype, device=device)
        codeflash_output = _quantize_k_cache_slow(input_k_cache, dv=dv, tile_size=tile_size); result = codeflash_output # 193μs -> 176μs (9.44% faster)
        expected_shape = (num_blocks, block_size, h_k, dv + 4 * (dv // tile_size) + input_k_cache.element_size() * (d - dv))
        assert result.shape == expected_shape


def test_large_scale_rope_part_content():
    # Test rope part for large d-dv
    num_blocks, block_size, h_k, d = 2, 2, 1, 256
    dv, tile_size = 128, 64
    device = get_test_device()
    input_k_cache = torch.randn((num_blocks, block_size, h_k, d), dtype=torch.float32, device=device)
    codeflash_output = _quantize_k_cache_slow(input_k_cache, dv=dv, tile_size=tile_size); result = codeflash_output # 109μs -> 99.5μs (10.5% faster)
    rope_start = dv + 4 * (dv // tile_size)
    rope_end = rope_start + input_k_cache.element_size() * (d - dv)
    result_rope = result.view(num_blocks, block_size, rope_end)[..., rope_start:rope_end].view(input_k_cache.dtype)
    input_rope = input_k_cache.squeeze(2)[..., dv:]
    assert torch.equal(result_rope, input_rope)
```
`codeflash_output` is used to check that the output of the original code is the same as that of the optimized code.
To edit these changes, run `git checkout codeflash/optimize-_quantize_k_cache_slow-mhv5256r` and push.