@@ -209,14 +209,18 @@ def __init__(self, graph_id, device_id):
209209 self .output_dtypes = []
210210 self .output_datasize = []
211211 for item in shapes :
212+ if item == '' :
213+ self .output_shapes .append ([])
214+ continue
212215 elems = item .split (',' )
213216 elems = [int (x ) for x in elems ]
214217 self .output_shapes .append (elems )
215218 for item in dtypes :
216219 elem = int (item )
217220 self .output_dtypes .append (elem )
218221 for i in range (len (shapes )):
219- elem_size = math .prod (self .output_shapes [i ])
222+ elem_size = math .prod (self .output_shapes [i ]) if len (
223+ self .output_shapes [i ]) > 0 else 1
220224 self .output_datasize .append (
221225 elem_size * acl .data_type_size (self .output_dtypes [i ]))
222226 self .output_datasize_c = (
@@ -242,14 +246,18 @@ def __init__(self, graph_id, device_id):
242246 self .input_datasize = []
243247
244248 for item in shapes :
249+ if item == '' :
250+ self .input_shapes .append ([])
251+ continue
245252 elems = item .split (',' )
246253 elems = [int (x ) for x in elems ]
247254 self .input_shapes .append (elems )
248255 for item in dtypes :
249256 elem = int (item )
250257 self .input_dtypes .append (elem )
251258 for i in range (len (shapes )):
252- elem_size = math .prod (self .input_shapes [i ])
259+ elem_size = math .prod (self .input_shapes [i ]) if len (
260+ self .input_shapes [i ]) > 0 else 1
253261 self .input_datasize .append (
254262 elem_size * acl .data_type_size (self .input_dtypes [i ]))
255263 self .input_datasize_c = (
0 commit comments