@@ -153,7 +153,7 @@ def execute_pipeline(pipeline: dm.Pipeline, mode: ExecutionType, pipeline_input:
153
153
execute_and_node (node , pre_edges , edge_args , post_edges )
154
154
155
155
out_args = {}
156
- terminal_nodes = pipeline .get_terminal_nodes ()
156
+ terminal_nodes = pipeline .get_output_nodes ()
157
157
for terminal_node in terminal_nodes :
158
158
edge = dm .Edge (terminal_node , None )
159
159
out_args [terminal_node ] = edge_args [edge ]
@@ -184,6 +184,32 @@ def select_pipeline(pipeline_output: dm.PipelineOutput, chosen_xyref: dm.XYRef):
184
184
return pipeline
185
185
186
186
187
def get_pipeline_input(pipeline: dm.Pipeline, pipeline_output: dm.PipelineOutput, chosen_xyref: dm.XYRef):
    """
    Rebuild the pipeline input that led to ``chosen_xyref``.

    Starting from ``chosen_xyref``, the lineage is walked backwards through
    each xyref's previous xyrefs. Whenever the node behind an xyref sits at
    level 0 of ``pipeline``, its previous xyrefs are registered as arguments
    for that node on the returned pipeline input.

    :param pipeline: Pipeline used to look up the level of each visited node
    :param pipeline_output: Not read by this function (kept in the signature)
    :param chosen_xyref: The xyref from which the backwards walk starts
    :return: A ``dm.PipelineInput`` holding the xyrefs found at level-0 nodes
    """
    collected_input = dm.PipelineInput()

    # FIFO work list of xyrefs whose lineage still has to be explored
    pending = SimpleQueue()
    pending.put(chosen_xyref)

    while not pending.empty():
        xyref = pending.get()
        # The xyref only carries a reference to its node state; resolve it via ray
        node = ray.get(xyref.get_curr_node_state_ref())
        level = pipeline.get_node_level(node)
        parents = xyref.get_prev_xyrefs()

        if level == 0:
            # This is an input node: everything feeding it is pipeline input
            for parent in parents:
                collected_input.add_xyref_arg(node, parent)

        # Keep walking only through parents that have a node state reference
        for parent in parents:
            if parent.get_curr_node_state_ref() is not None:
                pending.put(parent)

    return collected_input
211
+
212
+
187
213
@ray .remote (num_returns = 2 )
188
214
def split (cross_validator : BaseCrossValidator , xy_ref ):
189
215
x = ray .get (xy_ref .get_Xref ())
@@ -220,7 +246,7 @@ def cross_validate(cross_validator: BaseCrossValidator, pipeline: dm.Pipeline, p
220
246
221
247
in_args = pipeline_input .get_in_args ()
222
248
for node , xyref_ptrs in in_args .items ():
223
- # NOTE: The assumption is that this node has only one input, the check earlier will ensure this !
249
+ # NOTE: The assumption is that this node has only one input!
224
250
xyref_ptr = xyref_ptrs [0 ]
225
251
xy_train_refs_ptr , xy_test_refs_ptr = split .remote (cross_validator , xyref_ptr )
226
252
xy_train_refs = ray .get (xy_train_refs_ptr )
@@ -238,7 +264,7 @@ def cross_validate(cross_validator: BaseCrossValidator, pipeline: dm.Pipeline, p
238
264
pipeline_output_train = execute_pipeline (pipeline , ExecutionType .FIT , pipeline_input_train )
239
265
240
266
# Now we can choose the pipeline and then score for each of the chosen pipelines
241
- out_nodes = pipeline .get_terminal_nodes ()
267
+ out_nodes = pipeline .get_output_nodes ()
242
268
if len (out_nodes ) > 1 :
243
269
raise pe .PipelineException ("Cannot cross validate as output is not a single node" )
244
270
@@ -267,6 +293,42 @@ def cross_validate(cross_validator: BaseCrossValidator, pipeline: dm.Pipeline, p
267
293
return result_scores
268
294
269
295
296
def grid_search(cross_validator: BaseCrossValidator, pipeline: dm.Pipeline, pipeline_input: dm.PipelineInput):
    """
    Prepare cross-validation folds for every pipeline input and fit the pipeline on them.

    Each input node's single xyref is split with ``cross_validator``. All
    training folds go into one shared training input, while the i-th test fold
    goes into the i-th of ``k`` separate test inputs. The pipeline is then
    executed in FIT mode on the training input.

    NOTE(review): the function currently ends after fetching the output nodes
    and returns nothing; choosing the "best" pipeline is not implemented in
    the code shown here.

    :param cross_validator: Cross validator providing ``get_n_splits`` and used by ``split``
    :param pipeline: The pipeline to fit
    :param pipeline_input: Input whose per-node xyrefs are split into folds
    :raises pe.PipelineException: If any input node carries more than one xyref
    """
    train_input = dm.PipelineInput()

    # One independent test input per fold
    k = cross_validator.get_n_splits()
    test_inputs = [dm.PipelineInput() for _ in range(k)]

    for node, xyref_ptrs in pipeline_input.get_in_args().items():
        # NOTE: The assumption is that this node has only one input!
        if len(xyref_ptrs) > 1:
            raise pe.PipelineException("Input to grid search is multiple objects, re-run with only single object")
        xyref_ptr = xyref_ptrs[0]

        train_refs_ptr, test_refs_ptr = split.remote(cross_validator, xyref_ptr)
        train_refs = ray.get(train_refs_ptr)
        test_refs = ray.get(test_refs_ptr)

        # Every training fold feeds the same training input
        for train_ref in train_refs:
            train_input.add_xyref_arg(node, train_ref)

        # For testing, each fold goes only to its own dedicated input
        for fold, test_input in enumerate(test_inputs):
            test_input.add_xyref_arg(node, test_refs[fold])

    # Data is prepared; run the fit. The original author notes this executes in
    # parallel thanks to the pipeline graph and the multiple input objects.
    pipeline_output_train = execute_pipeline(pipeline, ExecutionType.FIT, train_input)

    # Grid search yields multiple output nodes that still need to be iterated
    # over to select the "best" pipeline.
    out_nodes = pipeline.get_output_nodes()
330
+
331
+
270
332
def save (pipeline_output : dm .PipelineOutput , xy_ref : dm .XYRef , filehandle ):
271
333
pipeline = select_pipeline (pipeline_output , xy_ref )
272
334
pipeline .save (filehandle )
0 commit comments