Skip to content

Commit c4c7990

Browse files
yuanchi2807GitHub Enterprise
authored andcommitted
Merge pull request #27 from codeflare/pipeline_terminal_bug
Pipeline terminal bug
2 parents 14c53e4 + 93c1a40 commit c4c7990

File tree

9 files changed

+504
-363
lines changed

9 files changed

+504
-363
lines changed

codeflare/pipelines/Datamodel.py

Lines changed: 77 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from abc import ABC, abstractmethod
2-
import uuid
32
from enum import Enum
43

5-
64
import sklearn.base as base
75
from sklearn.base import TransformerMixin
86
from sklearn.base import BaseEstimator
97

8+
import ray
9+
import codeflare.pipelines.Exceptions as pe
10+
1011
class Xy:
1112
"""
1213
Holder class for Xy, where X is array-like and y is array-like. This is the base
@@ -40,11 +41,11 @@ class XYRef:
4041
computed), these holders are essential to the pipeline constructs.
4142
"""
4243

43-
def __init__(self, Xref, yref, prev_noderef=None, curr_noderef=None, prev_Xyrefs = None):
44+
def __init__(self, Xref, yref, prev_node_state_ref=None, curr_node_state_ref=None, prev_Xyrefs = None):
4445
self.__Xref__ = Xref
4546
self.__yref__ = yref
46-
self.__prevnoderef__ = prev_noderef
47-
self.__currnoderef__ = curr_noderef
47+
self.__prev_node_state_ref__ = prev_node_state_ref
48+
self.__curr_node_state_ref__ = curr_node_state_ref
4849
self.__prev_Xyrefs__ = prev_Xyrefs
4950

5051
def get_Xref(self):
@@ -59,11 +60,11 @@ def get_yref(self):
5960
"""
6061
return self.__yref__
6162

62-
def get_prevnoderef(self):
63-
return self.__prevnoderef__
63+
def get_prev_node_state_ref(self):
64+
return self.__prev_node_state_ref__
6465

65-
def get_currnoderef(self):
66-
return self.__currnoderef__
66+
def get_curr_node_state_ref(self):
67+
return self.__curr_node_state_ref__
6768

6869
def get_prev_xyrefs(self):
6970
return self.__prev_Xyrefs__
@@ -98,14 +99,10 @@ def __init__(self, node_name, node_input_type: NodeInputType, node_firing_type:
9899
self.__node_input_type__ = node_input_type
99100
self.__node_firing_type__ = node_firing_type
100101
self.__node_state_type__ = node_state_type
101-
self.__id__ = uuid.uuid4()
102102

103103
def __str__(self):
104104
return self.__node_name__
105105

106-
def get_id(self):
107-
return self.__id__
108-
109106
def get_node_input_type(self):
110107
return self.__node_input_type__
111108

@@ -125,8 +122,7 @@ def __hash__(self):
125122
126123
:return: Hash code
127124
"""
128-
129-
return self.__id__.__hash__()
125+
return self.__node_name__.__hash__()
130126

131127
def __eq__(self, other):
132128
"""
@@ -138,7 +134,6 @@ def __eq__(self, other):
138134
"""
139135
return (
140136
self.__class__ == other.__class__ and
141-
self.__id__ == other.__id__ and
142137
self.__node_name__ == other.__node_name__
143138
)
144139

@@ -373,5 +368,69 @@ def get_post_edges(self, node: Node):
373368
return post_edges
374369

375370
def is_terminal(self, node: Node):
376-
node_post_edges = self.get_post_edges(node)
377-
return len(node_post_edges) == 0
371+
post_nodes = self.__post_graph__[node]
372+
return not post_nodes
373+
374+
def get_terminal_nodes(self):
375+
# dict from level to nodes
376+
terminal_nodes = []
377+
for node in self.__pre_graph__.keys():
378+
if self.is_terminal(node):
379+
terminal_nodes.append(node)
380+
return terminal_nodes
381+
382+
383+
class PipelineOutput:
384+
"""
385+
Pipeline output to keep reference counters so that pipelines can be materialized
386+
"""
387+
def __init__(self, out_args, edge_args):
388+
self.__out_args__ = out_args
389+
self.__edge_args__ = edge_args
390+
391+
def get_xyrefs(self, node: Node):
392+
if node in self.__out_args__:
393+
xyrefs_ptr = self.__out_args__[node]
394+
elif node in self.__edge_args__:
395+
xyrefs_ptr = self.__edge_args__[node]
396+
else:
397+
raise pe.PipelineNodeNotFoundException("Node " + str(node) + " not found")
398+
399+
xyrefs = ray.get(xyrefs_ptr)
400+
return xyrefs
401+
402+
def get_edge_args(self):
403+
return self.__edge_args__
404+
405+
406+
class PipelineInput:
407+
"""
408+
in_args is a dict from a node -> [Xy]
409+
"""
410+
def __init__(self):
411+
self.__in_args__ = {}
412+
413+
def add_xyref_ptr_arg(self, node: Node, xyref_ptr):
414+
if node not in self.__in_args__:
415+
self.__in_args__[node] = []
416+
417+
self.__in_args__[node].append(xyref_ptr)
418+
419+
def add_xyref_arg(self, node: Node, xyref: XYRef):
420+
if node not in self.__in_args__:
421+
self.__in_args__[node] = []
422+
423+
xyref_ptr = ray.put(xyref)
424+
self.__in_args__[node].append(xyref_ptr)
425+
426+
def add_xy_arg(self, node: Node, xy: Xy):
427+
if node not in self.__in_args__:
428+
self.__in_args__[node] = []
429+
430+
x_ref = ray.put(xy.get_x())
431+
y_ref = ray.put(xy.get_y())
432+
xyref = XYRef(x_ref, y_ref)
433+
self.add_xyref_arg(node, xyref)
434+
435+
def get_in_args(self):
436+
return self.__in_args__

codeflare/pipelines/Exceptions.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
class BasePipelineException(Exception):
2+
pass
3+
4+
5+
class PipelineSaveException(BasePipelineException):
6+
def __init__(self, message):
7+
self.message = message
8+
9+
10+
class PipelineNodeNotFoundException(BasePipelineException):
11+
def __init__(self, message):
12+
self.message = message
13+
14+
15+
class PipelineException(BasePipelineException):
16+
def __init__(self, message):
17+
self.message = message

0 commit comments

Comments
 (0)