Skip to content

Commit c86564b

Browse files
committed
avoid launching two quant kernels and an extra reshape
Signed-off-by: Siyuan Fu <[email protected]>
1 parent 0353d2e commit c86564b

File tree

1 file changed

+26
-12
lines changed

1 file changed

+26
-12
lines changed

vllm/v1/attention/backends/mla/common.py

Lines changed: 26 additions & 12 deletions
Original file line number · Diff line number · Diff line change
@@ -2026,21 +2026,35 @@ def forward(
20262026

20272027
if fp8_attention:
20282028
ql_nope_shape = decode_ql_nope.shape
2029-
decode_ql_nope, _ = ops.scaled_fp8_quant(
2030-
decode_ql_nope.reshape(
2031-
[ql_nope_shape[0], ql_nope_shape[1] * ql_nope_shape[2]]
2032-
),
2033-
layer._q_scale,
2034-
)
2035-
decode_ql_nope = decode_ql_nope.reshape(ql_nope_shape)
20362029
q_pe_shape = decode_q_pe.shape
2037-
decode_q_pe, _ = ops.scaled_fp8_quant(
2038-
decode_q_pe.reshape([q_pe_shape[0], q_pe_shape[1] * q_pe_shape[2]]),
2039-
layer._q_scale,
2030+
assert decode_ql_nope.shape[0] == decode_q_pe.shape[0]
2031+
assert decode_ql_nope.shape[1] == decode_q_pe.shape[1]
2032+
decode_q_shape = (
2033+
ql_nope_shape[0],
2034+
ql_nope_shape[1],
2035+
ql_nope_shape[2] + q_pe_shape[2],
2036+
)
2037+
decode_q0 = torch.empty(
2038+
decode_q_shape,
2039+
device=decode_ql_nope.device,
2040+
dtype=decode_ql_nope.dtype,
2041+
)
2042+
decode_q0[..., : ql_nope_shape[2]].copy_(decode_ql_nope)
2043+
decode_q0[..., ql_nope_shape[2] :].copy_(decode_q_pe)
2044+
decode_q = torch.empty(
2045+
decode_q_shape,
2046+
device=decode_ql_nope.device,
2047+
dtype=torch.float8_e4m3fn,
20402048
)
2041-
decode_q_pe = decode_q_pe.reshape(q_pe_shape)
20422049

2043-
decode_q = (decode_ql_nope, decode_q_pe)
2050+
decode_q, _ = ops.scaled_fp8_quant(
2051+
decode_q0.view(decode_q_shape[0], -1),
2052+
layer._q_scale,
2053+
output=decode_q.view(decode_q_shape[0], -1),
2054+
)
2055+
decode_q = decode_q.view(decode_q_shape)
2056+
else:
2057+
decode_q = (decode_ql_nope, decode_q_pe)
20442058
if self.dcp_world_size > 1:
20452059
assert not fp8_attention, "DCP not support fp8 kvcache now."
20462060
# concatenate decode_ql_nope and decode_q_pe -> (B, N, L + P)

0 commit comments

Comments
 (0)