vllm/envs.py (5 additions, 0 deletions)
@@ -203,6 +203,7 @@
VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT: bool = True
VLLM_ROCM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD: bool = True
VLLM_ROCM_USE_AITER_TRITON_BF16_GEMM: bool = True
VLLM_ROCM_USE_AITER_TRITON_FUSED_GEMM_FP4_SPLIT_CAT: bool = True
ROCM_TRITON_MOE_PRESHUFFLE_SCALES: bool = True
VLLM_ROCM_USE_AITER_FUSED_MOE_A16W4: bool = False
VLLM_GPT_OSS_USE_CONTAINER_TOOL: bool = False
@@ -1380,13 +1381,13 @@
# Use AITER Triton fused RMSNORM + Quantization
"VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP8_QUANT":
lambda: bool(int(os.getenv("VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP8_QUANT", "1"))),
"VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT":

Check failure on line 1384 in vllm/envs.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/envs.py:1384:81: E501 Line too long (92 > 80)
lambda: bool(int(os.getenv("VLLM_ROCM_USE_AITER_TRITON_FUSED_RMSNORM_FP4_QUANT", "1"))),

Check failure (GitHub Actions / pre-commit, Ruff E501): vllm/envs.py:1384:81: Line too long (92 > 80)
Check failure (GitHub Actions / pre-commit, Ruff E501): vllm/envs.py:1386:81: Line too long (92 > 80)
# Use AITER Triton fused elementwise multiply + elementwise addition
"VLLM_ROCM_USE_AITER_TRITON_FUSED_MUL_ADD":
lambda: bool(int(os.getenv("VLLM_ROCM_USE_AITER_TRITON_FUSED_MUL_ADD", "1"))),

Check failure (GitHub Actions / pre-commit, Ruff E501): vllm/envs.py:1390:81: Line too long (82 > 80)
# Use AITER Triton fused rope + zeros + reshape_and_cache
"VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE":
lambda: bool(int(os.getenv("VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE", "1"))),
@@ -1448,6 +1449,10 @@
# Use AITER Triton MLA
"VLLM_ROCM_USE_AITER_TRITON_MLA":
lambda: bool(int(os.getenv("VLLM_ROCM_USE_AITER_TRITON_MLA", "1"))),

# Use AITER Triton fused FP4 GEMM + split + cat
"VLLM_ROCM_USE_AITER_TRITON_FUSED_GEMM_FP4_SPLIT_CAT":
lambda: bool(int(os.getenv("VLLM_ROCM_USE_AITER_TRITON_FUSED_GEMM_FP4_SPLIT_CAT", "1"))),
}
# --8<-- [end:env-vars-definition]
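The entries added in this PR all parse the same way, bool(int(os.getenv(NAME, "1"))): with default "1" the flag is on unless the variable is explicitly set to "0", and a non-integer value raises ValueError rather than silently disabling the feature. A minimal standalone sketch of that convention (the helper name is ours, not vLLM's):

    import os

    def read_bool_flag(name: str, default: str = "1") -> bool:
        # Unset or "1" -> True; "0" -> False; "yes"/"true" raise ValueError.
        return bool(int(os.getenv(name, default)))

For example, setting VLLM_ROCM_USE_AITER_TRITON_FUSED_GEMM_FP4_SPLIT_CAT=0 in the serving environment disables the new fused FP4 GEMM + split + cat path at startup.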

vllm/v1/attention/backends/mla/common.py (92 additions, 17 deletions)
@@ -245,13 +245,22 @@ def dynamic_per_batched_tensor_quant(
    return x_quant, x_quant_scale

from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant

VLLM_ROCM_USE_AITER_TRITON_FUSED_GEMM_FP4_SPLIT_CAT = envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_GEMM_FP4_SPLIT_CAT
if VLLM_ROCM_USE_AITER_TRITON_FUSED_GEMM_FP4_SPLIT_CAT:
    from aiter.ops.triton.fused_gemm_afp4wfp4_split_cat import fused_gemm_afp4wfp4_split_cat
    from aiter.ops.triton.quant import dynamic_mxfp4_quant
    from vllm.model_executor.layers.quantization.quark.quark import QuarkLinearMethod
else:
    VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = False
    VLLM_ROCM_USE_AITER_TRITON_FP8_BMM = False
    VLLM_ROCM_USE_AITER_TRITON_FP8_BMM_MAX_BATCH_SIZE = 0

logger.info(f"[Aiter] {VLLM_ROCM_USE_AITER_TRITON_FP8_BMM=} {VLLM_ROCM_USE_AITER_TRITON_FP8_BMM_MAX_BATCH_SIZE=}")
logger.info(f"[Aiter] {VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE=}")
logger.info(f"[Aiter] {VLLM_ROCM_USE_AITER_TRITON_FUSED_GEMM_FP4_SPLIT_CAT=}")

try:
    from vllm.vllm_flash_attn import flash_attn_varlen_func

@@ -1383,13 +1392,35 @@ def _compute_prefill_context(
k_pe = workspace[:toks][..., self.kv_lora_rank:].unsqueeze(1)

if (
    VLLM_ROCM_USE_AITER_TRITON_FUSED_GEMM_FP4_SPLIT_CAT
    and (self.kv_b_proj.bias is None
         or self.kv_b_proj.skip_bias_add)
    and self.kv_b_proj.quant_method is not None
    and isinstance(self.kv_b_proj.quant_method, QuarkLinearMethod)
    and not self.kv_b_proj.gather_output
):
    input = kv_c_normed
    weight = self.kv_b_proj.weight
    weight_scale = self.kv_b_proj.weight_scale

    # View the input as a 2D matrix for the fp4 quant path
    input_2d = input.view(-1, input.shape[-1])
    output_dtype = input.dtype

    q_input, x_scale = dynamic_mxfp4_quant(input_2d)

    k, v = fused_gemm_afp4wfp4_split_cat(
        q_input, weight, k_pe.expand((-1, self.num_heads, -1)),
        x_scale, weight_scale.T, self.qk_nope_head_dim,
        self.v_head_dim, output_dtype)
else:
    kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
        -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
    k_nope, v = kv_nope.split(
        [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

    k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
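# NOTE (editor's sketch, not part of this PR): quantization aside, the fused
# call above must agree with the unfused else-branch. A standalone PyTorch
# reference for that GEMM + split + cat contract; argument names and shapes
# are assumptions read off the surrounding code:
import torch

def _reference_split_cat(x, w, k_pe, num_heads, qk_nope_head_dim, v_head_dim):
    # x: [tokens, kv_lora_rank]; w: [num_heads * (nope + v), kv_lora_rank]
    kv_nope = (x @ w.T).view(-1, num_heads, qk_nope_head_dim + v_head_dim)
    k_nope, v = kv_nope.split([qk_nope_head_dim, v_head_dim], dim=-1)
    # k_pe ([tokens, 1, rope_dim]) broadcasts across heads before the concat
    k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
    return k, v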

attn_output, attn_softmax_lse = self._run_prefill_context_chunk(
    prefill=prefill_metadata,

@@ -1486,12 +1517,34 @@ def _context_parallel_compute_prefill_context(
    chunk_idx=i,
    toks=toks)

if (
    VLLM_ROCM_USE_AITER_TRITON_FUSED_GEMM_FP4_SPLIT_CAT
    and (self.kv_b_proj.bias is None
         or self.kv_b_proj.skip_bias_add)
    and self.kv_b_proj.quant_method is not None
    and isinstance(self.kv_b_proj.quant_method, QuarkLinearMethod)
    and not self.kv_b_proj.gather_output
):
    input = kv_c_normed
    weight = self.kv_b_proj.weight
    weight_scale = self.kv_b_proj.weight_scale

    # View the input as a 2D matrix for the fp4 quant path
    input_2d = input.view(-1, input.shape[-1])
    output_dtype = input.dtype

    q_input, x_scale = dynamic_mxfp4_quant(input_2d)

    k, v = fused_gemm_afp4wfp4_split_cat(
        q_input, weight, k_pe.expand((-1, self.num_heads, -1)),
        x_scale, weight_scale.T, self.qk_nope_head_dim,
        self.v_head_dim, output_dtype)
else:
    kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
        -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
    k_nope, v = kv_nope.split(
        [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

    k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)

attn_output, attn_softmax_lse = self._run_prefill_context_chunk(
    prefill=prefill_metadata,

@@ -1533,12 +1586,34 @@ def _forward_prefill(
assert self.dcp_world_size is not None

has_context = attn_metadata.prefill.chunked_context is not None
if (
    VLLM_ROCM_USE_AITER_TRITON_FUSED_GEMM_FP4_SPLIT_CAT
    and (self.kv_b_proj.bias is None
         or self.kv_b_proj.skip_bias_add)
    and self.kv_b_proj.quant_method is not None
    and isinstance(self.kv_b_proj.quant_method, QuarkLinearMethod)
    and not self.kv_b_proj.gather_output
):
    input = kv_c_normed
    weight = self.kv_b_proj.weight
    weight_scale = self.kv_b_proj.weight_scale

    # View the input as a 2D matrix for the fp4 quant path
    input_2d = input.view(-1, input.shape[-1])
    output_dtype = input.dtype

    q_input, x_scale = dynamic_mxfp4_quant(input_2d)

    k, v = fused_gemm_afp4wfp4_split_cat(
        q_input, weight, k_pe.expand((-1, self.num_heads, -1)),
        x_scale, weight_scale.T, self.qk_nope_head_dim,
        self.v_head_dim, output_dtype)
else:
    kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
        -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
    k_nope, v = kv_nope.split(
        [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

    k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
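# NOTE (editor's sketch, not part of this PR): the fused-vs-unfused branch
# above is now repeated at three call sites (_compute_prefill_context,
# _context_parallel_compute_prefill_context, _forward_prefill). A follow-up
# could hoist it into a single helper along these lines (name hypothetical):
def _project_kv_split_cat(self, kv_c_normed, k_pe):
    if (VLLM_ROCM_USE_AITER_TRITON_FUSED_GEMM_FP4_SPLIT_CAT
            and (self.kv_b_proj.bias is None or self.kv_b_proj.skip_bias_add)
            and isinstance(self.kv_b_proj.quant_method, QuarkLinearMethod)
            and not self.kv_b_proj.gather_output):
        # isinstance(None, QuarkLinearMethod) is False, so the explicit
        # quant_method-is-not-None check folds into the isinstance test
        q_input, x_scale = dynamic_mxfp4_quant(
            kv_c_normed.view(-1, kv_c_normed.shape[-1]))
        return fused_gemm_afp4wfp4_split_cat(
            q_input, self.kv_b_proj.weight,
            k_pe.expand((-1, self.num_heads, -1)), x_scale,
            self.kv_b_proj.weight_scale.T, self.qk_nope_head_dim,
            self.v_head_dim, kv_c_normed.dtype)
    kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
        -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
    k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
    k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
    return k, v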

output = self._run_prefill_new_tokens(
    prefill=attn_metadata.prefill,