2222from torch .library import triton_op , wrap_triton
2323
2424
25+ def _is_power_of_2 (n : int ) -> bool :
26+ """Check if n is a power of 2."""
27+ return n > 0 and (n & (n - 1 )) == 0
28+
29+
30+ def _next_power_of_2 (x : int ) -> int :
31+ """Get the next power of 2 >= x, clamped to [16, 256]."""
32+ if x <= 16 :
33+ return 16
34+ if x <= 32 :
35+ return 32
36+ if x <= 64 :
37+ return 64
38+ if x <= 128 :
39+ return 128
40+ return 256
41+
42+
2543def _validate_qkv_shapes (
2644 query : torch .Tensor ,
2745 key : torch .Tensor ,
@@ -64,6 +82,131 @@ def _validate_qkv_shapes(
6482 return B_q , H_q , L_q , L_kv_k , D_q , D_k
6583
6684
85+ # ==============================================================================
86+ # Non-power-of-2 HEAD_DIM kernel
87+ # ==============================================================================
88+ @triton .jit
89+ def _sdpa_fwd_kernel_non_pow2 (
90+ q_ptr ,
91+ k_ptr ,
92+ v_ptr ,
93+ o_ptr ,
94+ mask_ptr ,
95+ B ,
96+ H ,
97+ LQ ,
98+ LK ,
99+ HEAD_DIM ,
100+ stride_qb ,
101+ stride_qh ,
102+ stride_ql ,
103+ stride_qd ,
104+ stride_kb ,
105+ stride_kh ,
106+ stride_kl ,
107+ stride_kd ,
108+ stride_vb ,
109+ stride_vh ,
110+ stride_vl ,
111+ stride_vd ,
112+ stride_ob ,
113+ stride_oh ,
114+ stride_ol ,
115+ stride_od ,
116+ stride_mb ,
117+ stride_mh ,
118+ stride_mlq ,
119+ stride_mlk ,
120+ scale ,
121+ BLOCK_M : tl .constexpr ,
122+ BLOCK_N : tl .constexpr ,
123+ BLOCK_D : tl .constexpr ,
124+ HAS_MASK : tl .constexpr ,
125+ IS_CAUSAL : tl .constexpr ,
126+ ):
127+ """
128+ SDPA forward kernel for non-power-of-2 HEAD_DIM.
129+ Uses dynamic masking to handle arbitrary head dimensions.
130+ """
131+ pid_m = tl .program_id (axis = 0 )
132+ pid_bh = tl .program_id (axis = 1 )
133+
134+ b = pid_bh // H
135+ h = pid_bh % H
136+
137+ offs_m = pid_m * BLOCK_M + tl .arange (0 , BLOCK_M )
138+ offs_n = tl .arange (0 , BLOCK_N )
139+ offs_d = tl .arange (0 , BLOCK_D )
140+
141+ d_mask = offs_d < HEAD_DIM
142+ q_row_mask = offs_m < LQ
143+
144+ q_base = q_ptr + b * stride_qb + h * stride_qh
145+ k_base = k_ptr + b * stride_kb + h * stride_kh
146+ v_base = v_ptr + b * stride_vb + h * stride_vh
147+ o_base = o_ptr + b * stride_ob + h * stride_oh
148+
149+ q_ptrs = q_base + (offs_m [:, None ] * stride_ql + offs_d [None , :] * stride_qd )
150+ q = tl .load (q_ptrs , mask = q_row_mask [:, None ] & d_mask [None , :], other = 0.0 )
151+
152+ acc = tl .zeros ((BLOCK_M , BLOCK_D ), dtype = tl .float32 )
153+ m_i = tl .full ((BLOCK_M ,), - float ("inf" ), dtype = tl .float32 )
154+ l_i = tl .full ((BLOCK_M ,), 1.0 , dtype = tl .float32 )
155+
156+ qk_scale_log2 = scale * 1.4426950408889634
157+
158+ if HAS_MASK :
159+ mask_b_base = mask_ptr + b * stride_mb
160+
161+ for start_n in tl .range (0 , LK , BLOCK_N , num_stages = 2 ):
162+ kn = start_n + offs_n
163+ kv_col_mask = kn < LK
164+
165+ k_ptrs = k_base + (kn [:, None ] * stride_kl + offs_d [None , :] * stride_kd )
166+ k = tl .load (k_ptrs , mask = kv_col_mask [:, None ] & d_mask [None , :], other = 0.0 )
167+
168+ qk = tl .dot (q , tl .trans (k ))
169+ qk = qk * qk_scale_log2
170+
171+ if IS_CAUSAL :
172+ row_abs = offs_m [:, None ]
173+ col_abs = kn [None , :]
174+ causal_mask = col_abs > row_abs
175+ qk = tl .where (causal_mask , - float ("inf" ), qk )
176+
177+ if HAS_MASK :
178+ mask_ptrs = (
179+ mask_b_base + offs_m [:, None ] * stride_mlq + kn [None , :] * stride_mlk
180+ )
181+ tile_valid = q_row_mask [:, None ] & kv_col_mask [None , :]
182+ keep = tl .load (mask_ptrs , mask = tile_valid , other = True )
183+ qk = tl .where (keep , qk , - float ("inf" ))
184+
185+ qk = tl .where (kv_col_mask [None , :], qk , - float ("inf" ))
186+
187+ m_ij = tl .maximum (m_i , tl .max (qk , 1 ))
188+ p = tl .math .exp2 (qk - m_ij [:, None ])
189+ l_ij = tl .sum (p , 1 )
190+ alpha = tl .math .exp2 (m_i - m_ij )
191+
192+ acc = acc * alpha [:, None ]
193+
194+ v_ptrs = v_base + (kn [:, None ] * stride_vl + offs_d [None , :] * stride_vd )
195+ v = tl .load (v_ptrs , mask = kv_col_mask [:, None ] & d_mask [None , :], other = 0.0 )
196+
197+ acc = tl .dot (p .to (v .dtype ), v , acc )
198+
199+ l_i = l_i * alpha + l_ij
200+ m_i = m_ij
201+
202+ out = acc / l_i [:, None ]
203+ o_ptrs = o_base + (offs_m [:, None ] * stride_ol + offs_d [None , :] * stride_od )
204+ tl .store (o_ptrs , out .to (tl .bfloat16 ), mask = q_row_mask [:, None ] & d_mask [None , :])
205+
206+
207+ # ==============================================================================
208+ # Power-of-2 HEAD_DIM kernels
209+ # ==============================================================================
67210@triton .jit
68211def _sdpa_fwd_kernel_body (
69212 Q_ptr ,
@@ -463,57 +606,122 @@ def sdpa(
463606 def grid (meta ):
464607 return (triton .cdiv (L_q , meta ["BLOCK_M" ]), B * H )
465608
466- # Dynamic kernel selection based on workload
467- total_ctas_m64 = ((L_q + 63 ) // 64 ) * (B * H )
468- threshold = 4 * 84 # Heuristic threshold for kernel selection
469- use_small_block = total_ctas_m64 < threshold
470-
471- if use_small_block :
472- wrap_triton (_sdpa_fwd_kernel_m32 )[grid ](
473- query ,
474- key ,
475- value ,
476- out ,
477- Mask_ptr if HAS_MASK else 0 ,
478- B ,
479- H ,
480- L_q ,
481- L_kv ,
482- stride_qb ,
483- stride_qh ,
484- stride_qm ,
485- stride_qd ,
486- stride_kb ,
487- stride_kh ,
488- stride_kn ,
489- stride_kd ,
490- stride_vb ,
491- stride_vh ,
492- stride_vn ,
493- stride_vd ,
494- stride_ob ,
495- stride_oh ,
496- stride_om ,
497- stride_od ,
498- stride_mb ,
499- stride_mq ,
500- stride_mk ,
501- sm_scale ,
502- HAS_MASK = HAS_MASK ,
503- IS_CAUSAL = is_causal ,
504- HEAD_DIM = D ,
505- )
609+ # Select kernel based on whether HEAD_DIM is power of 2
610+ if _is_power_of_2 (D ):
611+ # Use power-of-2 optimized kernels with autotune
612+ # Dynamic kernel selection based on workload
613+ total_ctas_m64 = ((L_q + 63 ) // 64 ) * (B * H )
614+ threshold = 4 * 84 # Heuristic threshold for kernel selection
615+ use_small_block = total_ctas_m64 < threshold
616+
617+ if use_small_block :
618+ wrap_triton (_sdpa_fwd_kernel_m32 )[grid ](
619+ query ,
620+ key ,
621+ value ,
622+ out ,
623+ Mask_ptr if HAS_MASK else 0 ,
624+ B ,
625+ H ,
626+ L_q ,
627+ L_kv ,
628+ stride_qb ,
629+ stride_qh ,
630+ stride_qm ,
631+ stride_qd ,
632+ stride_kb ,
633+ stride_kh ,
634+ stride_kn ,
635+ stride_kd ,
636+ stride_vb ,
637+ stride_vh ,
638+ stride_vn ,
639+ stride_vd ,
640+ stride_ob ,
641+ stride_oh ,
642+ stride_om ,
643+ stride_od ,
644+ stride_mb ,
645+ stride_mq ,
646+ stride_mk ,
647+ sm_scale ,
648+ HAS_MASK = HAS_MASK ,
649+ IS_CAUSAL = is_causal ,
650+ HEAD_DIM = D ,
651+ )
652+ else :
653+ wrap_triton (_sdpa_fwd_kernel_m64 )[grid ](
654+ query ,
655+ key ,
656+ value ,
657+ out ,
658+ Mask_ptr if HAS_MASK else 0 ,
659+ B ,
660+ H ,
661+ L_q ,
662+ L_kv ,
663+ stride_qb ,
664+ stride_qh ,
665+ stride_qm ,
666+ stride_qd ,
667+ stride_kb ,
668+ stride_kh ,
669+ stride_kn ,
670+ stride_kd ,
671+ stride_vb ,
672+ stride_vh ,
673+ stride_vn ,
674+ stride_vd ,
675+ stride_ob ,
676+ stride_oh ,
677+ stride_om ,
678+ stride_od ,
679+ stride_mb ,
680+ stride_mq ,
681+ stride_mk ,
682+ sm_scale ,
683+ HAS_MASK = HAS_MASK ,
684+ IS_CAUSAL = is_causal ,
685+ HEAD_DIM = D ,
686+ )
506687 else :
507- wrap_triton (_sdpa_fwd_kernel_m64 )[grid ](
688+ # Use non-power-of-2 kernel with dynamic HEAD_DIM masking
689+ BLOCK_D = _next_power_of_2 (D )
690+
691+ if BLOCK_D >= 256 :
692+ BLOCK_N = 64
693+ else :
694+ BLOCK_N = 128
695+
696+ BLOCK_M = 32
697+ num_warps = 4
698+ num_stages = 2
699+
700+ # Handle mask for non-pow2 kernel (different stride layout)
701+ if HAS_MASK :
702+ mask_ptr = attn_mask
703+ stride_mb_np2 = attn_mask .stride (0 )
704+ stride_mh_np2 = attn_mask .stride (1 )
705+ stride_mlq_np2 = attn_mask .stride (2 )
706+ stride_mlk_np2 = attn_mask .stride (3 )
707+ else :
708+ mask_ptr = torch .empty ((1 ,), device = query .device , dtype = torch .bool )
709+ stride_mb_np2 = stride_mh_np2 = stride_mlq_np2 = stride_mlk_np2 = 0
710+
711+ def grid_non_pow2 (meta ):
712+ return (triton .cdiv (L_q , meta ["BLOCK_M" ]), B * H )
713+
714+ wrap_triton (_sdpa_fwd_kernel_non_pow2 )[grid_non_pow2 ](
508715 query ,
509716 key ,
510717 value ,
511718 out ,
512- Mask_ptr if HAS_MASK else 0 ,
719+ mask_ptr ,
513720 B ,
514721 H ,
515722 L_q ,
516723 L_kv ,
724+ D ,
517725 stride_qb ,
518726 stride_qh ,
519727 stride_qm ,
@@ -530,13 +738,18 @@ def grid(meta):
530738 stride_oh ,
531739 stride_om ,
532740 stride_od ,
533- stride_mb ,
534- stride_mq ,
535- stride_mk ,
741+ stride_mb_np2 ,
742+ stride_mh_np2 ,
743+ stride_mlq_np2 ,
744+ stride_mlk_np2 ,
536745 sm_scale ,
746+ BLOCK_M = BLOCK_M ,
747+ BLOCK_N = BLOCK_N ,
748+ BLOCK_D = BLOCK_D ,
537749 HAS_MASK = HAS_MASK ,
538750 IS_CAUSAL = is_causal ,
539- HEAD_DIM = D ,
751+ num_warps = num_warps ,
752+ num_stages = num_stages ,
540753 )
541754
542755 return out
0 commit comments