Skip to content

Commit a195653

Browse files
Adding key support for grid search. This is internal code right now; a layer for easy creation, similar to GridSearchCV, will be added.
1. Migrated CV to use this internal method 2. Tests for grid search internal are now added 3. A few additional pipeline methods added for supporting grid_search techniques
1 parent 626de59 commit a195653

File tree

4 files changed

+525
-60
lines changed

4 files changed

+525
-60
lines changed

codeflare/pipelines/Datamodel.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,16 @@ def get_prev_xyrefs(self):
122122
"""
123123
return self.__prev_Xyrefs__
124124

125+
def __hash__(self):
    """
    Hash this object from its X and y references.

    The two references are packed into a tuple and the tuple is hashed.
    Unlike XOR-ing the two component hashes, this keeps the order of the
    references significant and does not collapse to 0 when both references
    are the same object. The attributes used here are the same ones that
    ``__eq__`` compares, so objects that compare equal hash equal.

    :return: Integer hash of the ``(X reference, y reference)`` pair
    """
    return hash((self.__Xref__, self.__yref__))
127+
128+
def __eq__(self, other):
    """
    Compare this object with ``other``.

    Two objects are considered equal when they are of the same class and
    both their X references and their y references compare equal.

    :param other: Object to compare against
    :return: Falsy when the classes differ or either reference differs,
        otherwise the result of comparing the y references
    """
    # Different classes can never be equal; bail out before touching attributes
    if not (self.__class__ == other.__class__):
        return False

    same_x = self.__Xref__ == other.__Xref__
    return same_x and self.__yref__ == other.__yref__
134+
125135

126136
class NodeInputType(Enum):
127137
OR = 0,
@@ -303,6 +313,18 @@ def __init__(self):
303313
self.__level_nodes__ = None
304314
self.__node_name_map__ = {}
305315

316+
def __hash__(self):
    """
    Hash this object from the keys of its node-name map.

    Starts from a fixed seed and XOR-folds in the hash of every key, so the
    result does not depend on the order in which the keys are visited.

    :return: Integer hash value
    """
    digest = 1234
    for key in self.__node_name_map__:
        digest ^= key.__hash__()
    return digest
321+
322+
def __eq__(self, other):
    """
    Compare this object with ``other``.

    Equality requires the same class and equal ``__pre_graph__`` attributes.

    :param other: Object to compare against
    :return: ``False`` when the classes differ, otherwise the result of
        comparing the two ``__pre_graph__`` attributes
    """
    # Guard first: attributes of a foreign class are never inspected
    if not (self.__class__ == other.__class__):
        return False
    return other.__pre_graph__ == self.__pre_graph__
327+
306328
def add_node(self, node: Node):
307329
self.__node_levels__ = None
308330
self.__level_nodes__ = None
@@ -459,6 +481,18 @@ def get_input_nodes(self):
459481
def get_node(self, node_name: str) -> Node:
460482
return self.__node_name_map__[node_name]
461483

484+
def has_single_estimator(self):
    """
    Check whether this pipeline reduces to a single estimator.

    The answer is ``False`` when the pipeline has more than one output node,
    or when any node whose input type is ``NodeInputType.OR`` (treated here
    as an estimator node) has more than one predecessor node.

    :return: ``True`` if neither of the above conditions holds, else ``False``
    """
    if len(self.get_output_nodes()) > 1:
        return False

    # The node-name map is keyed by node name (see get_node), so the node
    # objects are its values; iterating the keys would yield plain strings.
    for node in self.__node_name_map__.values():
        if node.get_node_input_type() != NodeInputType.OR:
            continue
        if len(self.get_pre_nodes(node)) > 1:
            return False
    return True
495+
462496
def save(self, filehandle):
463497
nodes = {}
464498
edges = []
@@ -528,6 +562,9 @@ def get_xyrefs(self, node: Node):
528562
def get_edge_args(self):
529563
return self.__edge_args__
530564

565+
def get_out_args(self):
    """
    Accessor for the ``__out_args__`` attribute of this object.

    :return: The stored ``__out_args__`` value, returned as-is (not copied)
    """
    out_args = self.__out_args__
    return out_args
567+
531568

532569
class PipelineInput:
533570
"""

codeflare/pipelines/Runtime.py

Lines changed: 65 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from enum import Enum
99

1010
from queue import SimpleQueue
11+
import pandas as pd
1112

1213

1314
class ExecutionType(Enum):
@@ -219,8 +220,15 @@ def split(cross_validator: BaseCrossValidator, xy_ref):
219220
xy_test_refs = []
220221

221222
for train_index, test_index in cross_validator.split(x, y):
222-
x_train, x_test = x[train_index], x[test_index]
223-
y_train, y_test = y[train_index], y[test_index]
223+
if isinstance(x, pd.DataFrame) or isinstance(x, pd.Series):
224+
x_train, x_test = x.iloc[train_index], x.iloc[test_index]
225+
else:
226+
x_train, x_test = x[train_index], x[test_index]
227+
228+
if isinstance(y, pd.DataFrame) or isinstance(y, pd.Series):
229+
y_train, y_test = y.iloc[train_index], y.iloc[test_index]
230+
else:
231+
y_train, y_test = y[train_index], y[test_index]
224232

225233
x_train_ref = ray.put(x_train)
226234
y_train_ref = ray.put(y_train)
@@ -236,64 +244,22 @@ def split(cross_validator: BaseCrossValidator, xy_ref):
236244

237245

238246
def cross_validate(cross_validator: BaseCrossValidator, pipeline: dm.Pipeline, pipeline_input: dm.PipelineInput):
239-
pipeline_input_train = dm.PipelineInput()
240-
241-
pipeline_input_test = []
242-
k = cross_validator.get_n_splits()
243-
# add k pipeline inputs for testing
244-
for i in range(k):
245-
pipeline_input_test.append(dm.PipelineInput())
246-
247-
in_args = pipeline_input.get_in_args()
248-
for node, xyref_ptrs in in_args.items():
249-
# NOTE: The assumption is that this node has only one input!
250-
xyref_ptr = xyref_ptrs[0]
251-
xy_train_refs_ptr, xy_test_refs_ptr = split.remote(cross_validator, xyref_ptr)
252-
xy_train_refs = ray.get(xy_train_refs_ptr)
253-
xy_test_refs = ray.get(xy_test_refs_ptr)
254-
255-
for xy_train_ref in xy_train_refs:
256-
pipeline_input_train.add_xyref_arg(node, xy_train_ref)
257-
258-
# for testing, add only to the specific input
259-
for i in range(k):
260-
pipeline_input_test[i].add_xyref_arg(node, xy_test_refs[i])
261-
262-
# Ready for execution now that data has been prepared! This execution happens in parallel
263-
# because of the underlying pipeline graph and multiple input objects
264-
pipeline_output_train = execute_pipeline(pipeline, ExecutionType.FIT, pipeline_input_train)
265-
266-
# Now we can choose the pipeline and then score for each of the chosen pipelines
267-
out_nodes = pipeline.get_output_nodes()
268-
if len(out_nodes) > 1:
269-
raise pe.PipelineException("Cannot cross validate as output is not a single node")
270-
271-
out_node = out_nodes[0]
272-
out_xyref_ptrs = pipeline_output_train.get_xyrefs(out_node)
273-
274-
k = cross_validator.get_n_splits()
275-
if len(out_xyref_ptrs) != k:
276-
raise pe.PipelineException("Number of outputs from pipeline fit is not equal to the folds from cross validator")
277-
278-
pipeline_score_outputs = []
279-
# Below, jobs get submitted and then we can collect the results in the next loop
280-
for i in range(k):
281-
selected_pipeline = select_pipeline(pipeline_output_train, out_xyref_ptrs[i])
282-
selected_pipeline_output = execute_pipeline(selected_pipeline, ExecutionType.SCORE, pipeline_input_test[i])
283-
pipeline_score_outputs.append(selected_pipeline_output)
284-
285-
result_scores = []
286-
for pipeline_score_output in pipeline_score_outputs:
287-
pipeline_out_xyrefs = pipeline_score_output.get_xyrefs(out_node)
288-
# again, only single xyref to be gotten out
289-
pipeline_out_xyref = pipeline_out_xyrefs[0]
290-
out_x = ray.get(pipeline_out_xyref.get_Xref())
291-
result_scores.append(out_x)
247+
has_single_estimator = pipeline.has_single_estimator()
248+
if not has_single_estimator:
249+
raise pe.PipelineException("Cross validation can only be done on pipelines with single estimator, "
250+
"use grid_search_cv instead")
251+
252+
result_grid_search_cv = grid_search_cv(cross_validator, pipeline, pipeline_input)
253+
# only one output here
254+
result_scores = None
255+
for scores in result_grid_search_cv.values():
256+
result_scores = scores
257+
break
292258

293259
return result_scores
294260

295261

296-
def grid_search(cross_validator: BaseCrossValidator, pipeline: dm.Pipeline, pipeline_input: dm.PipelineInput):
262+
def grid_search_cv(cross_validator: BaseCrossValidator, pipeline: dm.Pipeline, pipeline_input: dm.PipelineInput):
297263
pipeline_input_train = dm.PipelineInput()
298264

299265
pipeline_input_test = []
@@ -303,18 +269,24 @@ def grid_search(cross_validator: BaseCrossValidator, pipeline: dm.Pipeline, pipe
303269
pipeline_input_test.append(dm.PipelineInput())
304270

305271
in_args = pipeline_input.get_in_args()
272+
# Keep a map from the pointer of train to test
273+
train_test_mapper = {}
274+
306275
for node, xyref_ptrs in in_args.items():
307276
# NOTE: The assumption is that this node has only one input!
308277
xyref_ptr = xyref_ptrs[0]
309278
if len(xyref_ptrs) > 1:
310-
raise pe.PipelineException("Input to grid search is multiple objects, re-run with only single object")
279+
raise pe.PipelineException("Grid search supports single object input only, multiple provided, number is " + str(len(xyref_ptrs)))
311280

312281
xy_train_refs_ptr, xy_test_refs_ptr = split.remote(cross_validator, xyref_ptr)
313282
xy_train_refs = ray.get(xy_train_refs_ptr)
314283
xy_test_refs = ray.get(xy_test_refs_ptr)
315284

316-
for xy_train_ref in xy_train_refs:
285+
for i in range(len(xy_train_refs)):
286+
xy_train_ref = xy_train_refs[i]
287+
xy_test_ref = xy_test_refs[i]
317288
pipeline_input_train.add_xyref_arg(node, xy_train_ref)
289+
train_test_mapper[xy_train_ref] = xy_test_ref
318290

319291
# for testing, add only to the specific input
320292
for i in range(k):
@@ -324,9 +296,42 @@ def grid_search(cross_validator: BaseCrossValidator, pipeline: dm.Pipeline, pipe
324296
# because of the underlying pipeline graph and multiple input objects
325297
pipeline_output_train = execute_pipeline(pipeline, ExecutionType.FIT, pipeline_input_train)
326298

327-
# For grid search, we will have multiple output nodes that need to be iterated on and select the pipeline
328-
# that is "best"
299+
# For grid search, we will have multiple output nodes that need to be iterated on
300+
selected_pipeline_test_outputs = {}
329301
out_nodes = pipeline.get_output_nodes()
302+
for out_node in out_nodes:
303+
out_node_xyrefs = pipeline_output_train.get_xyrefs(out_node)
304+
for out_node_xyref in out_node_xyrefs:
305+
selected_pipeline = select_pipeline(pipeline_output_train, out_node_xyref)
306+
selected_pipeline_input = get_pipeline_input(pipeline, pipeline_output_train, out_node_xyref)
307+
selected_pipeline_inargs = selected_pipeline_input.get_in_args()
308+
test_pipeline_input = dm.PipelineInput()
309+
for node, train_xyref_ptr in selected_pipeline_inargs.items():
310+
# xyrefs is a singleton by construction
311+
train_xyrefs = ray.get(train_xyref_ptr)
312+
test_xyref = train_test_mapper[train_xyrefs[0]]
313+
test_pipeline_input.add_xyref_arg(node, test_xyref)
314+
selected_pipeline_test_output = execute_pipeline(selected_pipeline, ExecutionType.SCORE, test_pipeline_input)
315+
if selected_pipeline not in selected_pipeline_test_outputs.keys():
316+
selected_pipeline_test_outputs[selected_pipeline] = []
317+
selected_pipeline_test_outputs[selected_pipeline].append(selected_pipeline_test_output)
318+
319+
# now, test outputs can be materialized
320+
result_scores = {}
321+
for selected_pipeline, selected_pipeline_test_output_list in selected_pipeline_test_outputs.items():
322+
output_nodes = selected_pipeline.get_output_nodes()
323+
# by design, output_nodes will only have one node
324+
output_node = output_nodes[0]
325+
for selected_pipeline_test_output in selected_pipeline_test_output_list:
326+
pipeline_out_xyrefs = selected_pipeline_test_output.get_xyrefs(output_node)
327+
# again, only single xyref to be gotten out
328+
pipeline_out_xyref = pipeline_out_xyrefs[0]
329+
out_x = ray.get(pipeline_out_xyref.get_Xref())
330+
if selected_pipeline not in result_scores.keys():
331+
result_scores[selected_pipeline] = []
332+
result_scores[selected_pipeline].append(out_x)
333+
334+
return result_scores
330335

331336

332337
def save(pipeline_output: dm.PipelineOutput, xy_ref: dm.XYRef, filehandle):

codeflare/pipelines/tests/test_runtime.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import codeflare.pipelines.Datamodel as dm
44
import codeflare.pipelines.Runtime as rt
55

6+
from sklearn.model_selection import KFold
7+
68

79
def test_runtime_pipeline_input_getter():
810
"""
@@ -41,3 +43,37 @@ def test_runtime_pipeline_input_getter():
4143
assert xyref.get_Xref() == input_xyref.get_Xref()
4244
assert xyref.get_yref() == input_xyref.get_yref()
4345

46+
47+
def test_grid_search():
    """
    Run ``rt.grid_search_cv`` with a 2-fold ``KFold`` on the helper pipeline.

    The test passes when the result contains a pipeline whose output node is
    ``random_forest`` and one whose output node is ``gradient_boost``, and
    every pipeline in the result carries exactly two scores (one per fold).
    """
    import ray
    # restart ray so the test starts from a fresh runtime
    ray.shutdown()
    ray.init()

    # only the training split is needed; the test split is discarded
    X_train, _, y_train, _ = test_helper.get_data()
    pipeline = test_helper.get_pipeline(X_train)

    input_node = pipeline.get_node('preprocess')

    pipeline_input = dm.PipelineInput()
    xy = dm.Xy(X_train, y_train)
    pipeline_input.add_xy_arg(input_node, xy)

    kf = KFold(2)
    result = rt.grid_search_cv(kf, pipeline, pipeline_input)
    node_rf = pipeline.get_node('random_forest')
    node_gb = pipeline.get_node('gradient_boost')

    # result should have two pipelines, with two scored outputs each
    node_rf_pipeline = False
    node_gb_pipeline = False
    for cv_pipeline, scores in result.items():
        out_node = cv_pipeline.get_output_nodes()[0]
        if out_node.get_node_name() == node_rf.get_node_name():
            node_rf_pipeline = True
        elif out_node.get_node_name() == node_gb.get_node_name():
            node_gb_pipeline = True
        assert len(scores) == 2, "expected one score per fold for each pipeline"

    assert node_rf_pipeline
    assert node_gb_pipeline

0 commit comments

Comments
 (0)