@@ -200,6 +200,9 @@ class Node(ABC):
200
200
A node class that is an abstract one, this is capturing basic info re the Node.
201
201
The hash code of this node is the name of the node and equality is defined if the
202
202
node name and the type of the node match.
203
+
204
+ When doing a grid search, a node can be parameterized with new params for the estimator and updated. This
205
+ is an internal method used by grid search.
203
206
"""
204
207
205
208
def __init__ (self , node_name , estimator : BaseEstimator , node_input_type : NodeInputType , node_firing_type : NodeFiringType , node_state_type : NodeStateType ):
@@ -210,6 +213,11 @@ def __init__(self, node_name, estimator: BaseEstimator, node_input_type: NodeInp
210
213
self .__node_state_type__ = node_state_type
211
214
212
215
def __str__ (self ):
216
+ """
217
+ Returns a string representation of the node along with the parameters of the estimator of the node.
218
+
219
+ :return: String representation of the node
220
+ """
213
221
estimator_params_str = str (self .get_estimator ().get_params ())
214
222
retval = self .__node_name__ + estimator_params_str
215
223
return retval
@@ -247,9 +255,22 @@ def get_node_state_type(self) -> NodeStateType:
247
255
return self .__node_state_type__
248
256
249
257
def get_estimator (self ):
258
+ """
259
+ Return the estimator of the node
260
+
261
+ :return: The node's estimator
262
+ """
250
263
return self .__estimator__
251
264
252
265
def get_parameterized_node (self , node_name , ** params ):
266
+ """
267
+ Get a parameterized node, given kwargs **params, convert this node and update the estimator with the
268
+ new set of parameters. It will clone the node and its underlying estimator.
269
+
270
+ :param node_name: New node name
271
+ :param params: Updated parameters
272
+ :return:
273
+ """
253
274
cloned_node = self .clone ()
254
275
cloned_node .__node_name__ = node_name
255
276
estimator = cloned_node .get_estimator ()
@@ -311,7 +332,6 @@ def __init__(self, node_name: str, estimator: BaseEstimator):
311
332
"""
312
333
super ().__init__ (node_name , estimator , NodeInputType .OR , NodeFiringType .ANY , NodeStateType .IMMUTABLE )
313
334
314
-
315
335
def clone (self ):
316
336
"""
317
337
Clones the given node and the underlying estimator as well, if it was initialized with
@@ -323,6 +343,17 @@ def clone(self):
323
343
324
344
325
345
class AndEstimator (BaseEstimator ):
346
+ """
347
+ An and estimator, is part of the AndNode, it is very similar to a standard estimator, however the key
348
+ difference is that it takes a `xy_list` as input and outputs an `xy`, contrasting to the EstimatorNode,
349
+ which takes an input as `xy` and outputs `xy_t`.
350
+
351
+ In the pipeline execution, we expect three modes: (a) FIT: A regressor or classifier will call the fit
352
+ and then pass on the transform results downstream, a non-regressor/classifier will call the fit_transform
353
+ method, (b) PREDICT: A regressor or classifier will call the predict method, whereas a non-regressor/classifier
354
+ will call the transform method, and (c) SCORE: A regressor will call the score method, and a non-regressor/classifer
355
+ will call the transform method.
356
+ """
326
357
@abstractmethod
327
358
def transform (self , xy_list : list ) -> Xy :
328
359
raise NotImplementedError ("And estimator needs to implement a transform method" )
0 commit comments