Support the case where query_len>context_len in the multi-queries paged attention. (#8356)
vanbasten23 authored Nov 14, 2024
1 parent f610f70 commit 102cd48
Showing 4 changed files with 170 additions and 30 deletions.
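
This commit teaches the multi-queries paged attention path to handle query sequences that are padded past the KV length: a new per-sequence `effective_q_lens` tensor tells the kernel how many query positions are real, and the causal mask is offset by `kv_len - effective_q_len` instead of `kv_len - query_len`. Below is a minimal sketch of the updated call; the shapes, dtypes, and the import from `torch_xla.experimental.custom_kernel` mirror the tests in this commit, but the snippet itself is illustrative and assumes an XLA (TPU) device is available.

```python
import torch
import torch_xla  # registers the "xla" device
from torch_xla.experimental.custom_kernel import multi_queries_paged_attention

batch_size, query_len, max_kv_len = 3, 64, 2048
page_size, num_kv_heads, num_q_heads, head_dim = 16, 8, 32, 256
pages_per_seq = max_kv_len // page_size
total_pages = batch_size * pages_per_seq

q = torch.randn(batch_size, query_len, num_q_heads, head_dim).to("xla")
k_pages = torch.randn(num_kv_heads, total_pages, page_size, head_dim).to("xla")
v_pages = torch.randn(num_kv_heads, total_pages, page_size, head_dim).to("xla")
kv_seq_lens = torch.randint(
    query_len, max_kv_len, (batch_size,), dtype=torch.int32).to("xla")
page_indices = torch.arange(
    total_pages, dtype=torch.int32).reshape(batch_size, pages_per_seq).to("xla")
# New in this commit: the number of non-padded query positions per sequence.
effective_q_lens = torch.full(
    (batch_size,), query_len, dtype=torch.int32).to("xla")

output = multi_queries_paged_attention(
    q,
    k_pages,
    v_pages,
    kv_seq_lens,
    page_indices,
    effective_q_lens,
    num_kv_pages_per_compute_block=256 // page_size,
    num_queries_per_compute_block=32,
    use_kernel=True,
)  # [batch_size, query_len, num_q_heads, head_dim]
```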
21 changes: 18 additions & 3 deletions test/test_pallas.py
@@ -567,7 +567,10 @@ def test_paged_attention_multi_queries_wrapper(self):

max_kv_len = 2048
query_len = 64
kv_seq_lens = torch.randint(query_len, max_kv_len, (3,), dtype=torch.int32)
batch_size = 3
kv_seq_lens = torch.randint(
query_len, max_kv_len, (batch_size,), dtype=torch.int32)
effective_q_lens = torch.full((batch_size,), query_len, dtype=torch.int32)
assert query_len <= max_kv_len
for cur_kv_seq in kv_seq_lens:
assert query_len <= cur_kv_seq, f'{query_len} should be less than or equal to the kv_len {cur_kv_seq} in the current sequence.'
@@ -592,13 +595,15 @@ def test_paged_attention_multi_queries_wrapper(self):
v_pages_xla = v_pages.to("xla")
kv_seq_lens_xla = kv_seq_lens.to("xla")
page_indices_xla = page_indices.to("xla")
effective_q_lens_xla = effective_q_lens.to("xla")

output = multi_queries_paged_attention(
q_xla,
k_pages_xla,
v_pages_xla,
kv_seq_lens_xla,
page_indices_xla,
effective_q_lens_xla,
num_kv_pages_per_compute_block=block_kv_size // page_size,
num_queries_per_compute_block=num_queries_per_compute_block,
)
@@ -609,6 +614,7 @@ def test_paged_attention_multi_queries_wrapper(self):
v_pages_xla,
kv_seq_lens_xla,
page_indices_xla,
effective_q_lens_xla,
num_kv_pages_per_compute_block=block_kv_size // page_size,
num_queries_per_compute_block=num_queries_per_compute_block,
use_kernel=False,
@@ -619,6 +625,7 @@ def test_paged_attention_multi_queries_wrapper(self):
v_pages_jax = jnp.array(v_pages.numpy(), dtype=jnp.float32)
kv_seq_lens_jax = jnp.array(kv_seq_lens.numpy(), dtype=jnp.int32)
page_indices_jax = jnp.array(page_indices.numpy(), dtype=jnp.int32)
effective_q_lens_jax = jnp.array(effective_q_lens.numpy(), dtype=jnp.int32)
expected_output = torch.from_numpy(
np.array(
jax_multi_queries_paged_attention(
@@ -627,6 +634,7 @@ def test_paged_attention_multi_queries_wrapper(self):
v_pages_jax,
kv_seq_lens_jax,
page_indices_jax,
effective_q_lens_jax,
num_kv_pages_per_compute_block=block_kv_size // page_size,
num_queries_per_compute_block=num_queries_per_compute_block,
)))
@@ -654,7 +662,10 @@ def test_paged_attention_multi_queries_wrapper_with_dynamo(self):

max_kv_len = 2048
query_len = 64
kv_seq_lens = torch.randint(query_len, max_kv_len, (3,), dtype=torch.int32)
batch_size = 3
kv_seq_lens = torch.randint(
query_len, max_kv_len, (batch_size,), dtype=torch.int32)
effective_q_lens = torch.full((batch_size,), query_len, dtype=torch.int32)
assert query_len <= max_kv_len
for cur_kv_seq in kv_seq_lens:
assert query_len <= cur_kv_seq, f'{query_len} should be less than or equal to the kv_len {cur_kv_seq} in the current sequence.'
@@ -679,9 +690,10 @@ def test_paged_attention_multi_queries_wrapper_with_dynamo(self):
v_pages_xla = v_pages.to("xla")
kv_seq_lens_xla = kv_seq_lens.to("xla")
page_indices_xla = page_indices.to("xla")
effective_q_lens_xla = effective_q_lens.to("xla")

def multi_queries_paged_attention_wrapper(q, k_pages, v_pages, kv_seq_lens,
page_indices,
page_indices, effective_q_lens,
num_kv_pages_per_compute_block,
num_queries_per_compute_block,
use_kernel):
@@ -691,6 +703,7 @@ def multi_queries_paged_attention_wrapper(q, k_pages, v_pages, kv_seq_lens,
v_pages,
kv_seq_lens,
page_indices,
effective_q_lens,
num_kv_pages_per_compute_block,
num_queries_per_compute_block,
use_kernel=use_kernel,
@@ -705,6 +718,7 @@ def multi_queries_paged_attention_wrapper(q, k_pages, v_pages, kv_seq_lens,
v_pages_xla,
kv_seq_lens_xla,
page_indices_xla,
effective_q_lens_xla,
num_kv_pages_per_compute_block=block_kv_size // page_size,
num_queries_per_compute_block=num_queries_per_compute_block,
use_kernel=True,
@@ -716,6 +730,7 @@ def multi_queries_paged_attention_wrapper(q, k_pages, v_pages, kv_seq_lens,
v_pages_xla,
kv_seq_lens_xla,
page_indices_xla,
effective_q_lens_xla,
num_kv_pages_per_compute_block=block_kv_size // page_size,
num_queries_per_compute_block=num_queries_per_compute_block,
use_kernel=False,
141 changes: 125 additions & 16 deletions test/test_tpu_paged_attention_kernel.py
@@ -43,8 +43,9 @@ def _ref_jax_extended_paged_attention(
q, # [batch_size, query_len, num_query_heads, head_size]
k_pages, # [num_kv_heads, total_num_pages, page_size, head_size]
v_pages, # [num_kv_heads, total_num_pages, page_size, head_size]
lengths, # [batch_size]
lengths, # [batch_size], the effective kv_length.
page_indices, # [batch_size, pages_per_sequence]
    effective_q_lens, # [batch_size], the effective q_length.
):
batch_size, query_len, num_query_heads, head_size = q.shape
num_kv_heads, total_num_pages, page_size, _ = k_pages.shape
@@ -72,7 +73,8 @@ def _ref_jax_extended_paged_attention(

attn = jnp.einsum("qhd,khd->hqk", q[i], k)
attn = attn.astype('float32')
q_span = (kv_len - query_len) + jax.lax.broadcasted_iota(
effective_q_len = effective_q_lens[i]
q_span = (kv_len - effective_q_len) + jax.lax.broadcasted_iota(
jnp.int32, (query_len, kv_len), 0)
kv_span = jax.lax.broadcasted_iota(jnp.int32, (query_len, kv_len), 1)
mask = jnp.where(q_span < kv_span, float("-inf"), 0.)
@@ -91,17 +93,16 @@ class PagedAttentionKernelTest(jtu.JaxTestCase):
def setUp(self):
super().setUp()


# def test_paged_attention(
# self,
# ):
# dtype = jnp.bfloat16
# page_size=16
# num_kv_heads = 8
# q_kv_head_ratio = 4
# head_dim = 256
# num_queries_per_compute_block = 32
# block_kv_size = 256
# def test_paged_attention(
# self,
# ):
# dtype = jnp.bfloat16
# page_size=16
# num_kv_heads = 8
# q_kv_head_ratio = 4
# head_dim = 256
# num_queries_per_compute_block = 32
# block_kv_size = 256

@parameterized.product(
dtype=(jnp.float32, jnp.bfloat16),
@@ -112,7 +113,7 @@ def setUp(self):
num_queries_per_compute_block=(16, 32),
block_kv_size=(128, 256),
)
def test_paged_attention(
def test_paged_attention_without_query_padding(
self,
dtype,
page_size,
@@ -125,13 +126,13 @@ def test_paged_attention(

max_kv_len = 2048
query_len = 64
batch_size = 3
kv_seq_lens = jax.random.randint(
jax.random.key(0), (3,), query_len, max_kv_len)
jax.random.key(0), (batch_size,), query_len, max_kv_len)

assert query_len <= max_kv_len
for cur_kv_seq in kv_seq_lens:
assert query_len <= cur_kv_seq, f'{query_len} should be less than or equal to the kv_len {cur_kv_seq} in the current sequence.'
batch_size = len(kv_seq_lens)
pages_per_sequence = max_kv_len // page_size
total_num_pages = batch_size * pages_per_sequence
assert max_kv_len <= total_num_pages * page_size
@@ -150,12 +151,14 @@ def test_paged_attention(

print(f'Running paged_attention with {query_len=}')
num_kv_pages_per_compute_block = block_kv_size // page_size
effective_q_lens = jnp.full_like(kv_seq_lens, query_len)
actual_output = paged_attention(
q,
k_pages,
v_pages,
kv_seq_lens,
page_indices,
effective_q_lens,
num_kv_pages_per_compute_block=num_kv_pages_per_compute_block,
num_queries_per_compute_block=num_queries_per_compute_block,
)
@@ -168,6 +171,7 @@ def test_paged_attention(
v_pages,
kv_seq_lens,
page_indices,
effective_q_lens,
)

self.assertEqual(actual_output.shape, expected_output.shape)
@@ -183,5 +187,110 @@ def test_paged_attention(
self.assertTrue(
jnp.allclose(expected_output, actual_output, atol=atol, rtol=rtol))

# def test_paged_attention_query_len_longer_than_kv_seq_len(
# self,
# ):
# dtype = jnp.float32
# page_size=16
# num_kv_heads = 8
# q_kv_head_ratio = 4
# head_dim = 256
# num_queries_per_compute_block = 32
# block_kv_size = 256
  # In practice, vLLM pads the query so that the query seq len can be longer than the kv seq len: the query seq len may be padded, but the kv seq len is not.
  # When this happens, paged_attention needs the additional `effective_q_lens` parameter to set the causal mask correctly.
@parameterized.product(
dtype=(jnp.float32, jnp.bfloat16),
page_size=(16, 32, 64),
num_kv_heads=(1, 8),
q_kv_head_ratio=(1, 4, 8),
head_dim=(128, 256),
num_queries_per_compute_block=(16, 32),
block_kv_size=(128, 256),
)
def test_paged_attention_with_query_padding(
self,
dtype,
page_size,
num_kv_heads,
q_kv_head_ratio,
head_dim,
num_queries_per_compute_block,
block_kv_size,
):

max_kv_len = 2048
# Set query_len>kv_seq_lens
query_len = max_kv_len
batch_size = 3
kv_seq_lens = jax.random.randint(
jax.random.key(0), (batch_size,), 0, max_kv_len)
effective_q_lens = jax.random.randint(
jax.random.key(0), (batch_size,), 0, kv_seq_lens)
for cur_effec_q_len, cur_kv_seq_len in zip(effective_q_lens, kv_seq_lens):
assert cur_effec_q_len <= cur_kv_seq_len, f'The effective query len {cur_effec_q_len} should be less than or equal to the kv_len {cur_kv_seq_len} in the current sequence.'

pages_per_sequence = max_kv_len // page_size
total_num_pages = batch_size * pages_per_sequence
assert max_kv_len <= total_num_pages * page_size
q, k_pages, v_pages, page_indices = _generate_qkv(
kv_seq_lens,
page_size,
max_kv_len,
query_len,
num_kv_heads,
num_kv_heads * q_kv_head_ratio,
head_dim,
jax.random.key(0),
dtype,
)

print(
f'Running paged_attention with {query_len=}, {kv_seq_lens=}, {effective_q_lens=}'
)
num_kv_pages_per_compute_block = block_kv_size // page_size
actual_output = paged_attention(
q,
k_pages,
v_pages,
kv_seq_lens,
page_indices,
effective_q_lens,
num_kv_pages_per_compute_block=num_kv_pages_per_compute_block,
num_queries_per_compute_block=num_queries_per_compute_block,
)
# actual_output = jax.block_until_ready(actual_output)

# Run the ref impl.
expected_output = _ref_jax_extended_paged_attention(
q,
k_pages,
v_pages,
kv_seq_lens,
page_indices,
effective_q_lens,
)

self.assertEqual(actual_output.shape, expected_output.shape)

if dtype == jnp.float32:
atol = 2e-2
rtol = 1e-2
elif dtype == jnp.bfloat16:
atol = 6e-1
rtol = 1e-1
else:
self.fail(f'Unsupported dtype: {dtype}')
for b in range(batch_size):
      # N.B. In the output ([batch_size, query_len, num_q_heads, head_dim]), along the query_len dim, all values past effective_q_len are thrown away because the query seq len is padded; those padded positions may differ between the kernel and the ref impl because of the causal mask.
effective_q_len = effective_q_lens[b]
self.assertTrue(
jnp.allclose(
expected_output[b, :effective_q_len],
actual_output[b, :effective_q_len],
atol=atol,
rtol=rtol))


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())
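
The comments in `test_paged_attention_with_query_padding` above explain the core of the change: when the query is padded past the KV length, the causal mask must be anchored at the effective (unpadded) query length rather than at the padded `query_len`. Below is a toy JAX sketch of the masking logic used by `_ref_jax_extended_paged_attention`, with made-up sizes chosen only for illustration.

```python
import jax
import jax.numpy as jnp

# Toy sizes: the query is padded past the kv length, but only the first
# `effective_q_len` query rows are real.
query_len, kv_len, effective_q_len = 6, 4, 3

# Same construction as in the reference implementation above: the last
# *effective* query token is aligned with the last kv token.
q_span = (kv_len - effective_q_len) + jax.lax.broadcasted_iota(
    jnp.int32, (query_len, kv_len), 0)
kv_span = jax.lax.broadcasted_iota(jnp.int32, (query_len, kv_len), 1)
mask = jnp.where(q_span < kv_span, float("-inf"), 0.)

# Rows 0..effective_q_len-1 carry the causal mask; rows at or past
# effective_q_len are padding, and the test above only compares the first
# effective_q_len rows of the output.
print(mask)
```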
22 changes: 14 additions & 8 deletions torch_xla/experimental/custom_kernel.py
@@ -487,8 +487,9 @@ def _multi_queries_paged_attention_nonkernel(
q, # [batch_size, query_len, num_heads, head_size]
k_pages, # [num_kv_heads, total_num_pages, page_size, head_size]
v_pages, # [num_kv_heads, total_num_pages, page_size, head_size]
lengths, # seq_lengths, [batch_size]. nb batch_size = len(seq_lens)
    lengths, # [batch_size], the effective kv_length of each sequence; batch_size = len(seq_lens).
page_indices, # [batch_size, pages_per_sequence]
effective_q_lens, # [batch_size], the effective q_length
) -> torch.Tensor: # [batch_size, query_len, num_heads, head_dim]
batch_size, query_len, num_query_heads, head_size = q.shape
num_kv_heads, total_num_pages, page_size, _ = k_pages.shape
@@ -528,7 +529,8 @@ def _multi_queries_paged_attention_nonkernel(
k) # [num_query_heads, query_len, kv_len]
attn = attn.float()
empty_mask = torch.ones(query_len, kv_len, device=attn.device)
mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
effective_q_len = effective_q_lens[i]
mask = torch.triu(empty_mask, diagonal=kv_len - effective_q_len + 1).bool()
attn.masked_fill_(mask, float("-inf"))
attn = torch.softmax(
attn, dim=-1).to(v.dtype) # [num_query_heads, query_len, kv_len]
@@ -547,6 +549,7 @@ def multi_queries_paged_attention(
v_pages, # [num_kv_heads, total_num_pages, page_size, head_size]
lengths, # seq_lengths, [batch_size]. nb batch_size = len(seq_lens)
page_indices, # [batch_size, pages_per_sequence]
effective_q_lens, # [batch_size]
num_kv_pages_per_compute_block,
num_queries_per_compute_block,
use_kernel=True,
@@ -559,6 +562,7 @@ def multi_queries_paged_attention(
v_pages,
lengths,
page_indices,
effective_q_lens,
)

# Import JAX within the function such that we don't need to call the jax_import_guard()
@@ -572,6 +576,7 @@ def multi_queries_paged_attention(
v_pages,
lengths,
page_indices,
effective_q_lens,
num_kv_pages_per_compute_block=num_kv_pages_per_compute_block,
num_queries_per_compute_block=num_queries_per_compute_block,
static_argnames=[
@@ -592,6 +597,7 @@ def multi_queries_paged_attention(
[
lengths,
page_indices_reshaped,
effective_q_lens,
buffer_index,
step,
q.to(q_dtype_for_kernel_launch),
@@ -1081,18 +1087,18 @@ def paged_attention_non_xla(q: torch.Tensor,


XLA_LIB.define(
"multi_queries_paged_attention(Tensor q, Tensor k_pages, Tensor v_pages, Tensor lengths, Tensor page_indices, int num_kv_pages_per_compute_block, int num_queries_per_compute_block, bool use_kernel) -> Tensor",
"multi_queries_paged_attention(Tensor q, Tensor k_pages, Tensor v_pages, Tensor lengths, Tensor page_indices, Tensor effective_q_lens, int num_kv_pages_per_compute_block, int num_queries_per_compute_block, bool use_kernel) -> Tensor",
)


@impl(XLA_LIB, "multi_queries_paged_attention", "XLA")
def multi_queries_paged_attention_xla(
q: torch.Tensor, k_pages: torch.Tensor, v_pages: torch.Tensor,
lengths: torch.Tensor, page_indices: torch.Tensor,
num_kv_pages_per_compute_block: int, num_queries_per_compute_block: int,
use_kernel: bool):
effective_q_lens: torch.Tensor, num_kv_pages_per_compute_block: int,
num_queries_per_compute_block: int, use_kernel: bool):
return multi_queries_paged_attention(q, k_pages, v_pages, lengths,
page_indices,
page_indices, effective_q_lens,
num_kv_pages_per_compute_block,
num_queries_per_compute_block,
use_kernel)
@@ -1102,8 +1108,8 @@ def multi_queries_paged_attention_xla(
def multi_queries_paged_attention_non_xla(
q: torch.Tensor, k_pages: torch.Tensor, v_pages: torch.Tensor,
lengths: torch.Tensor, page_indices: torch.Tensor,
num_kv_pages_per_compute_block: int, num_queries_per_compute_block: int,
use_kernel: bool):
effective_q_lens: torch.Tensor, num_kv_pages_per_compute_block: int,
num_queries_per_compute_block: int, use_kernel: bool):
return non_xla_attetion(q, k_pages, v_pages, "paged")


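
The non-kernel fallback in `_multi_queries_paged_attention_nonkernel` builds the equivalent mask with `torch.triu`: shifting the diagonal by `kv_len - effective_q_len + 1` marks exactly the entries where `q_span < kv_span` in the JAX reference. A small self-contained check with toy sizes (illustrative only, not taken from the tests):

```python
import torch

query_len, kv_len, effective_q_len = 6, 4, 3  # toy sizes

# Mask as built by _multi_queries_paged_attention_nonkernel above.
empty_mask = torch.ones(query_len, kv_len)
triu_mask = torch.triu(empty_mask, diagonal=kv_len - effective_q_len + 1).bool()

# Equivalent formulation mirroring the JAX reference implementation.
q_span = (kv_len - effective_q_len) + torch.arange(query_len).unsqueeze(1)
kv_span = torch.arange(kv_len).unsqueeze(0)
span_mask = q_span < kv_span

# Both constructions mask the same "future" positions.
assert torch.equal(triu_mask, span_mask)
```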