Skip to content

Commit 626de59

Browse files
Start of the grid search implementation. This commit adds:
1. Support for deriving the pipeline input for a given pipeline and a chosen XYRef. 2. New methods for easier access to pipeline internals. 3. Renaming of the older pre_image/post_image methods. 4. A test for pipeline_input retrieval.
1 parent 47137c1 commit 626de59

File tree

6 files changed

+205
-70
lines changed

6 files changed

+205
-70
lines changed

codeflare/pipelines/Datamodel.py

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -301,13 +301,15 @@ def __init__(self):
301301
self.__post_graph__ = {}
302302
self.__node_levels__ = None
303303
self.__level_nodes__ = None
304+
self.__node_name_map__ = {}
304305

305306
def add_node(self, node: Node):
    """Register *node* in the pipeline graph.

    Any structural change invalidates the cached level computations,
    so those caches are reset first. Re-adding an existing node is a no-op
    (beyond the cache reset).
    """
    self.__node_levels__ = None
    self.__level_nodes__ = None
    if node in self.__pre_graph__:
        return
    self.__pre_graph__[node] = []
    self.__post_graph__[node] = []
    self.__node_name_map__[node.get_node_name()] = node
311313

312314
def __str__(self):
313315
res = ''
@@ -333,23 +335,17 @@ def add_edge(self, from_node: Node, to_node: Node):
333335
self.__pre_graph__[to_node].append(from_node)
334336
self.__post_graph__[from_node].append(to_node)
335337

336-
def get_preimage(self, node: Node):
337-
return self.__pre_graph__[node]
338-
339-
def get_postimage(self, node: Node):
340-
return self.__post_graph__[node]
341-
342338
def compute_node_level(self, node: Node, result: dict):
343339
if node in result:
344340
return result[node]
345341

346-
node_preimage = self.get_preimage(node)
347-
if not node_preimage:
342+
pre_nodes = self.get_pre_nodes(node)
343+
if not pre_nodes:
348344
result[node] = 0
349345
return 0
350346

351347
max_level = 0
352-
for p_node in node_preimage:
348+
for p_node in pre_nodes:
353349
level = self.compute_node_level(p_node, result)
354350
max_level = max(level, max_level)
355351

@@ -369,6 +365,10 @@ def compute_node_levels(self):
369365

370366
return self.__node_levels__
371367

368+
def get_node_level(self, node: Node):
    """Return the topological level of *node* (0 for input nodes)."""
    # compute_node_levels() returns the node -> level mapping.
    levels = self.compute_node_levels()
    return levels[node]
371+
372372
def compute_max_level(self):
373373
levels = self.compute_node_levels()
374374
max_level = 0
@@ -423,30 +423,42 @@ def get_post_edges(self, node: Node):
423423
post_edges.append(Edge(node, post_node))
424424
return post_edges
425425

426-
def is_terminal(self, node: Node):
427-
post_nodes = self.__post_graph__[node]
426+
def is_output(self, node: Node):
    """True if *node* has no downstream nodes, i.e. it is a sink/output node."""
    return not self.get_post_nodes(node)
429429

430-
def get_terminal_nodes(self):
430+
def get_output_nodes(self):
    """Return the list of all sink nodes (nodes with no outgoing edges)."""
    return [node for node in self.__pre_graph__ if self.is_output(node)]
437437

438438
def get_nodes(self):
    """Return a mapping from node name to node for every registered node.

    Returns a shallow copy: the previous implementation built a fresh dict
    on every call, so callers could mutate the result safely. Returning the
    internal __node_name_map__ directly would let callers corrupt the
    pipeline's internal state.
    """
    return dict(self.__node_name_map__)
443440

444441
def get_pre_nodes(self, node):
    """Return the list of immediate upstream (predecessor) nodes of *node*.

    Raises KeyError if *node* was never added to the pipeline.
    """
    return self.__pre_graph__[node]
446443

447444
def get_post_nodes(self, node):
    """Return the list of immediate downstream (successor) nodes of *node*.

    Raises KeyError if *node* was never added to the pipeline.
    """
    return self.__post_graph__[node]
449446

447+
def is_input(self, node: Node):
    """True if *node* has no upstream nodes, i.e. it is a source/input node."""
    return not self.get_pre_nodes(node)
450+
451+
def get_input_nodes(self):
    """Return all input nodes, i.e. nodes at topological level 0.

    BUG FIX: the original called self.get_node_level() without the required
    node argument, which raises TypeError on every invocation.
    """
    input_nodes = []
    for node in self.__node_name_map__.values():
        if self.get_node_level(node) == 0:
            input_nodes.append(node)

    return input_nodes
458+
459+
def get_node(self, node_name: str) -> Node:
    """Look up a registered node by its name.

    :param node_name: name the node was created with
    :return: the matching Node; raises KeyError if no such node was added
    """
    return self.__node_name_map__[node_name]
461+
450462
def save(self, filehandle):
451463
nodes = {}
452464
edges = []

codeflare/pipelines/Runtime.py

Lines changed: 65 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def execute_pipeline(pipeline: dm.Pipeline, mode: ExecutionType, pipeline_input:
153153
execute_and_node(node, pre_edges, edge_args, post_edges)
154154

155155
out_args = {}
156-
terminal_nodes = pipeline.get_terminal_nodes()
156+
terminal_nodes = pipeline.get_output_nodes()
157157
for terminal_node in terminal_nodes:
158158
edge = dm.Edge(terminal_node, None)
159159
out_args[terminal_node] = edge_args[edge]
@@ -184,6 +184,32 @@ def select_pipeline(pipeline_output: dm.PipelineOutput, chosen_xyref: dm.XYRef):
184184
return pipeline
185185

186186

187+
def get_pipeline_input(pipeline: dm.Pipeline, pipeline_output: dm.PipelineOutput, chosen_xyref: dm.XYRef):
    """
    Reconstruct the PipelineInput that fed the (sub-)pipeline which produced
    *chosen_xyref*.

    Performs a breadth-first backward walk from chosen_xyref through its
    lineage of previous XYRefs; whenever a level-0 (input) node is reached,
    its previous XYRefs are recorded as input args of the returned
    PipelineInput.

    NOTE(review): the pipeline_output parameter is currently unused — confirm
    whether it is needed or can be dropped from the signature.
    """
    pipeline_input = dm.PipelineInput()

    xyref_queue = SimpleQueue()
    xyref_queue.put(chosen_xyref)
    while not xyref_queue.empty():
        curr_xyref = xyref_queue.get()
        # Node state lives in the Ray object store; resolve the ref to get the node.
        curr_node_state_ptr = curr_xyref.get_curr_node_state_ref()
        curr_node = ray.get(curr_node_state_ptr)
        curr_node_level = pipeline.get_node_level(curr_node)
        prev_xyrefs = curr_xyref.get_prev_xyrefs()

        if curr_node_level == 0:
            # This is an input node
            for prev_xyref in prev_xyrefs:
                pipeline_input.add_xyref_arg(curr_node, prev_xyref)

        # Continue the backward walk; refs without node state have nothing
        # further upstream and are not enqueued.
        for prev_xyref in prev_xyrefs:
            prev_node_state_ptr = prev_xyref.get_curr_node_state_ref()
            if prev_node_state_ptr is None:
                continue
            xyref_queue.put(prev_xyref)

    return pipeline_input
211+
212+
187213
@ray.remote(num_returns=2)
188214
def split(cross_validator: BaseCrossValidator, xy_ref):
189215
x = ray.get(xy_ref.get_Xref())
@@ -220,7 +246,7 @@ def cross_validate(cross_validator: BaseCrossValidator, pipeline: dm.Pipeline, p
220246

221247
in_args = pipeline_input.get_in_args()
222248
for node, xyref_ptrs in in_args.items():
223-
# NOTE: The assumption is that this node has only one input, the check earlier will ensure this!
249+
# NOTE: The assumption is that this node has only one input!
224250
xyref_ptr = xyref_ptrs[0]
225251
xy_train_refs_ptr, xy_test_refs_ptr = split.remote(cross_validator, xyref_ptr)
226252
xy_train_refs = ray.get(xy_train_refs_ptr)
@@ -238,7 +264,7 @@ def cross_validate(cross_validator: BaseCrossValidator, pipeline: dm.Pipeline, p
238264
pipeline_output_train = execute_pipeline(pipeline, ExecutionType.FIT, pipeline_input_train)
239265

240266
# Now we can choose the pipeline and then score for each of the chosen pipelines
241-
out_nodes = pipeline.get_terminal_nodes()
267+
out_nodes = pipeline.get_output_nodes()
242268
if len(out_nodes) > 1:
243269
raise pe.PipelineException("Cannot cross validate as output is not a single node")
244270

@@ -267,6 +293,42 @@ def cross_validate(cross_validator: BaseCrossValidator, pipeline: dm.Pipeline, p
267293
return result_scores
268294

269295

296+
def grid_search(cross_validator: BaseCrossValidator, pipeline: dm.Pipeline, pipeline_input: dm.PipelineInput):
    """
    Run a k-fold grid search over the pipeline (WORK IN PROGRESS).

    Splits each single input object into k train/test folds, fits the
    pipeline on the training folds, and gathers the output nodes.

    TODO: scoring the fitted pipelines and selecting the "best" one is not
    implemented yet; the function currently returns None implicitly.

    :param cross_validator: sklearn cross validator providing get_n_splits()
    :param pipeline: pipeline to search over
    :param pipeline_input: input with exactly one object per input node
    """
    pipeline_input_train = dm.PipelineInput()

    pipeline_input_test = []
    k = cross_validator.get_n_splits()
    # add k pipeline inputs for testing
    for i in range(k):
        pipeline_input_test.append(dm.PipelineInput())

    in_args = pipeline_input.get_in_args()
    for node, xyref_ptrs in in_args.items():
        # NOTE: The assumption is that this node has only one input!
        # BUG FIX: validate before indexing, so a multi-object input fails
        # with the intended PipelineException rather than after binding
        # xyref_ptrs[0] first.
        if len(xyref_ptrs) > 1:
            raise pe.PipelineException("Input to grid search is multiple objects, re-run with only single object")
        xyref_ptr = xyref_ptrs[0]

        xy_train_refs_ptr, xy_test_refs_ptr = split.remote(cross_validator, xyref_ptr)
        xy_train_refs = ray.get(xy_train_refs_ptr)
        xy_test_refs = ray.get(xy_test_refs_ptr)

        for xy_train_ref in xy_train_refs:
            pipeline_input_train.add_xyref_arg(node, xy_train_ref)

        # for testing, add only to the specific input
        for i in range(k):
            pipeline_input_test[i].add_xyref_arg(node, xy_test_refs[i])

    # Ready for execution now that data has been prepared! This execution happens in parallel
    # because of the underlying pipeline graph and multiple input objects
    pipeline_output_train = execute_pipeline(pipeline, ExecutionType.FIT, pipeline_input_train)

    # For grid search, we will have multiple output nodes that need to be iterated on and select the pipeline
    # that is "best"
    out_nodes = pipeline.get_output_nodes()
330+
331+
270332
def save(pipeline_output: dm.PipelineOutput, xy_ref: dm.XYRef, filehandle):
271333
pipeline = select_pipeline(pipeline_output, xy_ref)
272334
pipeline.save(filehandle)

codeflare/pipelines/test_Datamodel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def test_multibranch(self):
7676
pipeline_input = dm.PipelineInput()
7777
pipeline_input.add_xy_arg(node_a, dm.Xy(X_train, y_train))
7878

79-
terminal_nodes = pipeline.get_terminal_nodes()
79+
terminal_nodes = pipeline.get_output_nodes()
8080
assert len(terminal_nodes) == 4
8181

8282
## execute the codeflare pipeline
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import codeflare.pipelines.Datamodel as dm
2+
3+
import pandas as pd
4+
from sklearn.pipeline import Pipeline
5+
from sklearn.impute import SimpleImputer
6+
from sklearn.preprocessing import StandardScaler, OneHotEncoder
7+
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
8+
9+
10+
def get_pipeline(train) -> dm.Pipeline:
    """
    Build a three-node CodeFlare pipeline: a shared preprocessing node
    fanning out to a random-forest node and a gradient-boosting node.

    :param train: training DataFrame; used only to infer which columns are
        numeric (int64/float64) vs categorical (object)
    :return: the assembled dm.Pipeline
    """
    imputer = SimpleImputer(strategy='median')
    scaler = StandardScaler()

    # Numeric columns: median-impute then standardize.
    numeric_transformer = Pipeline(steps=[
        ('imputer', imputer),
        ('scaler', scaler)])

    cat_imputer = SimpleImputer(strategy='constant', fill_value='missing')
    cat_onehot = OneHotEncoder(handle_unknown='ignore')

    # Categorical columns: constant-impute then one-hot encode.
    categorical_transformer = Pipeline(steps=[
        ('imputer', cat_imputer),
        ('onehot', cat_onehot)])
    numeric_features = train.select_dtypes(include=['int64', 'float64']).columns
    categorical_features = train.select_dtypes(include=['object']).columns
    from sklearn.compose import ColumnTransformer
    preprocessor = ColumnTransformer(
        transformers=[
            ('num', numeric_transformer, numeric_features),
            ('cat', categorical_transformer, categorical_features)])

    classifiers = [
        RandomForestClassifier(),
        GradientBoostingClassifier()
    ]
    pipeline = dm.Pipeline()
    node_pre = dm.EstimatorNode('preprocess', preprocessor)
    node_rf = dm.EstimatorNode('random_forest', classifiers[0])
    node_gb = dm.EstimatorNode('gradient_boost', classifiers[1])

    # NOTE(review): nodes are only introduced via add_edge here — presumably
    # add_edge registers its endpoints; confirm against dm.Pipeline.add_edge.
    pipeline.add_edge(node_pre, node_rf)
    pipeline.add_edge(node_pre, node_gb)

    return pipeline
45+
46+
47+
def get_data():
    """
    Load the loan dataset and split it into train/test sets.

    NOTE(review): the CSV path is relative to the test's working directory —
    this is fragile; confirm tests are always run from the expected location.

    :return: (X_train, X_test, y_train, y_test); the 80/20 split is random
        (no fixed random_state), so results differ between calls.
    """
    train = pd.read_csv('../../../resources/data/train_ctrUa4K.csv')
    # Loan_ID is an identifier, not a feature.
    train = train.drop('Loan_ID', axis=1)

    X = train.drop('Loan_Status', axis=1)
    y = train['Loan_Status']
    from sklearn.model_selection import train_test_split
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

    return X_train, X_test, y_train, y_test
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from codeflare.pipelines.tests import test_helper
2+
3+
import codeflare.pipelines.Datamodel as dm
4+
import codeflare.pipelines.Runtime as rt
5+
6+
7+
def test_runtime_pipeline_input_getter():
    """
    A test to get the pipeline inputs after a selection is done.

    Fits the helper pipeline, picks the random-forest branch's output XYRef,
    reconstructs the pipeline input via rt.get_pipeline_input, and checks it
    contains the original input node with the same underlying X/y refs.
    :return:
    """

    import ray
    # Restart Ray to guarantee a clean cluster state for this test.
    ray.shutdown()
    ray.init()
    X_train, X_test, y_train, y_test = test_helper.get_data()
    pipeline = test_helper.get_pipeline(X_train)

    node_rf = pipeline.get_node('random_forest')
    # NOTE(review): node_gb is unused in this test — kept for symmetry?
    node_gb = pipeline.get_node('gradient_boost')
    input_node = pipeline.get_node('preprocess')

    pipeline_input = dm.PipelineInput()
    xy = dm.Xy(X_train, y_train)
    pipeline_input.add_xy_arg(input_node, xy)

    pipeline_output = rt.execute_pipeline(pipeline, rt.ExecutionType.FIT, pipeline_input)
    node_rf_xyrefs = pipeline_output.get_xyrefs(node_rf)

    # Reconstruct the input that led to the chosen random-forest output.
    selected_pipeline_input = rt.get_pipeline_input(pipeline, pipeline_output, node_rf_xyrefs[0])
    in_args = selected_pipeline_input.get_in_args()
    is_input_node_present = (input_node in in_args.keys())
    assert is_input_node_present

    # check if the XYref is the same
    xyref_ptrs = in_args[input_node]
    xyref_ptr = xyref_ptrs[0]
    xyref = ray.get(xyref_ptr)

    input_xyref = ray.get(pipeline_input.get_in_args()[input_node][0])
    assert xyref.get_Xref() == input_xyref.get_Xref()
    assert xyref.get_yref() == input_xyref.get_yref()
43+

0 commit comments

Comments
 (0)