
Commit 7546e7a

[DICP][Ascend] Support llama2 7B with lightllm. (#787)
1 parent b21abb9 commit 7546e7a

13 files changed: +638 -110 lines

dicp/dicp/dynamo_bridge/compile_fx.py

Lines changed: 2 additions & 1 deletion
@@ -208,7 +208,8 @@ def compile_fx_210(
     def fw_compiler_base(model: torch.fx.GraphModule, example_inputs, is_inference):
         if is_inference:
             # partition_fn won't be called
-            joint_graph_passes(model)
+            # joint_graph_passes(model)
+            pass
 
         fixed = len(example_inputs) - num_example_inputs
         return inner_compile(
Lines changed: 38 additions & 0 deletions (new file)
@@ -0,0 +1,38 @@
+from collections import defaultdict
+from typing import Callable, Dict, Sequence, Union
+
+import torch
+from torch._decomp import register_decomposition
+from torch._ops import OpOverload, OpOverloadPacket
+
+dicp_decomposition_table = {}
+aten = torch.ops.aten
+
+
+def register_decomposition_for_dicp(fn):
+    return register_decomposition(fn, registry=dicp_decomposition_table)
+
+
+@register_decomposition_for_dicp(aten.count_nonzero.default)
+def count_nonzero_default(x, dim=None):
+    cond = x != 0
+    dim = [] if dim is None else dim
+    return aten.sum.dim_IntList(cond, dim=dim, keepdim=False, dtype=torch.int64)
+
+
+def get_decompositions(
+    aten_ops: Sequence[Union[OpOverload, OpOverloadPacket]],
+    target_decomposition_table: Dict[OpOverload, Callable] = None,
+) -> Dict[OpOverload, Callable]:
+    registry = dicp_decomposition_table
+    packets_to_overloads = defaultdict(list)
+    for opo in registry:
+        packets_to_overloads[opo.overloadpacket].append(opo)
+    decompositions = target_decomposition_table if target_decomposition_table else {}
+    for op in aten_ops:
+        if isinstance(op, OpOverloadPacket) and op in packets_to_overloads:
+            for op_overload in packets_to_overloads[op]:
+                decompositions[op_overload] = registry[op_overload]
+        elif isinstance(op, OpOverload) and op in registry:
+            decompositions[op] = registry[op]
+    return decompositions

dicp/dicp/vendor/AscendGraph/ascend_op.py

Lines changed: 88 additions & 7 deletions
@@ -14,6 +14,7 @@
 
 aten = torch.ops.aten
 
+
 def negative_in_shape(shape):
     for elem in shape:
         if elem < 0:
@@ -43,12 +44,12 @@ def __init__(self):
 
     def infer_result(self, x, shape):
         x, x_shape, _, x_dtype = get_fake_tensor_meta_val(x)
-        if isinstance(shape, torch._subclasses.fake_tensor.FakeTensor): # case1: shape is a fakeTensor, like conversion for 'scatter' and 'where'
+        if isinstance(shape, torch._subclasses.fake_tensor.FakeTensor):  # case1: shape is a fakeTensor, like conversion for 'scatter' and 'where'
             shape, shape_shape, _, _ = get_fake_tensor_meta_val(shape)
             shape = shape_shape
-        elif isinstance(shape, Tuple): # case2: shape is tuple from 'Const' , like conversion for 'lt'
-            shape, _, _, _ =get_op_const_arg_kwarg(shape)
-        else: # other cases, unsupported yet
+        elif isinstance(shape, Tuple):  # case2: shape is tuple from 'Const' , like conversion for 'lt'
+            shape, _, _, _ = get_op_const_arg_kwarg(shape)
+        else:  # other cases, unsupported yet
             assert False, self.__class__.__name__ + "unsupported 'shape' input type!"
 
         out_shape = get_broadcast_res_two_shape(x_shape, shape)
@@ -97,7 +98,7 @@ def __init__(self):
 class MatMul(Operator):
     def __init__(self):
         super().__init__("MatMul")
-
+
     def infer_result(self, x1, x2, adj_x1=False, adj_x2=False):
         attr = acl.op.create_attr()
         check_ret("acl.op.set_attr_bool", acl.op.set_attr_bool(attr, "transpose_x1", adj_x1))
@@ -290,6 +291,14 @@ def infer_result(self, x, dims, keepdim):
         return reduce_op_infer(x, dims, keepdim)
 
 
+class ReduceSum(Operator):
+    def __init__(self):
+        super().__init__("ReduceSum")
+
+    def infer_result(self, x, dims, keepdim):
+        return reduce_op_infer(x, dims, keepdim)
+
+
 class Unsqueeze(Operator):
     def __init__(self):
         super().__init__("Unsqueeze")
@@ -628,7 +637,7 @@ def infer_result(self, x, index, orig_index):
 
         # assume not none index, and replace prefix x_shape dims
         len_idx_shape = len(orig_index)
-        assert(len_idx_shape > 0)
+        assert (len_idx_shape > 0)
         bcast_index_shape = list(orig_index[0].shape)
         x_shape = bcast_index_shape + list(x_shape[len_idx_shape:])
         return torch.empty(x_shape, dtype=x_dtype, memory_format=get_memory_format(x))
@@ -962,6 +971,14 @@ def infer_result(self, x1, x2):
         return common_binary_op_infer(x1, x2, torch.bool)
 
 
+class LogicalNot(Operator):
+    def __init__(self):
+        super().__init__("LogicalNot")
+
+    def infer_result(self, x):
+        return common_binary_op_infer(x, torch.bool)
+
+
 class Tril(Operator):
     def __init__(self):
         super().__init__("Tril")
@@ -1023,7 +1040,7 @@ def infer_result(
         output_batch_var = torch.empty(
             [channel_size], dtype=torch.float32, memory_format=torch.contiguous_format
         )
-        return [output_y,output_mean,output_var,output_batch_mean,output_batch_var]
+        return [output_y, output_mean, output_var, output_batch_mean, output_batch_var]
 
 
 class TileWithAxis(Operator):
@@ -1032,6 +1049,38 @@ def __init__(self):
         self.torch_op = aten.repeat_interleave.self_int
 
 
+class RotaryMul(Operator):
+    def __init__(self):
+        super().__init__("RotaryMul")
+
+    def infer_result(self, x, cos, sin):
+        return torch.empty_like(x)
+
+
+class RmsNorm(Operator):
+    def __init__(self):
+        super().__init__("RmsNorm")
+
+    def infer_result(self, x, weight, eps):
+        return torch.empty_like(x)
+
+
+class PromptFlashAttention(Operator):
+    def __init__(self):
+        super().__init__("PromptFlashAttention")
+
+    def infer_result(self, q, k, v, num_head, seqlen, mask, head_dim):
+        return torch.empty_like(q)
+
+
+class IncreFlashAttention(Operator):
+    def __init__(self):
+        super().__init__("IncreFlashAttention")
+
+    def infer_result(self, q, k, v, head_num):
+        return torch.empty_like(q)
+
+
 class TensorScatterUpdate(Operator):
     def __init__(self):
         super().__init__("TensorScatterUpdate")
@@ -1054,6 +1103,38 @@ def infer_result(self, x, indices, updates):
         return torch.empty(x_shape, dtype=x_dtype, memory_format=get_memory_format(x))
 
 
+class ExpandDims(Operator):
+    def __init__(self):
+        super().__init__("ExpandDims")
+
+    def infer_result(self, x, axis):
+        return torch.unsqueeze(x, axis)
+
+
+class MaskedScatter(Operator):
+    def __init__(self):
+        super().__init__("MaskedScatter")
+
+    def infer_result(self, x, mask, updates):
+        return x
+
+
+class ViewCopy(Operator):
+    def __init__(self):
+        super().__init__("ViewCopy")
+
+    def infer_result(self, dst, dst_size, dst_stride, dst_storage_offset, src, src_size, src_stride, src_storage_offset):
+        return dst
+
+
+class ScatterNdUpdate(Operator):
+    def __init__(self):
+        super().__init__("ScatterNdUpdate")
+
+    def infer_result(self, x, indices, updates):
+        return x
+
+
 def ret_triple(a, b, c) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     return a, b, c
 
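
The operator classes added above only carry metadata for graph tracing: each infer_result returns a tensor whose shape and dtype mirror the relevant input, while the real computation is left to the corresponding Ascend kernels (RotaryMul presumably applies rotary position embeddings, RmsNorm the RMS normalization used by llama2, and PromptFlashAttention / IncreFlashAttention the fused attention paths for prefill and single-token decode). As a reference point only, here is a minimal eager-mode sketch of the RMSNorm math these layers rely on; this is an assumption about the kernel's semantics, not code from the commit:

import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # llama2-style RMSNorm: scale by the reciprocal root-mean-square over the
    # last dimension, then apply the learned per-channel weight.
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight

hidden = torch.randn(1, 8, 4096)
weight = torch.ones(4096)
out = rms_norm_reference(hidden, weight)
assert out.shape == hidden.shape  # consistent with RmsNorm.infer_result returning empty_like(x)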