tfgnn.Context

A composite tensor for graph context features.

tfgnn.Context(
    data: Data, spec: 'GraphPieceSpecBase'
)

The items of the context are the graph components, just as the items of a node set are the nodes and the items of an edge set are the edges. The Context is a composite tensor. It stores features that belong to a graph component as a whole, not to any particular node or edge. Each context feature has shape [*graph_shape, num_components, ...], where num_components is the number of graph components in a graph (which may be ragged, i.e., vary from graph to graph).

Args

data: A nest of Field objects or instances of GraphPieceBase subclasses.
spec: An instance of a GraphPieceSpecBase subclass whose _data_spec matches data.

Attributes

features: A read-only mapping from feature name to feature value (Tensor or RaggedTensor).
indices_dtype: The dtype used for indexing graph items. One of tf.int32 or tf.int64.
num_components: The number of graph components in each graph.
rank: The rank of this tensor, i.e., the length of graph_shape. Guaranteed not to be None.
row_splits_dtype: The dtype for ragged row partitions. One of tf.int32 or tf.int64.
shape: A possibly partial shape specification for this tensor (the graph_shape). The returned tf.TensorShape is guaranteed to have a known rank and no unknown dimensions except possibly the outermost.
sizes: The number of items in each graph component. Since the items of a Context are the components themselves, every value is 1.
spec: The public type specification of this tensor.
total_num_components: The total number of graph components.
total_size: The total number of items, which for a Context equals total_num_components.

Methods

from_fields

@classmethod
from_fields(
    *_,
    features: Optional[Fields] = None,
    sizes: Optional[Field] = None,
    shape: Optional[ShapeLike] = None,
    indices_dtype: Optional[tf.dtypes.DType] = None,
    validate: Optional[bool] = None
) -> 'Context'

Constructs a new instance from context fields.

Example:

import tensorflow_gnn as tfgnn

tfgnn.Context.from_fields(features={'country_code': ['CH']})

Args
features: A mapping from feature name to feature Tensor or RaggedTensor. All feature tensors must have shape [*graph_shape, num_components, *feature_shape], where num_components is the number of graph components (may be ragged) and feature_shape are the feature-specific dimensions (may be ragged).
sizes: A Tensor of 1's with shape [*graph_shape, num_components], where num_components is the number of graph components (may be ragged). For symmetry with sizes in NodeSet and EdgeSet, this counts the items per graph component, but since the items of a Context are the components themselves, each value is 1. Must be compatible with shape, if that is specified.
shape: The shape of this tensor and of a GraphTensor containing it, also known as the graph_shape. If not specified, the shape is inferred from sizes, or set to [] if sizes is not specified.
indices_dtype: The indices_dtype of a GraphTensor containing this object, used as row_splits_dtype when batching potentially ragged fields. If sizes is specified, it is cast to this dtype.
validate: If True, use tf.assert ops to inspect the shapes of each field and check at runtime that they form a valid Context. If not set, the default behavior is controlled by enable_graph_tensor_validation_at_runtime() and disable_graph_tensor_validation_at_runtime().
Returns
A Context composite tensor.

get_features_dict

get_features_dict() -> Dict[FieldName, Field]

Returns a copy of the features as a dictionary.

replace_features

replace_features(
    features: Fields
) -> 'Context'

Returns a new instance with a new set of features.

set_shape

set_shape(
    new_shape: ShapeLike
) -> 'GraphPieceBase'

Deprecated. Use with_shape().

with_indices_dtype

with_indices_dtype(
    dtype: tf.dtypes.DType
) -> 'GraphPieceBase'

Returns a copy of this piece with the given indices dtype.

with_row_splits_dtype

with_row_splits_dtype(
    dtype: tf.dtypes.DType
) -> 'GraphPieceBase'

Returns a copy of this piece with the given row splits dtype.

with_shape

with_shape(
    new_shape: ShapeLike
) -> 'GraphPieceBase'

Enforces new_shape as the common prefix shape (the graph_shape) of all contained features.

__getitem__

__getitem__(
    feature_name: FieldName
) -> Field

Indexing operator [] to access feature values by their name.