Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 48 additions & 14 deletions python/sglang/jit_kernel/hadamard.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def _hadamard_transform_impl(
scale: float,
pad_multiple: int,
kernel_fn: Callable,
out: torch.Tensor | None = None,
) -> torch.Tensor:
if not x.is_cuda:
raise RuntimeError(f"{kernel_fn.__name__} only supports CUDA tensors")
Expand All @@ -46,36 +47,69 @@ def _hadamard_transform_impl(

needs_pad = dim_og % pad_multiple != 0
if needs_pad:
if out is not None:
raise ValueError("out= not supported when input needs padding")
x = torch.nn.functional.pad(x, (0, pad_multiple - dim_og % pad_multiple))

out = torch.empty_like(x)
kernel_fn(x, out, scale)
if out is not None:
if out.device != x.device:
raise ValueError(f"out device {out.device} != x device {x.device}")
if out.dtype != x.dtype:
raise ValueError(f"out dtype {out.dtype} != x dtype {x.dtype}")
out_2d = out.reshape(-1, x.size(-1))
if out_2d.stride(-1) != 1:
raise ValueError("out must be contiguous in last dimension")
if out_2d.shape != x.shape:
raise ValueError(f"out shape {out_2d.shape} != x shape {x.shape}")
else:
out_2d = torch.empty_like(x)

kernel_fn(x, out_2d, scale)

if needs_pad:
out = out[:, :dim_og]
return out.reshape(shapes_og)
out_2d = out_2d[:, :dim_og]
return out_2d.reshape(shapes_og)
return out_2d.reshape(shapes_og)


def hadamard_transform(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
def hadamard_transform(
x: torch.Tensor, scale: float = 1.0, out: torch.Tensor | None = None
) -> torch.Tensor:
module = _jit_hadamard_module(x.dtype)
return _hadamard_transform_impl(x, scale, 8, module.hadamard_transform)
return _hadamard_transform_impl(x, scale, 8, module.hadamard_transform, out=out)


def hadamard_transform_12n(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
def hadamard_transform_12n(
x: torch.Tensor, scale: float = 1.0, out: torch.Tensor | None = None
) -> torch.Tensor:
module = _jit_hadamard_module(x.dtype)
return _hadamard_transform_impl(x, scale, 4 * 12, module.hadamard_transform_12n)
return _hadamard_transform_impl(
x, scale, 4 * 12, module.hadamard_transform_12n, out=out
)


def hadamard_transform_20n(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
def hadamard_transform_20n(
x: torch.Tensor, scale: float = 1.0, out: torch.Tensor | None = None
) -> torch.Tensor:
module = _jit_hadamard_module(x.dtype)
return _hadamard_transform_impl(x, scale, 4 * 20, module.hadamard_transform_20n)
return _hadamard_transform_impl(
x, scale, 4 * 20, module.hadamard_transform_20n, out=out
)


def hadamard_transform_28n(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
def hadamard_transform_28n(
x: torch.Tensor, scale: float = 1.0, out: torch.Tensor | None = None
) -> torch.Tensor:
module = _jit_hadamard_module(x.dtype)
return _hadamard_transform_impl(x, scale, 4 * 28, module.hadamard_transform_28n)
return _hadamard_transform_impl(
x, scale, 4 * 28, module.hadamard_transform_28n, out=out
)


def hadamard_transform_40n(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
def hadamard_transform_40n(
x: torch.Tensor, scale: float = 1.0, out: torch.Tensor | None = None
) -> torch.Tensor:
module = _jit_hadamard_module(x.dtype)
return _hadamard_transform_impl(x, scale, 4 * 40, module.hadamard_transform_40n)
return _hadamard_transform_impl(
x, scale, 4 * 40, module.hadamard_transform_40n, out=out
)
49 changes: 44 additions & 5 deletions python/sglang/srt/layers/attention/flashinfer_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def is_enabled(self) -> bool:
@dataclass
class DecodeMetadata:
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper]
active_kv_indices: Optional[List[Optional[torch.Tensor]]] = None


@dataclass
Expand Down Expand Up @@ -436,7 +437,10 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
fixed_split_size=self.decode_split_tile_size,
disable_split_kv=False,
)
self.forward_metadata = DecodeMetadata(self.decode_wrappers)
self.forward_metadata = DecodeMetadata(
self.decode_wrappers,
active_kv_indices=self.indices_updater_decode._pending_kv_indices,
)
elif forward_batch.forward_mode.is_draft_extend():
self.indices_updater_prefill.update(
forward_batch.req_pool_indices,
Expand Down Expand Up @@ -874,9 +878,8 @@ def forward_decode(
forward_batch: ForwardBatch,
save_kv_cache=True,
):
decode_wrapper = self.forward_metadata.decode_wrappers[
self._get_wrapper_idx(layer)
]
wrapper_idx = self._get_wrapper_idx(layer)
decode_wrapper = self.forward_metadata.decode_wrappers[wrapper_idx]
cache_loc = (
forward_batch.out_cache_loc
if not layer.is_cross_attention
Expand All @@ -890,10 +893,35 @@ def forward_decode(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
)

pool = forward_batch.token_to_kv_pool
is_tq = getattr(pool, "supports_sparse_dequant", False)

if is_tq:
# TurboQuant sparse dequant does not support sliding-window attention.
# SWA translates kv_indices into SWA-pool locations that don't address
# the TQ compressed pool, so sparse dequant would read wrong data.
if (
self.dispatch_reason == WrapperDispatch.SLIDING_WINDOW
and wrapper_idx == 0
):
raise NotImplementedError(
"TurboQuant does not support sliding-window attention"
)

# Sparse dequant path: dequant only referenced tokens, then FlashInfer decode
kv_indices = (
self.forward_metadata.active_kv_indices[wrapper_idx]
if self.forward_metadata.active_kv_indices is not None
else None
)
kv_buf = pool.get_kv_buffer(layer.layer_id, kv_indices=kv_indices)
else:
kv_buf = pool.get_kv_buffer(layer.layer_id)

# Call the wrapped function
o = decode_wrapper.forward(
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
kv_buf,
sm_scale=layer.scaling,
logits_soft_cap=layer.logit_cap,
# Must use _float to avoid device-to-host copy that breaks cuda graph capture.
Expand Down Expand Up @@ -936,6 +964,10 @@ def __init__(self, model_runner: ModelRunner, attn_backend: FlashInferAttnBacken
self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator

# Per-wrapper active kv_indices for sparse dequant (TurboQuant)
self._pending_kv_indices: List[Optional[torch.Tensor]] = [None] * attn_backend.num_wrappers
self._last_active_kv_indices: Optional[torch.Tensor] = None

# Dispatch the update function
if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
self.update = self.update_sliding_window
Expand Down Expand Up @@ -985,6 +1017,7 @@ def update_single_wrapper(
fixed_split_size=fixed_split_size,
disable_split_kv=disable_split_kv,
)
self._pending_kv_indices[0] = self._last_active_kv_indices

def update_sliding_window(
self,
Expand Down Expand Up @@ -1035,6 +1068,7 @@ def update_sliding_window(
seq_lens_cpu=seq_lens_cpu_tmp,
use_sliding_window_kv_pool=use_sliding_window_kv_pool,
)
self._pending_kv_indices[wrapper_id] = self._last_active_kv_indices

def update_cross_attention(
self,
Expand Down Expand Up @@ -1069,6 +1103,7 @@ def update_cross_attention(
spec_info,
seq_lens_cpu=seq_lens_cpu,
)
self._pending_kv_indices[wrapper_id] = self._last_active_kv_indices

def call_begin_forward(
self,
Expand Down Expand Up @@ -1171,6 +1206,10 @@ def call_begin_forward(
),
)

# Save active kv_indices slice for sparse dequant (TurboQuant)
active_len = int(kv_indptr[-1])
self._last_active_kv_indices = kv_indices[:active_len] if active_len > 0 else None

if locally_override:
global_override_indptr_cpu = None

Expand Down
111 changes: 107 additions & 4 deletions python/sglang/srt/layers/attention/flashinfer_mla_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@
@dataclass
class DecodeMetadata:
decode_wrapper: BatchMLAPagedAttentionWrapper
# Cached for TurboQuant fused decode kernel (bypasses FlashInfer)
kv_indptr: Optional[torch.Tensor] = None
kv_indices: Optional[torch.Tensor] = None
seq_lens: Optional[torch.Tensor] = None


@dataclass
Expand Down Expand Up @@ -285,6 +289,11 @@ def __init__(
self.decode_cuda_graph_metadata = {}
self.prefill_cuda_graph_metadata = {} # For verify

# Pre-allocated TurboQuant fused decode split buffers (lazy init on first use)
self._tq_attn_logits: Optional[torch.Tensor] = None
self._tq_attn_lse: Optional[torch.Tensor] = None
self._tq_max_kv_splits = 128

def init_forward_metadata(self, forward_batch: ForwardBatch):
if forward_batch.forward_mode.is_decode_or_idle():
self.indices_updater_decode.update(
Expand All @@ -294,7 +303,12 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
decode_wrapper=self.decode_wrapper,
init_metadata_replay=False,
)
self.forward_metadata = DecodeMetadata(self.decode_wrapper)
self.forward_metadata = DecodeMetadata(
self.decode_wrapper,
kv_indptr=self.indices_updater_decode._cached_kv_indptr,
kv_indices=self.indices_updater_decode._cached_kv_indices,
seq_lens=self.indices_updater_decode._cached_seq_lens,
)
elif forward_batch.forward_mode.is_draft_extend():
self.indices_updater_prefill.update(
forward_batch.req_pool_indices,
Expand Down Expand Up @@ -401,7 +415,12 @@ def init_forward_metadata_capture_cuda_graph(
spec_info=spec_info,
)
self.decode_cuda_graph_metadata[bs] = decode_wrapper
self.forward_metadata = DecodeMetadata(decode_wrapper)
self.forward_metadata = DecodeMetadata(
decode_wrapper,
kv_indptr=self.indices_updater_decode._cached_kv_indptr,
kv_indices=self.indices_updater_decode._cached_kv_indices,
seq_lens=self.indices_updater_decode._cached_seq_lens,
)
decode_wrapper.plan = partial(fast_mla_decode_plan, decode_wrapper)
elif forward_mode.is_target_verify():
verify_wrapper = BatchMLAPagedAttentionWrapper(
Expand Down Expand Up @@ -484,6 +503,13 @@ def init_forward_metadata_replay_cuda_graph(
spec_info=spec_info,
**self.fast_decode_kwargs,
)
# Update forward_metadata with refreshed indices for TQ fused kernel
self.forward_metadata = DecodeMetadata(
self.decode_cuda_graph_metadata[bs],
kv_indptr=self.indices_updater_decode._cached_kv_indptr,
kv_indices=self.indices_updater_decode._cached_kv_indices,
seq_lens=self.indices_updater_decode._cached_seq_lens,
)
elif forward_mode.is_target_verify():
self.indices_updater_prefill.update(
req_pool_indices[:bs],
Expand Down Expand Up @@ -636,9 +662,81 @@ def forward_decode(
q_nope = reshaped_q[:, :, : layer.v_head_dim]
q_rope = reshaped_q[:, :, layer.v_head_dim :]

k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
q.dtype
pool = forward_batch.token_to_kv_pool

# TurboQuant fused dequant-attention path (4-bit MSE only)
_tq_fused = getattr(pool, "can_use_fused_kernel", False) and getattr(
layer, "_tq_fused_ready", False
)
if _tq_fused:
from sglang.srt.layers.attention.triton_ops.decode_attention_turboquant import (
decode_attention_fwd_tq,
)

# Rotate Q using existing HadamardTransform
q_nope_rot = pool.nope_hadamard.forward(q_nope)
q_rope_rot = pool.rope_hadamard.forward(q_rope)

# Get raw compressed buffers (no dequant)
nope_packed = pool.get_nope_packed_buffer(layer.layer_id)
rope_packed = pool.get_rope_packed_buffer(layer.layer_id)
nope_norms = pool.get_nope_norms_buffer(layer.layer_id)
rope_norms = pool.get_rope_norms_buffer(layer.layer_id)

# Compute num_kv_splits
metadata = self.forward_metadata
bs = q_nope.shape[0]
head_num = q_nope.shape[1]
kv_lora_rank = q_nope.shape[2]
BLOCK_N = 32
max_kv_splits = self._tq_max_kv_splits
num_kv_splits = torch.clamp(
(metadata.seq_lens[:bs] + BLOCK_N * 4 - 1) // (BLOCK_N * 4),
min=1,
max=max_kv_splits,
).to(torch.int32)

# Lazy-allocate split buffers once, reuse across all subsequent calls
if self._tq_attn_logits is None or self._tq_attn_logits.shape[1] < head_num:
max_bs = self.kv_indptr.shape[0] - 1 # req_to_token_pool.size
self._tq_attn_logits = torch.zeros(
(max_bs, head_num, max_kv_splits, kv_lora_rank),
dtype=torch.float32,
device=q_nope.device,
)
self._tq_attn_lse = torch.zeros(
(max_bs, head_num, max_kv_splits),
dtype=torch.float32,
device=q_nope.device,
)

o = q_nope.new_empty(q_nope.shape[0], q_nope.shape[1], q_nope.shape[2])
decode_attention_fwd_tq(
q_nope_rot,
q_rope_rot,
nope_packed,
rope_packed,
nope_norms,
rope_norms,
pool.nope_centroids_scaled,
pool.rope_centroids_scaled,
o,
metadata.kv_indptr,
metadata.kv_indices,
num_kv_splits,
max_kv_splits,
sm_scale=layer.scaling,
logit_cap=layer.logit_cap,
attn_logits=self._tq_attn_logits,
attn_lse=self._tq_attn_lse,
)

# Flag so forward_mla.py uses w_vc_tq_rotated
forward_batch._tq_rotated_output = True
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

# Existing FlashInfer path (unchanged fallback)
k_buffer = pool.get_key_buffer(layer.layer_id).to(q.dtype)

o = q_nope.new_empty(q_nope.shape)
# Direct call to run without the wrapper
Expand Down Expand Up @@ -730,6 +828,11 @@ def call_begin_forward(
else:
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices

# Cache for TurboQuant fused decode (which bypasses FlashInfer wrapper)
self._cached_kv_indptr = kv_indptr
self._cached_kv_indices = kv_indices
self._cached_seq_lens = paged_kernel_lens

if not init_metadata_replay:
wrapper.plan(
q_indptr,
Expand Down
Loading