Skip to content

Commit 5077e2d

Browse files
Updating docs with more info, pipeline done
1 parent fef91d0 commit 5077e2d

File tree

1 file changed

+200
-7
lines changed

1 file changed

+200
-7
lines changed

codeflare/pipelines/Datamodel.py

Lines changed: 200 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,11 @@ def __init__(self, node_name, node_input_type: NodeInputType, node_firing_type:
199199
self.__node_state_type__ = node_state_type
200200

201201
def __str__(self):
202+
"""
203+
The string representation, which is the node name itself
204+
205+
:return: Node string
206+
"""
202207
return self.__node_name__
203208

204209
def get_node_name(self) -> str:
@@ -266,6 +271,17 @@ class EstimatorNode(Node):
266271
267272
This estimator node is typically an OR node, with ANY firing semantics, and IMMUTABLE state. For
268273
partial fit, we will have to define a different node type to keep semantics very clear.
274+
275+
.. code-block:: python
276+
277+
random_forest = RandomForestClassifier(n_estimators=200)
278+
node_rf = dm.EstimatorNode('randomforest', random_forest)
279+
280+
# get the estimator
281+
node_rf_estimator = node_rf.get_estimator()
282+
283+
# clone the node, clones the estimator as well
284+
node_rf_cloned = node_rf.clone()
269285
"""
270286

271287
def __init__(self, node_name: str, estimator: BaseEstimator):
@@ -283,11 +299,16 @@ def get_estimator(self) -> BaseEstimator:
283299
"""
284300
Return the estimator that this was initialize with
285301
286-
:return: Estimator
302+
:return: Estimator that was initialized
287303
"""
288304
return self.__estimator__
289305

290306
def clone(self):
307+
"""
308+
Clones the given node and the underlying estimator as well, if it was initialized with
309+
310+
:return: A cloned node
311+
"""
291312
cloned_estimator = base.clone(self.__estimator__)
292313
return EstimatorNode(self.__node_name__, cloned_estimator)
293314

@@ -317,17 +338,31 @@ def clone(self):
317338

318339

319340
class Edge:
320-
__from_node__ = None
321-
__to_node__ = None
341+
"""
342+
An edge connects two nodes, it's an internal data structure for pipeline construction. An edge
343+
is a directed edge and has a "from_node" and a "to_node".
322344
345+
An edge also defines a hash function and an equality, where the equality is on the from and to
346+
node names being the same.
347+
"""
323348
def __init__(self, from_node: Node, to_node: Node):
324349
self.__from_node__ = from_node
325350
self.__to_node__ = to_node
326351

327352
def get_from_node(self) -> Node:
353+
"""
354+
The from_node of this edge (originating node)
355+
356+
:return: The from_node of this edge
357+
"""
328358
return self.__from_node__
329359

330360
def get_to_node(self) -> Node:
361+
"""
362+
The to_node of this edge (terminating node)
363+
364+
:return: The to_node of this edge
365+
"""
331366
return self.__to_node__
332367

333368
def __str__(self):
@@ -361,7 +396,62 @@ def get_object_ref(self):
361396

362397
class Pipeline:
363398
"""
364-
The pipeline class that defines the DAG structure composed of Node(s). The
399+
The pipeline class that defines the DAG structure composed of Node(s). This is the core data structure that
400+
defines the computation graph. A key note is that unlike SKLearn pipeline, CodeFlare pipelines are "abstract"
401+
graphs and get realized only when executed. Upon execution, they can potentially be multiple pathways in
402+
the pipeline, i.e. multiple "single" pipelines can be realized.
403+
404+
Examples
405+
--------
406+
Pipelines can be constructed quite simply using the builder paradigm with add_node and/or add_edge. In its
407+
simplest form, one can create nodes and then wire the DAG by adding edges. An example that does a simple
408+
pipeline is below:
409+
410+
.. code-block:: python
411+
412+
feature_union = FeatureUnion(transformer_list=[('PCA', PCA()),
413+
('Nystroem', Nystroem()), ('SelectKBest', SelectKBest(k=3))])
414+
random_forest = RandomForestClassifier(n_estimators=200)
415+
node_fu = dm.EstimatorNode('feature_union', feature_union)
416+
node_rf = dm.EstimatorNode('randomforest', random_forest)
417+
pipeline.add_edge(node_fu, node_rf)
418+
419+
One can of course construct complex pipelines with multiple outgoing edges as well. An example of one that
420+
explores multiple models is shown below:
421+
422+
.. code-block:: python
423+
424+
preprocessor = ColumnTransformer(
425+
transformers=[
426+
('num', numeric_transformer, numeric_features),
427+
('cat', categorical_transformer, categorical_features)])
428+
429+
classifiers = [
430+
RandomForestClassifier(),
431+
GradientBoostingClassifier()
432+
]
433+
pipeline = dm.Pipeline()
434+
node_pre = dm.EstimatorNode('preprocess', preprocessor)
435+
node_rf = dm.EstimatorNode('random_forest', classifiers[0])
436+
node_gb = dm.EstimatorNode('gradient_boost', classifiers[1])
437+
438+
pipeline.add_edge(node_pre, node_rf)
439+
pipeline.add_edge(node_pre, node_gb)
440+
441+
A pipeline can be saved and loaded, which in essence saves the "graph" and not the state of this pipeline.
442+
For saving the state of the pipeline, one can use the Runtime's save method! Save/load of pipeline uses
443+
Pickle protocol 5.
444+
445+
.. code-block:: python
446+
447+
fname = 'save_pipeline.cfp'
448+
fh = open(fname, 'wb')
449+
pipeline.save(fh)
450+
fh.close()
451+
452+
r_fh = open(fname, 'rb')
453+
saved_pipeline = dm.Pipeline.load(r_fh)
454+
365455
"""
366456

367457
def __init__(self):
@@ -371,6 +461,12 @@ def __init__(self):
371461
self.__level_nodes__ = None
372462

373463
def add_node(self, node: Node):
464+
"""
465+
Adds a node to this pipeline
466+
467+
:param node: The node to add
468+
:return: None
469+
"""
374470
self.__node_levels__ = None
375471
self.__level_nodes__ = None
376472
if node not in self.__pre_graph__.keys():
@@ -395,6 +491,13 @@ def get_str(nodes: list):
395491
return res
396492

397493
def add_edge(self, from_node: Node, to_node: Node):
494+
"""
495+
Adds an edge to this pipeline
496+
497+
:param from_node: The from node
498+
:param to_node: The to node
499+
:return: None
500+
"""
398501
self.add_node(from_node)
399502
self.add_node(to_node)
400503

@@ -408,6 +511,14 @@ def get_postimage(self, node: Node):
408511
return self.__post_graph__[node]
409512

410513
def compute_node_level(self, node: Node, result: dict):
514+
"""
515+
Computes the node levels for a given node, an internal supporting function that is recursive, so it
516+
takes the result computed so far.
517+
518+
:param node: The node for which level needs to be computed
519+
:param result: The node levels that have already been computed
520+
:return: The level for this node
521+
"""
411522
if node in result:
412523
return result[node]
413524

@@ -426,6 +537,13 @@ def compute_node_level(self, node: Node, result: dict):
426537
return max_level + 1
427538

428539
def compute_node_levels(self):
540+
"""
541+
Computes node levels for all nodes. If a cache of node levels from previous calls exists, it will return
542+
the cache to avoid repeated computation.
543+
544+
:return: The mapping from node to its level as a dict
545+
"""
546+
# TODO: This is incorrect when pipelines are mutable
429547
if self.__node_levels__:
430548
return self.__node_levels__
431549

@@ -438,13 +556,24 @@ def compute_node_levels(self):
438556
return self.__node_levels__
439557

440558
def compute_max_level(self):
559+
"""
560+
Get the max depth of this pipeline graph.
561+
562+
:return: The max depth of pipeline
563+
"""
441564
levels = self.compute_node_levels()
442565
max_level = 0
443566
for node, node_level in levels.items():
444567
max_level = max(node_level, max_level)
445568
return max_level
446569

447570
def get_nodes_by_level(self):
571+
"""
572+
A mapping from level to a list of nodes, useful for pipeline execution time. Similar to compute_levels,
573+
this method will return a cache if it exists, else will compute the levels and cache it.
574+
575+
:return: The mapping from level to a list of nodes at that level
576+
"""
448577
if self.__level_nodes__:
449578
return self.__level_nodes__
450579

@@ -460,16 +589,19 @@ def get_nodes_by_level(self):
460589
self.__level_nodes__ = result
461590
return self.__level_nodes__
462591

463-
###
464-
# Get downstream node
465-
###
466592
def get_post_nodes(self, node: Node):
467593
return self.__post_graph__[node]
468594

469595
def get_pre_nodes(self, node: Node):
470596
return self.__pre_graph__[node]
471597

472598
def get_pre_edges(self, node: Node):
599+
"""
600+
Get the incoming edges to a specific node.
601+
602+
:param node: Given node
603+
:return: Incoming edges for the node
604+
"""
473605
pre_edges = []
474606
pre_nodes = self.__pre_graph__[node]
475607
# Empty pre
@@ -481,6 +613,12 @@ def get_pre_edges(self, node: Node):
481613
return pre_edges
482614

483615
def get_post_edges(self, node: Node):
616+
"""
617+
Get the outgoing edges for the given node
618+
619+
:param node: Given node
620+
:return: Outgoing edges for the node
621+
"""
484622
post_edges = []
485623
post_nodes = self.__post_graph__[node]
486624
# Empty post
@@ -492,10 +630,21 @@ def get_post_edges(self, node: Node):
492630
return post_edges
493631

494632
def is_terminal(self, node: Node):
633+
"""
634+
Checks if the given node is a terminal node, i.e. has no outgoing edges
635+
636+
:param node: Node to check terminal condition on
637+
:return: True if terminal else False
638+
"""
495639
post_nodes = self.__post_graph__[node]
496640
return not post_nodes
497641

498642
def get_terminal_nodes(self):
643+
"""
644+
Get all the terminal nodes for this pipeline
645+
646+
:return: List of all terminal nodes
647+
"""
499648
# dict from level to nodes
500649
terminal_nodes = []
501650
for node in self.__pre_graph__.keys():
@@ -504,18 +653,42 @@ def get_terminal_nodes(self):
504653
return terminal_nodes
505654

506655
def get_nodes(self):
656+
"""
657+
Get the nodes in this pipeline
658+
659+
:return: Node name to node dict
660+
"""
507661
nodes = {}
508662
for node in self.__pre_graph__.keys():
509663
nodes[node.get_node_name()] = node
510664
return nodes
511665

512666
def get_pre_nodes(self, node):
667+
"""
668+
Get the nodes that have edges incoming to the given node
669+
670+
:param node: Given node
671+
:return: List of nodes with incoming edges to the provided node
672+
"""
513673
return self.__pre_graph__[node]
514674

515675
def get_post_nodes(self, node):
676+
"""
677+
Get the nodes that have edges outgoing to the given node
678+
679+
:param node: Given node
680+
:return: List of nodes with outgoing edges from the provided node
681+
"""
516682
return self.__post_graph__[node]
517683

518684
def save(self, filehandle):
685+
"""
686+
Saves the pipeline graph (without state) to a file. A filehandle with write and binary mode
687+
is expected.
688+
689+
:param filehandle: Filehandle with wb mode
690+
:return: None
691+
"""
519692
nodes = {}
520693
edges = []
521694

@@ -534,6 +707,12 @@ def save(self, filehandle):
534707

535708
@staticmethod
536709
def load(filehandle):
710+
"""
711+
Loads a pipeline that has been saved given the filehandle. Filehandle is in rb format.
712+
713+
:param filehandle: Filehandle to load pipeline from
714+
:return:
715+
"""
537716
saved_pipeline = pickle.load(filehandle)
538717
if not isinstance(saved_pipeline, _SavedPipeline):
539718
raise pe.PipelineException("Filehandle is not a saved pipeline instance")
@@ -551,14 +730,28 @@ def load(filehandle):
551730

552731

553732
class _SavedPipeline:
733+
"""
734+
Internal class that serializes the pipeline so that it can be pickled. As noted, this only captures
735+
the graph and not the state of the pipeline.
736+
"""
554737
def __init__(self, nodes, edges):
555738
self.__nodes__ = nodes
556739
self.__edges__ = edges
557740

558741
def get_nodes(self):
742+
"""
743+
Nodes of the saved pipeline
744+
745+
:return: Dict of node name to node mapping
746+
"""
559747
return self.__nodes__
560748

561749
def get_edges(self):
750+
"""
751+
Edges of the saved pipeline
752+
753+
:return: List of edges
754+
"""
562755
return self.__edges__
563756

564757

0 commit comments

Comments
 (0)