Skip to content

Commit e450cfe

Browse files
committed
gemma3 support
1 parent 06d11de commit e450cfe

File tree

5 files changed

+51658
-64
lines changed

5 files changed

+51658
-64
lines changed

backends/cuda/cuda_backend.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,8 @@ def get_custom_passes(cls, compile_specs: List[CompileSpec]) -> List[typing.Any]
6868
)
6969
triton_kernel_mode = mode
7070

71-
return [ReplaceEdgeOpWithTritonOpPass()] if triton_kernel_mode == "ON" else []
71+
# return [ReplaceEdgeOpWithTritonOpPass()] if triton_kernel_mode == "ON" else []
72+
return [ReplaceEdgeOpWithTritonOpPass()]
7273

7374
@classmethod
7475
def get_aoti_compile_options(
@@ -134,20 +135,20 @@ def get_aoti_compile_options(
134135

135136
return options
136137

137-
@classmethod
138-
def get_extra_aoti_compile_context_manager(cls):
139-
"""
140-
Return SDPA MATH backend context manager for CUDA compilation.
141-
142-
This context manager plays as a fallback solution for any remaining PyTorch SDPA
143-
operations to use the MATH backend (decomposed SDPA) during AOTInductor compilation.
144-
145-
Note:
146-
- If SDPA ops are replaced with Triton kernels by ReplaceEdgeOpWithTritonOpPass,
147-
this context manager will have no effect on those ops (they are no longer
148-
PyTorch SDPA ops).
149-
- If SDPA ops are NOT replaced (e.g., when triton_kernel_mode="OFF"), this
150-
context manager will force them to use the MATH backend, causing them to
151-
be automatically decomposed during compilation.
152-
"""
153-
return torch.nn.attention.sdpa_kernel([SDPBackend.MATH])
138+
# @classmethod
139+
# def get_extra_aoti_compile_context_manager(cls):
140+
# """
141+
# Return SDPA MATH backend context manager for CUDA compilation.
142+
143+
# This context manager plays as a fallback solution for any remaining PyTorch SDPA
144+
# operations to use the MATH backend (decomposed SDPA) during AOTInductor compilation.
145+
146+
# Note:
147+
# - If SDPA ops are replaced with Triton kernels by ReplaceEdgeOpWithTritonOpPass,
148+
# this context manager will have no effect on those ops (they are no longer
149+
# PyTorch SDPA ops).
150+
# - If SDPA ops are NOT replaced (e.g., when triton_kernel_mode="OFF"), this
151+
# context manager will force them to use the MATH backend, causing them to
152+
# be automatically decomposed during compilation.
153+
# """
154+
# return torch.nn.attention.sdpa_kernel([SDPBackend.MATH])

backends/cuda/triton/kernels/sdpa.py

Lines changed: 259 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,24 @@
2222
from torch.library import triton_op, wrap_triton
2323

2424

25+
def _is_power_of_2(n: int) -> bool:
26+
"""Check if n is a power of 2."""
27+
return n > 0 and (n & (n - 1)) == 0
28+
29+
30+
def _next_power_of_2(x: int) -> int:
31+
"""Get the next power of 2 >= x, clamped to [16, 256]."""
32+
if x <= 16:
33+
return 16
34+
if x <= 32:
35+
return 32
36+
if x <= 64:
37+
return 64
38+
if x <= 128:
39+
return 128
40+
return 256
41+
42+
2543
def _validate_qkv_shapes(
2644
query: torch.Tensor,
2745
key: torch.Tensor,
@@ -64,6 +82,131 @@ def _validate_qkv_shapes(
6482
return B_q, H_q, L_q, L_kv_k, D_q, D_k
6583

6684

85+
# ==============================================================================
86+
# Non-power-of-2 HEAD_DIM kernel
87+
# ==============================================================================
88+
@triton.jit
89+
def _sdpa_fwd_kernel_non_pow2(
90+
q_ptr,
91+
k_ptr,
92+
v_ptr,
93+
o_ptr,
94+
mask_ptr,
95+
B,
96+
H,
97+
LQ,
98+
LK,
99+
HEAD_DIM,
100+
stride_qb,
101+
stride_qh,
102+
stride_ql,
103+
stride_qd,
104+
stride_kb,
105+
stride_kh,
106+
stride_kl,
107+
stride_kd,
108+
stride_vb,
109+
stride_vh,
110+
stride_vl,
111+
stride_vd,
112+
stride_ob,
113+
stride_oh,
114+
stride_ol,
115+
stride_od,
116+
stride_mb,
117+
stride_mh,
118+
stride_mlq,
119+
stride_mlk,
120+
scale,
121+
BLOCK_M: tl.constexpr,
122+
BLOCK_N: tl.constexpr,
123+
BLOCK_D: tl.constexpr,
124+
HAS_MASK: tl.constexpr,
125+
IS_CAUSAL: tl.constexpr,
126+
):
127+
"""
128+
SDPA forward kernel for non-power-of-2 HEAD_DIM.
129+
Uses dynamic masking to handle arbitrary head dimensions.
130+
"""
131+
pid_m = tl.program_id(axis=0)
132+
pid_bh = tl.program_id(axis=1)
133+
134+
b = pid_bh // H
135+
h = pid_bh % H
136+
137+
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
138+
offs_n = tl.arange(0, BLOCK_N)
139+
offs_d = tl.arange(0, BLOCK_D)
140+
141+
d_mask = offs_d < HEAD_DIM
142+
q_row_mask = offs_m < LQ
143+
144+
q_base = q_ptr + b * stride_qb + h * stride_qh
145+
k_base = k_ptr + b * stride_kb + h * stride_kh
146+
v_base = v_ptr + b * stride_vb + h * stride_vh
147+
o_base = o_ptr + b * stride_ob + h * stride_oh
148+
149+
q_ptrs = q_base + (offs_m[:, None] * stride_ql + offs_d[None, :] * stride_qd)
150+
q = tl.load(q_ptrs, mask=q_row_mask[:, None] & d_mask[None, :], other=0.0)
151+
152+
acc = tl.zeros((BLOCK_M, BLOCK_D), dtype=tl.float32)
153+
m_i = tl.full((BLOCK_M,), -float("inf"), dtype=tl.float32)
154+
l_i = tl.full((BLOCK_M,), 1.0, dtype=tl.float32)
155+
156+
qk_scale_log2 = scale * 1.4426950408889634
157+
158+
if HAS_MASK:
159+
mask_b_base = mask_ptr + b * stride_mb
160+
161+
for start_n in tl.range(0, LK, BLOCK_N, num_stages=2):
162+
kn = start_n + offs_n
163+
kv_col_mask = kn < LK
164+
165+
k_ptrs = k_base + (kn[:, None] * stride_kl + offs_d[None, :] * stride_kd)
166+
k = tl.load(k_ptrs, mask=kv_col_mask[:, None] & d_mask[None, :], other=0.0)
167+
168+
qk = tl.dot(q, tl.trans(k))
169+
qk = qk * qk_scale_log2
170+
171+
if IS_CAUSAL:
172+
row_abs = offs_m[:, None]
173+
col_abs = kn[None, :]
174+
causal_mask = col_abs > row_abs
175+
qk = tl.where(causal_mask, -float("inf"), qk)
176+
177+
if HAS_MASK:
178+
mask_ptrs = (
179+
mask_b_base + offs_m[:, None] * stride_mlq + kn[None, :] * stride_mlk
180+
)
181+
tile_valid = q_row_mask[:, None] & kv_col_mask[None, :]
182+
keep = tl.load(mask_ptrs, mask=tile_valid, other=True)
183+
qk = tl.where(keep, qk, -float("inf"))
184+
185+
qk = tl.where(kv_col_mask[None, :], qk, -float("inf"))
186+
187+
m_ij = tl.maximum(m_i, tl.max(qk, 1))
188+
p = tl.math.exp2(qk - m_ij[:, None])
189+
l_ij = tl.sum(p, 1)
190+
alpha = tl.math.exp2(m_i - m_ij)
191+
192+
acc = acc * alpha[:, None]
193+
194+
v_ptrs = v_base + (kn[:, None] * stride_vl + offs_d[None, :] * stride_vd)
195+
v = tl.load(v_ptrs, mask=kv_col_mask[:, None] & d_mask[None, :], other=0.0)
196+
197+
acc = tl.dot(p.to(v.dtype), v, acc)
198+
199+
l_i = l_i * alpha + l_ij
200+
m_i = m_ij
201+
202+
out = acc / l_i[:, None]
203+
o_ptrs = o_base + (offs_m[:, None] * stride_ol + offs_d[None, :] * stride_od)
204+
tl.store(o_ptrs, out.to(tl.bfloat16), mask=q_row_mask[:, None] & d_mask[None, :])
205+
206+
207+
# ==============================================================================
208+
# Power-of-2 HEAD_DIM kernels
209+
# ==============================================================================
67210
@triton.jit
68211
def _sdpa_fwd_kernel_body(
69212
Q_ptr,
@@ -463,57 +606,122 @@ def sdpa(
463606
def grid(meta):
464607
return (triton.cdiv(L_q, meta["BLOCK_M"]), B * H)
465608

466-
# Dynamic kernel selection based on workload
467-
total_ctas_m64 = ((L_q + 63) // 64) * (B * H)
468-
threshold = 4 * 84 # Heuristic threshold for kernel selection
469-
use_small_block = total_ctas_m64 < threshold
470-
471-
if use_small_block:
472-
wrap_triton(_sdpa_fwd_kernel_m32)[grid](
473-
query,
474-
key,
475-
value,
476-
out,
477-
Mask_ptr if HAS_MASK else 0,
478-
B,
479-
H,
480-
L_q,
481-
L_kv,
482-
stride_qb,
483-
stride_qh,
484-
stride_qm,
485-
stride_qd,
486-
stride_kb,
487-
stride_kh,
488-
stride_kn,
489-
stride_kd,
490-
stride_vb,
491-
stride_vh,
492-
stride_vn,
493-
stride_vd,
494-
stride_ob,
495-
stride_oh,
496-
stride_om,
497-
stride_od,
498-
stride_mb,
499-
stride_mq,
500-
stride_mk,
501-
sm_scale,
502-
HAS_MASK=HAS_MASK,
503-
IS_CAUSAL=is_causal,
504-
HEAD_DIM=D,
505-
)
609+
# Select kernel based on whether HEAD_DIM is power of 2
610+
if _is_power_of_2(D):
611+
# Use power-of-2 optimized kernels with autotune
612+
# Dynamic kernel selection based on workload
613+
total_ctas_m64 = ((L_q + 63) // 64) * (B * H)
614+
threshold = 4 * 84 # Heuristic threshold for kernel selection
615+
use_small_block = total_ctas_m64 < threshold
616+
617+
if use_small_block:
618+
wrap_triton(_sdpa_fwd_kernel_m32)[grid](
619+
query,
620+
key,
621+
value,
622+
out,
623+
Mask_ptr if HAS_MASK else 0,
624+
B,
625+
H,
626+
L_q,
627+
L_kv,
628+
stride_qb,
629+
stride_qh,
630+
stride_qm,
631+
stride_qd,
632+
stride_kb,
633+
stride_kh,
634+
stride_kn,
635+
stride_kd,
636+
stride_vb,
637+
stride_vh,
638+
stride_vn,
639+
stride_vd,
640+
stride_ob,
641+
stride_oh,
642+
stride_om,
643+
stride_od,
644+
stride_mb,
645+
stride_mq,
646+
stride_mk,
647+
sm_scale,
648+
HAS_MASK=HAS_MASK,
649+
IS_CAUSAL=is_causal,
650+
HEAD_DIM=D,
651+
)
652+
else:
653+
wrap_triton(_sdpa_fwd_kernel_m64)[grid](
654+
query,
655+
key,
656+
value,
657+
out,
658+
Mask_ptr if HAS_MASK else 0,
659+
B,
660+
H,
661+
L_q,
662+
L_kv,
663+
stride_qb,
664+
stride_qh,
665+
stride_qm,
666+
stride_qd,
667+
stride_kb,
668+
stride_kh,
669+
stride_kn,
670+
stride_kd,
671+
stride_vb,
672+
stride_vh,
673+
stride_vn,
674+
stride_vd,
675+
stride_ob,
676+
stride_oh,
677+
stride_om,
678+
stride_od,
679+
stride_mb,
680+
stride_mq,
681+
stride_mk,
682+
sm_scale,
683+
HAS_MASK=HAS_MASK,
684+
IS_CAUSAL=is_causal,
685+
HEAD_DIM=D,
686+
)
506687
else:
507-
wrap_triton(_sdpa_fwd_kernel_m64)[grid](
688+
# Use non-power-of-2 kernel with dynamic HEAD_DIM masking
689+
BLOCK_D = _next_power_of_2(D)
690+
691+
if BLOCK_D >= 256:
692+
BLOCK_N = 64
693+
else:
694+
BLOCK_N = 128
695+
696+
BLOCK_M = 32
697+
num_warps = 4
698+
num_stages = 2
699+
700+
# Handle mask for non-pow2 kernel (different stride layout)
701+
if HAS_MASK:
702+
mask_ptr = attn_mask
703+
stride_mb_np2 = attn_mask.stride(0)
704+
stride_mh_np2 = attn_mask.stride(1)
705+
stride_mlq_np2 = attn_mask.stride(2)
706+
stride_mlk_np2 = attn_mask.stride(3)
707+
else:
708+
mask_ptr = torch.empty((1,), device=query.device, dtype=torch.bool)
709+
stride_mb_np2 = stride_mh_np2 = stride_mlq_np2 = stride_mlk_np2 = 0
710+
711+
def grid_non_pow2(meta):
712+
return (triton.cdiv(L_q, meta["BLOCK_M"]), B * H)
713+
714+
wrap_triton(_sdpa_fwd_kernel_non_pow2)[grid_non_pow2](
508715
query,
509716
key,
510717
value,
511718
out,
512-
Mask_ptr if HAS_MASK else 0,
719+
mask_ptr,
513720
B,
514721
H,
515722
L_q,
516723
L_kv,
724+
D,
517725
stride_qb,
518726
stride_qh,
519727
stride_qm,
@@ -530,13 +738,18 @@ def grid(meta):
530738
stride_oh,
531739
stride_om,
532740
stride_od,
533-
stride_mb,
534-
stride_mq,
535-
stride_mk,
741+
stride_mb_np2,
742+
stride_mh_np2,
743+
stride_mlq_np2,
744+
stride_mlk_np2,
536745
sm_scale,
746+
BLOCK_M=BLOCK_M,
747+
BLOCK_N=BLOCK_N,
748+
BLOCK_D=BLOCK_D,
537749
HAS_MASK=HAS_MASK,
538750
IS_CAUSAL=is_causal,
539-
HEAD_DIM=D,
751+
num_warps=num_warps,
752+
num_stages=num_stages,
540753
)
541754

542755
return out

output.wav

938 KB
Binary file not shown.

0 commit comments

Comments
 (0)