Skip to content

Commit be7f237

Browse files
Adding grid search in runtime that provides a similar API as GridSearchCV, but a richer one that can explore non-gridded spaces as well.
Tests have been added to cover this Modifications to datamodel have been done to accommodate for grid search
1 parent a20a9c4 commit be7f237

File tree

9 files changed

+450
-386
lines changed

9 files changed

+450
-386
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ dist/
55
ray-graphs.egginfo/
66
.idea/
77
.ipynb_checkpoints/
8+
*__pycache*

codeflare/pipelines/Datamodel.py

Lines changed: 82 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
from enum import Enum
33

44
import sklearn.base as base
5-
from codeflare.pipelines.Datamodel import PipelineParam
6-
from sklearn.base import TransformerMixin
75
from sklearn.base import BaseEstimator
86
from sklearn.model_selection import ParameterGrid
97

@@ -159,16 +157,17 @@ class Node(ABC):
159157
node name and the type of the node match.
160158
"""
161159

162-
def __init__(self, node_name, node_input_type: NodeInputType, node_firing_type: NodeFiringType, node_state_type: NodeStateType):
163-
if '__' in node_name:
164-
raise pe.PipelineException("Node name cannot have __, please rename")
160+
def __init__(self, node_name, estimator: BaseEstimator, node_input_type: NodeInputType, node_firing_type: NodeFiringType, node_state_type: NodeStateType):
165161
self.__node_name__ = node_name
162+
self.__estimator__ = estimator
166163
self.__node_input_type__ = node_input_type
167164
self.__node_firing_type__ = node_firing_type
168165
self.__node_state_type__ = node_state_type
169166

170167
def __str__(self):
171-
return self.__node_name__
168+
estimator_params_str = str(self.get_estimator().get_params())
169+
retval = self.__node_name__ + estimator_params_str
170+
return retval
172171

173172
def get_node_name(self):
174173
return self.__node_name__
@@ -182,6 +181,16 @@ def get_node_firing_type(self):
182181
def get_node_state_type(self):
183182
return self.__node_state_type__
184183

184+
def get_estimator(self):
185+
return self.__estimator__
186+
187+
def get_parameterized_node(self, node_name, **params):
188+
cloned_node = self.clone()
189+
cloned_node.__node_name__ = node_name
190+
estimator = cloned_node.get_estimator()
191+
estimator.set_params(**params)
192+
return cloned_node
193+
185194
@abstractmethod
186195
def clone(self):
187196
raise NotImplementedError("Please implement the clone method")
@@ -222,44 +231,50 @@ def __init__(self, node_name: str, estimator: BaseEstimator):
222231
:param estimator: The base estimator
223232
"""
224233

225-
super().__init__(node_name, NodeInputType.OR, NodeFiringType.ANY, NodeStateType.IMMUTABLE)
226-
self.__estimator__ = estimator
227-
228-
def get_estimator(self) -> BaseEstimator:
229-
"""
230-
Return the estimator that this was initialize with
231-
232-
:return: Estimator
233-
"""
234-
return self.__estimator__
234+
super().__init__(node_name, estimator, NodeInputType.OR, NodeFiringType.ANY, NodeStateType.IMMUTABLE)
235235

236236
def clone(self):
237237
cloned_estimator = base.clone(self.__estimator__)
238238
return EstimatorNode(self.__node_name__, cloned_estimator)
239239

240240

241-
class AndTransform(TransformerMixin, BaseEstimator):
241+
class AndEstimator(BaseEstimator):
242242
@abstractmethod
243243
def transform(self, xy_list: list) -> Xy:
244-
raise NotImplementedError("Please implement this method")
244+
raise NotImplementedError("And estimator needs to implement a transform method")
245+
246+
@abstractmethod
247+
def fit(self, xy_list: list):
248+
raise NotImplementedError("And estimator needs to implement a fit method")
245249

250+
@abstractmethod
251+
def fit_transform(self, xy_list: list):
252+
raise NotImplementedError("And estimator needs to implement a fit method")
246253

247-
class GeneralTransform(TransformerMixin, BaseEstimator):
248254
@abstractmethod
249-
def transform(self, xy: Xy) -> Xy:
250-
raise NotImplementedError("Please implement this method")
255+
def predict(self, xy_list: list) -> Xy:
256+
raise NotImplementedError("And classifier needs to implement the predict method")
251257

258+
@abstractmethod
259+
def score(self, xy_list: list) -> Xy:
260+
raise NotImplementedError("And classifier needs to implement the score method")
252261

253-
class AndNode(Node):
254-
def __init__(self, node_name: str, and_func: AndTransform):
255-
super().__init__(node_name, NodeInputType.AND, NodeFiringType.ANY, NodeStateType.STATELESS)
256-
self.__andfunc__ = and_func
262+
@abstractmethod
263+
def get_estimator_type(self):
264+
raise NotImplementedError("And classifier needs to implement the get_estimator_type method")
265+
266+
@abstractmethod
267+
def clone(self):
268+
raise NotImplementedError("And estimator needs to implement a clone method")
257269

258-
def get_and_func(self) -> AndTransform:
259-
return self.__andfunc__
270+
271+
class AndNode(Node):
272+
def __init__(self, node_name: str, and_estimator: AndEstimator):
273+
super().__init__(node_name, and_estimator, NodeInputType.AND, NodeFiringType.ANY, NodeStateType.STATELESS)
260274

261275
def clone(self):
262-
return AndNode(self.__node_name__, self.__andfunc__)
276+
cloned_estimator = self.__estimator__.clone()
277+
return AndNode(self.__node_name__, cloned_estimator)
263278

264279

265280
class Edge:
@@ -477,7 +492,7 @@ def is_input(self, node: Node):
477492
def get_input_nodes(self):
478493
input_nodes = []
479494
for node in self.__node_name_map__.values():
480-
if self.get_node_level() == 0:
495+
if self.get_node_level(node) == 0:
481496
input_nodes.append(node)
482497

483498
return input_nodes
@@ -514,7 +529,7 @@ def save(self, filehandle):
514529
saved_pipeline = _SavedPipeline(nodes, edges)
515530
pickle.dump(saved_pipeline, filehandle)
516531

517-
def set_param_grid(self, pipeline_param: PipelineParam):
532+
def get_parameterized_pipeline(self, pipeline_param):
518533
result = Pipeline()
519534
pipeline_params = pipeline_param.get_all_params()
520535
parameterized_nodes = {}
@@ -523,14 +538,27 @@ def set_param_grid(self, pipeline_param: PipelineParam):
523538
if node_name_part not in parameterized_nodes.keys():
524539
parameterized_nodes[node_name_part] = []
525540
node = self.__node_name_map__[node_name_part]
526-
estimator = node.get_estimator()
527-
cloned_estimator = estimator.clone()
528-
cloned_estimator.set_params(**params)
541+
parameterized_node = node.get_parameterized_node(node_name, **params)
542+
parameterized_nodes[node_name_part].append(parameterized_node)
529543

530-
parameterized_nodes[node_name_part].append()
531-
result.add_node()
544+
# update parameterized nodes with missing non-parameterized nodes for completeness
545+
for node in self.__pre_graph__.keys():
546+
node_name = node.get_node_name()
547+
if node_name not in parameterized_nodes.keys():
548+
parameterized_nodes[node_name] = [node]
549+
550+
# loop through the graph and add edges
551+
for node, pre_nodes in self.__pre_graph__.items():
552+
node_name = node.get_node_name()
553+
expanded_nodes = parameterized_nodes[node_name]
554+
for pre_node in pre_nodes:
555+
pre_node_name = pre_node.get_node_name()
556+
expanded_pre_nodes = parameterized_nodes[pre_node_name]
557+
for expanded_pre_node in expanded_pre_nodes:
558+
for expanded_node in expanded_nodes:
559+
result.add_edge(expanded_pre_node, expanded_node)
532560

533-
# construct nodes
561+
return result
534562

535563
@staticmethod
536564
def load(filehandle):
@@ -617,9 +645,27 @@ def add_xy_arg(self, node: Node, xy: Xy):
617645
xyref = XYRef(x_ref, y_ref)
618646
self.add_xyref_arg(node, xyref)
619647

648+
def add_all(self, node, node_inargs):
649+
self.__in_args__[node] = node_inargs
650+
620651
def get_in_args(self):
621652
return self.__in_args__
622653

654+
def get_parameterized_input(self, pipeline: Pipeline, parameterized_pipeline: Pipeline):
655+
input_nodes = parameterized_pipeline.get_input_nodes()
656+
parameterized_pipeline_input = PipelineInput()
657+
for input_node in input_nodes:
658+
input_node_name = input_node.get_node_name()
659+
if '__' not in input_node_name:
660+
node_name = input_node_name
661+
else:
662+
node_name, param = input_node.get_node_name().split('__', 1)
663+
664+
pipeline_node = pipeline.get_node(node_name)
665+
if pipeline_node in self.__in_args__:
666+
parameterized_pipeline_input.add_all(input_node, self.__in_args__[pipeline_node])
667+
return parameterized_pipeline_input
668+
623669

624670
class PipelineParam:
625671
def __init__(self):

codeflare/pipelines/Runtime.py

Lines changed: 76 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,11 @@ def execute_or_node_remote(node: dm.EstimatorNode, mode: ExecutionType, xy_ref:
2323
# Blocking operation -- not avoidable
2424
X = ray.get(xy_ref.get_Xref())
2525
y = ray.get(xy_ref.get_yref())
26+
prev_node_ptr = ray.put(node)
2627

2728
# TODO: Can optimize the node pointers without replicating them
2829
if mode == ExecutionType.FIT:
2930
cloned_node = node.clone()
30-
prev_node_ptr = ray.put(node)
3131

3232
if base.is_classifier(estimator) or base.is_regressor(estimator):
3333
# Always clone before fit, else fit is invalid
@@ -49,22 +49,22 @@ def execute_or_node_remote(node: dm.EstimatorNode, mode: ExecutionType, xy_ref:
4949
if base.is_classifier(estimator) or base.is_regressor(estimator):
5050
estimator = node.get_estimator()
5151
res_Xref = ray.put(estimator.score(X, y))
52-
result = dm.XYRef(res_Xref, xy_ref.get_yref())
52+
result = dm.XYRef(res_Xref, xy_ref.get_yref(), prev_node_ptr, prev_node_ptr, [xy_ref])
5353
return result
5454
else:
5555
res_Xref = ray.put(estimator.transform(X))
56-
result = dm.XYRef(res_Xref, xy_ref.get_yref())
56+
result = dm.XYRef(res_Xref, xy_ref.get_yref(), prev_node_ptr, prev_node_ptr, [xy_ref])
5757

5858
return result
5959
elif mode == ExecutionType.PREDICT:
6060
# Test mode does not clone as it is a simple predict or transform
6161
if base.is_classifier(estimator) or base.is_regressor(estimator):
6262
res_Xref = ray.put(estimator.predict(X))
63-
result = dm.XYRef(res_Xref, xy_ref.get_yref())
63+
result = dm.XYRef(res_Xref, xy_ref.get_yref(), prev_node_ptr, prev_node_ptr, [xy_ref])
6464
return result
6565
else:
6666
res_Xref = ray.put(estimator.transform(X))
67-
result = dm.XYRef(res_Xref, xy_ref.get_yref())
67+
result = dm.XYRef(res_Xref, xy_ref.get_yref(), prev_node_ptr, prev_node_ptr, [xy_ref])
6868
return result
6969

7070

@@ -84,38 +84,88 @@ def execute_or_node(node, pre_edges, edge_args, post_edges, mode: ExecutionType)
8484

8585

8686
@ray.remote
87-
def execute_and_node_remote(node: dm.AndNode, Xyref_list):
87+
def execute_and_node_remote(node: dm.AndNode, mode: ExecutionType, Xyref_list):
8888
xy_list = []
8989
prev_node_ptr = ray.put(node)
9090
for Xyref in Xyref_list:
9191
X = ray.get(Xyref.get_Xref())
9292
y = ray.get(Xyref.get_yref())
9393
xy_list.append(dm.Xy(X, y))
9494

95-
cloned_node = node.clone()
96-
curr_node_ptr = ray.put(cloned_node)
95+
estimator = node.get_estimator()
96+
97+
# TODO: Can optimize the node pointers without replicating them
98+
if mode == ExecutionType.FIT:
99+
cloned_node = node.clone()
100+
101+
if base.is_classifier(estimator) or base.is_regressor(estimator):
102+
# Always clone before fit, else fit is invalid
103+
cloned_estimator = cloned_node.get_estimator()
104+
cloned_estimator.fit(xy_list)
97105

98-
cloned_and_func = cloned_node.get_and_func()
99-
res_Xy = cloned_and_func.transform(xy_list)
100-
res_Xref = ray.put(res_Xy.get_x())
101-
res_yref = ray.put(res_Xy.get_y())
102-
return dm.XYRef(res_Xref, res_yref, prev_node_ptr, curr_node_ptr, Xyref_list)
106+
curr_node_ptr = ray.put(cloned_node)
107+
res_xy = cloned_estimator.predict(xy_list)
108+
res_xref = ray.put(res_xy.get_x())
109+
res_yref = ray.put(res_xy.get_y())
103110

111+
result = dm.XYRef(res_xref, res_yref, prev_node_ptr, curr_node_ptr, Xyref_list)
112+
return result
113+
else:
114+
cloned_estimator = cloned_node.get_estimator()
115+
res_xy = cloned_estimator.fit_transform(xy_list)
116+
res_xref = ray.put(res_xy.get_x())
117+
res_yref = ray.put(res_xy.get_y())
104118

105-
def execute_and_node_inner(node: dm.AndNode, Xyref_ptrs):
119+
curr_node_ptr = ray.put(cloned_node)
120+
result = dm.XYRef(res_xref, res_yref, prev_node_ptr, curr_node_ptr, Xyref_list)
121+
return result
122+
elif mode == ExecutionType.SCORE:
123+
if base.is_classifier(estimator) or base.is_regressor(estimator):
124+
estimator = node.get_estimator()
125+
res_xy = estimator.score(xy_list)
126+
res_xref = ray.put(res_xy.get_x())
127+
res_yref = ray.put(res_xy.get_y())
128+
129+
result = dm.XYRef(res_xref, res_yref, prev_node_ptr, prev_node_ptr, Xyref_list)
130+
return result
131+
else:
132+
res_xy = estimator.transform(xy_list)
133+
res_xref = ray.put(res_xy.get_x())
134+
res_yref = ray.put(res_xy.get_y())
135+
result = dm.XYRef(res_xref, res_yref, prev_node_ptr, prev_node_ptr, Xyref_list)
136+
137+
return result
138+
elif mode == ExecutionType.PREDICT:
139+
# Test mode does not clone as it is a simple predict or transform
140+
if base.is_classifier(estimator) or base.is_regressor(estimator):
141+
res_xy = estimator.predict(xy_list)
142+
res_xref = ray.put(res_xy.get_x())
143+
res_yref = ray.put(res_xy.get_y())
144+
145+
result = dm.XYRef(res_xref, res_yref, prev_node_ptr, prev_node_ptr, Xyref_list)
146+
return result
147+
else:
148+
res_xy = estimator.transform(xy_list)
149+
res_xref = ray.put(res_xy.get_x())
150+
res_yref = ray.put(res_xy.get_y())
151+
result = dm.XYRef(res_xref, res_yref, prev_node_ptr, prev_node_ptr, Xyref_list)
152+
return result
153+
154+
155+
def execute_and_node_inner(node: dm.AndNode, mode: ExecutionType, Xyref_ptrs):
106156
result = []
107157

108158
Xyref_list = []
109159
for Xyref_ptr in Xyref_ptrs:
110160
Xyref = ray.get(Xyref_ptr)
111161
Xyref_list.append(Xyref)
112162

113-
Xyref_ptr = execute_and_node_remote.remote(node, Xyref_list)
163+
Xyref_ptr = execute_and_node_remote.remote(node, mode, Xyref_list)
114164
result.append(Xyref_ptr)
115165
return result
116166

117167

118-
def execute_and_node(node, pre_edges, edge_args, post_edges):
168+
def execute_and_node(node, pre_edges, edge_args, post_edges, mode: ExecutionType):
119169
edge_args_lists = list()
120170
for pre_edge in pre_edges:
121171
edge_args_lists.append(edge_args[pre_edge])
@@ -125,7 +175,7 @@ def execute_and_node(node, pre_edges, edge_args, post_edges):
125175
cross_product = itertools.product(*edge_args_lists)
126176

127177
for element in cross_product:
128-
exec_xyref_ptrs = execute_and_node_inner(node, element)
178+
exec_xyref_ptrs = execute_and_node_inner(node, mode, element)
129179
for post_edge in post_edges:
130180
if post_edge not in edge_args.keys():
131181
edge_args[post_edge] = []
@@ -151,7 +201,7 @@ def execute_pipeline(pipeline: dm.Pipeline, mode: ExecutionType, pipeline_input:
151201
if node.get_node_input_type() == dm.NodeInputType.OR:
152202
execute_or_node(node, pre_edges, edge_args, post_edges, mode)
153203
elif node.get_node_input_type() == dm.NodeInputType.AND:
154-
execute_and_node(node, pre_edges, edge_args, post_edges)
204+
execute_and_node(node, pre_edges, edge_args, post_edges, mode)
155205

156206
out_args = {}
157207
terminal_nodes = pipeline.get_output_nodes()
@@ -249,7 +299,7 @@ def cross_validate(cross_validator: BaseCrossValidator, pipeline: dm.Pipeline, p
249299
raise pe.PipelineException("Cross validation can only be done on pipelines with single estimator, "
250300
"use grid_search_cv instead")
251301

252-
result_grid_search_cv = grid_search_cv(cross_validator, pipeline, pipeline_input)
302+
result_grid_search_cv = _grid_search_cv(cross_validator, pipeline, pipeline_input)
253303
# only one output here
254304
result_scores = None
255305
for scores in result_grid_search_cv.values():
@@ -259,7 +309,13 @@ def cross_validate(cross_validator: BaseCrossValidator, pipeline: dm.Pipeline, p
259309
return result_scores
260310

261311

262-
def grid_search_cv(cross_validator: BaseCrossValidator, pipeline: dm.Pipeline, pipeline_input: dm.PipelineInput):
312+
def grid_search_cv(cross_validator: BaseCrossValidator, pipeline: dm.Pipeline, pipeline_input: dm.PipelineInput, pipeline_params: dm.PipelineParam):
313+
parameterized_pipeline = pipeline.get_parameterized_pipeline(pipeline_params)
314+
parameterized_pipeline_input = pipeline_input.get_parameterized_input(pipeline, parameterized_pipeline)
315+
return _grid_search_cv(cross_validator, parameterized_pipeline, parameterized_pipeline_input)
316+
317+
318+
def _grid_search_cv(cross_validator: BaseCrossValidator, pipeline: dm.Pipeline, pipeline_input: dm.PipelineInput):
263319
pipeline_input_train = dm.PipelineInput()
264320

265321
pipeline_input_test = []

0 commit comments

Comments
 (0)