runner.RootNodeBinaryClassification

Root node binary (or multi-label) classification.

Inherits From: Task

runner.RootNodeBinaryClassification(
    node_set_name: str,
    units: int = 1,
    *,
    state_name: str = tfgnn.HIDDEN_STATE,
    name: str = 'classification_logits',
    label_fn: Optional[LabelFn] = None,
    label_feature_name: Optional[str] = None
)

Args

node_set_name: The node set containing the root node.
units: The number of output units (logits) of the classification head, typically 1 for binary classification and the number of labels for multi-label classification.
state_name: The feature name for activations (e.g., tfgnn.HIDDEN_STATE).
name: The classification head's layer name. To control the naming of saved model outputs, see the runner model exporters (e.g., KerasModelExporter).
label_fn: A label extraction function. This function mutates the input GraphTensor. Mutually exclusive with label_feature_name.
label_feature_name: A label feature name for readout from the auxiliary '_readout' node set. Readout does not mutate the input GraphTensor. Mutually exclusive with label_fn.
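As a rough illustration of what units implies for the labels, the sketch below assumes that each root node's label is a vector of units entries in {0, 1}: a single entry for binary classification and a multi-hot vector for multi-label classification. The helper to_multi_hot is hypothetical and not part of the runner API.

```python
from typing import List, Sequence


def to_multi_hot(label_ids: Sequence[int], units: int) -> List[float]:
    """Hypothetical helper: encode positive label ids as a 0/1 vector of width `units`."""
    vec = [0.0] * units
    for i in label_ids:
        if not 0 <= i < units:
            raise ValueError(f"label id {i} out of range for units={units}")
        vec[i] = 1.0
    return vec


# Binary classification (units=1): one entry, 1.0 for a positive root node.
binary_positive = to_multi_hot([0], units=1)
binary_negative = to_multi_hot([], units=1)
# Multi-label classification (units=4): labels 1 and 3 apply to this root node.
multi_label = to_multi_hot([1, 3], units=4)
```

Under this encoding a binary task needs only one logit per root node rather than two, which is why units defaults to 1.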

Methods

gather_activations

gather_activations(
    inputs: GraphTensor
) -> Field

Gather activations from root nodes.
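TF-GNN's sampled subgraphs conventionally place each root (seed) node first in its graph component within node_set_name. Assuming that convention, gathering root activations amounts to picking one row per component, as in the pure-Python sketch below. The name gather_root_states is hypothetical; the library itself provides readout helpers such as tfgnn.gather_first_node for this.

```python
from itertools import accumulate
from typing import List, Sequence


def gather_root_states(
    states: Sequence[Sequence[float]], sizes: Sequence[int]
) -> List[Sequence[float]]:
    """Return the state of the first (root) node of each graph component.

    `states` holds one row per node of the node set, with components laid out
    back to back (as in a scalar GraphTensor after merging a batch);
    `sizes[i]` is the number of nodes in component i.
    """
    if sum(sizes) != len(states):
        raise ValueError("sizes must add up to the number of nodes")
    if any(s < 1 for s in sizes):
        raise ValueError("every component needs at least its root node")
    if not sizes:
        return []
    # Index of each component's first node: [0, s0, s0 + s1, ...].
    offsets = [0] + list(accumulate(sizes))[:-1]
    return [states[o] for o in offsets]


# Two components with 3 and 2 nodes: the roots are rows 0 and 3.
node_states = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8], [0.9, 1.0]]
roots = gather_root_states(node_states, [3, 2])
```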

losses

losses() -> interfaces.Losses

Returns arbitrary task-specific losses.
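The docstring does not name the loss. For a head whose units logits are independent binary decisions, the usual choice is sigmoid binary cross-entropy on logits (in Keras, tf.keras.losses.BinaryCrossentropy(from_logits=True)). Assuming that choice, the sketch below computes it in the numerically stable form; sigmoid_bce_from_logits is a hypothetical name.

```python
import math
from typing import Sequence


def sigmoid_bce_from_logits(logits: Sequence[float], labels: Sequence[float]) -> float:
    """Mean binary cross-entropy computed directly from logits.

    Uses max(x, 0) - x * z + log(1 + exp(-|x|)), which equals
    -z * log(sigmoid(x)) - (1 - z) * log(1 - sigmoid(x)) but does not
    overflow for large |x|.
    """
    if not logits or len(logits) != len(labels):
        raise ValueError("need equal-length, non-empty logits and labels")
    total = 0.0
    for x, z in zip(logits, labels):
        total += max(x, 0.0) - x * z + math.log1p(math.exp(-abs(x)))
    return total / len(logits)


# A logit of 0.0 means p = 0.5, so the loss is log(2) for either label.
loss_at_zero = sigmoid_bce_from_logits([0.0], [1.0])
```

Working from logits rather than probabilities avoids computing log(0) when the sigmoid saturates.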

metrics

metrics() -> interfaces.Metrics

Returns arbitrary task-specific metrics.
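The docstring leaves the metric list open. Assuming typical binary metrics computed by thresholding logits (Keras offers tf.keras.metrics.BinaryAccuracy, Precision and Recall for this), the sketch below shows what they measure for a single output unit. The function binary_metrics is hypothetical, and the threshold of 0.0 on logits is an assumption equivalent to a probability threshold of 0.5.

```python
from typing import Dict, Sequence


def binary_metrics(
    logits: Sequence[float], labels: Sequence[int], threshold: float = 0.0
) -> Dict[str, float]:
    """Accuracy, precision and recall for logits thresholded at `threshold`."""
    if not labels or len(logits) != len(labels):
        raise ValueError("need equal-length, non-empty logits and labels")
    # A logit above 0.0 corresponds to a sigmoid probability above 0.5.
    preds = [1 if x > threshold else 0 for x in logits]
    pairs = list(zip(preds, labels))
    tp = sum(1 for p, y in pairs if p == 1 and y == 1)
    fp = sum(1 for p, y in pairs if p == 1 and y == 0)
    fn = sum(1 for p, y in pairs if p == 0 and y == 1)
    correct = sum(1 for p, y in pairs if p == y)
    return {
        "accuracy": correct / len(labels),
        "precision": tp / (tp + fp) if tp + fp else 0.0,
        "recall": tp / (tp + fn) if tp + fn else 0.0,
    }


example = binary_metrics([2.0, -1.0, 0.5, -3.0], [1, 0, 0, 1])
```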

predict

predict(
    inputs: tfgnn.GraphTensor
) -> interfaces.Predictions

Apply a linear head for classification.

Args
inputs: A tfgnn.GraphTensor for classification.
Returns
The classification logits.
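"Linear head" here means an affine map from each root node's activations to units logits with no output activation, which in Keras is typically a tf.keras.layers.Dense(units) layer. The pure-Python sketch below spells out that computation under this assumption; linear_head and its argument names are hypothetical.

```python
from typing import List, Sequence


def linear_head(
    activations: Sequence[Sequence[float]],
    kernel: Sequence[Sequence[float]],
    bias: Sequence[float],
) -> List[List[float]]:
    """Compute logits = activations @ kernel + bias, with no activation function.

    activations: [num_roots, hidden_dim]; kernel: [hidden_dim, units];
    bias: [units]. Returns [num_roots, units] logits.
    """
    hidden_dim, units = len(kernel), len(bias)
    if any(len(row) != units for row in kernel):
        raise ValueError("each kernel row must have `units` entries")
    logits = []
    for a in activations:
        if len(a) != hidden_dim:
            raise ValueError("activation width must match the kernel's row count")
        logits.append([
            sum(a[i] * kernel[i][j] for i in range(hidden_dim)) + bias[j]
            for j in range(units)
        ])
    return logits


# Two root nodes with 2-dimensional states and a binary head (units=1).
out = linear_head([[1.0, 2.0], [0.0, -1.0]], kernel=[[0.5], [0.25]], bias=[0.1])
```

Because the head returns raw logits, any sigmoid is left to the loss and metrics.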

preprocess

preprocess(
    inputs: GraphTensor
) -> tuple[GraphTensor, Field]

Preprocesses a scalar GraphTensor (i.e., one that has already gone through merge_batch_to_components).

This function uses the Keras functional API to define non-trainable transformations of the symbolic input GraphTensor, which get executed during dataset preprocessing in a tf.data.Dataset.map(...) operation. It has two responsibilities:

  1. Splitting the training label out of the input. The label must be returned as a separate tensor or mapping of tensors.
  2. Optionally, transforming input features. Some advanced modeling techniques require running the same base GNN on multiple different transformations, so this function may return a single GraphTensor or a non-empty sequence of GraphTensors. The corresponding base GNN output for each GraphTensor is provided to the predict(...) method.

Args
inputs: A symbolic Keras GraphTensor for processing.
Returns
A tuple of the processed GraphTensor(s) and the label(s), given as a single Field or a mapping of Fields.
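To make the two labeling modes concrete, the toy sketch below models a graph as nested dicts rather than a real GraphTensor. All names in it (ToyGraph, split_label_with_fn, read_label_from_readout, the "paper" node set) are hypothetical; it only mirrors the documented contrast that label_fn yields a graph with the label removed, while readout via label_feature_name leaves the input graph as it is.

```python
import copy
from typing import Any, Dict, Tuple

# Hypothetical stand-in for a GraphTensor: {node_set_name: {feature_name: values}}.
ToyGraph = Dict[str, Dict[str, Any]]


def split_label_with_fn(graph: ToyGraph, node_set: str, feature: str) -> Tuple[ToyGraph, Any]:
    """label_fn style: return a graph without the label feature, plus the label.

    The caller's dict is copied first, so only the returned graph lacks the label.
    """
    out = copy.deepcopy(graph)
    label = out[node_set].pop(feature)
    return out, label


def read_label_from_readout(graph: ToyGraph, feature: str) -> Tuple[ToyGraph, Any]:
    """label_feature_name style: read the label from the '_readout' node set as-is."""
    return graph, graph["_readout"][feature]


toy = {
    "paper": {"hidden_state": [[0.1], [0.2]], "label": [1, 0]},
    "_readout": {"label": [1]},
}
g1, y1 = split_label_with_fn(toy, "paper", "label")
g2, y2 = read_label_from_readout(toy, "label")
```

Either way, preprocess ends up returning a (graph, label) pair, matching the tuple described under Returns.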