@@ -2026,21 +2026,35 @@ def forward(
20262026
20272027 if fp8_attention :
20282028 ql_nope_shape = decode_ql_nope .shape
2029- decode_ql_nope , _ = ops .scaled_fp8_quant (
2030- decode_ql_nope .reshape (
2031- [ql_nope_shape [0 ], ql_nope_shape [1 ] * ql_nope_shape [2 ]]
2032- ),
2033- layer ._q_scale ,
2034- )
2035- decode_ql_nope = decode_ql_nope .reshape (ql_nope_shape )
20362029 q_pe_shape = decode_q_pe .shape
2037- decode_q_pe , _ = ops .scaled_fp8_quant (
2038- decode_q_pe .reshape ([q_pe_shape [0 ], q_pe_shape [1 ] * q_pe_shape [2 ]]),
2039- layer ._q_scale ,
2030+ assert decode_ql_nope .shape [0 ] == decode_q_pe .shape [0 ]
2031+ assert decode_ql_nope .shape [1 ] == decode_q_pe .shape [1 ]
2032+ decode_q_shape = (
2033+ ql_nope_shape [0 ],
2034+ ql_nope_shape [1 ],
2035+ ql_nope_shape [2 ] + q_pe_shape [2 ],
2036+ )
2037+ decode_q0 = torch .empty (
2038+ decode_q_shape ,
2039+ device = decode_ql_nope .device ,
2040+ dtype = decode_ql_nope .dtype ,
2041+ )
2042+ decode_q0 [..., : ql_nope_shape [2 ]].copy_ (decode_ql_nope )
2043+ decode_q0 [..., ql_nope_shape [2 ] :].copy_ (decode_q_pe )
2044+ decode_q = torch .empty (
2045+ decode_q_shape ,
2046+ device = decode_ql_nope .device ,
2047+ dtype = torch .float8_e4m3fn ,
20402048 )
2041- decode_q_pe = decode_q_pe .reshape (q_pe_shape )
20422049
2043- decode_q = (decode_ql_nope , decode_q_pe )
2050+ decode_q , _ = ops .scaled_fp8_quant (
2051+ decode_q0 .view (decode_q_shape [0 ], - 1 ),
2052+ layer ._q_scale ,
2053+ output = decode_q .view (decode_q_shape [0 ], - 1 ),
2054+ )
2055+ decode_q = decode_q .view (decode_q_shape )
2056+ else :
2057+ decode_q = (decode_ql_nope , decode_q_pe )
20442058 if self .dcp_world_size > 1 :
20452059 assert not fp8_attention , "DCP not support fp8 kvcache now."
20462060 # concatenate decode_ql_nope and decode_q_pe -> (B, N, L + P)
0 commit comments