
 aten = torch.ops.aten

+
 def negative_in_shape(shape):
     for elem in shape:
         if elem < 0:
@@ -43,12 +44,12 @@ def __init__(self):

     def infer_result(self, x, shape):
         x, x_shape, _, x_dtype = get_fake_tensor_meta_val(x)
-        if isinstance(shape, torch._subclasses.fake_tensor.FakeTensor): # case1: shape is a fakeTensor, like conversion for 'scatter' and 'where'
+        if isinstance(shape, torch._subclasses.fake_tensor.FakeTensor):  # case1: shape is a fakeTensor, like conversion for 'scatter' and 'where'
             shape, shape_shape, _, _ = get_fake_tensor_meta_val(shape)
             shape = shape_shape
-        elif isinstance(shape, Tuple): # case2: shape is tuple from 'Const', like conversion for 'lt'
-            shape, _, _, _ = get_op_const_arg_kwarg(shape)
-        else: # other cases, unsupported yet
+        elif isinstance(shape, Tuple):  # case2: shape is tuple from 'Const', like conversion for 'lt'
+            shape, _, _, _ = get_op_const_arg_kwarg(shape)
+        else:  # other cases, unsupported yet
             assert False, self.__class__.__name__ + "unsupported 'shape' input type!"

         out_shape = get_broadcast_res_two_shape(x_shape, shape)
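
The hunk above only adjusts comment spacing; the underlying logic accepts `shape` either as a FakeTensor (whose metadata supplies the target shape) or as a tuple coming from a `Const` node, then broadcasts it against `x_shape`. A minimal sketch of the assumed broadcast rule, using `torch.broadcast_shapes` as a stand-in for the repo's `get_broadcast_res_two_shape` helper (whose body is not shown in this diff):

```python
# Stand-in for get_broadcast_res_two_shape: assumed to follow standard
# NumPy/PyTorch broadcasting between the input shape and the target shape.
import torch

x_shape = (4, 1, 5)
shape = (3, 5)
out_shape = list(torch.broadcast_shapes(x_shape, shape))
assert out_shape == [4, 3, 5]
```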
@@ -97,7 +98,7 @@ def __init__(self):
 class MatMul(Operator):
     def __init__(self):
         super().__init__("MatMul")
-        
+
     def infer_result(self, x1, x2, adj_x1=False, adj_x2=False):
         attr = acl.op.create_attr()
         check_ret("acl.op.set_attr_bool", acl.op.set_attr_bool(attr, "transpose_x1", adj_x1))
@@ -290,6 +291,14 @@ def infer_result(self, x, dims, keepdim):
         return reduce_op_infer(x, dims, keepdim)


+class ReduceSum(Operator):
+    def __init__(self):
+        super().__init__("ReduceSum")
+
+    def infer_result(self, x, dims, keepdim):
+        return reduce_op_infer(x, dims, keepdim)
+
+
 class Unsqueeze(Operator):
     def __init__(self):
         super().__init__("Unsqueeze")
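
The new `ReduceSum` reuses the same `reduce_op_infer` helper as the existing reduction ops. The helper's body is not part of this diff; a minimal sketch, assuming it mirrors `torch.sum`'s `dim`/`keepdim` shape semantics:

```python
# Hypothetical reduce_op_infer: reduced dims become size 1 when keepdim is
# True and are dropped otherwise, mirroring torch.sum.
import torch

def sketch_reduce_op_infer(x, dims, keepdim):
    out_shape = [
        1 if i in dims else s
        for i, s in enumerate(x.shape)
        if keepdim or i not in dims
    ]
    return torch.empty(out_shape, dtype=x.dtype)

x = torch.empty(2, 3, 4)
assert sketch_reduce_op_infer(x, [1], True).shape == (2, 1, 4)
assert sketch_reduce_op_infer(x, [1], False).shape == (2, 4)
```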
@@ -628,7 +637,7 @@ def infer_result(self, x, index, orig_index):

         # assume not none index, and replace prefix x_shape dims
         len_idx_shape = len(orig_index)
-        assert (len_idx_shape > 0) 
+        assert (len_idx_shape > 0)
         bcast_index_shape = list(orig_index[0].shape)
         x_shape = bcast_index_shape + list(x_shape[len_idx_shape:])
         return torch.empty(x_shape, dtype=x_dtype, memory_format=get_memory_format(x))
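
The `assert` change above is whitespace-only; the surrounding code implements the advanced-indexing shape rule, where the first `len(orig_index)` dims of `x` are replaced by the (already broadcast) index shape. A quick eager-mode check of that rule, with illustrative shapes:

```python
# Shape rule check: indexing dims 0 and 1 of a (5, 6, 7) tensor with two
# (2, 3) index tensors yields shape (2, 3, 7), i.e. the broadcast index
# shape followed by the remaining dims of x.
import torch

x = torch.empty(5, 6, 7)
orig_index = [torch.zeros(2, 3, dtype=torch.int64)] * 2
out = x[orig_index[0], orig_index[1]]
assert list(out.shape) == [2, 3] + list(x.shape[2:])  # [2, 3, 7]
```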
@@ -962,6 +971,14 @@ def infer_result(self, x1, x2):
         return common_binary_op_infer(x1, x2, torch.bool)


+class LogicalNot(Operator):
+    def __init__(self):
+        super().__init__("LogicalNot")
+
+    def infer_result(self, x):
+        return common_binary_op_infer(x, torch.bool)
+
+
 class Tril(Operator):
     def __init__(self):
         super().__init__("Tril")
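
`LogicalNot` routes `torch.bool` through `common_binary_op_infer`; that helper's signature is not shown in this diff, but the intended meta result is presumably a tensor with `x`'s shape and a bool dtype, matching `aten.logical_not`:

```python
# Expected meta output of LogicalNot (assumption: shape preserved, dtype
# forced to bool, as aten.logical_not produces in eager mode).
import torch

x = torch.empty(2, 3, dtype=torch.float32)
out = torch.empty(x.shape, dtype=torch.bool)
assert out.shape == x.shape and out.dtype == torch.bool
```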
@@ -1023,7 +1040,7 @@ def infer_result(
         output_batch_var = torch.empty(
             [channel_size], dtype=torch.float32, memory_format=torch.contiguous_format
         )
-        return [output_y,output_mean,output_var,output_batch_mean,output_batch_var]
+        return [output_y, output_mean, output_var, output_batch_mean, output_batch_var]


 class TileWithAxis(Operator):
@@ -1032,6 +1049,38 @@ def __init__(self):
         self.torch_op = aten.repeat_interleave.self_int


+class RotaryMul(Operator):
+    def __init__(self):
+        super().__init__("RotaryMul")
+
+    def infer_result(self, x, cos, sin):
+        return torch.empty_like(x)
+
+
+class RmsNorm(Operator):
+    def __init__(self):
+        super().__init__("RmsNorm")
+
+    def infer_result(self, x, weight, eps):
+        return torch.empty_like(x)
+
+
+class PromptFlashAttention(Operator):
+    def __init__(self):
+        super().__init__("PromptFlashAttention")
+
+    def infer_result(self, q, k, v, num_head, seqlen, mask, head_dim):
+        return torch.empty_like(q)
+
+
+class IncreFlashAttention(Operator):
+    def __init__(self):
+        super().__init__("IncreFlashAttention")
+
+    def infer_result(self, q, k, v, head_num):
+        return torch.empty_like(q)
+
+
 class TensorScatterUpdate(Operator):
     def __init__(self):
         super().__init__("TensorScatterUpdate")
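
All four new fusion ops infer their output as `torch.empty_like` on the primary input, i.e. the result is assumed to share `q`'s (or `x`'s) shape, dtype, and memory format. A sketch of how that behaves under `FakeTensorMode`, with illustrative shapes that the diff itself does not mandate:

```python
# Under FakeTensorMode, empty_like produces a FakeTensor whose metadata
# mirrors q, which is all these infer_result stubs need to report.
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    q = torch.empty(1, 32, 128, 128, dtype=torch.float16)  # (batch, heads, seq, head_dim), illustrative
    out = torch.empty_like(q)

assert out.shape == q.shape and out.dtype == q.dtype
```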
@@ -1054,6 +1103,38 @@ def infer_result(self, x, indices, updates):
         return torch.empty(x_shape, dtype=x_dtype, memory_format=get_memory_format(x))


+class ExpandDims(Operator):
+    def __init__(self):
+        super().__init__("ExpandDims")
+
+    def infer_result(self, x, axis):
+        return torch.unsqueeze(x, axis)
+
+
+class MaskedScatter(Operator):
+    def __init__(self):
+        super().__init__("MaskedScatter")
+
+    def infer_result(self, x, mask, updates):
+        return x
+
+
+class ViewCopy(Operator):
+    def __init__(self):
+        super().__init__("ViewCopy")
+
+    def infer_result(self, dst, dst_size, dst_stride, dst_storage_offset, src, src_size, src_stride, src_storage_offset):
+        return dst
+
+
+class ScatterNdUpdate(Operator):
+    def __init__(self):
+        super().__init__("ScatterNdUpdate")
+
+    def infer_result(self, x, indices, updates):
+        return x
+
+
 def ret_triple(a, b, c) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     return a, b, c

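`MaskedScatter`, `ViewCopy`, and `ScatterNdUpdate` return their destination argument unchanged, consistent with in-place update semantics: the output metadata is exactly the destination's. A hypothetical eager-mode reading of `ViewCopy` suggested by its parameter names (this is not the repo's actual kernel):

```python
# Hypothetical eager equivalent of ViewCopy: copy src, viewed through
# src_size/src_stride/src_offset, into a strided view of dst. Parameter
# names mirror the infer_result signature above.
import torch

def sketch_view_copy(dst, dst_size, dst_stride, dst_offset,
                     src, src_size, src_stride, src_offset):
    dst_view = dst.as_strided(dst_size, dst_stride, dst_offset)
    src_view = src.as_strided(src_size, src_stride, src_offset)
    dst_view.copy_(src_view)
    return dst  # dst's shape/dtype are unchanged, matching infer_result

dst = torch.zeros(4, 4)
src = torch.ones(2, 2)
out = sketch_view_copy(dst, [2, 2], [4, 1], 0, src, [2, 2], [2, 1], 0)
assert torch.equal(out[:2, :2], torch.ones(2, 2))
```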