1
1
from abc import ABC , abstractmethod
2
- import uuid
3
2
from enum import Enum
4
3
5
-
6
4
import sklearn .base as base
7
5
from sklearn .base import TransformerMixin
8
6
from sklearn .base import BaseEstimator
9
7
8
+ import ray
9
+ import codeflare .pipelines .Exceptions as pe
10
+
10
11
class Xy :
11
12
"""
12
13
Holder class for Xy, where X is array-like and y is array-like. This is the base
@@ -40,11 +41,11 @@ class XYRef:
40
41
computed), these holders are essential to the pipeline constructs.
41
42
"""
42
43
43
- def __init__ (self , Xref , yref , prev_noderef = None , curr_noderef = None , prev_Xyrefs = None ):
44
+ def __init__ (self , Xref , yref , prev_node_state_ref = None , curr_node_state_ref = None , prev_Xyrefs = None ):
44
45
self .__Xref__ = Xref
45
46
self .__yref__ = yref
46
- self .__prevnoderef__ = prev_noderef
47
- self .__currnoderef__ = curr_noderef
47
+ self .__prev_node_state_ref__ = prev_node_state_ref
48
+ self .__curr_node_state_ref__ = curr_node_state_ref
48
49
self .__prev_Xyrefs__ = prev_Xyrefs
49
50
50
51
def get_Xref (self ):
@@ -59,11 +60,11 @@ def get_yref(self):
59
60
"""
60
61
return self .__yref__
61
62
62
- def get_prevnoderef (self ):
63
- return self .__prevnoderef__
63
+ def get_prev_node_state_ref (self ):
64
+ return self .__prev_node_state_ref__
64
65
65
- def get_currnoderef (self ):
66
- return self .__currnoderef__
66
+ def get_curr_node_state_ref (self ):
67
+ return self .__curr_node_state_ref__
67
68
68
69
def get_prev_xyrefs (self ):
69
70
return self .__prev_Xyrefs__
@@ -98,14 +99,10 @@ def __init__(self, node_name, node_input_type: NodeInputType, node_firing_type:
98
99
self .__node_input_type__ = node_input_type
99
100
self .__node_firing_type__ = node_firing_type
100
101
self .__node_state_type__ = node_state_type
101
- self .__id__ = uuid .uuid4 ()
102
102
103
103
def __str__ (self ):
104
104
return self .__node_name__
105
105
106
- def get_id (self ):
107
- return self .__id__
108
-
109
106
def get_node_input_type (self ):
110
107
return self .__node_input_type__
111
108
@@ -125,8 +122,7 @@ def __hash__(self):
125
122
126
123
:return: Hash code
127
124
"""
128
-
129
- return self .__id__ .__hash__ ()
125
+ return self .__node_name__ .__hash__ ()
130
126
131
127
def __eq__ (self , other ):
132
128
"""
@@ -138,7 +134,6 @@ def __eq__(self, other):
138
134
"""
139
135
return (
140
136
self .__class__ == other .__class__ and
141
- self .__id__ == other .__id__ and
142
137
self .__node_name__ == other .__node_name__
143
138
)
144
139
@@ -373,5 +368,69 @@ def get_post_edges(self, node: Node):
373
368
return post_edges
374
369
375
370
def is_terminal (self , node : Node ):
376
- node_post_edges = self .get_post_edges (node )
377
- return len (node_post_edges ) == 0
371
+ post_nodes = self .__post_graph__ [node ]
372
+ return not post_nodes
373
+
374
+ def get_terminal_nodes (self ):
375
+ # dict from level to nodes
376
+ terminal_nodes = []
377
+ for node in self .__pre_graph__ .keys ():
378
+ if self .is_terminal (node ):
379
+ terminal_nodes .append (node )
380
+ return terminal_nodes
381
+
382
+
383
+ class PipelineOutput :
384
+ """
385
+ Pipeline output to keep reference counters so that pipelines can be materialized
386
+ """
387
+ def __init__ (self , out_args , edge_args ):
388
+ self .__out_args__ = out_args
389
+ self .__edge_args__ = edge_args
390
+
391
+ def get_xyrefs (self , node : Node ):
392
+ if node in self .__out_args__ :
393
+ xyrefs_ptr = self .__out_args__ [node ]
394
+ elif node in self .__edge_args__ :
395
+ xyrefs_ptr = self .__edge_args__ [node ]
396
+ else :
397
+ raise pe .PipelineNodeNotFoundException ("Node " + str (node ) + " not found" )
398
+
399
+ xyrefs = ray .get (xyrefs_ptr )
400
+ return xyrefs
401
+
402
+ def get_edge_args (self ):
403
+ return self .__edge_args__
404
+
405
+
406
+ class PipelineInput :
407
+ """
408
+ in_args is a dict from a node -> [Xy]
409
+ """
410
+ def __init__ (self ):
411
+ self .__in_args__ = {}
412
+
413
+ def add_xyref_ptr_arg (self , node : Node , xyref_ptr ):
414
+ if node not in self .__in_args__ :
415
+ self .__in_args__ [node ] = []
416
+
417
+ self .__in_args__ [node ].append (xyref_ptr )
418
+
419
+ def add_xyref_arg (self , node : Node , xyref : XYRef ):
420
+ if node not in self .__in_args__ :
421
+ self .__in_args__ [node ] = []
422
+
423
+ xyref_ptr = ray .put (xyref )
424
+ self .__in_args__ [node ].append (xyref_ptr )
425
+
426
+ def add_xy_arg (self , node : Node , xy : Xy ):
427
+ if node not in self .__in_args__ :
428
+ self .__in_args__ [node ] = []
429
+
430
+ x_ref = ray .put (xy .get_x ())
431
+ y_ref = ray .put (xy .get_y ())
432
+ xyref = XYRef (x_ref , y_ref )
433
+ self .add_xyref_arg (node , xyref )
434
+
435
+ def get_in_args (self ):
436
+ return self .__in_args__
0 commit comments