Skip to content

Commit e2aac4e

Browse files
yuanchi2807GitHub Enterprise
authored and committed
Merge pull request #63 from codeflare/viz
Viz
2 parents a5b14b2 + 186fbc1 commit e2aac4e

File tree

6 files changed

+386
-56
lines changed

6 files changed

+386
-56
lines changed

codeflare/pipelines/Datamodel.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,9 @@ class Node(ABC):
200200
An abstract node class that captures the basic information about a Node.
201201
The hash code of this node is the name of the node and equality is defined if the
202202
node name and the type of the node match.
203+
204+
When doing a grid search, a node can be parameterized with new params for the estimator and updated. This
205+
is an internal method used by grid search.
203206
"""
204207

205208
def __init__(self, node_name, estimator: BaseEstimator, node_input_type: NodeInputType, node_firing_type: NodeFiringType, node_state_type: NodeStateType):
@@ -210,6 +213,11 @@ def __init__(self, node_name, estimator: BaseEstimator, node_input_type: NodeInp
210213
self.__node_state_type__ = node_state_type
211214

212215
def __str__(self):
216+
"""
217+
Returns a string representation of the node along with the parameters of the estimator of the node.
218+
219+
:return: String representation of the node
220+
"""
213221
estimator_params_str = str(self.get_estimator().get_params())
214222
retval = self.__node_name__ + estimator_params_str
215223
return retval
@@ -247,9 +255,22 @@ def get_node_state_type(self) -> NodeStateType:
247255
return self.__node_state_type__
248256

249257
def get_estimator(self):
258+
"""
259+
Return the estimator of the node
260+
261+
:return: The node's estimator
262+
"""
250263
return self.__estimator__
251264

252265
def get_parameterized_node(self, node_name, **params):
266+
"""
267+
Get a parameterized node, given kwargs **params, convert this node and update the estimator with the
268+
new set of parameters. It will clone the node and its underlying estimator.
269+
270+
:param node_name: New node name
271+
:param params: Updated parameters
272+
:return:
273+
"""
253274
cloned_node = self.clone()
254275
cloned_node.__node_name__ = node_name
255276
estimator = cloned_node.get_estimator()
@@ -311,7 +332,6 @@ def __init__(self, node_name: str, estimator: BaseEstimator):
311332
"""
312333
super().__init__(node_name, estimator, NodeInputType.OR, NodeFiringType.ANY, NodeStateType.IMMUTABLE)
313334

314-
315335
def clone(self):
316336
"""
317337
Clones the given node and the underlying estimator as well, if it was initialized with
@@ -323,6 +343,17 @@ def clone(self):
323343

324344

325345
class AndEstimator(BaseEstimator):
346+
"""
347+
An AND estimator is part of the AndNode. It is very similar to a standard estimator; however, the key
348+
difference is that it takes a `xy_list` as input and outputs an `xy`, contrasting to the EstimatorNode,
349+
which takes an input as `xy` and outputs `xy_t`.
350+
351+
In the pipeline execution, we expect three modes: (a) FIT: A regressor or classifier will call the fit
352+
and then pass on the transform results downstream, a non-regressor/classifier will call the fit_transform
353+
method, (b) PREDICT: A regressor or classifier will call the predict method, whereas a non-regressor/classifier
354+
will call the transform method, and (c) SCORE: A regressor will call the score method, and a non-regressor/classifier
355+
will call the transform method.
356+
"""
326357
@abstractmethod
327358
def transform(self, xy_list: list) -> Xy:
328359
raise NotImplementedError("And estimator needs to implement a transform method")

codeflare/pipelines/utils.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import graphviz
2+
import codeflare.pipelines.Datamodel as dm
3+
4+
5+
def pipeline_to_graph(pipeline: dm.Pipeline) -> graphviz.Digraph:
6+
"""
7+
Converts the given pipeline to a graphviz directed graph (Digraph) for visualization.
8+
9+
:param pipeline: Pipeline to convert to a graphviz directed graph
10+
:return: A directed graph representing this pipeline
11+
"""
12+
graph = graphviz.Digraph()
13+
pipeline_nodes = pipeline.get_nodes()
14+
for pre_node in pipeline_nodes.values():
15+
post_nodes = pipeline.get_post_nodes(pre_node)
16+
graph.node(pre_node.get_node_name())
17+
for post_node in post_nodes:
18+
graph.node(post_node.get_node_name())
19+
graph.edge(pre_node.get_node_name(), post_node.get_node_name())
20+
return graph

codeflare_pipelines.egg-info/SOURCES.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ codeflare/pipelines/Datamodel.py
55
codeflare/pipelines/Exceptions.py
66
codeflare/pipelines/Runtime.py
77
codeflare/pipelines/__init__.py
8+
codeflare/pipelines/utils.py
89
codeflare_pipelines.egg-info/PKG-INFO
910
codeflare_pipelines.egg-info/SOURCES.txt
1011
codeflare_pipelines.egg-info/dependency_links.txt

docs/.DS_Store

6 KB
Binary file not shown.

0 commit comments

Comments
 (0)