
Commit 5e7a35b

[DICP][Ascend] Add some lightllm op tests. (#792)
1 parent 7546e7a commit 5e7a35b

File tree

5 files changed: +195 -0 lines changed

dicp/test/ascend_scripts/ops/static.ini

Lines changed: 4 additions & 0 deletions
@@ -35,6 +35,10 @@ python_files =
     test_index.py
     test_le.py
     ; test_lift_fresh_copy.py
+    test_lightllm_copy_with_offset.py
+    test_lightllm_incre_attention.py
+    test_lightllm_prompt_attention.py
+    test_lightllm_rotary_emb.py
     ; test_log.py
     test_logical_or.py
     test_lt.py
test_lightllm_copy_with_offset.py

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
import pytest

from dicp.vendor.AscendGraph import ext_ops
from ..common.utils import (
    torch,
    dynamo,
    parse_args,
    compile_model,
    get_device,
    Size,
    update_dynamo_config,
)


class OpModule(torch.nn.Module):
    def forward(self, out, k, start_dim, end_dim):
        res = torch.ops.lightllm.copy_with_offset.default(out, k, start_dim, end_dim)
        return res


model = OpModule()
args = parse_args()
compiled_model = compile_model(model, args.backend, args.dynamic)


class TestLightllmCopyWithOffset():
    @pytest.mark.parametrize("dtype", [torch.float32])
    @pytest.mark.parametrize("sizes", [Size(((8, 8, 16), (6, 8, 16)), ((8, 8, 16), (6, 8, 16))), Size(((8, 16, 32), (6, 16, 32)), ((8, 16, 32), (6, 16, 32)))])
    @pytest.mark.parametrize("compiled_model", compiled_model)
    def test_lightllm_copy_with_offset(self, sizes, dtype, compiled_model):
        device = get_device()
        size = sizes.dynamic if compiled_model.dynamic else sizes.static
        input1 = torch.randn(size[0], dtype=dtype)
        input2 = torch.randn(size[1], dtype=dtype)
        start_dim = 0
        end_dim = 6

        dicp_input1 = input1.to(device)
        dicp_input2 = input2.to(device)

        output = model(input1, input2, start_dim, end_dim)
        dynamo.reset()
        update_dynamo_config(compiled_model.dynamic)
        dicp_output = compiled_model.model(dicp_input1, dicp_input2, start_dim, end_dim)

        assert torch.allclose(output, dicp_output.cpu(), rtol=1e-02, atol=1e-02, equal_nan=True)
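
The lightllm op itself is not defined in this diff. As an orientation aid only, here is a minimal eager sketch of what the test shapes suggest `copy_with_offset` does (writing a (6, 8, 16) source into an (8, 8, 16) destination between start_dim=0 and end_dim=6 along the first dimension). The helper name `copy_with_offset_reference` and its semantics are assumptions for illustration, not the lightllm implementation:

import torch

def copy_with_offset_reference(out, k, start_dim, end_dim):
    # Hypothetical reference: copy k into out[start_dim:end_dim] along dim 0.
    # Inferred from the test's shapes; not taken from the lightllm kernel.
    out = out.clone()
    out[start_dim:end_dim] = k
    return out

# Shapes mirror the first parametrized case in the test above.
out = torch.randn(8, 8, 16)
k = torch.randn(6, 8, 16)
res = copy_with_offset_reference(out, k, 0, 6)
assert res.shape == (8, 8, 16)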
test_lightllm_incre_attention.py

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
import pytest

from dicp.vendor.AscendGraph import ext_ops
from ..common.utils import (
    torch,
    dynamo,
    parse_args,
    compile_model,
    get_device,
    Size,
    update_dynamo_config,
)


class OpModule(torch.nn.Module):
    def forward(self, q, k, v, int_index_list, max_seq_length):
        res = torch.ops.lightllm.flash_attention_inference.default(q, k, v, int_index_list, max_seq_length)
        return res


model = OpModule()
args = parse_args()
compiled_model = compile_model(model, args.backend, args.dynamic)


class TestLightllmIncreAttention():
    @pytest.mark.parametrize("dtype", [torch.float32])
    @pytest.mark.parametrize("sizes", [Size(((8, 16), (9,)), ((8, 16), (9,))), Size(((8, 32), (9,)), ((8, 32), (9,)))])
    @pytest.mark.parametrize("compiled_model", compiled_model)
    def test_lightllm_incre_attention(self, sizes, dtype, compiled_model):
        device = get_device()
        size = sizes.dynamic if compiled_model.dynamic else sizes.static
        input1 = torch.randn((1,) + size[0], dtype=dtype)
        input2 = torch.randn(size[1] + size[0], dtype=dtype)
        input3 = torch.randn(size[1] + size[0], dtype=dtype)
        input4 = list(size[1])
        max_seq_length = size[1][0]

        dicp_input1 = input1.to(device)
        dicp_input2 = input2.to(device)
        dicp_input3 = input3.to(device)
        dicp_input4 = input4

        output = model(input1, input2, input3, input4, max_seq_length)
        dynamo.reset()
        update_dynamo_config(compiled_model.dynamic)
        dicp_output = compiled_model.model(dicp_input1, dicp_input2, dicp_input3, dicp_input4, max_seq_length)

        assert torch.allclose(output, dicp_output.cpu(), rtol=1e-02, atol=1e-02, equal_nan=True)
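
For readers unfamiliar with the op being exercised here, a hedged sketch of single-token (decode) attention that matches the test's tensor layout — q as (batch, num_head, head_dim), k/v as concatenated per-request caches of shape (total_tokens, num_head, head_dim), and a list of per-request cache lengths — could look like the following. This is an illustrative assumption about `flash_attention_inference`'s semantics, not the actual lightllm kernel:

import math
import torch

def flash_attention_inference_reference(q, k, v, kv_lens, max_seq_length):
    # Hypothetical decode-attention reference. max_seq_length is kept only for
    # signature parity with the test; this sketch does not need it.
    outputs = []
    offset = 0
    for i, kv_len in enumerate(kv_lens):
        qi = q[i]                               # (num_head, head_dim)
        ki = k[offset:offset + kv_len]          # (kv_len, num_head, head_dim)
        vi = v[offset:offset + kv_len]
        offset += kv_len
        scale = 1.0 / math.sqrt(qi.shape[-1])
        scores = torch.einsum("hd,khd->hk", qi, ki) * scale   # (num_head, kv_len)
        probs = torch.softmax(scores, dim=-1)
        outputs.append(torch.einsum("hk,khd->hd", probs, vi)) # (num_head, head_dim)
    return torch.stack(outputs)                 # (batch, num_head, head_dim)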
test_lightllm_prompt_attention.py

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
import pytest

from dicp.vendor.AscendGraph import ext_ops
from ..common.utils import (
    torch,
    dynamo,
    parse_args,
    compile_model,
    get_device,
    Size,
    update_dynamo_config,
)


class OpModule(torch.nn.Module):
    def forward(self, q, k, v, seqlen, num_head, head_dim):
        res = torch.ops.lightllm.prompt_attention_inference.default(q, k, v, seqlen, num_head, head_dim)
        return res


model = OpModule()
args = parse_args()
compiled_model = compile_model(model, args.backend, args.dynamic)


class TestLightllmPromptAttention():
    @pytest.mark.parametrize("dtype", [torch.float16])
    @pytest.mark.parametrize("sizes", [Size(((1, 32, 16, 32), (32,)), ((1, 32, 16, 32), (32,))), Size(((1, 32, 16, 64), (32,)), ((1, 32, 16, 64), (32,)))])
    @pytest.mark.parametrize("compiled_model", compiled_model)
    def test_lightllm_prompt_attention(self, sizes, dtype, compiled_model):
        device = get_device()
        size = sizes.dynamic if compiled_model.dynamic else sizes.static
        input1 = torch.randn(size[0], dtype=dtype)
        input2 = torch.randn(size[0], dtype=dtype)
        input3 = torch.randn(size[0], dtype=dtype)
        input4 = torch.tensor(size[1], dtype=torch.int32)
        num_head = size[0][2]
        head_dim = size[0][3]

        dicp_input1 = input1.to(device)
        dicp_input2 = input2.to(device)
        dicp_input3 = input3.to(device)
        dicp_input4 = input4.to(device)

        output = model(input1, input2, input3, input4, num_head, head_dim).view(size[1][0], num_head * head_dim).half()
        dynamo.reset()
        update_dynamo_config(compiled_model.dynamic)
        dicp_output = compiled_model.model(dicp_input1.view(1, -1, num_head * head_dim), dicp_input2.view(1, -1, num_head * head_dim), dicp_input3.view(1, -1, num_head * head_dim), dicp_input4, num_head, head_dim).view(size[1][0], num_head * head_dim)

        assert torch.allclose(output, dicp_output.cpu(), rtol=1e-02, atol=1e-02, equal_nan=True)
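
As background for the shapes above — q/k/v of (batch, seq, num_head, head_dim) in the eager call, flattened to (1, -1, num_head * head_dim) for the dicp path — a hedged sketch of causal prefill (prompt) attention might look like the code below. The function name and layout assumptions are for illustration only, not the lightllm `prompt_attention_inference` kernel:

import math
import torch

def prompt_attention_inference_reference(q, k, v, seqlen, num_head, head_dim):
    # Hypothetical causal prefill-attention reference. Assumes q/k/v are
    # (batch, seq, num_head, head_dim); seqlen is assumed to hold per-request
    # prompt lengths equal to seq and is unused in this sketch.
    batch, seq = q.shape[0], q.shape[1]
    qh = q.permute(0, 2, 1, 3)                                 # (b, h, s, d)
    kh = k.permute(0, 2, 1, 3)
    vh = v.permute(0, 2, 1, 3)
    scale = 1.0 / math.sqrt(head_dim)
    scores = torch.matmul(qh, kh.transpose(-1, -2)) * scale    # (b, h, s, s)
    causal = torch.tril(torch.ones(seq, seq, dtype=torch.bool, device=q.device))
    scores = scores.masked_fill(~causal, float("-inf"))
    probs = torch.softmax(scores, dim=-1)
    out = torch.matmul(probs, vh)                              # (b, h, s, d)
    return out.permute(0, 2, 1, 3).reshape(batch, seq, num_head * head_dim)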
test_lightllm_rotary_emb.py

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
import pytest

from dicp.vendor.AscendGraph import ext_ops
from ..common.utils import (
    torch,
    dynamo,
    parse_args,
    compile_model,
    get_device,
    Size,
    update_dynamo_config,
)


class OpModule(torch.nn.Module):
    def forward(self, x, cos, sin):
        res = torch.ops.lightllm.rotary_emb.default(x, cos, sin)
        return res


model = OpModule()
args = parse_args()
compiled_model = compile_model(model, args.backend, args.dynamic)


class TestLightllmRotaryEmb():
    @pytest.mark.parametrize("dtype", [torch.float32])
    @pytest.mark.parametrize("sizes", [Size(((2, 32, 64), (2, 32), (2, 32)), ((2, 32, 64), (2, 32), (2, 32))), Size(((2, 32, 128), (2, 64), (2, 64)), ((2, 32, 128), (2, 64), (2, 64)))])
    @pytest.mark.parametrize("compiled_model", compiled_model)
    def test_lightllm_rotary_emb(self, sizes, dtype, compiled_model):
        device = get_device()
        size = sizes.dynamic if compiled_model.dynamic else sizes.static
        input1 = torch.randn(size[0], dtype=dtype)
        input2 = torch.randn(size[1], dtype=dtype)
        input3 = torch.randn(size[2], dtype=dtype)

        dicp_input1 = input1.to(device)
        dicp_input2 = input2.to(device)
        dicp_input3 = input3.to(device)

        output = model(input1, input2, input3)
        dynamo.reset()
        update_dynamo_config(compiled_model.dynamic)
        dicp_output = compiled_model.model(dicp_input1, dicp_input2, dicp_input3)

        assert torch.allclose(output, dicp_output.cpu(), rtol=1e-02, atol=1e-02, equal_nan=True)
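
The test shapes — x of (seq_len, num_head, head_dim) with cos/sin of (seq_len, head_dim // 2), e.g. (2, 32, 64) against (2, 32) — match the common "rotate half" form of rotary position embeddings. A hedged sketch of that form is shown below; it is an assumption for illustration, not the lightllm `rotary_emb` implementation:

import torch

def rotary_emb_reference(x, cos, sin):
    # Hypothetical rotary-embedding reference: pair the first and second
    # halves of head_dim and rotate them by the per-position cos/sin values.
    half = x.shape[-1] // 2
    x0, x1 = x[..., :half], x[..., half:]
    cos = cos.unsqueeze(1)        # (seq_len, 1, half) broadcasts over heads
    sin = sin.unsqueeze(1)
    out0 = x0 * cos - x1 * sin
    out1 = x0 * sin + x1 * cos
    return torch.cat((out0, out1), dim=-1)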
