runner.NodeMulticlassClassification


Node multiclass classification via structured readout.

Inherits From: Task

runner.NodeMulticlassClassification(
    key: str = 'seed',
    *,
    feature_name: str = tfgnn.HIDDEN_STATE,
    readout_node_set: tfgnn.NodeSetName = '_readout',
    validate: bool = True,
    num_classes: Optional[int] = None,
    class_names: Optional[Sequence[str]] = None,
    per_class_statistics: bool = False,
    name: str = 'classification_logits',
    label_fn: Optional[LabelFn] = None,
    label_feature_name: Optional[str] = None
)

Args

key A string key to select between possibly multiple named readouts.
feature_name The name of the feature to read. If unset, tfgnn.HIDDEN_STATE will be read.
readout_node_set A string, defaults to "_readout". This is used as the name for the readout node set and as a name prefix for its edge sets.
validate Setting this to False disables the validity checks for the auxiliary edge sets. This is strongly discouraged unless great care is taken to run tfgnn.validate_graph_tensor_for_readout() earlier on structurally unchanged GraphTensors.
num_classes The number of classes. Exactly one of num_classes or class_names must be specified.
class_names The class names. Exactly one of num_classes or class_names must be specified.
per_class_statistics Whether to compute statistics per class.
name The classification head's layer name. To control the naming of saved model outputs see the runner model exporters (e.g., KerasModelExporter).
label_fn A label extraction function. This function mutates the input GraphTensor. Mutually exclusive with label_feature_name.
label_feature_name A label feature name for readout from the auxiliary '_readout' node set. Readout does not mutate the input GraphTensor. Mutually exclusive with label_fn.
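
A minimal plain-Python sketch of the two exclusivity rules above. `check_task_args` is a hypothetical helper written for illustration, not the library's own validation code, and deriving the class count from `len(class_names)` is an assumption. The docs only say that label_fn and label_feature_name are mutually exclusive, so the sketch rejects setting both and does not require either.

```python
from typing import Callable, Optional, Sequence


def check_task_args(
    num_classes: Optional[int] = None,
    class_names: Optional[Sequence[str]] = None,
    label_fn: Optional[Callable] = None,
    label_feature_name: Optional[str] = None,
) -> int:
    """Hypothetical validator mirroring the documented argument rules.

    Returns the effective number of classes.
    """
    # Exactly one of num_classes or class_names must be specified.
    if (num_classes is None) == (class_names is None):
        raise ValueError("Exactly one of num_classes or class_names must be specified")
    # label_fn and label_feature_name are mutually exclusive.
    if label_fn is not None and label_feature_name is not None:
        raise ValueError("label_fn and label_feature_name are mutually exclusive")
    # Assumption: with class_names, the class count is the number of names.
    return num_classes if num_classes is not None else len(class_names)


print(check_task_args(class_names=["cat", "dog", "bird"], label_feature_name="label"))
# prints 3
```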

Methods

gather_activations


gather_activations(
    inputs: GraphTensor
) -> Field

Gather activations from auxiliary node (and edge) sets.
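
To illustrate what gathering from the auxiliary sets means, here is a plain-Python sketch under stated assumptions: each auxiliary edge set links a node (for example, a seed node) to one node of the readout node set, and the source node's activation is copied to its target readout node. `ReadoutEdgeSet` and `gather_readout_activations` are hypothetical names, and the two consistency checks only loosely resemble the validity checks that `validate` controls.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence

Vector = List[float]


@dataclass
class ReadoutEdgeSet:
    """Hypothetical stand-in for one auxiliary edge set into the readout node set."""
    source_node_set: str   # node set the activations are read from
    source: Sequence[int]  # index of the source node, per edge
    target: Sequence[int]  # index of the readout node, per edge


def gather_readout_activations(
    node_states: Dict[str, Sequence[Vector]],
    edge_sets: Sequence[ReadoutEdgeSet],
    num_readout_nodes: int,
) -> List[Vector]:
    """Collects one activation per readout node by following auxiliary edges."""
    gathered: List[Optional[Vector]] = [None] * num_readout_nodes
    for es in edge_sets:
        states = node_states[es.source_node_set]
        for src, tgt in zip(es.source, es.target):
            # Each readout node should receive exactly one activation.
            if gathered[tgt] is not None:
                raise ValueError(f"readout node {tgt} has more than one incoming edge")
            gathered[tgt] = list(states[src])
    if any(v is None for v in gathered):
        raise ValueError("some readout nodes have no incoming edge")
    return gathered  # every entry is set after the check above


paper_states = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
edges = ReadoutEdgeSet("paper", source=[2, 0], target=[0, 1])
print(gather_readout_activations({"paper": paper_states}, [edges], num_readout_nodes=2))
# prints [[0.5, 0.6], [0.1, 0.2]]
```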

losses


losses() -> interfaces.Losses

Sparse categorical crossentropy loss.
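
For reference, sparse categorical crossentropy takes integer class labels (not one-hot vectors) and, when computed from logits, equals the log-sum-exp of the logits minus the logit of the true class, averaged over examples. Since predict returns logits, the sketch below assumes a from-logits formulation, as in Keras' `tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)`; it is a plain-Python illustration, not the task's implementation.

```python
import math
from typing import Sequence


def sparse_categorical_crossentropy(
    labels: Sequence[int], logits: Sequence[Sequence[float]]
) -> float:
    """Mean crossentropy over examples, computed from unnormalized logits."""
    total = 0.0
    for label, row in zip(labels, logits):
        # log-sum-exp with the max subtracted for numerical stability
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        # -log softmax(row)[label]
        total += log_z - row[label]
    return total / len(labels)


loss = sparse_categorical_crossentropy([0, 2], [[2.0, 1.0, 0.1], [0.5, 0.5, 3.0]])
print(round(loss, 4))
```

With uniform logits over three classes the loss is log(3), the value of a classifier that has learned nothing.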

metrics


metrics() -> interfaces.Metrics

Sparse categorical metrics.
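
The exact list of metrics is not given here. As an assumption, the sketch below shows sparse categorical accuracy (the argmax of the logits compared with the integer label) and per-class precision and recall, the kind of statistics that `per_class_statistics` suggests. Both functions are hypothetical plain-Python illustrations.

```python
from typing import Dict, Sequence


def argmax(row: Sequence[float]) -> int:
    # Index of the largest logit; ties resolve to the first index.
    return max(range(len(row)), key=lambda i: row[i])


def sparse_categorical_accuracy(
    labels: Sequence[int], logits: Sequence[Sequence[float]]
) -> float:
    predictions = [argmax(row) for row in logits]
    return sum(p == y for p, y in zip(predictions, labels)) / len(labels)


def per_class_precision_recall(
    labels: Sequence[int], logits: Sequence[Sequence[float]], num_classes: int
) -> Dict[int, Dict[str, float]]:
    predictions = [argmax(row) for row in logits]
    stats: Dict[int, Dict[str, float]] = {}
    for c in range(num_classes):
        tp = sum(p == c and y == c for p, y in zip(predictions, labels))
        predicted = sum(p == c for p in predictions)
        actual = sum(y == c for y in labels)
        stats[c] = {
            "precision": tp / predicted if predicted else 0.0,
            "recall": tp / actual if actual else 0.0,
        }
    return stats


labels = [0, 1, 2, 1]
logits = [[3.0, 1.0, 0.0], [0.0, 2.0, 1.0], [1.0, 0.0, 0.5], [0.0, 0.2, 3.0]]
print(sparse_categorical_accuracy(labels, logits))  # prints 0.5
print(per_class_precision_recall(labels, logits, num_classes=3))
```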

predict


predict(
    inputs: tfgnn.GraphTensor
) -> interfaces.Predictions

Apply a linear head for classification.

Args
inputs A tfgnn.GraphTensor for classification.
Returns
The classification logits.
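
A linear head maps each readout activation h to one logit per class via h·W + b. The sketch assumes W has shape [hidden_dim, num_classes] and b has length num_classes; in Keras this corresponds to a Dense layer with num_classes units and no activation, though the exact layer the task builds is not stated here. `linear_head` is a hypothetical plain-Python helper.

```python
from typing import List, Sequence


def linear_head(
    activations: Sequence[Sequence[float]],
    weights: Sequence[Sequence[float]],  # shape [hidden_dim, num_classes]
    bias: Sequence[float],               # shape [num_classes]
) -> List[List[float]]:
    """Computes logits = activations @ weights + bias, one row per readout node."""
    num_classes = len(bias)
    logits = []
    for h in activations:
        row = [
            sum(h[k] * weights[k][c] for k in range(len(h))) + bias[c]
            for c in range(num_classes)
        ]
        logits.append(row)
    return logits


head_logits = linear_head(
    activations=[[1.0, 2.0]],
    weights=[[1.0, 0.0, -1.0], [0.5, 1.0, 0.0]],
    bias=[0.0, 0.1, 0.2],
)
print(head_logits)
```

The head produces unnormalized logits; any softmax is left to the loss and metrics.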

preprocess


preprocess(
    inputs: GraphTensor
) -> tuple[GraphTensor, Field]

Preprocesses a scalar GraphTensor (that is, one already passed through merge_batch_to_components).

This function uses the Keras functional API to define non-trainable transformations of the symbolic input GraphTensor, which get executed during dataset preprocessing in a tf.data.Dataset.map(...) operation. It has two responsibilities:

  1. Splitting the training label out of the input. It must be returned as a separate tensor or mapping of tensors.
  2. Optionally, transforming input features. Some advanced modeling techniques require running the same base GNN on multiple different transformations, so this function may return a single GraphTensor or a non-empty sequence of GraphTensors. The corresponding base GNN output for each GraphTensor is provided to the predict(...) method.
Args
inputs A symbolic Keras GraphTensor for processing.
Returns
A tuple of the processed GraphTensor(s) and a Field (or a mapping of Fields) to be used as labels.
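
The two labeling modes can be contrasted with a plain-Python sketch in which a dict of node sets stands in for a GraphTensor. Both functions are hypothetical. `label_fn_sketch` behaves like a label_fn that strips the label out of the graph (returning a modified copy), while `readout_label_sketch` behaves like label_feature_name, reading the label from the '_readout' node set and leaving the graph as-is. The feature name "label" and its placement on '_readout' in the label_fn case are assumptions.

```python
import copy
from typing import Any, Dict, List, Tuple

# Hypothetical stand-in for a GraphTensor: node set name -> feature name -> values.
Graph = Dict[str, Dict[str, List[Any]]]


def label_fn_sketch(graph: Graph) -> Tuple[Graph, List[int]]:
    """Like a label_fn: returns a copy without the label feature, plus the labels."""
    out = copy.deepcopy(graph)
    labels = out["_readout"].pop("label")  # the label no longer reaches the model
    return out, labels


def readout_label_sketch(graph: Graph, label_feature_name: str) -> Tuple[Graph, List[int]]:
    """Like label_feature_name: reads the label and leaves the graph unchanged."""
    return graph, list(graph["_readout"][label_feature_name])


graph = {"paper": {"hidden_state": [[0.1], [0.2]]}, "_readout": {"label": [3, 1]}}
stripped, y1 = label_fn_sketch(graph)
unchanged, y2 = readout_label_sketch(graph, "label")
print(y1, "label" in stripped["_readout"])   # prints [3, 1] False
print(y2, "label" in unchanged["_readout"])  # prints [3, 1] True
```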