@@ -199,6 +199,11 @@ def __init__(self, node_name, node_input_type: NodeInputType, node_firing_type:
199
199
self .__node_state_type__ = node_state_type
200
200
201
201
def __str__ (self ):
202
+ """
203
+ The string representation, which is the node name itself
204
+
205
+ :return: Node string
206
+ """
202
207
return self .__node_name__
203
208
204
209
def get_node_name (self ) -> str :
@@ -266,6 +271,17 @@ class EstimatorNode(Node):
266
271
267
272
This estimator node is typically an OR node, with ANY firing semantics, and IMMUTABLE state. For
268
273
partial fit, we will have to define a different node type to keep semantics very clear.
274
+
275
+ .. code-block:: python
276
+
277
+ random_forest = RandomForestClassifier(n_estimators=200)
278
+ node_rf = dm.EstimatorNode('randomforest', random_forest)
279
+
280
+ # get the estimator
281
+ node_rf_estimator = node_rf.get_estimator()
282
+
283
+ # clone the node, clones the estimator as well
284
+ node_rf_cloned = node_rf.clone()
269
285
"""
270
286
271
287
def __init__ (self , node_name : str , estimator : BaseEstimator ):
@@ -283,11 +299,16 @@ def get_estimator(self) -> BaseEstimator:
283
299
"""
284
300
Return the estimator that this was initialize with
285
301
286
- :return: Estimator
302
+ :return: Estimator that was initialized
287
303
"""
288
304
return self .__estimator__
289
305
290
306
def clone (self ):
307
+ """
308
+ Clones the given node and the underlying estimator as well, if it was initialized with
309
+
310
+ :return: A cloned node
311
+ """
291
312
cloned_estimator = base .clone (self .__estimator__ )
292
313
return EstimatorNode (self .__node_name__ , cloned_estimator )
293
314
@@ -317,17 +338,31 @@ def clone(self):
317
338
318
339
319
340
class Edge :
320
- __from_node__ = None
321
- __to_node__ = None
341
+ """
342
+ An edge connects two nodes, it's an internal data structure for pipeline construction. An edge
343
+ is a directed edge and has a "from_node" and a "to_node".
322
344
345
+ An edge also defines a hash function and an equality, where the equality is on the from and to
346
+ node names being the same.
347
+ """
323
348
def __init__ (self , from_node : Node , to_node : Node ):
324
349
self .__from_node__ = from_node
325
350
self .__to_node__ = to_node
326
351
327
352
def get_from_node (self ) -> Node :
353
+ """
354
+ The from_node of this edge (originating node)
355
+
356
+ :return: The from_node of this edge
357
+ """
328
358
return self .__from_node__
329
359
330
360
def get_to_node (self ) -> Node :
361
+ """
362
+ The to_node of this edge (terminating node)
363
+
364
+ :return: The to_node of this edge
365
+ """
331
366
return self .__to_node__
332
367
333
368
def __str__ (self ):
@@ -361,7 +396,62 @@ def get_object_ref(self):
361
396
362
397
class Pipeline :
363
398
"""
364
- The pipeline class that defines the DAG structure composed of Node(s). The
399
+ The pipeline class that defines the DAG structure composed of Node(s). This is the core data structure that
400
+ defines the computation graph. A key note is that unlike SKLearn pipeline, CodeFlare pipelines are "abstract"
401
+ graphs and get realized only when executed. Upon execution, they can potentially be multiple pathways in
402
+ the pipeline, i.e. multiple "single" pipelines can be realized.
403
+
404
+ Examples
405
+ --------
406
+ Pipelines can be constructed quite simply using the builder paradigm with add_node and/or add_edge. In its
407
+ simplest form, one can create nodes and then wire the DAG by adding edges. An example that does a simple
408
+ pipeline is below:
409
+
410
+ .. code-block:: python
411
+
412
+ feature_union = FeatureUnion(transformer_list=[('PCA', PCA()),
413
+ ('Nystroem', Nystroem()), ('SelectKBest', SelectKBest(k=3))])
414
+ random_forest = RandomForestClassifier(n_estimators=200)
415
+ node_fu = dm.EstimatorNode('feature_union', feature_union)
416
+ node_rf = dm.EstimatorNode('randomforest', random_forest)
417
+ pipeline.add_edge(node_fu, node_rf)
418
+
419
+ One can of course construct complex pipelines with multiple outgoing edges as well. An example of one that
420
+ explores multiple models is shown below:
421
+
422
+ .. code-block:: python
423
+
424
+ preprocessor = ColumnTransformer(
425
+ transformers=[
426
+ ('num', numeric_transformer, numeric_features),
427
+ ('cat', categorical_transformer, categorical_features)])
428
+
429
+ classifiers = [
430
+ RandomForestClassifier(),
431
+ GradientBoostingClassifier()
432
+ ]
433
+ pipeline = dm.Pipeline()
434
+ node_pre = dm.EstimatorNode('preprocess', preprocessor)
435
+ node_rf = dm.EstimatorNode('random_forest', classifiers[0])
436
+ node_gb = dm.EstimatorNode('gradient_boost', classifiers[1])
437
+
438
+ pipeline.add_edge(node_pre, node_rf)
439
+ pipeline.add_edge(node_pre, node_gb)
440
+
441
+ A pipeline can be saved and loaded, which in essence saves the "graph" and not the state of this pipeline.
442
+ For saving the state of the pipeline, one can use the Runtime's save method! Save/load of pipeline uses
443
+ Pickle protocol 5.
444
+
445
+ .. code-block:: python
446
+
447
+ fname = 'save_pipeline.cfp'
448
+ fh = open(fname, 'wb')
449
+ pipeline.save(fh)
450
+ fh.close()
451
+
452
+ r_fh = open(fname, 'rb')
453
+ saved_pipeline = dm.Pipeline.load(r_fh)
454
+
365
455
"""
366
456
367
457
def __init__ (self ):
@@ -371,6 +461,12 @@ def __init__(self):
371
461
self .__level_nodes__ = None
372
462
373
463
def add_node (self , node : Node ):
464
+ """
465
+ Adds a node to this pipeline
466
+
467
+ :param node: The node to add
468
+ :return: None
469
+ """
374
470
self .__node_levels__ = None
375
471
self .__level_nodes__ = None
376
472
if node not in self .__pre_graph__ .keys ():
@@ -395,6 +491,13 @@ def get_str(nodes: list):
395
491
return res
396
492
397
493
def add_edge (self , from_node : Node , to_node : Node ):
494
+ """
495
+ Adds an edge to this pipeline
496
+
497
+ :param from_node: The from node
498
+ :param to_node: The to node
499
+ :return: None
500
+ """
398
501
self .add_node (from_node )
399
502
self .add_node (to_node )
400
503
@@ -408,6 +511,14 @@ def get_postimage(self, node: Node):
408
511
return self .__post_graph__ [node ]
409
512
410
513
def compute_node_level (self , node : Node , result : dict ):
514
+ """
515
+ Computes the node levels for a given node, an internal supporting function that is recursive, so it
516
+ takes the result computed so far.
517
+
518
+ :param node: The node for which level needs to be computed
519
+ :param result: The node levels that have already been computed
520
+ :return: The level for this node
521
+ """
411
522
if node in result :
412
523
return result [node ]
413
524
@@ -426,6 +537,13 @@ def compute_node_level(self, node: Node, result: dict):
426
537
return max_level + 1
427
538
428
539
def compute_node_levels (self ):
540
+ """
541
+ Computes node levels for all nodes. If a cache of node levels from previous calls exists, it will return
542
+ the cache to avoid repeated computation.
543
+
544
+ :return: The mapping from node to its level as a dict
545
+ """
546
+ # TODO: This is incorrect when pipelines are mutable
429
547
if self .__node_levels__ :
430
548
return self .__node_levels__
431
549
@@ -438,13 +556,24 @@ def compute_node_levels(self):
438
556
return self .__node_levels__
439
557
440
558
def compute_max_level (self ):
559
+ """
560
+ Get the max depth of this pipeline graph.
561
+
562
+ :return: The max depth of pipeline
563
+ """
441
564
levels = self .compute_node_levels ()
442
565
max_level = 0
443
566
for node , node_level in levels .items ():
444
567
max_level = max (node_level , max_level )
445
568
return max_level
446
569
447
570
def get_nodes_by_level (self ):
571
+ """
572
+ A mapping from level to a list of nodes, useful for pipeline execution time. Similar to compute_levels,
573
+ this method will return a cache if it exists, else will compute the levels and cache it.
574
+
575
+ :return: The mapping from level to a list of nodes at that level
576
+ """
448
577
if self .__level_nodes__ :
449
578
return self .__level_nodes__
450
579
@@ -460,16 +589,19 @@ def get_nodes_by_level(self):
460
589
self .__level_nodes__ = result
461
590
return self .__level_nodes__
462
591
463
- ###
464
- # Get downstream node
465
- ###
466
592
def get_post_nodes (self , node : Node ):
467
593
return self .__post_graph__ [node ]
468
594
469
595
def get_pre_nodes (self , node : Node ):
470
596
return self .__pre_graph__ [node ]
471
597
472
598
def get_pre_edges (self , node : Node ):
599
+ """
600
+ Get the incoming edges to a specific node.
601
+
602
+ :param node: Given node
603
+ :return: Incoming edges for the node
604
+ """
473
605
pre_edges = []
474
606
pre_nodes = self .__pre_graph__ [node ]
475
607
# Empty pre
@@ -481,6 +613,12 @@ def get_pre_edges(self, node: Node):
481
613
return pre_edges
482
614
483
615
def get_post_edges (self , node : Node ):
616
+ """
617
+ Get the outgoing edges for the given node
618
+
619
+ :param node: Given node
620
+ :return: Outgoing edges for the node
621
+ """
484
622
post_edges = []
485
623
post_nodes = self .__post_graph__ [node ]
486
624
# Empty post
@@ -492,10 +630,21 @@ def get_post_edges(self, node: Node):
492
630
return post_edges
493
631
494
632
def is_terminal (self , node : Node ):
633
+ """
634
+ Checks if the given node is a terminal node, i.e. has no outgoing edges
635
+
636
+ :param node: Node to check terminal condition on
637
+ :return: True if terminal else False
638
+ """
495
639
post_nodes = self .__post_graph__ [node ]
496
640
return not post_nodes
497
641
498
642
def get_terminal_nodes (self ):
643
+ """
644
+ Get all the terminal nodes for this pipeline
645
+
646
+ :return: List of all terminal nodes
647
+ """
499
648
# dict from level to nodes
500
649
terminal_nodes = []
501
650
for node in self .__pre_graph__ .keys ():
@@ -504,18 +653,42 @@ def get_terminal_nodes(self):
504
653
return terminal_nodes
505
654
506
655
def get_nodes (self ):
656
+ """
657
+ Get the nodes in this pipeline
658
+
659
+ :return: Node name to node dict
660
+ """
507
661
nodes = {}
508
662
for node in self .__pre_graph__ .keys ():
509
663
nodes [node .get_node_name ()] = node
510
664
return nodes
511
665
512
666
def get_pre_nodes (self , node ):
667
+ """
668
+ Get the nodes that have edges incoming to the given node
669
+
670
+ :param node: Given node
671
+ :return: List of nodes with incoming edges to the provided node
672
+ """
513
673
return self .__pre_graph__ [node ]
514
674
515
675
def get_post_nodes (self , node ):
676
+ """
677
+ Get the nodes that have edges outgoing to the given node
678
+
679
+ :param node: Given node
680
+ :return: List of nodes with outgoing edges from the provided node
681
+ """
516
682
return self .__post_graph__ [node ]
517
683
518
684
def save (self , filehandle ):
685
+ """
686
+ Saves the pipeline graph (without state) to a file. A filehandle with write and binary mode
687
+ is expected.
688
+
689
+ :param filehandle: Filehandle with wb mode
690
+ :return: None
691
+ """
519
692
nodes = {}
520
693
edges = []
521
694
@@ -534,6 +707,12 @@ def save(self, filehandle):
534
707
535
708
@staticmethod
536
709
def load (filehandle ):
710
+ """
711
+ Loads a pipeline that has been saved given the filehandle. Filehandle is in rb format.
712
+
713
+ :param filehandle: Filehandle to load pipeline from
714
+ :return:
715
+ """
537
716
saved_pipeline = pickle .load (filehandle )
538
717
if not isinstance (saved_pipeline , _SavedPipeline ):
539
718
raise pe .PipelineException ("Filehandle is not a saved pipeline instance" )
@@ -551,14 +730,28 @@ def load(filehandle):
551
730
552
731
553
732
class _SavedPipeline :
733
+ """
734
+ Internal class that serializes the pipeline so that it can be pickled. As noted, this only captures
735
+ the graph and not the state of the pipeline.
736
+ """
554
737
def __init__ (self , nodes , edges ):
555
738
self .__nodes__ = nodes
556
739
self .__edges__ = edges
557
740
558
741
def get_nodes (self ):
742
+ """
743
+ Nodes of the saved pipeline
744
+
745
+ :return: Dict of node name to node mapping
746
+ """
559
747
return self .__nodes__
560
748
561
749
def get_edges (self ):
750
+ """
751
+ Edges of the saved pipeline
752
+
753
+ :return: List of edges
754
+ """
562
755
return self .__edges__
563
756
564
757
0 commit comments