codeflash-ai bot commented Nov 11, 2025

📄 55% (0.55x) speedup for _quantize_k_cache_slow in python/sglang/srt/layers/attention/nsa/quant_k_cache.py

⏱️ Runtime : 16.4 milliseconds → 10.6 milliseconds (best of 169 runs)

📝 Explanation and details

The optimized code achieves a **55% speedup** through key optimizations that reduce computational overhead in the quantization loop:

**Main Performance Improvements:**

1. **Vectorized tensor operations**: Replaced `torch.abs(...).max(dim=-1).values` with `torch.abs(input_tile).amax(dim=-1)`, eliminating the intermediate `.values` access and using the more efficient `amax` operation (the line profiler shows the ~47% of time spent on `torch.abs` dropping to ~38%).

2. **Eliminated in-place operations**: Removed `cur_scale_factors_inv.unsqueeze_(-1)` and used direct broadcasting with `.unsqueeze(-1)` in the division, avoiding unnecessary tensor mutations and temporary allocations.

3. **Pre-computed slice boundaries**: Stored `start` and `end` indices once per tile iteration instead of computing `tile_idx * tile_size` multiple times, reducing redundant arithmetic.

4. **Streamlined tensor slicing**: Used a cached `input_tile` variable to avoid repeated indexing of `input_k_cache[..., start:end]`, which the profiler showed was expensive (12.9% of time in the original); a simplified sketch of the resulting loop follows below.

**Why This Matters:**
The function is called from `quantize_k_cache()` as part of attention mechanism quantization in neural network inference. Based on the test results showing a **65.9% speedup on large tensors** (512 blocks, 32 sequences) and consistent 6-85% improvements across various input sizes, this optimization significantly benefits workloads with:

- Large batch sizes or sequence lengths
- Repeated quantization operations during model inference
- CUDA deployments where vectorized operations show greater benefits

The optimizations maintain identical output correctness while reducing the computational bottleneck in the tile-wise quantization loop, making it particularly valuable for production inference scenarios.
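
For illustration, here is a minimal, self-contained sketch of a tile loop with the four changes above applied. It is not the actual `_quantize_k_cache_slow` source: it omits the FP8 cast and the packed-output writes, and the 448 scaling constant (matching the 1/448 scale factors checked in the tests below) and the `clamp_min` guard against all-zero tiles are assumptions of the sketch.

```python
import torch

def quantize_tiles_sketch(input_k_cache: torch.Tensor, dv: int = 512, tile_size: int = 128):
    """Illustrative tile loop only; the real code also casts to FP8 and packs bytes."""
    assert dv % tile_size == 0
    num_tiles = dv // tile_size
    quantized_tiles, scales = [], []
    for tile_idx in range(num_tiles):
        # (3) compute the slice boundaries once per iteration
        start = tile_idx * tile_size
        end = start + tile_size
        # (4) slice once and reuse the cached view
        input_tile = input_k_cache[..., start:end]
        # (1) amax returns the reduced tensor directly (no .max(...).values)
        cur_scale = torch.abs(input_tile).amax(dim=-1) / 448.0
        # (2) broadcast with a plain unsqueeze instead of an in-place unsqueeze_
        quantized_tiles.append(input_tile / cur_scale.clamp_min(1e-12).unsqueeze(-1))
        scales.append(cur_scale)
    return torch.cat(quantized_tiles, dim=-1), torch.stack(scales, dim=-1)

# quick shape check on random data
k = torch.randn(2, 4, 1, 576)
q, s = quantize_tiles_sketch(k, dv=512, tile_size=128)
print(q.shape, s.shape)  # torch.Size([2, 4, 1, 512]) torch.Size([2, 4, 1, 4])
```

Swapping the chained `.max(dim=-1).values` for `amax` and dropping the in-place `unsqueeze_` is what removes the extra intermediates the profiler flagged.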

Correctness verification report:

| Test | Status |
|------|--------|
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | 26 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | 100.0% |
🌀 Generated Regression Tests and Runtime
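
As a reading aid for the `expected_shape` expressions below: the tests assume the quantized cache packs, per token, `dv` one-byte quantized values, then one 4-byte float32 scale factor per tile, then the unquantized rope tail stored at the input element size. A quick check of that arithmetic (the helper name here is ours, not part of the codebase):

```python
def packed_last_dim(d: int, dv: int, tile_size: int, input_elem_size: int) -> int:
    # dv quantized bytes + one float32 scale (4 bytes) per tile + raw rope tail bytes
    num_tiles = dv // tile_size
    return dv + num_tiles * 4 + input_elem_size * (d - dv)

# e.g. the large-scale case below: d=300, dv=256, tile_size=64, float32 input (4 bytes)
assert packed_last_dim(300, 256, 64, 4) == 448  # 256 + 16 + 176
```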

```python
import pytest
import torch
from sglang.srt.layers.attention.nsa.quant_k_cache import _quantize_k_cache_slow

# unit tests

# ----------- BASIC TEST CASES -----------

def test_basic_shape_and_dtype_cpu():
    # Basic test: 1 block, 1 sequence, dv=128, tile_size=64, d=192
    num_blocks, block_size, h_k, d = 1, 1, 1, 192
    dv, tile_size = 128, 64
    # dv % tile_size == 0
    input_k_cache = torch.randn((num_blocks, block_size, h_k, d), dtype=torch.float32, device='cpu')
    codeflash_output = _quantize_k_cache_slow(input_k_cache, dv=dv, tile_size=tile_size); output = codeflash_output  # 104μs -> 95.6μs (9.11% faster)
    # Output shape
    num_tiles = dv // tile_size
    input_elem_size = input_k_cache.element_size()
    expected_shape = (num_blocks, block_size, h_k, dv + num_tiles * 4 + input_elem_size * (d - dv))
    assert output.shape == expected_shape

def test_basic_values_no_rope():
    # Basic test: d=dv, so no rope part
    num_blocks, block_size, h_k, d = 2, 3, 1, 128
    dv, tile_size = 128, 64
    input_k_cache = torch.ones((num_blocks, block_size, h_k, d), dtype=torch.float32, device='cpu')
    codeflash_output = _quantize_k_cache_slow(input_k_cache, dv=dv, tile_size=tile_size); output = codeflash_output  # 112μs -> 100μs (12.7% faster)
    # Rope part should be empty
    input_elem_size = input_k_cache.element_size()
    num_tiles = dv // tile_size
    expected_shape = (num_blocks, block_size, h_k, dv + num_tiles * 4)
    assert output.shape == expected_shape
    # All quantized values should be the same (since input is all ones)
    assert torch.all(output[..., :dv] == output[0, 0, 0, 0])
    # The scale factor for each tile should be 1/448
    scale_factors = output[0, 0, 0, dv:dv + num_tiles * 4].view(torch.float32)
    assert torch.allclose(scale_factors, torch.full_like(scale_factors, 1.0 / 448.0))

def test_edge_minimum_sizes():
    # Edge case: minimum allowed sizes
    num_blocks, block_size, h_k, d = 1, 1, 1, 128
    dv, tile_size = 128, 128
    input_k_cache = torch.zeros((num_blocks, block_size, h_k, d), dtype=torch.float32)
    codeflash_output = _quantize_k_cache_slow(input_k_cache, dv=dv, tile_size=tile_size); output = codeflash_output  # 115μs -> 108μs (6.32% faster)
    # Output shape
    num_tiles = dv // tile_size
    input_elem_size = input_k_cache.element_size()
    expected_shape = (num_blocks, block_size, h_k, dv + num_tiles * 4)
    assert output.shape == expected_shape
    # All scale factors should be 0 (since max(abs(0)) == 0)
    scale_factors = output[0, 0, 0, dv:dv + num_tiles * 4].view(torch.float32)
    assert torch.all(scale_factors == 0)

def test_edge_tile_size_not_divisor():
    # Edge case: dv % tile_size != 0 should raise AssertionError
    num_blocks, block_size, h_k, d = 1, 1, 1, 128
    dv, tile_size = 128, 60
    input_k_cache = torch.randn((num_blocks, block_size, h_k, d), dtype=torch.float32)
    with pytest.raises(AssertionError):
        _quantize_k_cache_slow(input_k_cache, dv=dv, tile_size=tile_size)  # 1.41μs -> 1.40μs (0.357% faster)

def test_edge_h_k_not_one():
    # Edge case: h_k != 1 should raise AssertionError
    num_blocks, block_size, h_k, d = 1, 1, 2, 128
    dv, tile_size = 128, 64
    input_k_cache = torch.randn((num_blocks, block_size, h_k, d), dtype=torch.float32)
    with pytest.raises(AssertionError):
        _quantize_k_cache_slow(input_k_cache, dv=dv, tile_size=tile_size)  # 2.63μs -> 2.68μs (1.87% slower)

def test_edge_device_cuda():
    # Edge case: input on CUDA device
    if torch.cuda.is_available():
        num_blocks, block_size, h_k, d = 2, 2, 1, 128
        dv, tile_size = 128, 64
        input_k_cache = torch.randn((num_blocks, block_size, h_k, d), dtype=torch.float32, device='cuda')
        codeflash_output = _quantize_k_cache_slow(input_k_cache, dv=dv, tile_size=tile_size); output = codeflash_output

# ----------- LARGE SCALE TEST CASES -----------

def test_large_scale_512_blocks_32_seq():
    # Large scale: 512 blocks, 32 sequence, dv=256, tile_size=64, d=300
    num_blocks, block_size, h_k, d = 512, 32, 1, 300
    dv, tile_size = 256, 64
    input_k_cache = torch.randn((num_blocks, block_size, h_k, d), dtype=torch.float32)
    codeflash_output = _quantize_k_cache_slow(input_k_cache, dv=dv, tile_size=tile_size); output = codeflash_output  # 11.9ms -> 7.14ms (65.9% faster)
    num_tiles = dv // tile_size
    input_elem_size = input_k_cache.element_size()
    expected_shape = (num_blocks, block_size, h_k, dv + num_tiles * 4 + input_elem_size * (d - dv))
    assert output.shape == expected_shape

def test_large_scale_max_size_under_100mb():
    # Large scale: max size under 100MB
    # float32: 4 bytes, so 100MB/4 = 25M elements
    # Let's use: num_blocks=64, block_size=64, h_k=1, d=64
    num_blocks, block_size, h_k, d = 64, 64, 1, 64
    dv, tile_size = 64, 32
    input_k_cache = torch.randn((num_blocks, block_size, h_k, d), dtype=torch.float32)
    codeflash_output = _quantize_k_cache_slow(input_k_cache, dv=dv, tile_size=tile_size); output = codeflash_output  # 804μs -> 433μs (85.6% faster)
    num_tiles = dv // tile_size
    input_elem_size = input_k_cache.element_size()
    expected_shape = (num_blocks, block_size, h_k, dv + num_tiles * 4)
    assert output.shape == expected_shape

def test_large_scale_randomized():
    # Large scale: random values, check deterministic output shape and dtype
    num_blocks, block_size, h_k, d = 128, 8, 1, 256
    dv, tile_size = 256, 64
    input_k_cache = torch.randn((num_blocks, block_size, h_k, d), dtype=torch.float32)
    codeflash_output = _quantize_k_cache_slow(input_k_cache, dv=dv, tile_size=tile_size); output = codeflash_output  # 767μs -> 441μs (73.8% faster)
    num_tiles = dv // tile_size
    input_elem_size = input_k_cache.element_size()
    expected_shape = (num_blocks, block_size, h_k, dv + num_tiles * 4)
    assert output.shape == expected_shape

def test_large_scale_rope_part_consistency():
    # Large scale: d > dv, check rope part consistency
    num_blocks, block_size, h_k, d = 16, 16, 1, 300
    dv, tile_size = 256, 64
    input_k_cache = torch.randn((num_blocks, block_size, h_k, d), dtype=torch.float32)
    codeflash_output = _quantize_k_cache_slow(input_k_cache, dv=dv, tile_size=tile_size); output = codeflash_output  # 319μs -> 225μs (41.7% faster)
    num_tiles = dv // tile_size
    input_elem_size = input_k_cache.element_size()
    rope_part = output[0, 0, 0, dv + num_tiles * 4:].view(torch.float32)
    assert torch.equal(rope_part, input_k_cache[0, 0, 0, dv:])
```

codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

```python
#------------------------------------------------
import pytest
import torch
from sglang.srt.layers.attention.nsa.quant_k_cache import _quantize_k_cache_slow

# unit tests

# Helper: test device selection (CPU if no CUDA, else CUDA)

def get_test_device():
    return "cuda" if torch.cuda.is_available() else "cpu"

# Basic Test Cases

def test_basic_shape_and_dtype():
    # Test with typical parameters, small tensor
    num_blocks, block_size, h_k, d = 2, 3, 1, 520
    dv, tile_size = 512, 128
    device = get_test_device()
    input_k_cache = torch.randn((num_blocks, block_size, h_k, d), dtype=torch.bfloat16, device=device)
    codeflash_output = _quantize_k_cache_slow(input_k_cache, dv=dv, tile_size=tile_size); result = codeflash_output  # 197μs -> 180μs (9.88% faster)
    # Check output shape
    expected_shape = (num_blocks, block_size, h_k, dv + 4 * (dv // tile_size) + input_k_cache.element_size() * (d - dv))
    assert result.shape == expected_shape

def test_basic_content_rope_copy():
    # Test that the rope part is copied exactly
    num_blocks, block_size, h_k, d = 1, 2, 1, 520
    dv, tile_size = 512, 128
    device = get_test_device()
    input_k_cache = torch.randn((num_blocks, block_size, h_k, d), dtype=torch.bfloat16, device=device)
    codeflash_output = _quantize_k_cache_slow(input_k_cache, dv=dv, tile_size=tile_size); result = codeflash_output  # 183μs -> 163μs (11.9% faster)
    # Rope part in output
    rope_start = dv + 4 * (dv // tile_size)
    rope_end = rope_start + input_k_cache.element_size() * (d - dv)
    # The rope part should match input_k_cache[..., dv:]
    result_rope = result.view(num_blocks, block_size, rope_end)[..., rope_start:rope_end].view(input_k_cache.dtype)
    input_rope = input_k_cache.squeeze(2)[..., dv:]
    assert torch.equal(result_rope, input_rope)

def test_basic_scale_factor_shape_and_type():
    # Test scale factor shape and type
    num_blocks, block_size, h_k, d = 1, 1, 1, 512
    dv, tile_size = 512, 128
    device = get_test_device()
    input_k_cache = torch.randn((num_blocks, block_size, h_k, d), dtype=torch.bfloat16, device=device)
    codeflash_output = _quantize_k_cache_slow(input_k_cache, dv=dv, tile_size=tile_size); result = codeflash_output  # 168μs -> 149μs (12.6% faster)
    # scale factor part
    scale_start = dv
    scale_end = dv + 4 * (dv // tile_size)
    scale_part = result.view(num_blocks, block_size, scale_end)[..., scale_start:scale_end].view(torch.float32)
    # Should be shape [num_blocks, block_size, dv//tile_size]
    expected_shape = (num_blocks, block_size, dv // tile_size)
    assert scale_part.shape == expected_shape
    assert scale_part.dtype == torch.float32

# Edge Test Cases

def test_edge_zero_input():
    # All zeros input, scale factor should be zero
    num_blocks, block_size, h_k, d = 1, 1, 1, 512
    dv, tile_size = 512, 128
    device = get_test_device()
    input_k_cache = torch.zeros((num_blocks, block_size, h_k, d), dtype=torch.bfloat16, device=device)
    codeflash_output = _quantize_k_cache_slow(input_k_cache, dv=dv, tile_size=tile_size); result = codeflash_output  # 164μs -> 147μs (11.5% faster)
    scale_start = dv
    scale_end = dv + 4 * (dv // tile_size)
    scale_part = result.view(num_blocks, block_size, scale_end)[..., scale_start:scale_end].view(torch.float32)
    assert torch.all(scale_part == 0)

def test_edge_negative_input():
    # Negative values, scale should be max(abs(...))
    num_blocks, block_size, h_k, d = 1, 1, 1, 128
    dv, tile_size = 128, 128
    device = get_test_device()
    input_k_cache = -torch.ones((num_blocks, block_size, h_k, d), dtype=torch.bfloat16, device=device)
    codeflash_output = _quantize_k_cache_slow(input_k_cache, dv=dv, tile_size=tile_size); result = codeflash_output  # 80.7μs -> 75.5μs (6.94% faster)
    scale_start = dv
    scale_end = dv + 4 * (dv // tile_size)
    scale_part = result.view(num_blocks, block_size, scale_end)[..., scale_start:scale_end].view(torch.float32)
    assert torch.all(scale_part > 0)

def test_edge_single_tile():
    # dv == tile_size, only one tile
    num_blocks, block_size, h_k, d = 2, 2, 1, 128
    dv, tile_size = 128, 128
    device = get_test_device()
    input_k_cache = torch.randn((num_blocks, block_size, h_k, d), dtype=torch.bfloat16, device=device)
    codeflash_output = _quantize_k_cache_slow(input_k_cache, dv=dv, tile_size=tile_size); result = codeflash_output  # 89.4μs -> 85.5μs (4.46% faster)
    expected_shape = (num_blocks, block_size, h_k, dv + 4 * (dv // tile_size) + input_k_cache.element_size() * (d - dv))
    assert result.shape == expected_shape

def test_edge_minimal_block():
    # Minimal block size and num_blocks
    num_blocks, block_size, h_k, d = 1, 1, 1, 128
    dv, tile_size = 128, 128
    device = get_test_device()
    input_k_cache = torch.randn((num_blocks, block_size, h_k, d), dtype=torch.bfloat16, device=device)
    codeflash_output = _quantize_k_cache_slow(input_k_cache, dv=dv, tile_size=tile_size); result = codeflash_output  # 86.3μs -> 81.0μs (6.60% faster)

def test_edge_invalid_h_k():
    # Should assert if h_k != 1
    num_blocks, block_size, h_k, d = 1, 1, 2, 128
    dv, tile_size = 128, 128
    device = get_test_device()
    input_k_cache = torch.randn((num_blocks, block_size, h_k, d), dtype=torch.bfloat16, device=device)
    with pytest.raises(AssertionError):
        _quantize_k_cache_slow(input_k_cache, dv=dv, tile_size=tile_size)  # 2.66μs -> 2.45μs (8.28% faster)

def test_edge_invalid_tile_size():
    # Should assert if dv % tile_size != 0
    num_blocks, block_size, h_k, d = 1, 1, 1, 130
    dv, tile_size = 128, 7  # 128 % 7 != 0
    device = get_test_device()
    input_k_cache = torch.randn((num_blocks, block_size, h_k, d), dtype=torch.bfloat16, device=device)
    with pytest.raises(AssertionError):
        _quantize_k_cache_slow(input_k_cache, dv=dv, tile_size=tile_size)  # 1.28μs -> 1.22μs (5.02% faster)

def test_edge_d_less_than_dv():
    # d == dv here, so the rope part should be empty
    num_blocks, block_size, h_k, d = 1, 1, 1, 100
    dv, tile_size = 100, 100
    device = get_test_device()
    input_k_cache = torch.randn((num_blocks, block_size, h_k, d), dtype=torch.bfloat16, device=device)
    codeflash_output = _quantize_k_cache_slow(input_k_cache, dv=dv, tile_size=tile_size); result = codeflash_output  # 113μs -> 106μs (6.37% faster)
    # Rope part should be empty
    rope_start = dv + 4 * (dv // tile_size)
    expected_shape = (num_blocks, block_size, h_k, dv + 4 * (dv // tile_size))
    assert result.shape == expected_shape
    assert result.shape[-1] == rope_start

# Large Scale Test Cases

def test_large_scale_max_size():
    # Test with near-maximum allowed size (under 100MB)
    num_blocks, block_size, h_k, d = 8, 16, 1, 512
    dv, tile_size = 512, 128
    device = get_test_device()
    input_k_cache = torch.randn((num_blocks, block_size, h_k, d), dtype=torch.bfloat16, device=device)
    codeflash_output = _quantize_k_cache_slow(input_k_cache, dv=dv, tile_size=tile_size); result = codeflash_output  # 328μs -> 263μs (24.5% faster)
    expected_shape = (num_blocks, block_size, h_k, dv + 4 * (dv // tile_size) + input_k_cache.element_size() * (d - dv))
    assert result.shape == expected_shape

def test_large_scale_many_blocks():
    # Test with many blocks and block_size, but <1000 elements
    num_blocks, block_size, h_k, d = 32, 16, 1, 256
    dv, tile_size = 256, 128
    device = get_test_device()
    input_k_cache = torch.randn((num_blocks, block_size, h_k, d), dtype=torch.bfloat16, device=device)
    codeflash_output = _quantize_k_cache_slow(input_k_cache, dv=dv, tile_size=tile_size); result = codeflash_output  # 405μs -> 307μs (32.0% faster)
    expected_shape = (num_blocks, block_size, h_k, dv + 4 * (dv // tile_size) + input_k_cache.element_size() * (d - dv))
    assert result.shape == expected_shape

def test_large_scale_dv_not_512():
    # Test with dv not equal to 512
    num_blocks, block_size, h_k, d = 4, 4, 1, 256
    dv, tile_size = 256, 64
    device = get_test_device()
    input_k_cache = torch.randn((num_blocks, block_size, h_k, d), dtype=torch.bfloat16, device=device)
    codeflash_output = _quantize_k_cache_slow(input_k_cache, dv=dv, tile_size=tile_size); result = codeflash_output  # 201μs -> 182μs (10.5% faster)
    expected_shape = (num_blocks, block_size, h_k, dv + 4 * (dv // tile_size) + input_k_cache.element_size() * (d - dv))
    assert result.shape == expected_shape

def test_large_scale_varied_dtype():
    # Test with float16 and float32 input
    num_blocks, block_size, h_k, d = 2, 2, 1, 128
    dv, tile_size = 128, 64
    device = get_test_device()
    for dtype in [torch.float16, torch.float32]:
        input_k_cache = torch.randn((num_blocks, block_size, h_k, d), dtype=dtype, device=device)
        codeflash_output = _quantize_k_cache_slow(input_k_cache, dv=dv, tile_size=tile_size); result = codeflash_output  # 193μs -> 176μs (9.44% faster)
        expected_shape = (num_blocks, block_size, h_k, dv + 4 * (dv // tile_size) + input_k_cache.element_size() * (d - dv))
        assert result.shape == expected_shape

def test_large_scale_rope_part_content():
    # Test rope part for large d-dv
    num_blocks, block_size, h_k, d = 2, 2, 1, 256
    dv, tile_size = 128, 64
    device = get_test_device()
    input_k_cache = torch.randn((num_blocks, block_size, h_k, d), dtype=torch.float32, device=device)
    codeflash_output = _quantize_k_cache_slow(input_k_cache, dv=dv, tile_size=tile_size); result = codeflash_output  # 109μs -> 99.5μs (10.5% faster)
    rope_start = dv + 4 * (dv // tile_size)
    rope_end = rope_start + input_k_cache.element_size() * (d - dv)
    result_rope = result.view(num_blocks, block_size, rope_end)[..., rope_start:rope_end].view(input_k_cache.dtype)
    input_rope = input_k_cache.squeeze(2)[..., dv:]
    assert torch.equal(result_rope, input_rope)
```

codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
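
In other words, each recorded value above is the result the original implementation produced, re-checked against the optimized one. A minimal sketch of that style of check, with `original_fn` and `optimized_fn` as hypothetical stand-ins for the two versions:

```python
import torch

def assert_equivalent(original_fn, optimized_fn, input_k_cache, dv=512, tile_size=128):
    # Both versions must agree exactly on the packed output.
    expected = original_fn(input_k_cache, dv=dv, tile_size=tile_size)
    actual = optimized_fn(input_k_cache, dv=dv, tile_size=tile_size)
    assert expected.shape == actual.shape and expected.dtype == actual.dtype
    assert torch.equal(expected, actual)
```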

To edit these changes, `git checkout codeflash/optimize-_quantize_k_cache_slow-mhv5256r` and push.

Codeflash Static Badge

codeflash-ai bot requested a review from mashraf-222 November 11, 2025 22:24
codeflash-ai bot added the ⚡️ codeflash (Optimization PR opened by Codeflash AI) and 🎯 Quality: High (Optimization Quality according to Codeflash) labels Nov 11, 2025