@@ -54,8 +54,6 @@ def build(self, x):
5454 # repeat 8 times
5555 y0, w0 = ops.repeat(add_weight_graph0, 8, x0, inputs_dict={add_weight0.w: w0})
5656
57- See also `PyTorch Tensor.repeat <https://pytorch.org/docs/stable/generated/torch.Tensor.repeat.html>`__, `NumPy repeat <https://numpy.org/doc/stable/reference/generated/numpy.repeat.html>`__.
58-
5957 Args:
6058 graph (Graph): User defined graph to repeat `repeat_count` times.
6159 repeat_count (int): Number of times to repeat calling the graph.
@@ -74,7 +72,10 @@ def build(self, x):
7472 Tuple[Tensor, ...]:
7573 Tuple of the output tensors of the call in the parent graph.
7674 """
77- loop_info = repeat_with_info (graph , repeat_count , * inputs , inputs_dict = inputs_dict )
75+ loop_info = repeat_with_info (graph ,
76+ repeat_count ,
77+ * inputs ,
78+ inputs_dict = inputs_dict )
7879
7980 out_tensors = loop_info .outputs
8081 return out_tensors
@@ -193,8 +194,7 @@ def build(self, x):
193194 if total_inputs < total_outputs :
194195 raise ValueError (
195196 f"To repeat the subgraph ({ graph .id } ) the number of inputs must be greater than or equal to the number of outputs."
196- f" { total_inputs } < { total_outputs } "
197- )
197+ f" { total_inputs } < { total_outputs } " )
198198
199199 # For clarity, we rename our graphs:
200200 # - Bottom: The user provided bottom level graph. We call this with a call op. This has gone
@@ -215,16 +215,14 @@ def build(self, x):
215215
216216 # Create the middle graph, call and loop ops
217217 pb_middle_graph , pb_callop , pb_loop_op = _setup_call_and_repeat (
218- pb_ir , pb_top_graph , pb_bottom_graph
219- )
218+ pb_ir , pb_top_graph , pb_bottom_graph )
220219
221220 # set the number of times to loop
222221 pb_loop_op .setTripCountValue (repeat_count )
223222
224223 # Prep and validate inputs
225- inputs_all = _prep_and_validate_inputs (
226- check_inputs , top_graph , graph , "called" , inputs , inputs_dict
227- )
224+ inputs_all = _prep_and_validate_inputs (check_inputs , top_graph , graph ,
225+ "called" , inputs , inputs_dict )
228226
229227 # 1, 2. Connect inputs.
230228 _setup_inputs (
@@ -236,9 +234,8 @@ def build(self, x):
236234 )
237235
238236 # 3. Connect outputs.
239- _ = _setup_outputs (
240- pb_top_graph , pb_bottom_graph , pb_middle_graph , pb_callop , pb_loop_op
241- )
237+ _ = _setup_outputs (pb_top_graph , pb_bottom_graph , pb_middle_graph ,
238+ pb_callop , pb_loop_op )
242239
243240 pb_callop .setup ()
244241 pb_loop_op .setup ()
@@ -250,13 +247,14 @@ def build(self, x):
250247 loop_carried_inputs = pb_loop_op .getNumExplicitInputs ()
251248 for bottom_t in bottom_graph ._by_ref_inputs :
252249 middle_t = c_info .graph_to_parent (bottom_t )
253- loop_carried = pb_middle_graph .getInputIndex (middle_t .id ) < loop_carried_inputs
250+ loop_carried = pb_middle_graph .getInputIndex (
251+ middle_t .id ) < loop_carried_inputs
254252 # If a tensor was set as a by_ref_input, we should also do the same for the looped subgraph.
255253 c_info .set_parent_input_modified (
256- middle_t , infer_modified_regions = not loop_carried
257- )
254+ middle_t , infer_modified_regions = not loop_carried )
258255 top_t = r_info .graph_to_parent (middle_t )
259- r_info .set_parent_input_modified (top_t , infer_modified_regions = not loop_carried )
256+ r_info .set_parent_input_modified (
257+ top_t , infer_modified_regions = not loop_carried )
260258 r_info .called_graph ._by_ref_inputs .add (middle_t )
261259
262260 return r_info
@@ -280,34 +278,35 @@ def _setup_call_and_repeat(
280278 # This is the graph we will repeat.
281279 pb_middle_graph = pb_ir .createGraph (
282280 _ir .GraphId (
283- pb_ir .createUniqueSubgraphId (f"{ pb_bottom_graph .id .str ()} __loop_wrapper" )
284- )
285- )
281+ pb_ir .createUniqueSubgraphId (
282+ f"{ pb_bottom_graph .id .str ()} __loop_wrapper" )))
286283
287- opid = _ir .OperatorIdentifier ("ai.graphcore" , "Call" , 1 , _ir .NumInputs (), 0 )
284+ opid = _ir .OperatorIdentifier ("ai.graphcore" , "Call" , 1 , _ir .NumInputs (),
285+ 0 )
288286 op_name = pb_middle_graph .id .str () + "__call__" + pb_bottom_graph .id .str ()
289287
290288 ctx = get_current_context ()
291289 # Call the bottom_graph
292- pb_callop = pb_middle_graph .createOp_CallOp (
293- opid , pb_bottom_graph , ctx ._get_op_settings (op_name )
294- )
290+ pb_callop = pb_middle_graph .createOp_CallOp (opid , pb_bottom_graph ,
291+ ctx ._get_op_settings (op_name ))
295292
296293 opid = _ir .OperatorIdentifier ("ai.onnx" , "Loop" , 11 , _ir .NumInputs (), 0 )
297294 op_name = pb_top_graph .id .str () + "__loop__" + pb_middle_graph .id .str ()
298295
299296 # Loop the middle_graph
300- pb_loop_op = pb_top_graph .createOp_LoopOp (
301- opid , ctx ._get_op_settings (op_name ), pb_middle_graph
302- )
297+ pb_loop_op = pb_top_graph .createOp_LoopOp (opid ,
298+ ctx ._get_op_settings (op_name ),
299+ pb_middle_graph )
303300
304301 # Add mandatory loop iterator tensor to graph (is not an output)
305302 repeatIterId = _ir .addScope (pb_middle_graph , "Iterator___" )
306- pb_middle_graph .addInput (repeatIterId , _ir .TensorInfo (_ir .DataType .INT32 , ()))
303+ pb_middle_graph .addInput (repeatIterId ,
304+ _ir .TensorInfo (_ir .DataType .INT32 , ()))
307305
308306 # Add mandatory loop condition tensor to graph (is also an output)
309307 repeatCondId = _ir .addScope (pb_middle_graph , "LoopCond___" )
310- pb_middle_graph .addInput (repeatCondId , _ir .TensorInfo (_ir .DataType .BOOL , ()))
308+ pb_middle_graph .addInput (repeatCondId ,
309+ _ir .TensorInfo (_ir .DataType .BOOL , ()))
311310 pb_middle_graph .markAsOutput (repeatCondId )
312311
313312 return pb_middle_graph , pb_callop , pb_loop_op
@@ -354,8 +353,7 @@ def _setup_inputs(
354353 False ,
355354 )
356355 pb_callop .connectInTensor (
357- call_in_idx , _ir .addScope (pb_middle_graph , parent_tensor .name )
358- )
356+ call_in_idx , _ir .addScope (pb_middle_graph , parent_tensor .name ))
359357
360358
361359def _setup_outputs (
@@ -385,33 +383,31 @@ def _setup_outputs(
385383
386384 for pb_subgraph_out_id in pb_bottom_graph .getOutputIds ():
387385 top_tensor_id = _ir .addScope (
388- pb_top_graph , _ir .removeScope (pb_bottom_graph , pb_subgraph_out_id )
389- )
386+ pb_top_graph , _ir .removeScope (pb_bottom_graph , pb_subgraph_out_id ))
390387 # Already has scope added
391388 middle_tensor_id = _ir .removeScope (pb_bottom_graph , pb_subgraph_out_id )
392389 bottom_tensor_id = _ir .addScope (
393- pb_bottom_graph , _ir . removeScope ( pb_bottom_graph , pb_subgraph_out_id )
394- )
390+ pb_bottom_graph ,
391+ _ir . removeScope ( pb_bottom_graph , pb_subgraph_out_id ) )
395392
396393 sgOutIdx = pb_bottom_graph .getOutputIndex (bottom_tensor_id )
397394 callOutIdx = pb_callop .subgraphOutToOpOutIndex (sgOutIdx )
398395
399396 # Avoid tensor name collisions
400397 middle_tensor_id = pb_middle_graph .getIr ().createIntermediateTensorId (
401- middle_tensor_id
402- )
398+ middle_tensor_id )
403399 pb_callop .createAndConnectOutTensor (callOutIdx , middle_tensor_id )
404400
405401 pb_middle_graph .markAsOutput (middle_tensor_id )
406402 sgOutIdx = pb_middle_graph .getOutputIndex (middle_tensor_id )
407403 repeatOutIdx = pb_loop_op .subgraphOutToOpOutIndex (sgOutIdx )
408404 # Avoid tensor name collisions
409405 top_tensor_id = pb_middle_graph .getIr ().createIntermediateTensorId (
410- top_tensor_id
411- )
406+ top_tensor_id )
412407 # We overwrite here as we added the middle_tensor_id as an output above, but we want to make
413408 # sure the loop op is setup correctly.
414- pb_loop_op .addLoopOutput (repeatOutIdx , top_tensor_id , middle_tensor_id , True )
409+ pb_loop_op .addLoopOutput (repeatOutIdx , top_tensor_id , middle_tensor_id ,
410+ True )
415411
416412 outnames .append (top_tensor_id )
417413 return outnames
0 commit comments