diff --git a/vllm/envs.py b/vllm/envs.py
index 5919991aaa49..687516d4af43 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -203,6 +203,7 @@
     VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT: bool = True
     VLLM_ROCM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD: bool = True
     VLLM_ROCM_USE_AITER_TRITON_BF16_GEMM: bool = True
+    VLLM_ROCM_USE_AITER_TRITON_FUSED_GEMM_FP4_SPLIT_CAT: bool = True
     ROCM_TRITON_MOE_PRESHUFFLE_SCALES: bool = True
     VLLM_ROCM_USE_AITER_FUSED_MOE_A16W4: bool = False
     VLLM_GPT_OSS_USE_CONTAINER_TOOL: bool = False
@@ -1448,6 +1449,10 @@ def get_vllm_port() -> Optional[int]:
 
     # Apply preshuffling for mxfp4 scales for ROCm backend
     "VLLM_ROCM_USE_AITER_TRITON_MLA":
     lambda: bool(int(os.getenv("VLLM_ROCM_USE_AITER_TRITON_MLA", "1"))),
+
+    # Use AITER Triton fused FP4 GEMM + split + cat
+    "VLLM_ROCM_USE_AITER_TRITON_FUSED_GEMM_FP4_SPLIT_CAT":
+    lambda: bool(int(os.getenv("VLLM_ROCM_USE_AITER_TRITON_FUSED_GEMM_FP4_SPLIT_CAT", "1"))),
 }
 # --8<-- [end:env-vars-definition]
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 14d3040682ae..66fb63a2c80d 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -245,6 +245,14 @@ def dynamic_per_batched_tensor_quant(
         return x_quant, x_quant_scale
 
     from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant
+
+    VLLM_ROCM_USE_AITER_TRITON_FUSED_GEMM_FP4_SPLIT_CAT = envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_GEMM_FP4_SPLIT_CAT
+    if VLLM_ROCM_USE_AITER_TRITON_FUSED_GEMM_FP4_SPLIT_CAT:
+        from aiter.ops.triton.fused_gemm_afp4wfp4_split_cat import fused_gemm_afp4wfp4_split_cat
+        from aiter.ops.triton.quant import dynamic_mxfp4_quant
+        from vllm.model_executor.layers.quantization.quark.quark import QuarkLinearMethod
+
+
 else:
     VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = False
     VLLM_ROCM_USE_AITER_TRITON_FP8_BMM = False
@@ -252,6 +260,7 @@
 
 logger.info(f"[Aiter] {VLLM_ROCM_USE_AITER_TRITON_FP8_BMM=} {VLLM_ROCM_USE_AITER_TRITON_FP8_BMM_MAX_BATCH_SIZE=}")
 logger.info(f"[Aiter] {VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE=}")
+logger.info(f"[Aiter] {VLLM_ROCM_USE_AITER_TRITON_FUSED_GEMM_FP4_SPLIT_CAT=}")
 
 try:
     from vllm.vllm_flash_attn import flash_attn_varlen_func
@@ -1383,13 +1392,35 @@ def _compute_prefill_context(
             k_pe = workspace[:toks]\
                 [..., self.kv_lora_rank:].unsqueeze(1)
 
-            kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \
-                -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
-            k_nope, v = kv_nope\
-                .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+            if (
+                VLLM_ROCM_USE_AITER_TRITON_FUSED_GEMM_FP4_SPLIT_CAT
+                and (self.kv_b_proj.bias is None
+                     or self.kv_b_proj.skip_bias_add)
+                and self.kv_b_proj.quant_method is not None
+                and isinstance(self.kv_b_proj.quant_method, QuarkLinearMethod)
+                and not self.kv_b_proj.gather_output
+            ):
+                input = kv_c_normed
+                weight = self.kv_b_proj.weight
+                weight_scale = self.kv_b_proj.weight_scale
+
+                # View input as a 2D matrix for the fp4 quant + GEMM
+                input_2d = input.view(-1, input.shape[-1])
+                output_dtype = input.dtype
+
+                q_input, x_scale = dynamic_mxfp4_quant(input_2d)
+
+                k, v = fused_gemm_afp4wfp4_split_cat(
+                    q_input, weight, k_pe.expand((-1, self.num_heads, -1)), x_scale, weight_scale.T, self.qk_nope_head_dim, self.v_head_dim, output_dtype
+                )
+            else:
+                kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \
+                    -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
+                k_nope, v = kv_nope\
+                    .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
 
-            k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
-                          dim=-1)
+                k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
+                              dim=-1)
 
             attn_output, attn_softmax_lse = self._run_prefill_context_chunk(
                 prefill=prefill_metadata,
@@ -1486,12 +1517,34 @@ def _context_parallel_compute_prefill_context(
                 chunk_idx=i,
                 toks=toks)
 
-            kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \
-                -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
-            k_nope, v = kv_nope\
-                .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
-            k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
-                          dim=-1)
+            if (
+                VLLM_ROCM_USE_AITER_TRITON_FUSED_GEMM_FP4_SPLIT_CAT
+                and (self.kv_b_proj.bias is None
+                     or self.kv_b_proj.skip_bias_add)
+                and self.kv_b_proj.quant_method is not None
+                and isinstance(self.kv_b_proj.quant_method, QuarkLinearMethod)
+                and not self.kv_b_proj.gather_output
+            ):
+                input = kv_c_normed
+                weight = self.kv_b_proj.weight
+                weight_scale = self.kv_b_proj.weight_scale
+
+                # View input as a 2D matrix for the fp4 quant + GEMM
+                input_2d = input.view(-1, input.shape[-1])
+                output_dtype = input.dtype
+
+                q_input, x_scale = dynamic_mxfp4_quant(input_2d)
+
+                k, v = fused_gemm_afp4wfp4_split_cat(
+                    q_input, weight, k_pe.expand((-1, self.num_heads, -1)), x_scale, weight_scale.T, self.qk_nope_head_dim, self.v_head_dim, output_dtype
+                )
+            else:
+                kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
+                    -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
+                k_nope, v = kv_nope\
+                    .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+
+                k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
 
             attn_output, attn_softmax_lse = self._run_prefill_context_chunk(
                 prefill=prefill_metadata,
@@ -1533,12 +1586,34 @@ def _forward_prefill(
         assert self.dcp_world_size is not None
         has_context = attn_metadata.prefill.chunked_context is not None
 
-        kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
-            -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
-        k_nope, v = kv_nope\
-            .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+        if (
+            VLLM_ROCM_USE_AITER_TRITON_FUSED_GEMM_FP4_SPLIT_CAT
+            and (self.kv_b_proj.bias is None
+                 or self.kv_b_proj.skip_bias_add)
+            and self.kv_b_proj.quant_method is not None
+            and isinstance(self.kv_b_proj.quant_method, QuarkLinearMethod)
+            and not self.kv_b_proj.gather_output
+        ):
+            input = kv_c_normed
+            weight = self.kv_b_proj.weight
+            weight_scale = self.kv_b_proj.weight_scale
+
+            # View input as a 2D matrix for the fp4 quant + GEMM
+            input_2d = input.view(-1, input.shape[-1])
+            output_dtype = input.dtype
+
+            q_input, x_scale = dynamic_mxfp4_quant(input_2d)
+
+            k, v = fused_gemm_afp4wfp4_split_cat(
+                q_input, weight, k_pe.expand((-1, self.num_heads, -1)), x_scale, weight_scale.T, self.qk_nope_head_dim, self.v_head_dim, output_dtype
+            )
+        else:
+            kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
+                -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
+            k_nope, v = kv_nope\
+                .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
 
-        k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
+            k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
 
         output = self._run_prefill_new_tokens(
             prefill=attn_metadata.prefill,
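Note for reviewers: the standalone PyTorch sketch below restates the unfused semantics kept in the else branches above, i.e. the kv_b_proj GEMM, the per-head split into k_nope and v, and the concatenation of the shared k_pe onto every head. This is what the fused path is expected to reproduce; it is not the AITER kernel itself, which instead quantizes the activation with dynamic_mxfp4_quant and calls fused_gemm_afp4wfp4_split_cat to emit k and v directly. All shapes here are made up for illustration.

import torch

# Hypothetical sizes for illustration only.
num_tokens, kv_lora_rank = 8, 512
num_heads, qk_nope_head_dim, v_head_dim, qk_rope_head_dim = 16, 128, 128, 64

kv_c_normed = torch.randn(num_tokens, kv_lora_rank)
w_kv_b = torch.randn(kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim))
k_pe = torch.randn(num_tokens, 1, qk_rope_head_dim)

# Unfused path: GEMM, reshape to per-head [nope | v], then split.
kv_nope = (kv_c_normed @ w_kv_b).view(-1, num_heads,
                                      qk_nope_head_dim + v_head_dim)
k_nope, v = kv_nope.split([qk_nope_head_dim, v_head_dim], dim=-1)

# Broadcast the shared rope part across heads and concatenate to form k.
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)

assert k.shape == (num_tokens, num_heads, qk_nope_head_dim + qk_rope_head_dim)
assert v.shape == (num_tokens, num_heads, v_head_dim)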