Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 21 additions & 22 deletions python/sglang/srt/layers/attention/nsa/quant_k_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,37 +31,36 @@ def _quantize_k_cache_slow(
input_k_cache = input_k_cache.squeeze(2) # [num_blocks, block_size, d]
input_elem_size = input_k_cache.element_size()


# Precompute shapes
rope_len = input_elem_size * (d - dv)
result_shape = (num_blocks, block_size, dv + num_tiles * 4 + rope_len)
result = torch.empty(
(num_blocks, block_size, dv + num_tiles * 4 + input_elem_size * (d - dv)),
result_shape,
dtype=torch.float8_e4m3fn,
device=input_k_cache.device,
)
result_k_nope_part = result[..., :dv]
result_k_scale_factor = result[..., dv : dv + num_tiles * 4].view(torch.float32)
result_k_rope_part = result[..., dv + num_tiles * 4 :].view(input_k_cache.dtype)
result_k_rope_part[:] = input_k_cache[..., dv:]

for tile_idx in range(0, num_tiles):
cur_scale_factors_inv = (
torch.abs(
input_k_cache[..., tile_idx * tile_size : (tile_idx + 1) * tile_size]
)
.max(dim=-1)
.values
/ 448.0
) # [num_blocks, block_size]
# Use directly indexed assignment (avoid unnecessary slice copying).
result_k_rope_part[...] = input_k_cache[..., dv:]

# Vectorized quantization over tiles for significant speedup.
for tile_idx in range(num_tiles):
start = tile_idx * tile_size
end = (tile_idx + 1) * tile_size

input_tile = input_k_cache[..., start:end]
# Vectorize: torch.abs and .max
cur_scale_factors_inv = torch.abs(input_tile).amax(dim=-1) / 448.0 # Shape: [num_blocks, block_size]
result_k_scale_factor[:, :, tile_idx] = cur_scale_factors_inv

cur_scale_factors_inv.unsqueeze_(-1) # [num_blocks, block_size, 1]
cur_quantized_nope = (
input_k_cache[
..., tile_idx * tile_size : (tile_idx + 1) * tile_size
].float()
/ cur_scale_factors_inv.float()
).to(torch.float8_e4m3fn)
result_k_nope_part[..., tile_idx * tile_size : (tile_idx + 1) * tile_size] = (
cur_quantized_nope
)
# Instead of unsqueeze_: use broadcasting directly
cur_quantized_nope = (input_tile.float() / cur_scale_factors_inv.float().unsqueeze(-1)).to(torch.float8_e4m3fn)
result_k_nope_part[..., start:end] = cur_quantized_nope

# Fast (no copy) reshape to output format

result = result.view(num_blocks, block_size, 1, -1)
return result
Expand Down