diff --git a/rtp_llm/models_py/modules/factory/attention/rocm_impl/aiter.py b/rtp_llm/models_py/modules/factory/attention/rocm_impl/aiter.py index 3ba0f322ab..ea31a37723 100644 --- a/rtp_llm/models_py/modules/factory/attention/rocm_impl/aiter.py +++ b/rtp_llm/models_py/modules/factory/attention/rocm_impl/aiter.py @@ -1371,6 +1371,34 @@ def forward( ) return output.view(num_seqs, -1) +def _try_embedding_fast_path( + fmha_input: Any, + fmha_impl: Any, + fmha_params: Any, +) -> Optional[torch.Tensor]: + """Embedding fast path: extract packed QKV from C++ return and call flash_attn_varlen directly. + + C++ returns (packed_qkv, empty_k, empty_v) for embedding models. + Skips FP8 inputs which require a dedicated attention path. + Returns attention output if fast path is taken, None otherwise. + """ + packed_qkv = fmha_input[0] if isinstance(fmha_input, tuple) else fmha_input + if packed_qkv.dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz): + return None + + token_q_num = getattr(fmha_params, "token_q_num", packed_qkv.shape[0]) + token_kv_num = getattr(fmha_params, "token_kv_num", packed_qkv.shape[0]) + q, k, v = split_raw_qkv( + packed_qkv, fmha_impl.head_num, fmha_impl.head_num_kv, + fmha_impl.head_dim, token_q_num, token_kv_num, + ) + return aiter.flash_attn_varlen_func( + q, k, v, + fmha_params.cu_seqlens_q, fmha_params.cu_seqlens_k, + fmha_params.max_seqlen_q, fmha_params.max_seqlen_k, + dropout_p=0.0, causal=fmha_impl.is_causal, + ).reshape(token_q_num, fmha_impl.head_num * fmha_impl.head_dim) + class AiterPrefillImplAsm(FMHAImplBase): """Aiter prefill attention implementation using ASM.""" @@ -1408,13 +1436,17 @@ def forward( layer_idx: int = 0, ) -> torch.Tensor: if kv_cache is None: - # Embedding models still need positional encoding even without a KV cache. if self.need_rope_kv_cache: fmha_input = self.rope_kvcache_impl.forward( qkv, kv_cache, self.rope_params ) else: fmha_input = qkv + + fast_out = _try_embedding_fast_path(fmha_input, self.fmha_impl, self.fmha_params) + if fast_out is not None: + return fast_out + return self.fmha_impl.forward(fmha_input, kv_cache, self.fmha_params) # Apply RoPE and KV Cache processing @@ -1468,13 +1500,17 @@ def forward( layer_idx: int = 0, ) -> torch.Tensor: if kv_cache is None: - # Embedding models still need positional encoding even without a KV cache. if self.need_rope_kv_cache: fmha_input = self.rope_kvcache_impl.forward( qkv, kv_cache, self.rope_params ) else: fmha_input = qkv + + fast_out = _try_embedding_fast_path(fmha_input, self.fmha_impl, self.fmha_params) + if fast_out is not None: + return fast_out + return self.fmha_impl.forward(fmha_input, kv_cache, self.fmha_params) # Apply RoPE and KV Cache processing diff --git a/rtp_llm/models_py/modules/factory/attention/rocm_impl/test/BUILD b/rtp_llm/models_py/modules/factory/attention/rocm_impl/test/BUILD index c5c035ce01..faa0cda25a 100644 --- a/rtp_llm/models_py/modules/factory/attention/rocm_impl/test/BUILD +++ b/rtp_llm/models_py/modules/factory/attention/rocm_impl/test/BUILD @@ -29,6 +29,14 @@ py_test( exec_properties = {"gpu": "MI308X-ROCM7"}, ) +py_test( + name = "test_embedding_fast_path", + srcs = ["test_embedding_fast_path.py"], + deps = py_test_deps, + tags = ["rocm"], + exec_properties = {"gpu": "MI308X-ROCM7"}, +) + py_test( name = "test_fused_qkv_transpose_v3", srcs = ["test_fused_qkv_transpose_v3.py"], diff --git a/rtp_llm/models_py/modules/factory/attention/rocm_impl/test/test_embedding_fast_path.py b/rtp_llm/models_py/modules/factory/attention/rocm_impl/test/test_embedding_fast_path.py new file mode 100644 index 0000000000..7b875f6d90 --- /dev/null +++ b/rtp_llm/models_py/modules/factory/attention/rocm_impl/test/test_embedding_fast_path.py @@ -0,0 +1,114 @@ +"""Regression tests for _try_embedding_fast_path.""" +import unittest + +import torch + +from rtp_llm.models_py.modules.factory.attention.rocm_impl.aiter import ( + _try_embedding_fast_path, +) + +HAS_GPU = torch.cuda.is_available() + + +class FakeFmhaImpl: + def __init__(self, head_num=14, head_num_kv=2, head_dim=64, is_causal=False): + self.head_num = head_num + self.head_num_kv = head_num_kv + self.head_dim = head_dim + self.is_causal = is_causal + + +class FakeFmhaParams: + def __init__(self, seqlens, device="cpu"): + cu = torch.zeros(len(seqlens) + 1, dtype=torch.int32, device=device) + cu[1:] = torch.tensor(seqlens, dtype=torch.int32, device=device).cumsum(0) + self.cu_seqlens_q = cu + self.cu_seqlens_k = cu.clone() + self.max_seqlen_q = max(seqlens) + self.max_seqlen_k = max(seqlens) + total = sum(seqlens) + self.token_q_num = total + self.token_kv_num = total + + +class TestEmbeddingFastPathSkip(unittest.TestCase): + """Cases where fast path must NOT trigger.""" + + def _hidden(self, impl): + return (impl.head_num + 2 * impl.head_num_kv) * impl.head_dim + + def test_fp8_input_skipped(self): + impl = FakeFmhaImpl() + params = FakeFmhaParams([4]) + qkv = torch.randn(4, self._hidden(impl)).to(torch.float8_e4m3fnuz) + self.assertIsNone(_try_embedding_fast_path(qkv, impl, params)) + + def test_3d_input_skipped(self): + impl = FakeFmhaImpl() + params = FakeFmhaParams([4]) + self.assertIsNone(_try_embedding_fast_path(torch.randn(1, 4, 100), impl, params)) + + def test_tuple_with_valid_kv_skipped(self): + impl = FakeFmhaImpl() + params = FakeFmhaParams([4]) + q = torch.randn(4, impl.head_num, impl.head_dim) + k = torch.randn(4, impl.head_num_kv, impl.head_dim) + v = torch.randn(4, impl.head_num_kv, impl.head_dim) + self.assertIsNone(_try_embedding_fast_path((q, k, v), impl, params)) + + +@unittest.skipUnless(HAS_GPU, "requires GPU") +class TestEmbeddingFastPathTrigger(unittest.TestCase): + """Cases where fast path must trigger and produce correct shapes.""" + + def _hidden(self, impl): + return (impl.head_num + 2 * impl.head_num_kv) * impl.head_dim + + def test_2d_packed_qkv(self): + impl = FakeFmhaImpl() + params = FakeFmhaParams([6], device="cuda") + qkv = torch.randn(6, self._hidden(impl), device="cuda", dtype=torch.bfloat16) + result = _try_embedding_fast_path(qkv, impl, params) + self.assertIsNotNone(result) + self.assertEqual(result.shape, (6, impl.head_num * impl.head_dim)) + + def test_tuple_packed_none_none(self): + impl = FakeFmhaImpl() + params = FakeFmhaParams([4], device="cuda") + qkv = torch.randn(4, self._hidden(impl), device="cuda", dtype=torch.bfloat16) + result = _try_embedding_fast_path((qkv, None, None), impl, params) + self.assertIsNotNone(result) + self.assertEqual(result.shape, (4, impl.head_num * impl.head_dim)) + + def test_tuple_packed_empty_tensors(self): + impl = FakeFmhaImpl() + params = FakeFmhaParams([4], device="cuda") + qkv = torch.randn(4, self._hidden(impl), device="cuda", dtype=torch.bfloat16) + result = _try_embedding_fast_path((qkv, torch.Tensor(), torch.Tensor()), impl, params) + self.assertIsNotNone(result) + self.assertEqual(result.shape, (4, impl.head_num * impl.head_dim)) + + def test_variable_length_trimming(self): + """Unequal seqlens: output rows == token_q_num, not input rows.""" + impl = FakeFmhaImpl() + seqlens = [3, 5, 2] + params = FakeFmhaParams(seqlens, device="cuda") + # Allocate extra rows beyond token_q_num + qkv = torch.randn(12, self._hidden(impl), device="cuda", dtype=torch.bfloat16) + params.token_q_num = 10 + params.token_kv_num = 10 + result = _try_embedding_fast_path(qkv, impl, params) + self.assertIsNotNone(result) + self.assertEqual(result.shape[0], 10, "output should be trimmed to token_q_num") + + def test_single_token(self): + impl = FakeFmhaImpl() + params = FakeFmhaParams([1], device="cuda") + qkv = torch.randn(1, self._hidden(impl), device="cuda", dtype=torch.bfloat16) + result = _try_embedding_fast_path(qkv, impl, params) + self.assertIsNotNone(result) + self.assertEqual(result.shape, (1, impl.head_num * impl.head_dim)) + + +if __name__ == "__main__": + unittest.main(verbosity=2)