@@ -23,11 +23,11 @@ def execute_or_node_remote(node: dm.EstimatorNode, mode: ExecutionType, xy_ref:
23
23
# Blocking operation -- not avoidable
24
24
X = ray .get (xy_ref .get_Xref ())
25
25
y = ray .get (xy_ref .get_yref ())
26
+ prev_node_ptr = ray .put (node )
26
27
27
28
# TODO: Can optimize the node pointers without replicating them
28
29
if mode == ExecutionType .FIT :
29
30
cloned_node = node .clone ()
30
- prev_node_ptr = ray .put (node )
31
31
32
32
if base .is_classifier (estimator ) or base .is_regressor (estimator ):
33
33
# Always clone before fit, else fit is invalid
@@ -49,22 +49,22 @@ def execute_or_node_remote(node: dm.EstimatorNode, mode: ExecutionType, xy_ref:
49
49
if base .is_classifier (estimator ) or base .is_regressor (estimator ):
50
50
estimator = node .get_estimator ()
51
51
res_Xref = ray .put (estimator .score (X , y ))
52
- result = dm .XYRef (res_Xref , xy_ref .get_yref ())
52
+ result = dm .XYRef (res_Xref , xy_ref .get_yref (), prev_node_ptr , prev_node_ptr , [ xy_ref ] )
53
53
return result
54
54
else :
55
55
res_Xref = ray .put (estimator .transform (X ))
56
- result = dm .XYRef (res_Xref , xy_ref .get_yref ())
56
+ result = dm .XYRef (res_Xref , xy_ref .get_yref (), prev_node_ptr , prev_node_ptr , [ xy_ref ] )
57
57
58
58
return result
59
59
elif mode == ExecutionType .PREDICT :
60
60
# Test mode does not clone as it is a simple predict or transform
61
61
if base .is_classifier (estimator ) or base .is_regressor (estimator ):
62
62
res_Xref = ray .put (estimator .predict (X ))
63
- result = dm .XYRef (res_Xref , xy_ref .get_yref ())
63
+ result = dm .XYRef (res_Xref , xy_ref .get_yref (), prev_node_ptr , prev_node_ptr , [ xy_ref ] )
64
64
return result
65
65
else :
66
66
res_Xref = ray .put (estimator .transform (X ))
67
- result = dm .XYRef (res_Xref , xy_ref .get_yref ())
67
+ result = dm .XYRef (res_Xref , xy_ref .get_yref (), prev_node_ptr , prev_node_ptr , [ xy_ref ] )
68
68
return result
69
69
70
70
@@ -84,38 +84,88 @@ def execute_or_node(node, pre_edges, edge_args, post_edges, mode: ExecutionType)
84
84
85
85
86
86
@ray .remote
87
- def execute_and_node_remote (node : dm .AndNode , Xyref_list ):
87
+ def execute_and_node_remote (node : dm .AndNode , mode : ExecutionType , Xyref_list ):
88
88
xy_list = []
89
89
prev_node_ptr = ray .put (node )
90
90
for Xyref in Xyref_list :
91
91
X = ray .get (Xyref .get_Xref ())
92
92
y = ray .get (Xyref .get_yref ())
93
93
xy_list .append (dm .Xy (X , y ))
94
94
95
- cloned_node = node .clone ()
96
- curr_node_ptr = ray .put (cloned_node )
95
+ estimator = node .get_estimator ()
96
+
97
+ # TODO: Can optimize the node pointers without replicating them
98
+ if mode == ExecutionType .FIT :
99
+ cloned_node = node .clone ()
100
+
101
+ if base .is_classifier (estimator ) or base .is_regressor (estimator ):
102
+ # Always clone before fit, else fit is invalid
103
+ cloned_estimator = cloned_node .get_estimator ()
104
+ cloned_estimator .fit (xy_list )
97
105
98
- cloned_and_func = cloned_node .get_and_func ()
99
- res_Xy = cloned_and_func .transform (xy_list )
100
- res_Xref = ray .put (res_Xy .get_x ())
101
- res_yref = ray .put (res_Xy .get_y ())
102
- return dm .XYRef (res_Xref , res_yref , prev_node_ptr , curr_node_ptr , Xyref_list )
106
+ curr_node_ptr = ray .put (cloned_node )
107
+ res_xy = cloned_estimator .predict (xy_list )
108
+ res_xref = ray .put (res_xy .get_x ())
109
+ res_yref = ray .put (res_xy .get_y ())
103
110
111
+ result = dm .XYRef (res_xref , res_yref , prev_node_ptr , curr_node_ptr , Xyref_list )
112
+ return result
113
+ else :
114
+ cloned_estimator = cloned_node .get_estimator ()
115
+ res_xy = cloned_estimator .fit_transform (xy_list )
116
+ res_xref = ray .put (res_xy .get_x ())
117
+ res_yref = ray .put (res_xy .get_y ())
104
118
105
- def execute_and_node_inner (node : dm .AndNode , Xyref_ptrs ):
119
+ curr_node_ptr = ray .put (cloned_node )
120
+ result = dm .XYRef (res_xref , res_yref , prev_node_ptr , curr_node_ptr , Xyref_list )
121
+ return result
122
+ elif mode == ExecutionType .SCORE :
123
+ if base .is_classifier (estimator ) or base .is_regressor (estimator ):
124
+ estimator = node .get_estimator ()
125
+ res_xy = estimator .score (xy_list )
126
+ res_xref = ray .put (res_xy .get_x ())
127
+ res_yref = ray .put (res_xy .get_y ())
128
+
129
+ result = dm .XYRef (res_xref , res_yref , prev_node_ptr , prev_node_ptr , Xyref_list )
130
+ return result
131
+ else :
132
+ res_xy = estimator .transform (xy_list )
133
+ res_xref = ray .put (res_xy .get_x ())
134
+ res_yref = ray .put (res_xy .get_y ())
135
+ result = dm .XYRef (res_xref , res_yref , prev_node_ptr , prev_node_ptr , Xyref_list )
136
+
137
+ return result
138
+ elif mode == ExecutionType .PREDICT :
139
+ # Test mode does not clone as it is a simple predict or transform
140
+ if base .is_classifier (estimator ) or base .is_regressor (estimator ):
141
+ res_xy = estimator .predict (xy_list )
142
+ res_xref = ray .put (res_xy .get_x ())
143
+ res_yref = ray .put (res_xy .get_y ())
144
+
145
+ result = dm .XYRef (res_xref , res_yref , prev_node_ptr , prev_node_ptr , Xyref_list )
146
+ return result
147
+ else :
148
+ res_xy = estimator .transform (xy_list )
149
+ res_xref = ray .put (res_xy .get_x ())
150
+ res_yref = ray .put (res_xy .get_y ())
151
+ result = dm .XYRef (res_xref , res_yref , prev_node_ptr , prev_node_ptr , Xyref_list )
152
+ return result
153
+
154
+
155
+ def execute_and_node_inner (node : dm .AndNode , mode : ExecutionType , Xyref_ptrs ):
106
156
result = []
107
157
108
158
Xyref_list = []
109
159
for Xyref_ptr in Xyref_ptrs :
110
160
Xyref = ray .get (Xyref_ptr )
111
161
Xyref_list .append (Xyref )
112
162
113
- Xyref_ptr = execute_and_node_remote .remote (node , Xyref_list )
163
+ Xyref_ptr = execute_and_node_remote .remote (node , mode , Xyref_list )
114
164
result .append (Xyref_ptr )
115
165
return result
116
166
117
167
118
- def execute_and_node (node , pre_edges , edge_args , post_edges ):
168
+ def execute_and_node (node , pre_edges , edge_args , post_edges , mode : ExecutionType ):
119
169
edge_args_lists = list ()
120
170
for pre_edge in pre_edges :
121
171
edge_args_lists .append (edge_args [pre_edge ])
@@ -125,7 +175,7 @@ def execute_and_node(node, pre_edges, edge_args, post_edges):
125
175
cross_product = itertools .product (* edge_args_lists )
126
176
127
177
for element in cross_product :
128
- exec_xyref_ptrs = execute_and_node_inner (node , element )
178
+ exec_xyref_ptrs = execute_and_node_inner (node , mode , element )
129
179
for post_edge in post_edges :
130
180
if post_edge not in edge_args .keys ():
131
181
edge_args [post_edge ] = []
@@ -151,7 +201,7 @@ def execute_pipeline(pipeline: dm.Pipeline, mode: ExecutionType, pipeline_input:
151
201
if node .get_node_input_type () == dm .NodeInputType .OR :
152
202
execute_or_node (node , pre_edges , edge_args , post_edges , mode )
153
203
elif node .get_node_input_type () == dm .NodeInputType .AND :
154
- execute_and_node (node , pre_edges , edge_args , post_edges )
204
+ execute_and_node (node , pre_edges , edge_args , post_edges , mode )
155
205
156
206
out_args = {}
157
207
terminal_nodes = pipeline .get_output_nodes ()
@@ -249,7 +299,7 @@ def cross_validate(cross_validator: BaseCrossValidator, pipeline: dm.Pipeline, p
249
299
raise pe .PipelineException ("Cross validation can only be done on pipelines with single estimator, "
250
300
"use grid_search_cv instead" )
251
301
252
- result_grid_search_cv = grid_search_cv (cross_validator , pipeline , pipeline_input )
302
+ result_grid_search_cv = _grid_search_cv (cross_validator , pipeline , pipeline_input )
253
303
# only one output here
254
304
result_scores = None
255
305
for scores in result_grid_search_cv .values ():
@@ -259,7 +309,13 @@ def cross_validate(cross_validator: BaseCrossValidator, pipeline: dm.Pipeline, p
259
309
return result_scores
260
310
261
311
262
- def grid_search_cv (cross_validator : BaseCrossValidator , pipeline : dm .Pipeline , pipeline_input : dm .PipelineInput ):
312
+ def grid_search_cv (cross_validator : BaseCrossValidator , pipeline : dm .Pipeline , pipeline_input : dm .PipelineInput , pipeline_params : dm .PipelineParam ):
313
+ parameterized_pipeline = pipeline .get_parameterized_pipeline (pipeline_params )
314
+ parameterized_pipeline_input = pipeline_input .get_parameterized_input (pipeline , parameterized_pipeline )
315
+ return _grid_search_cv (cross_validator , parameterized_pipeline , parameterized_pipeline_input )
316
+
317
+
318
+ def _grid_search_cv (cross_validator : BaseCrossValidator , pipeline : dm .Pipeline , pipeline_input : dm .PipelineInput ):
263
319
pipeline_input_train = dm .PipelineInput ()
264
320
265
321
pipeline_input_test = []
0 commit comments