1
- from sklearn .base import BaseEstimator
2
1
from abc import ABC , abstractmethod
2
+ from enum import Enum
3
+
4
+ import sklearn .base as base
5
+ from sklearn .base import TransformerMixin
6
+ from sklearn .base import BaseEstimator
3
7
8
+ import ray
9
+ import codeflare .pipelines .Exceptions as pe
4
10
5
11
class Xy :
6
12
"""
@@ -35,9 +41,12 @@ class XYRef:
35
41
computed), these holders are essential to the pipeline constructs.
36
42
"""
37
43
38
- def __init__ (self , Xref , yref ):
44
+ def __init__ (self , Xref , yref , prev_node_state_ref = None , curr_node_state_ref = None , prev_Xyrefs = None ):
39
45
self .__Xref__ = Xref
40
46
self .__yref__ = yref
47
+ self .__prev_node_state_ref__ = prev_node_state_ref
48
+ self .__curr_node_state_ref__ = curr_node_state_ref
49
+ self .__prev_Xyrefs__ = prev_Xyrefs
41
50
42
51
def get_Xref (self ):
43
52
"""
@@ -51,6 +60,32 @@ def get_yref(self):
51
60
"""
52
61
return self .__yref__
53
62
63
+ def get_prev_node_state_ref (self ):
64
+ return self .__prev_node_state_ref__
65
+
66
+ def get_curr_node_state_ref (self ):
67
+ return self .__curr_node_state_ref__
68
+
69
+ def get_prev_xyrefs (self ):
70
+ return self .__prev_Xyrefs__
71
+
72
+
73
class NodeInputType(Enum):
    """
    Input type of a pipeline node. Every member carries a plain integer value.

    NOTE(review): the precise OR/AND semantics are applied by the pipeline
    runtime, which is not part of this definition -- confirm there.
    """
    # No trailing commas here: ``OR = 0,`` would make the value the tuple (0,)
    # instead of the integer 0, so that NodeInputType(0) would fail.
    OR = 0
    AND = 1
76
+
77
+
78
class NodeFiringType(Enum):
    """
    Firing type of a pipeline node. Every member carries a plain integer value.

    NOTE(review): the precise ANY/ALL semantics are applied by the pipeline
    runtime, which is not part of this definition -- confirm there.
    """
    # No trailing commas here: ``ANY = 0,`` would make the value the tuple (0,)
    # instead of the integer 0, so that NodeFiringType(0) would fail.
    ANY = 0
    ALL = 1
81
+
82
+
83
class NodeStateType(Enum):
    """
    State type of a pipeline node. Every member carries a plain integer value.

    NOTE(review): how each state type is handled is decided by the pipeline
    runtime, which is not part of this definition -- confirm there.
    """
    # No trailing commas here: a trailing comma turns the value into a
    # one-element tuple (e.g. (0,)), which made the members inconsistent and
    # broke lookups by integer value such as NodeStateType(0).
    STATELESS = 0
    IMMUTABLE = 1
    MUTABLE_SEQUENTIAL = 2
    MUTABLE_AGGREGATE = 3
88
+
54
89
55
90
class Node (ABC ):
56
91
"""
@@ -59,12 +94,27 @@ class Node(ABC):
59
94
node name and the type of the node match.
60
95
"""
61
96
97
+ def __init__ (self , node_name , node_input_type : NodeInputType , node_firing_type : NodeFiringType , node_state_type : NodeStateType ):
98
+ self .__node_name__ = node_name
99
+ self .__node_input_type__ = node_input_type
100
+ self .__node_firing_type__ = node_firing_type
101
+ self .__node_state_type__ = node_state_type
102
+
62
103
def __str__ (self ):
63
104
return self .__node_name__
64
105
106
+ def get_node_input_type (self ):
107
+ return self .__node_input_type__
108
+
109
+ def get_node_firing_type (self ):
110
+ return self .__node_firing_type__
111
+
112
+ def get_node_state_type (self ):
113
+ return self .__node_state_type__
114
+
65
115
@abstractmethod
66
- def get_and_flag (self ):
67
- raise NotImplementedError ("Please implement this method" )
116
+ def clone (self ):
117
+ raise NotImplementedError ("Please implement the clone method" )
68
118
69
119
def __hash__ (self ):
70
120
"""
@@ -88,12 +138,11 @@ def __eq__(self, other):
88
138
)
89
139
90
140
91
- class OrNode (Node ):
141
+ class EstimatorNode (Node ):
92
142
"""
93
143
Or node, which is the basic node that would be the equivalent of any SKlearn pipeline
94
144
stage. This node is initialized with an estimator that needs to extend sklearn.BaseEstimator.
95
145
"""
96
- __estimator__ = None
97
146
98
147
def __init__ (self , node_name : str , estimator : BaseEstimator ):
99
148
"""
@@ -102,7 +151,8 @@ def __init__(self, node_name: str, estimator: BaseEstimator):
102
151
:param node_name: Name of the node
103
152
:param estimator: The base estimator
104
153
"""
105
- self .__node_name__ = node_name
154
+
155
+ super ().__init__ (node_name , NodeInputType .OR , NodeFiringType .ANY , NodeStateType .IMMUTABLE )
106
156
self .__estimator__ = estimator
107
157
108
158
def get_estimator (self ) -> BaseEstimator :
@@ -113,37 +163,33 @@ def get_estimator(self) -> BaseEstimator:
113
163
"""
114
164
return self .__estimator__
115
165
116
- def get_and_flag (self ):
117
- """
118
- A flag to check if node is AND or not. By definition, this is NOT
119
- an AND node.
120
- :return: False, always
121
- """
122
- return False
166
+ def clone (self ):
167
+ cloned_estimator = base .clone (self .__estimator__ )
168
+ return EstimatorNode (self .__node_name__ , cloned_estimator )
123
169
124
170
125
- class AndFunc ( ABC ):
126
- """
127
- Or nodes are init-ed from the
128
- """
171
class AndTransform(TransformerMixin, BaseEstimator, ABC):
    """
    Base class for transforms that take a list of inputs and produce a single Xy.

    ABC is part of the bases so that @abstractmethod is actually enforced:
    without an ABCMeta metaclass the decorator has no effect, and a subclass
    that forgets to override transform could still be instantiated.
    """

    @abstractmethod
    def transform(self, xy_list: list) -> Xy:
        """
        Transform the given list of inputs into one Xy; subclasses must override this.

        :param xy_list: The list of inputs (presumably Xy holders -- confirm with callers)
        :return: The resulting Xy
        """
        raise NotImplementedError("Please implement this method")
129
175
176
+
177
class GeneralTransform(TransformerMixin, BaseEstimator, ABC):
    """
    Base class for transforms that take a single Xy and produce a single Xy.

    ABC is part of the bases so that @abstractmethod is actually enforced:
    without an ABCMeta metaclass the decorator has no effect, and a subclass
    that forgets to override transform could still be instantiated.
    """

    @abstractmethod
    def transform(self, xy: Xy) -> Xy:
        """
        Transform the given Xy into a new Xy; subclasses must override this.

        :param xy: The input Xy
        :return: The resulting Xy
        """
        raise NotImplementedError("Please implement this method")
133
181
134
182
135
183
class AndNode(Node):
    """
    Node that wraps an AndTransform. It is registered with the AND input type,
    the ANY firing type and the STATELESS state type.
    """

    def __init__(self, node_name: str, and_func: AndTransform):
        """
        :param node_name: Name of the node
        :param and_func: The AndTransform held by this node
        """
        super().__init__(
            node_name,
            node_input_type=NodeInputType.AND,
            node_firing_type=NodeFiringType.ANY,
            node_state_type=NodeStateType.STATELESS,
        )
        self.__andfunc__ = and_func

    def get_and_func(self) -> AndTransform:
        """
        :return: The AndTransform this node was created with
        """
        return self.__andfunc__

    def clone(self):
        """
        :return: A new AndNode with the same name; the AndTransform object is
                 passed on as-is (shared with this node, not copied)
        """
        shared_func = self.__andfunc__
        return AndNode(self.__node_name__, shared_func)
147
193
148
194
149
195
class Edge :
@@ -322,5 +368,69 @@ def get_post_edges(self, node: Node):
322
368
return post_edges
323
369
324
370
def is_terminal(self, node: Node):
    """
    Tell whether the given node is terminal, i.e. whether its entry in the
    post graph is empty.

    :param node: Node to look up; it must be a key of the post graph
    :return: True if the post-graph entry of the node is empty, else False
    """
    return not self.__post_graph__[node]
373
+
374
def get_terminal_nodes(self):
    """
    Collect the nodes of the pre graph for which is_terminal returns a true value.

    :return: A list of the terminal nodes, in the iteration order of the pre graph
    """
    return [node for node in self.__pre_graph__ if self.is_terminal(node)]
381
+
382
+
383
class PipelineOutput:
    """
    Holder for the output of a pipeline: keeps the out args and the edge args
    handed to it, so that results can be fetched per node afterwards.
    """

    def __init__(self, out_args, edge_args):
        """
        :param out_args: Mapping from a node to its pointer (looked up first)
        :param edge_args: Mapping from a node to its pointer (looked up second)
        """
        self.__out_args__ = out_args
        self.__edge_args__ = edge_args

    def get_xyrefs(self, node: Node):
        """
        Fetch the value stored for a node, resolving its pointer with ray.get.

        :param node: Node whose output is requested
        :return: Whatever ray.get returns for the pointer stored for the node
        :raises PipelineNodeNotFoundException: If the node is in neither mapping
        """
        # out args take precedence over edge args, as in a plain if/elif lookup
        for args in (self.__out_args__, self.__edge_args__):
            if node in args:
                return ray.get(args[node])

        raise pe.PipelineNodeNotFoundException("Node " + str(node) + " not found")

    def get_edge_args(self):
        """
        :return: The edge args mapping this object was created with
        """
        return self.__edge_args__
404
+
405
+
406
class PipelineInput:
    """
    Collects the inputs of a pipeline: in_args is a dict from a node to the
    list of pointers registered for that node, in insertion order.
    """

    def __init__(self):
        self.__in_args__ = {}

    def add_xyref_ptr_arg(self, node: Node, xyref_ptr):
        """
        Register an already created pointer as an input of the given node.

        :param node: Node the input belongs to
        :param xyref_ptr: Pointer to append (presumably a Ray object reference
                          to an XYRef -- confirm with callers)
        """
        # setdefault creates the per-node list on first use; every other add_*
        # method funnels through here so this logic lives in one place only.
        self.__in_args__.setdefault(node, []).append(xyref_ptr)

    def add_xyref_arg(self, node: Node, xyref: XYRef):
        """
        Put the given XYRef into Ray with ray.put and register the returned
        pointer as an input of the given node.

        :param node: Node the input belongs to
        :param xyref: The XYRef holder to store
        """
        self.add_xyref_ptr_arg(node, ray.put(xyref))

    def add_xy_arg(self, node: Node, xy: Xy):
        """
        Put the X and y of the given Xy into Ray separately, wrap both
        references in an XYRef and register it as an input of the given node.

        :param node: Node the input belongs to
        :param xy: The Xy holder; read through get_x() and get_y()
        """
        x_ref = ray.put(xy.get_x())
        y_ref = ray.put(xy.get_y())
        self.add_xyref_arg(node, XYRef(x_ref, y_ref))

    def get_in_args(self):
        """
        :return: The dict from a node to the list of its registered pointers
        """
        return self.__in_args__
0 commit comments