Skip to content

Commit 616516b

Browse files
yuanchi2807 authored and GitHub Enterprise committed
Merge pull request #36 from codeflare/pickle
Pickle
2 parents c4c7990 + 6506ada commit 616516b

File tree

5 files changed

+209
-3
lines changed

5 files changed

+209
-3
lines changed

codeflare/pipelines/Datamodel.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from sklearn.base import BaseEstimator
77

88
import ray
9+
import pickle5 as pickle
910
import codeflare.pipelines.Exceptions as pe
1011

1112
class Xy:
@@ -103,6 +104,9 @@ def __init__(self, node_name, node_input_type: NodeInputType, node_firing_type:
103104
def __str__(self):
104105
return self.__node_name__
105106

107+
def get_node_name(self):
    """
    Returns the name stored on this node, i.e. the same value that ``__str__`` reports.

    :return: The node name
    """
    node_name = self.__node_name__
    return node_name
109+
106110
def get_node_input_type(self):
107111
return self.__node_input_type__
108112

@@ -379,6 +383,64 @@ def get_terminal_nodes(self):
379383
terminal_nodes.append(node)
380384
return terminal_nodes
381385

386+
def get_nodes(self):
    """
    Builds a lookup of every node that is a key of the pre-graph, indexed by node name.
    If two nodes share a name, the one iterated last wins.

    :return: Dict mapping node name to node object
    """
    return {graph_node.get_node_name(): graph_node for graph_node in self.__pre_graph__}
391+
392+
def get_pre_nodes(self, node):
    """
    Looks up the pre-graph entry recorded for the given node.

    :param node: Node used as the key into the pre-graph
    :return: Whatever the pre-graph stores for this node
    """
    pre_entry = self.__pre_graph__[node]
    return pre_entry
394+
395+
def get_post_nodes(self, node):
    """
    Looks up the post-graph entry recorded for the given node.

    :param node: Node used as the key into the post-graph
    :return: Whatever the post-graph stores for this node
    """
    post_entry = self.__post_graph__[node]
    return post_entry
397+
398+
def save(self, filehandle):
    """
    Pickles this pipeline's graph to the given file handle. The graph is flattened into a
    name -> node dict plus a list of (from_name, to_name) tuples, wrapped in a
    ``_SavedPipeline`` and written with ``pickle.dump``.

    :param filehandle: File object the pickle is written to (must accept bytes)
    :return: None
    """
    name_to_node = {}
    edge_names = []

    for graph_node in self.__pre_graph__.keys():
        name_to_node[graph_node.get_node_name()] = graph_node
        for pre_edge in self.get_pre_edges(graph_node):
            source = pre_edge.get_from_node()
            # Edges without a from-node carry no connection to record, so skip them
            if source is None:
                continue
            # We are walking pre-edges, so the to-node is always present
            target = pre_edge.get_to_node()
            edge_names.append((source.get_node_name(), target.get_node_name()))

    pickle.dump(_SavedPipeline(name_to_node, edge_names), filehandle)
414+
415+
@staticmethod
def load(filehandle):
    """
    Rebuilds a pipeline from a file handle previously written by ``save``.

    The pipeline is reconstructed purely by replaying the saved edges through ``add_edge``,
    so a saved node that takes part in no edge is not added back here.

    :param filehandle: File object positioned at a pickled ``_SavedPipeline``
    :return: A new ``Pipeline`` holding the saved edges
    :raises pe.PipelineException: If the unpickled object is not a ``_SavedPipeline``
    """
    # NOTE(review): unpickling can execute arbitrary code and the type check below only runs
    # afterwards -- only load files from a trusted source.
    restored = pickle.load(filehandle)
    if not isinstance(restored, _SavedPipeline):
        raise pe.PipelineException("Filehandle is not a saved pipeline instance")

    name_to_node = restored.get_nodes()
    edge_names = restored.get_edges()

    rebuilt = Pipeline()
    for from_name, to_name in edge_names:
        rebuilt.add_edge(name_to_node[from_name], name_to_node[to_name])
    return rebuilt
431+
432+
433+
class _SavedPipeline:
    """
    Picklable snapshot of a pipeline graph, as assembled by ``Pipeline.save``: the nodes keyed
    by name and the edges as (from_name, to_name) tuples.
    """

    def __init__(self, nodes, edges):
        """
        :param nodes: Dict mapping node name to node object
        :param edges: List of (from_name, to_name) tuples
        """
        # Attribute names are part of the pickled state, so they must not be renamed
        self.__nodes__ = nodes
        self.__edges__ = edges

    def get_nodes(self):
        """
        :return: The nodes object passed to the constructor
        """
        stored_nodes = self.__nodes__
        return stored_nodes

    def get_edges(self):
        """
        :return: The edges object passed to the constructor
        """
        stored_edges = self.__edges__
        return stored_edges
443+
382444

383445
class PipelineOutput:
384446
"""

codeflare/pipelines/Runtime.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,11 @@ def execute_or_node_remote(node: dm.EstimatorNode, mode: ExecutionType, xy_ref:
5858
elif mode == ExecutionType.PREDICT:
5959
# Test mode does not clone as it is a simple predict or transform
6060
if base.is_classifier(estimator) or base.is_regressor(estimator):
61-
res_Xref = estimator.predict(X)
61+
res_Xref = ray.put(estimator.predict(X))
6262
result = dm.XYRef(res_Xref, xy_ref.get_yref())
6363
return result
6464
else:
65-
res_Xref = estimator.transform(X)
65+
res_Xref = ray.put(estimator.transform(X))
6666
result = dm.XYRef(res_Xref, xy_ref.get_yref())
6767
return result
6868

@@ -265,3 +265,8 @@ def cross_validate(cross_validator: BaseCrossValidator, pipeline: dm.Pipeline, p
265265
result_scores.append(out_x)
266266

267267
return result_scores
268+
269+
270+
def save(pipeline_output: dm.PipelineOutput, xy_ref: dm.XYRef, filehandle):
    """
    Picks a pipeline via ``select_pipeline(pipeline_output, xy_ref)`` and writes it to the given
    file handle through that pipeline's ``save`` method.

    :param pipeline_output: Pipeline output handed to ``select_pipeline``
    :param xy_ref: XYRef handed to ``select_pipeline`` to identify the pipeline to keep
    :param filehandle: File object the selected pipeline is saved to
    :return: None
    """
    select_pipeline(pipeline_output, xy_ref).save(filehandle)

codeflare/pipelines/tests/__init__.py

Whitespace-only changes.
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
import codeflare.pipelines.Datamodel as dm
2+
import codeflare.pipelines.Runtime as rt
3+
4+
import numpy as np
5+
from sklearn.preprocessing import MinMaxScaler
6+
import os
7+
import pandas as pd
8+
from sklearn.pipeline import Pipeline
9+
from sklearn.impute import SimpleImputer
10+
from sklearn.preprocessing import StandardScaler, OneHotEncoder
11+
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
12+
13+
import ray
14+
15+
16+
class FeatureUnion(dm.AndTransform):
    """
    Minimal AND-node transform used as a fixture by the save/load tests.
    """

    def __init__(self):
        pass

    def transform(self, xy_list):
        """
        Concatenates the X part of every incoming Xy along axis 0.

        :param xy_list: Iterable of objects exposing ``get_x()``
        :return: A ``dm.Xy`` holding the concatenated X and ``None`` for y
        """
        # NOTE(review): axis=0 stacks rows; a feature union would normally join columns
        # (axis=1) -- confirm this is intended for the fixture.
        X_concat = np.concatenate([xy.get_x() for xy in xy_list], axis=0)

        return dm.Xy(X_concat, None)
29+
30+
31+
def test_save_load():
    """
    A simple save load test for a pipeline graph: two estimator nodes feed one AND node, the
    pipeline is saved to disk, loaded back, and the AND node must still have two pre-edges.

    :return:
    """
    pipeline = dm.Pipeline()
    minmax_scaler = MinMaxScaler()

    node_a = dm.EstimatorNode('a', minmax_scaler)
    node_b = dm.EstimatorNode('b', minmax_scaler)
    node_c = dm.AndNode('c', FeatureUnion())

    pipeline.add_edge(node_a, node_c)
    pipeline.add_edge(node_b, node_c)

    fname = 'save_pipeline.cfp'
    try:
        # Context managers guarantee both handles are closed before the file is removed
        with open(fname, 'wb') as w_fh:
            pipeline.save(w_fh)

        with open(fname, 'rb') as r_fh:
            saved_pipeline = dm.Pipeline.load(r_fh)

        # NOTE(review): the lookup uses the original node_c object, so this presumably relies
        # on node equality/hash surviving the pickle round trip -- confirm.
        pre_edges = saved_pipeline.get_pre_edges(node_c)
        assert len(pre_edges) == 2
    finally:
        # Clean up even when the assertion or the save/load itself fails
        if os.path.exists(fname):
            os.remove(fname)
56+
57+
58+
def test_runtime_save_load():
    """
    Tests for selecting a pipeline and save/load it, we also test the predict to ensure state is
    captured accurately

    :return:
    """
    from sklearn.compose import ColumnTransformer
    from sklearn.model_selection import train_test_split

    train = pd.read_csv('../../../resources/data/train_ctrUa4K.csv')
    train = train.drop('Loan_ID', axis=1)

    X = train.drop('Loan_Status', axis=1)
    y = train['Loan_Status']
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

    imputer = SimpleImputer(strategy='median')
    scaler = StandardScaler()

    numeric_transformer = Pipeline(steps=[
        ('imputer', imputer),
        ('scaler', scaler)])

    cat_imputer = SimpleImputer(strategy='constant', fill_value='missing')
    cat_onehot = OneHotEncoder(handle_unknown='ignore')

    categorical_transformer = Pipeline(steps=[
        ('imputer', cat_imputer),
        ('onehot', cat_onehot)])

    numeric_features = train.select_dtypes(include=['int64', 'float64']).columns
    categorical_features = train.select_dtypes(include=['object']).drop(['Loan_Status'], axis=1).columns
    preprocessor = ColumnTransformer(
        transformers=[
            ('num', numeric_transformer, numeric_features),
            ('cat', categorical_transformer, categorical_features)])

    classifiers = [
        RandomForestClassifier(),
        GradientBoostingClassifier()
    ]

    # One preprocessing node fanning out to two classifier nodes
    pipeline = dm.Pipeline()
    node_pre = dm.EstimatorNode('preprocess', preprocessor)
    node_rf = dm.EstimatorNode('random_forest', classifiers[0])
    node_gb = dm.EstimatorNode('gradient_boost', classifiers[1])

    pipeline.add_edge(node_pre, node_rf)
    pipeline.add_edge(node_pre, node_gb)

    # Restart ray so the test starts from a fresh runtime (uses the module-level ray import)
    ray.shutdown()
    ray.init()

    pipeline_input = dm.PipelineInput()
    xy = dm.Xy(X_train, y_train)
    pipeline_input.add_xy_arg(node_pre, xy)

    pipeline_output = rt.execute_pipeline(pipeline, rt.ExecutionType.FIT, pipeline_input)
    node_rf_xyrefs = pipeline_output.get_xyrefs(node_rf)

    # save this pipeline for random forest and load and then predict on test data
    fname = 'random_forest.cfp'
    try:
        with open(fname, 'wb') as w_fh:
            rt.save(pipeline_output, node_rf_xyrefs[0], w_fh)

        # load it
        with open(fname, 'rb') as r_fh:
            saved_pipeline = dm.Pipeline.load(r_fh)

        nodes = saved_pipeline.get_nodes()
        # this should not exist in the saved pipeline
        assert node_gb.get_node_name() not in nodes

        # should be predictable as well; any exception is left to propagate so that pytest
        # reports the real traceback instead of a bare "assert False"
        predict_pipeline_input = dm.PipelineInput()
        predict_pipeline_input.add_xy_arg(node_pre, dm.Xy(X_test, y_test))
        predict_pipeline_output = rt.execute_pipeline(saved_pipeline, rt.ExecutionType.PREDICT, predict_pipeline_input)
        predict_pipeline_output.get_xyrefs(node_rf)
    finally:
        # Clean up even when an assertion or the pipeline execution fails
        if os.path.exists(fname):
            os.remove(fname)

requirements.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,7 @@ ray~=1.3.0
22
setuptools~=52.0.0
33
sklearn~=0.0
44
scikit-learn~=0.24.1
5-
pandas~=1.2.4
5+
pandas~=1.2.4
6+
pytest~=6.2.4
7+
numpy~=1.18.5
8+
pickle5~=0.0.11

0 commit comments

Comments
 (0)