Skip to content

Commit a20a9c4

Browse files
Adding a proper API for grid search CV, in the midst of change
1 parent a195653 commit a20a9c4

File tree

5 files changed

+176
-28
lines changed

5 files changed

+176
-28
lines changed

codeflare/pipelines/Datamodel.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
from enum import Enum
33

44
import sklearn.base as base
5+
from codeflare.pipelines.Datamodel import PipelineParam
56
from sklearn.base import TransformerMixin
67
from sklearn.base import BaseEstimator
8+
from sklearn.model_selection import ParameterGrid
79

810
import ray
911
import pickle5 as pickle
@@ -158,6 +160,8 @@ class Node(ABC):
158160
"""
159161

160162
def __init__(self, node_name, node_input_type: NodeInputType, node_firing_type: NodeFiringType, node_state_type: NodeStateType):
163+
if '__' in node_name:
164+
raise pe.PipelineException("Node name cannot have __, please rename")
161165
self.__node_name__ = node_name
162166
self.__node_input_type__ = node_input_type
163167
self.__node_firing_type__ = node_firing_type
@@ -510,6 +514,24 @@ def save(self, filehandle):
510514
saved_pipeline = _SavedPipeline(nodes, edges)
511515
pickle.dump(saved_pipeline, filehandle)
512516

517+
def set_param_grid(self, pipeline_param: PipelineParam):
    """
    Work in progress: begin building a new Pipeline whose estimators carry
    the parameter settings held in ``pipeline_param``.

    For every entry returned by ``pipeline_param.get_all_params()`` the node
    is looked up in ``self.__node_name_map__``, its estimator is fetched and
    cloned, and the entry's parameters are applied to the clone.

    The method has no ``return`` statement, so it currently yields ``None``;
    ``result`` is created but never handed back to the caller.

    :param pipeline_param: Parameter settings keyed by ``<node name>__<suffix>``
    """
    result = Pipeline()
    pipeline_params = pipeline_param.get_all_params()
    parameterized_nodes = {}
    for node_name, params in pipeline_params.items():
        # Keys are split on the first '__'; the part before it is used as the
        # node name. The part after it (`num`) is not used below.
        node_name_part, num = node_name.split('__', 1)
        if node_name_part not in parameterized_nodes.keys():
            parameterized_nodes[node_name_part] = []
        # Raises KeyError when the pipeline has no node with this name.
        node = self.__node_name_map__[node_name_part]
        estimator = node.get_estimator()
        # NOTE(review): if get_estimator() returns a scikit-learn estimator,
        # it has no clone() method; the function form is
        # sklearn.base.clone(estimator) -- confirm the estimator type.
        cloned_estimator = estimator.clone()
        cloned_estimator.set_params(**params)

        # NOTE(review): list.append() requires exactly one argument, so this
        # line raises TypeError as written; presumably the configured clone
        # (or a node wrapping it) is meant to be stored -- confirm.
        parameterized_nodes[node_name_part].append()
        # NOTE(review): add_node() is called without a node; presumably it
        # needs the node built from cloned_estimator -- confirm against
        # Pipeline.add_node.
        result.add_node()

    # construct nodes
534+
513535
@staticmethod
514536
def load(filehandle):
515537
saved_pipeline = pickle.load(filehandle)
@@ -597,3 +619,46 @@ def add_xy_arg(self, node: Node, xy: Xy):
597619

598620
def get_in_args(self):
599621
return self.__in_args__
622+
623+
624+
class PipelineParam:
    """
    Parameter settings for the nodes of a pipeline.

    Settings are stored in a dict whose keys have the form
    ``<node_name>__<i>`` and whose values are dicts mapping a parameter name
    to a single value. When built through :meth:`from_param_grid`, ``i`` is
    the position of the setting in that node's expanded parameter grid.
    """

    def __init__(self):
        # Maps "<node_name>__<i>" -> {parameter name: value}
        self.__node_name_param_map__ = {}

    @staticmethod
    def from_param_grid(fit_params: dict) -> 'PipelineParam':
        """
        Build a PipelineParam from a grid keyed as ``<node_name>__<parameter>``.

        The grid is first regrouped per node, then each node's grid is
        expanded independently with ``ParameterGrid``; every resulting
        combination is stored under ``<node_name>__<i>``. For example
        ``{'pca__n_components': [5, 15], 'logistic__C': [0.1]}`` yields the
        keys ``pca__0``, ``pca__1`` and ``logistic__0``.

        :param fit_params: Mapping from ``<node_name>__<parameter>`` to the
            candidate values for that parameter
        :return: A PipelineParam holding one entry per node and combination
        :raises ValueError: If a key is not of the ``<node_name>__<parameter>``
            form (no ``__``, or an empty node name or parameter name)
        """
        pipeline_param = PipelineParam()
        fit_params_nodes = {}
        for pname, pval in fit_params.items():
            # Only the first '__' separates node from parameter, so nested
            # parameter names such as 'node__step__param' stay intact.
            node_name, sep, param = pname.partition('__')
            if not sep or not node_name or not param:
                raise ValueError(
                    "Parameter grid key {!r} is not of the form "
                    "nodename__parameter. Each key must name the pipeline "
                    "node and the parameter to vary, e.g. "
                    "`pca__n_components`.".format(pname))
            fit_params_nodes.setdefault(node_name, {})[param] = pval

        # The grid is now split per node by naming convention; expand each
        # node's grid into its individual parameter combinations.
        for node_name, param in fit_params_nodes.items():
            for i, p in enumerate(ParameterGrid(param)):
                pipeline_param.add_param(node_name + '__' + str(i), p)

        return pipeline_param

    def add_param(self, node_name: str, params: dict):
        """
        Store ``params`` under ``node_name``, replacing any previous entry.

        :param node_name: Key to store the settings under
        :param params: Mapping from parameter name to value
        """
        self.__node_name_param_map__[node_name] = params

    def get_param(self, node_name: str) -> dict:
        """
        Return the settings stored under ``node_name``.

        :param node_name: Key the settings were stored under
        :return: The stored parameter dict
        :raises KeyError: If nothing is stored under ``node_name``
        """
        return self.__node_name_param_map__[node_name]

    def get_all_params(self) -> dict:
        """
        Return the underlying key -> settings mapping (not a copy).

        :return: Dict of all stored settings
        """
        return self.__node_name_param_map__
File renamed without changes.

codeflare/pipelines/tests/test_runtime.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,8 @@ def test_grid_search():
6060
xy = dm.Xy(X_train, y_train)
6161
pipeline_input.add_xy_arg(input_node, xy)
6262

63-
kf = KFold(2)
63+
k = 2
64+
kf = KFold(k)
6465
result = rt.grid_search_cv(kf, pipeline, pipeline_input)
6566
node_rf = pipeline.get_node('random_forest')
6667
node_gb = pipeline.get_node('gradient_boost')
@@ -73,7 +74,7 @@ def test_grid_search():
7374
node_rf_pipeline = True
7475
elif out_node.get_node_name() == node_gb.get_node_name():
7576
node_gb_pipeline = True
76-
if len(scores) != 2:
77+
if len(scores) != k:
7778
assert False
7879
assert node_rf_pipeline
7980
assert node_gb_pipeline

codeflare_pipelines.egg-info/SOURCES.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ codeflare/pipelines/Datamodel.py
55
codeflare/pipelines/Exceptions.py
66
codeflare/pipelines/Runtime.py
77
codeflare/pipelines/__init__.py
8-
codeflare/pipelines/test_Datamodel.py
98
codeflare_pipelines.egg-info/PKG-INFO
109
codeflare_pipelines.egg-info/SOURCES.txt
1110
codeflare_pipelines.egg-info/dependency_links.txt

notebooks/Untitled.ipynb

Lines changed: 108 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
},
1313
{
1414
"cell_type": "code",
15-
"execution_count": 3,
15+
"execution_count": 2,
1616
"id": "da96167d",
1717
"metadata": {},
1818
"outputs": [],
@@ -22,20 +22,75 @@
2222
},
2323
{
2424
"cell_type": "code",
25-
"execution_count": 4,
25+
"execution_count": 17,
2626
"id": "a559b7bb",
2727
"metadata": {},
2828
"outputs": [],
2929
"source": [
3030
"param_grid = {\n",
3131
" 'pca__n_components': [5, 15, 30, 45, 64],\n",
32+
" 'pca__m_components': [6, 10],\n",
3233
" 'logistic__C': np.logspace(-4, 4, 4),\n",
3334
"}"
3435
]
3536
},
3637
{
3738
"cell_type": "code",
38-
"execution_count": 5,
39+
"execution_count": 18,
40+
"id": "4e1ff38c",
41+
"metadata": {},
42+
"outputs": [],
43+
"source": [
44+
"import codeflare.pipelines.Datamodel as dm"
45+
]
46+
},
47+
{
48+
"cell_type": "code",
49+
"execution_count": 19,
50+
"id": "4691ea01",
51+
"metadata": {},
52+
"outputs": [],
53+
"source": [
54+
"pipeline_params = dm.PipelineParam.from_param_grid(param_grid)"
55+
]
56+
},
57+
{
58+
"cell_type": "code",
59+
"execution_count": 20,
60+
"id": "1da83403",
61+
"metadata": {},
62+
"outputs": [
63+
{
64+
"data": {
65+
"text/plain": [
66+
"{'pca__0': {'m_components': 6, 'n_components': 5},\n",
67+
" 'pca__1': {'m_components': 6, 'n_components': 15},\n",
68+
" 'pca__2': {'m_components': 6, 'n_components': 30},\n",
69+
" 'pca__3': {'m_components': 6, 'n_components': 45},\n",
70+
" 'pca__4': {'m_components': 6, 'n_components': 64},\n",
71+
" 'pca__5': {'m_components': 10, 'n_components': 5},\n",
72+
" 'pca__6': {'m_components': 10, 'n_components': 15},\n",
73+
" 'pca__7': {'m_components': 10, 'n_components': 30},\n",
74+
" 'pca__8': {'m_components': 10, 'n_components': 45},\n",
75+
" 'pca__9': {'m_components': 10, 'n_components': 64},\n",
76+
" 'logistic__0': {'C': 0.0001},\n",
77+
" 'logistic__1': {'C': 0.046415888336127774},\n",
78+
" 'logistic__2': {'C': 21.54434690031882},\n",
79+
" 'logistic__3': {'C': 10000.0}}"
80+
]
81+
},
82+
"execution_count": 20,
83+
"metadata": {},
84+
"output_type": "execute_result"
85+
}
86+
],
87+
"source": [
88+
"pipeline_params.__node_name_param_map__"
89+
]
90+
},
91+
{
92+
"cell_type": "code",
93+
"execution_count": 8,
3994
"id": "ec28e356",
4095
"metadata": {},
4196
"outputs": [],
@@ -45,7 +100,7 @@
45100
},
46101
{
47102
"cell_type": "code",
48-
"execution_count": 6,
103+
"execution_count": 9,
49104
"id": "fca4ae99",
50105
"metadata": {},
51106
"outputs": [],
@@ -55,34 +110,62 @@
55110
},
56111
{
57112
"cell_type": "code",
58-
"execution_count": 8,
113+
"execution_count": null,
114+
"id": "d24579ee",
115+
"metadata": {},
116+
"outputs": [],
117+
"source": []
118+
},
119+
{
120+
"cell_type": "code",
121+
"execution_count": 10,
59122
"id": "acc69456",
60123
"metadata": {},
61124
"outputs": [
62125
{
63126
"name": "stdout",
64127
"output_type": "stream",
65128
"text": [
66-
"{'logistic__C': 0.0001, 'pca__n_components': 5}\n",
67-
"{'logistic__C': 0.0001, 'pca__n_components': 15}\n",
68-
"{'logistic__C': 0.0001, 'pca__n_components': 30}\n",
69-
"{'logistic__C': 0.0001, 'pca__n_components': 45}\n",
70-
"{'logistic__C': 0.0001, 'pca__n_components': 64}\n",
71-
"{'logistic__C': 0.046415888336127774, 'pca__n_components': 5}\n",
72-
"{'logistic__C': 0.046415888336127774, 'pca__n_components': 15}\n",
73-
"{'logistic__C': 0.046415888336127774, 'pca__n_components': 30}\n",
74-
"{'logistic__C': 0.046415888336127774, 'pca__n_components': 45}\n",
75-
"{'logistic__C': 0.046415888336127774, 'pca__n_components': 64}\n",
76-
"{'logistic__C': 21.54434690031882, 'pca__n_components': 5}\n",
77-
"{'logistic__C': 21.54434690031882, 'pca__n_components': 15}\n",
78-
"{'logistic__C': 21.54434690031882, 'pca__n_components': 30}\n",
79-
"{'logistic__C': 21.54434690031882, 'pca__n_components': 45}\n",
80-
"{'logistic__C': 21.54434690031882, 'pca__n_components': 64}\n",
81-
"{'logistic__C': 10000.0, 'pca__n_components': 5}\n",
82-
"{'logistic__C': 10000.0, 'pca__n_components': 15}\n",
83-
"{'logistic__C': 10000.0, 'pca__n_components': 30}\n",
84-
"{'logistic__C': 10000.0, 'pca__n_components': 45}\n",
85-
"{'logistic__C': 10000.0, 'pca__n_components': 64}\n"
129+
"{'logistic__C': 0.0001, 'pca__m_components': 6, 'pca__n__components': 5}\n",
130+
"{'logistic__C': 0.0001, 'pca__m_components': 6, 'pca__n__components': 15}\n",
131+
"{'logistic__C': 0.0001, 'pca__m_components': 6, 'pca__n__components': 30}\n",
132+
"{'logistic__C': 0.0001, 'pca__m_components': 6, 'pca__n__components': 45}\n",
133+
"{'logistic__C': 0.0001, 'pca__m_components': 6, 'pca__n__components': 64}\n",
134+
"{'logistic__C': 0.0001, 'pca__m_components': 10, 'pca__n__components': 5}\n",
135+
"{'logistic__C': 0.0001, 'pca__m_components': 10, 'pca__n__components': 15}\n",
136+
"{'logistic__C': 0.0001, 'pca__m_components': 10, 'pca__n__components': 30}\n",
137+
"{'logistic__C': 0.0001, 'pca__m_components': 10, 'pca__n__components': 45}\n",
138+
"{'logistic__C': 0.0001, 'pca__m_components': 10, 'pca__n__components': 64}\n",
139+
"{'logistic__C': 0.046415888336127774, 'pca__m_components': 6, 'pca__n__components': 5}\n",
140+
"{'logistic__C': 0.046415888336127774, 'pca__m_components': 6, 'pca__n__components': 15}\n",
141+
"{'logistic__C': 0.046415888336127774, 'pca__m_components': 6, 'pca__n__components': 30}\n",
142+
"{'logistic__C': 0.046415888336127774, 'pca__m_components': 6, 'pca__n__components': 45}\n",
143+
"{'logistic__C': 0.046415888336127774, 'pca__m_components': 6, 'pca__n__components': 64}\n",
144+
"{'logistic__C': 0.046415888336127774, 'pca__m_components': 10, 'pca__n__components': 5}\n",
145+
"{'logistic__C': 0.046415888336127774, 'pca__m_components': 10, 'pca__n__components': 15}\n",
146+
"{'logistic__C': 0.046415888336127774, 'pca__m_components': 10, 'pca__n__components': 30}\n",
147+
"{'logistic__C': 0.046415888336127774, 'pca__m_components': 10, 'pca__n__components': 45}\n",
148+
"{'logistic__C': 0.046415888336127774, 'pca__m_components': 10, 'pca__n__components': 64}\n",
149+
"{'logistic__C': 21.54434690031882, 'pca__m_components': 6, 'pca__n__components': 5}\n",
150+
"{'logistic__C': 21.54434690031882, 'pca__m_components': 6, 'pca__n__components': 15}\n",
151+
"{'logistic__C': 21.54434690031882, 'pca__m_components': 6, 'pca__n__components': 30}\n",
152+
"{'logistic__C': 21.54434690031882, 'pca__m_components': 6, 'pca__n__components': 45}\n",
153+
"{'logistic__C': 21.54434690031882, 'pca__m_components': 6, 'pca__n__components': 64}\n",
154+
"{'logistic__C': 21.54434690031882, 'pca__m_components': 10, 'pca__n__components': 5}\n",
155+
"{'logistic__C': 21.54434690031882, 'pca__m_components': 10, 'pca__n__components': 15}\n",
156+
"{'logistic__C': 21.54434690031882, 'pca__m_components': 10, 'pca__n__components': 30}\n",
157+
"{'logistic__C': 21.54434690031882, 'pca__m_components': 10, 'pca__n__components': 45}\n",
158+
"{'logistic__C': 21.54434690031882, 'pca__m_components': 10, 'pca__n__components': 64}\n",
159+
"{'logistic__C': 10000.0, 'pca__m_components': 6, 'pca__n__components': 5}\n",
160+
"{'logistic__C': 10000.0, 'pca__m_components': 6, 'pca__n__components': 15}\n",
161+
"{'logistic__C': 10000.0, 'pca__m_components': 6, 'pca__n__components': 30}\n",
162+
"{'logistic__C': 10000.0, 'pca__m_components': 6, 'pca__n__components': 45}\n",
163+
"{'logistic__C': 10000.0, 'pca__m_components': 6, 'pca__n__components': 64}\n",
164+
"{'logistic__C': 10000.0, 'pca__m_components': 10, 'pca__n__components': 5}\n",
165+
"{'logistic__C': 10000.0, 'pca__m_components': 10, 'pca__n__components': 15}\n",
166+
"{'logistic__C': 10000.0, 'pca__m_components': 10, 'pca__n__components': 30}\n",
167+
"{'logistic__C': 10000.0, 'pca__m_components': 10, 'pca__n__components': 45}\n",
168+
"{'logistic__C': 10000.0, 'pca__m_components': 10, 'pca__n__components': 64}\n"
86169
]
87170
}
88171
],

0 commit comments

Comments
 (0)