1
1
import ray
2
2
3
- from codeflare .pipelines .Datamodel import OrNode
3
+
4
+ from codeflare .pipelines .Datamodel import EstimatorNode
4
5
from codeflare .pipelines .Datamodel import AndNode
5
6
from codeflare .pipelines .Datamodel import Edge
6
7
from codeflare .pipelines .Datamodel import Pipeline
7
8
from codeflare .pipelines .Datamodel import XYRef
8
9
from codeflare .pipelines .Datamodel import Xy
10
+ from codeflare .pipelines .Datamodel import NodeInputType
11
+ from codeflare .pipelines .Datamodel import NodeStateType
12
+ from codeflare .pipelines .Datamodel import NodeFiringType
9
13
10
14
import sklearn .base as base
11
15
from enum import Enum
@@ -18,47 +22,60 @@ class ExecutionType(Enum):
18
22
19
23
20
24
@ray.remote
def execute_or_node_remote(node: EstimatorNode, train_mode: ExecutionType, xy_ref: XYRef):
    """Run one estimator node on a single XYRef as a Ray remote task.

    :param node: Estimator node whose estimator is applied to the data.
    :param train_mode: ExecutionType selecting FIT, SCORE or PREDICT behavior.
    :param xy_ref: XYRef holding object-store references to the input X and y.
    :return: A new XYRef whose X reference points at the node's output. The y
        reference is passed through unchanged. For FIT and SCORE the result
        also carries references to the incoming node, the cloned node that was
        fitted, and the input xy_ref. Returns None for any other train_mode.
    """
    estimator = node.get_estimator()
    # Blocking operation -- not avoidable
    X = ray.get(xy_ref.get_Xref())
    y = ray.get(xy_ref.get_yref())

    is_predictor = base.is_classifier(estimator) or base.is_regressor(estimator)

    # TODO: Can optimize the node pointers without replicating them
    if train_mode == ExecutionType.FIT or train_mode == ExecutionType.SCORE:
        # Always clone before fit, else fit is invalid
        cloned_node = node.clone()
        prev_node_ptr = ray.put(node)
        cloned_estimator = cloned_node.get_estimator()

        if is_predictor:
            cloned_estimator.fit(X, y)
            # Store the cloned node only after fitting so the stored copy
            # reflects the fitted estimator.
            curr_node_ptr = ray.put(cloned_node)
            if train_mode == ExecutionType.FIT:
                output = cloned_estimator.predict(X)
            else:
                output = cloned_estimator.score(X, y)
        else:
            output = cloned_estimator.fit_transform(X, y)
            curr_node_ptr = ray.put(cloned_node)

        # TODO: For now, make yref passthrough - this has to be fixed more comprehensively
        res_Xref = ray.put(output)
        return XYRef(res_Xref, xy_ref.get_yref(), prev_node_ptr, curr_node_ptr, [xy_ref])

    if train_mode == ExecutionType.PREDICT:
        # Test mode does not clone as it is a simple predict or transform
        if is_predictor:
            output = estimator.predict(X)
        else:
            output = estimator.transform(X)
        # The output must live in the object store like every other branch:
        # consumers call ray.get() on get_Xref(), which needs an object
        # reference rather than the raw array.
        res_Xref = ray.put(output)
        return XYRef(res_Xref, xy_ref.get_yref())

    # Unknown execution type: keep the implicit None of the original chain.
    return None
63
80
64
81
@@ -68,7 +85,7 @@ def execute_or_node(node, pre_edges, edge_args, post_edges, mode: ExecutionType)
68
85
exec_xyrefs = []
69
86
for xy_ref_ptr in Xyref_ptrs :
70
87
xy_ref = ray .get (xy_ref_ptr )
71
- inner_result = execute_or_node_inner .remote (node , mode , xy_ref )
88
+ inner_result = execute_or_node_remote .remote (node , mode , xy_ref )
72
89
exec_xyrefs .append (inner_result )
73
90
74
91
for post_edge in post_edges :
@@ -78,29 +95,33 @@ def execute_or_node(node, pre_edges, edge_args, post_edges, mode: ExecutionType)
78
95
79
96
80
97
@ray.remote
def execute_and_node_remote(node: AndNode, Xyref_list):
    """Apply an AND node's function to a list of inputs as a Ray remote task.

    :param node: AND node providing the combining function via get_and_func().
    :param Xyref_list: List of XYRef objects whose X and y references are
        fetched and handed to the combining function together.
    :return: XYRef referencing the combined X and y, along with references to
        the incoming node, its clone, and the input Xyref_list.
    """
    prev_node_ptr = ray.put(node)

    # Materialize every input pair; fetching is blocking.
    materialized = [
        Xy(ray.get(ref.get_Xref()), ray.get(ref.get_yref()))
        for ref in Xyref_list
    ]

    node_copy = node.clone()
    curr_node_ptr = ray.put(node_copy)

    combined = node_copy.get_and_func().transform(materialized)
    out_Xref = ray.put(combined.get_x())
    out_yref = ray.put(combined.get_y())
    return XYRef(out_Xref, out_yref, prev_node_ptr, curr_node_ptr, Xyref_list)
92
114
93
115
94
116
def execute_and_node_inner(node: AndNode, Xyref_ptrs):
    """Resolve the input pointers and launch the AND node as one remote task.

    :param node: AND node to execute.
    :param Xyref_ptrs: Iterable of object references, each resolving to an XYRef.
    :return: Single-element list holding the reference returned by the remote
        AND-node execution.
    """
    # Resolve each pointer to its XYRef before handing the batch to the task.
    resolved_refs = [ray.get(ptr) for ptr in Xyref_ptrs]
    return [execute_and_node_remote.remote(node, resolved_refs)]
106
127
@@ -136,9 +157,9 @@ def execute_pipeline(pipeline: Pipeline, mode: ExecutionType, in_args: dict):
136
157
for node in nodes :
137
158
pre_edges = pipeline .get_pre_edges (node )
138
159
post_edges = pipeline .get_post_edges (node )
139
- if not node .get_and_flag () :
160
+ if node .get_node_input_type () == NodeInputType . OR :
140
161
execute_or_node (node , pre_edges , edge_args , post_edges , mode )
141
- elif node .get_and_flag () :
162
+ elif node .get_node_input_type () == NodeInputType . AND :
142
163
execute_and_node (node , pre_edges , edge_args , post_edges )
143
164
144
165
out_args = {}
0 commit comments