Skip to content

Commit 2613953

Browse files
committed
Merge branch 'develop' into test
2 parents 943cdd3 + c4c7990 commit 2613953

File tree

10 files changed

+705
-495
lines changed

10 files changed

+705
-495
lines changed

.github/pull_request_template.md

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
**Related Issue**
2+
3+
<!-- List any issues that this PR is related to, fixes, or supports -->
4+
5+
Supports #ISSUE_NUMBER
6+
7+
**Related PRs**
8+
9+
<!-- Does this PR depend on or replace an existing PR? Link it here. -->
10+
11+
This PR is not dependent on any other PR
12+
13+
**What does this PR do?**
14+
15+
<!-- The intent of this section is to help team members understand the scope of your changes and why you are making said changes. If there are any concerns or risks associated with your pull request, they should be documented here so your team members can give informative feedback in their reviews. -->
16+
17+
**Description of Changes**
18+
19+
<!-- The intent of this section is to describe the implementation details of the changes you made. Any architectural decisions made should be noted here as well as a short description of why you made that decision. -->
20+
21+
**What gif most accurately describes how I feel towards this PR?**
22+
23+
<!-- Cute animals are also acceptable. -->
24+
25+
![Example of a gif](https://media.github.ibm.com/user/59/files/074da280-78ba-11ea-81d1-49ce0654c29d)

codeflare/pipelines/Datamodel.py

Lines changed: 138 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
1-
from sklearn.base import BaseEstimator
21
from abc import ABC, abstractmethod
2+
from enum import Enum
3+
4+
import sklearn.base as base
5+
from sklearn.base import TransformerMixin
6+
from sklearn.base import BaseEstimator
37

8+
import ray
9+
import codeflare.pipelines.Exceptions as pe
410

511
class Xy:
612
"""
@@ -35,9 +41,12 @@ class XYRef:
3541
computed), these holders are essential to the pipeline constructs.
3642
"""
3743

38-
def __init__(self, Xref, yref):
44+
def __init__(self, Xref, yref, prev_node_state_ref=None, curr_node_state_ref=None, prev_Xyrefs = None):
3945
self.__Xref__ = Xref
4046
self.__yref__ = yref
47+
self.__prev_node_state_ref__ = prev_node_state_ref
48+
self.__curr_node_state_ref__ = curr_node_state_ref
49+
self.__prev_Xyrefs__ = prev_Xyrefs
4150

4251
def get_Xref(self):
4352
"""
@@ -51,6 +60,32 @@ def get_yref(self):
5160
"""
5261
return self.__yref__
5362

63+
def get_prev_node_state_ref(self):
    """
    Accessor for the previous-node state reference held by this object.

    :return: The stored previous node state reference, as set at construction
    """
    return self.__prev_node_state_ref__
65+
66+
def get_curr_node_state_ref(self):
    """
    Accessor for the current-node state reference held by this object.

    :return: The stored current node state reference, as set at construction
    """
    return self.__curr_node_state_ref__
68+
69+
def get_prev_xyrefs(self):
    """
    Accessor for the previous XYRefs held by this object.

    :return: The stored previous XYRefs, as set at construction
    """
    return self.__prev_Xyrefs__
71+
72+
73+
class NodeInputType(Enum):
    """
    Input type of a pipeline node.

    In this module EstimatorNode is constructed with OR and AndNode with AND.
    """
    # Values are plain ints. No trailing commas: `OR = 0,` would make the member
    # value the tuple (0,) and break lookups such as NodeInputType(0).
    OR = 0
    AND = 1
76+
77+
78+
class NodeFiringType(Enum):
    """
    Firing type of a pipeline node.

    Both EstimatorNode and AndNode in this module are constructed with ANY.
    """
    # Values are plain ints. No trailing commas: `ANY = 0,` would make the member
    # value the tuple (0,) and break lookups such as NodeFiringType(0).
    ANY = 0
    ALL = 1
81+
82+
83+
class NodeStateType(Enum):
    """
    State type of a pipeline node.

    In this module EstimatorNode is constructed with IMMUTABLE and AndNode with
    STATELESS.
    """
    # Values are plain ints. No trailing commas: `STATELESS = 0,` would make the
    # member value the tuple (0,), leaving the enum with mixed tuple/int values.
    STATELESS = 0
    IMMUTABLE = 1
    MUTABLE_SEQUENTIAL = 2
    MUTABLE_AGGREGATE = 3
88+
5489

5590
class Node(ABC):
5691
"""
@@ -59,12 +94,27 @@ class Node(ABC):
5994
node name and the type of the node match.
6095
"""
6196

97+
def __init__(self, node_name, node_input_type: NodeInputType, node_firing_type: NodeFiringType, node_state_type: NodeStateType):
98+
self.__node_name__ = node_name
99+
self.__node_input_type__ = node_input_type
100+
self.__node_firing_type__ = node_firing_type
101+
self.__node_state_type__ = node_state_type
102+
62103
def __str__(self):
63104
return self.__node_name__
64105

106+
def get_node_input_type(self):
    """
    Accessor for the input type this node was constructed with.

    :return: The stored node input type
    """
    return self.__node_input_type__
108+
109+
def get_node_firing_type(self):
    """
    Accessor for the firing type this node was constructed with.

    :return: The stored node firing type
    """
    return self.__node_firing_type__
111+
112+
def get_node_state_type(self):
    """
    Accessor for the state type this node was constructed with.

    :return: The stored node state type
    """
    return self.__node_state_type__
114+
65115
@abstractmethod
66-
def get_and_flag(self):
67-
raise NotImplementedError("Please implement this method")
116+
def clone(self):
117+
raise NotImplementedError("Please implement the clone method")
68118

69119
def __hash__(self):
70120
"""
@@ -88,12 +138,11 @@ def __eq__(self, other):
88138
)
89139

90140

91-
class OrNode(Node):
141+
class EstimatorNode(Node):
92142
"""
93143
Estimator node (OR input type), which is the basic node that would be the equivalent of any SKlearn pipeline
94144
stage. This node is initialized with an estimator that needs to extend sklearn.BaseEstimator.
95145
"""
96-
__estimator__ = None
97146

98147
def __init__(self, node_name: str, estimator: BaseEstimator):
99148
"""
@@ -102,7 +151,8 @@ def __init__(self, node_name: str, estimator: BaseEstimator):
102151
:param node_name: Name of the node
103152
:param estimator: The base estimator
104153
"""
105-
self.__node_name__ = node_name
154+
155+
super().__init__(node_name, NodeInputType.OR, NodeFiringType.ANY, NodeStateType.IMMUTABLE)
106156
self.__estimator__ = estimator
107157

108158
def get_estimator(self) -> BaseEstimator:
@@ -113,37 +163,33 @@ def get_estimator(self) -> BaseEstimator:
113163
"""
114164
return self.__estimator__
115165

116-
def get_and_flag(self):
117-
"""
118-
A flag to check if node is AND or not. By definition, this is NOT
119-
an AND node.
120-
:return: False, always
121-
"""
122-
return False
166+
def clone(self):
    """
    Builds a new EstimatorNode carrying this node's name and a copy of its
    estimator obtained through base.clone.

    :return: The newly constructed EstimatorNode
    """
    return EstimatorNode(self.__node_name__, base.clone(self.__estimator__))
123169

124170

125-
class AndFunc(ABC):
126-
"""
127-
Or nodes are init-ed from the
128-
"""
171+
class AndTransform(TransformerMixin, BaseEstimator):
    """
    Base class for transforms that take a list of inputs and produce a single Xy.
    Subclasses are expected to provide their own transform.
    """
    # NOTE(review): @abstractmethod is only enforced for classes whose metaclass
    # is ABCMeta (e.g. via inheriting ABC). Unless one of the sklearn bases
    # provides that, this class stays instantiable and the NotImplementedError
    # below is the effective guard -- confirm whether ABC should be a base.

    @abstractmethod
    def transform(self, xy_list: list) -> Xy:
        """
        Combines the given inputs into one Xy.

        :param xy_list: List of inputs to combine
        :return: The combined Xy
        :raises NotImplementedError: Always, in this base implementation
        """
        raise NotImplementedError("Please implement this method")
129175

176+
177+
class GeneralTransform(TransformerMixin, BaseEstimator):
130178
@abstractmethod
131-
def eval(self, xy_list: list) -> Xy:
179+
def transform(self, xy: Xy) -> Xy:
132180
raise NotImplementedError("Please implement this method")
133181

134182

135183
class AndNode(Node):
136-
__andfunc__ = None
137-
138-
def __init__(self, node_name: str, and_func: AndFunc):
139-
self.__node_name__ = node_name
184+
def __init__(self, node_name: str, and_func: AndTransform):
185+
super().__init__(node_name, NodeInputType.AND, NodeFiringType.ANY, NodeStateType.STATELESS)
140186
self.__andfunc__ = and_func
141187

142-
def get_and_func(self) -> AndFunc:
188+
def get_and_func(self) -> AndTransform:
143189
return self.__andfunc__
144190

145-
def get_and_flag(self):
146-
return True
191+
def clone(self):
192+
return AndNode(self.__node_name__, self.__andfunc__)
147193

148194

149195
class Edge:
@@ -322,5 +368,69 @@ def get_post_edges(self, node: Node):
322368
return post_edges
323369

324370
def is_terminal(self, node: Node):
325-
node_post_edges = self.get_post_edges(node)
326-
return len(node_post_edges) == 0
371+
post_nodes = self.__post_graph__[node]
372+
return not post_nodes
373+
374+
def get_terminal_nodes(self):
    """
    Collects the terminal nodes of this graph.

    Every node that is a key of the pre-graph is checked with is_terminal; the
    matching nodes are returned in the pre-graph's iteration order.

    :return: List of terminal nodes, empty when there are none
    """
    return [node for node in self.__pre_graph__ if self.is_terminal(node)]
381+
382+
383+
class PipelineOutput:
    """
    Output of a pipeline: keeps the per-node pointers around so that results can
    be looked up and materialized later.
    """

    def __init__(self, out_args, edge_args):
        """
        :param out_args: Mapping from node to the pointer of its output
        :param edge_args: Mapping from node to the pointer of its edge arguments
        """
        self.__edge_args__ = edge_args
        self.__out_args__ = out_args

    def get_xyrefs(self, node: Node):
        """
        Fetches, through ray.get, what the pointer registered for the node refers
        to. out_args is consulted first, edge_args second.

        :param node: Node to look up
        :return: Result of ray.get on the pointer stored for the node
        :raises pe.PipelineNodeNotFoundException: If the node is in neither mapping
        """
        for pointers in (self.__out_args__, self.__edge_args__):
            if node in pointers:
                return ray.get(pointers[node])

        raise pe.PipelineNodeNotFoundException("Node " + str(node) + " not found")

    def get_edge_args(self):
        """
        :return: The edge_args mapping given at construction
        """
        return self.__edge_args__
404+
405+
406+
class PipelineInput:
    """
    Inputs of a pipeline, grouped by the node they are destined for.

    in_args is a dict from a node -> list of pointers. Pointers added through
    add_xyref_arg and add_xy_arg are the result of ray.put on an XYRef; pointers
    added through add_xyref_ptr_arg are stored exactly as given.
    """

    def __init__(self):
        self.__in_args__ = {}

    def add_xyref_ptr_arg(self, node: Node, xyref_ptr):
        """
        Appends an existing pointer to the inputs of the given node.

        :param node: Node the input is meant for
        :param xyref_ptr: Pointer to register, stored as is
        """
        # Single place where the per-node list is created on first use.
        self.__in_args__.setdefault(node, []).append(xyref_ptr)

    def add_xyref_arg(self, node: Node, xyref: XYRef):
        """
        Puts the XYRef into ray and registers the resulting pointer for the node.

        :param node: Node the input is meant for
        :param xyref: XYRef holder to register
        """
        self.add_xyref_ptr_arg(node, ray.put(xyref))

    def add_xy_arg(self, node: Node, xy: Xy):
        """
        Puts X and y of the given Xy into ray, wraps both references in an XYRef
        and registers it for the node.

        :param node: Node the input is meant for
        :param xy: Xy holder whose X and y get registered
        """
        x_ref = ray.put(xy.get_x())
        y_ref = ray.put(xy.get_y())
        self.add_xyref_arg(node, XYRef(x_ref, y_ref))

    def get_in_args(self):
        """
        :return: The dict from node to its list of registered pointers
        """
        return self.__in_args__

codeflare/pipelines/Exceptions.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
class BasePipelineException(Exception):
    """Common base of the pipeline exceptions defined in this module."""


class PipelineSaveException(BasePipelineException):
    """Pipeline save error carrying a message."""

    def __init__(self, message):
        """
        :param message: Description of the error, also available as .message
        """
        # Initialise the Exception machinery explicitly rather than relying on
        # BaseException.__new__ having captured the arguments.
        super().__init__(message)
        self.message = message


class PipelineNodeNotFoundException(BasePipelineException):
    """Error for a node that could not be found, carrying a message."""

    def __init__(self, message):
        """
        :param message: Description of the error, also available as .message
        """
        super().__init__(message)
        self.message = message


class PipelineException(BasePipelineException):
    """General pipeline error carrying a message."""

    def __init__(self, message):
        """
        :param message: Description of the error, also available as .message
        """
        super().__init__(message)
        self.message = message

0 commit comments

Comments
 (0)