Skip to content

Commit

Permalink
Switch tfgnn.broadcast() to implementation v2 and remove v1.
Browse files Browse the repository at this point in the history
This adds support for broadcasting to multiple edge sets (or node sets).
Matching changes are planned for tfgnn.pool() and the Keras wrappers of both.

PiperOrigin-RevId: 533987359
  • Loading branch information
arnoegw authored and tensorflower-gardener committed May 22, 2023
1 parent b01902f commit 8980ab1
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 85 deletions.
2 changes: 1 addition & 1 deletion tensorflow_gnn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@
broadcast_node_to_edges = broadcast_ops.broadcast_node_to_edges
broadcast_context_to_nodes = broadcast_ops.broadcast_context_to_nodes
broadcast_context_to_edges = broadcast_ops.broadcast_context_to_edges
broadcast = broadcast_ops.broadcast_v1 # TODO(b/265760014): Switch to v2.
broadcast = broadcast_ops.broadcast_v2
pool_edges_to_node = pool_ops_v1.pool_edges_to_node
pool_nodes_to_context = pool_ops_v1.pool_nodes_to_context
pool_edges_to_context = pool_ops_v1.pool_edges_to_context
Expand Down
85 changes: 9 additions & 76 deletions tensorflow_gnn/graph/broadcast_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,11 +178,8 @@ def _broadcast_context(graph_tensor: GraphTensor,
repeats_sum_hint=node_or_edge_set.spec.total_size)


# TODO(b/265760014): Export as tfgnn.broadcast() and remove broadcast_v1().
# The difference is that v2 supports multiple node/edge sets and v1 does not,
# mirroring the difference between pool_v2() and pool_v1().
def broadcast_v2(
graph: GraphTensor,
graph_tensor: GraphTensor,
from_tag: IncidentNodeOrContextTag,
*,
edge_set_name: Union[Sequence[EdgeSetName], EdgeSetName, None] = None,
Expand All @@ -207,7 +204,7 @@ def broadcast_v2(
image of `tfgnn.pool()`, which comes in handy for some algorithms.
Args:
graph: A scalar GraphTensor.
graph_tensor: A scalar GraphTensor.
from_tag: Values are broadcast from context if this is `tfgnn.CONTEXT` or
from the incident node on each edge with this tag.
edge_set_name: The name of the edge set to which values are broadcast, or
Expand All @@ -230,10 +227,10 @@ def broadcast_v2(
If a list of names was specified, the result is a list of tensors,
with parallel indices.
"""
gt.check_scalar_graph_tensor(graph, "broadcast()")
gt.check_scalar_graph_tensor(graph_tensor, "broadcast()")
edge_set_names, node_set_names, got_sequence_args = (
tag_utils.get_edge_or_node_set_name_args_for_tag(
graph.spec, from_tag,
graph_tensor.spec, from_tag,
edge_set_name=edge_set_name, node_set_name=node_set_name,
function_name="broadcast()"))
del edge_set_name, node_set_name # Replaced by their cleaned-up versions.
Expand All @@ -244,82 +241,18 @@ def broadcast_v2(

if from_tag == const.CONTEXT:
if edge_set_names is not None:
result = [broadcast_context_to_edges(graph, name, **feature_kwargs)
result = [broadcast_context_to_edges(graph_tensor, name, **feature_kwargs)
for name in edge_set_names]
else:
result = [broadcast_context_to_nodes(graph, name, **feature_kwargs)
result = [broadcast_context_to_nodes(graph_tensor, name, **feature_kwargs)
for name in node_set_names]
else:
result = [broadcast_node_to_edges(graph, name, from_tag, **feature_kwargs)
for name in edge_set_names]
result = [
broadcast_node_to_edges(graph_tensor, name, from_tag, **feature_kwargs)
for name in edge_set_names]

if got_sequence_args:
return result
else:
assert len(result) == 1
return result[0]


# TODO(b/265760014): Remove in favor of broadcast_v2().
# The difference is that v2 supports multiple node/edge sets and v1 does not,
# mirroring the difference between pool_v2() and pool_v1().
def broadcast_v1(graph_tensor: GraphTensor,
from_tag: const.IncidentNodeOrContextTag,
*,
edge_set_name: Optional[EdgeSetName] = None,
node_set_name: Optional[NodeSetName] = None,
feature_value: Optional[Field] = None,
feature_name: Optional[FieldName] = None) -> Field:
"""Broadcasts values from nodes to edges, or from context to nodes or edges.
This function broadcasts from context if `from_tag=tfgnn.CONTEXT` and
broadcasts from incident nodes to edges if `from_tag` is an ordinary node tag
like `tfgnn.SOURCE` or `tfgnn.TARGET`. Most user code will not need this
flexibility and can directly call one of the underlying functions
`broadcast_node_to_edges()`, `broadcast_context_to_nodes()`, or
`broadcast_context_to_edges()`.
Args:
graph_tensor: A scalar GraphTensor.
from_tag: Values are broadcast from context if this is `tfgnn.CONTEXT` or
from the incident node on each edge with this tag.
edge_set_name: The name of the edge set to which values are broadcast.
node_set_name: The name of the node set to which values are broadcast.
Can only be set with `from_tag=tfgnn.CONTEXT`. Either edge_set_name or
node_set_name must be set.
feature_value: As for the underlying broadcast_*() function.
feature_name: As for the underlying broadcast_*() function.
Exactly one of feature_name or feature_value must be set.
Returns:
The result of the underlying broadcast_*() function.
"""
_validate_names_and_tag(
from_tag, edge_set_name=edge_set_name, node_set_name=node_set_name)
if from_tag == const.CONTEXT:
if node_set_name is not None:
return broadcast_context_to_nodes(
graph_tensor, node_set_name=node_set_name,
feature_value=feature_value, feature_name=feature_name)
else:
return broadcast_context_to_edges(
graph_tensor, edge_set_name=edge_set_name,
feature_value=feature_value, feature_name=feature_name)
else:
return broadcast_node_to_edges(
graph_tensor, edge_set_name=edge_set_name, node_tag=from_tag,
feature_value=feature_value, feature_name=feature_name)


# TODO(b/265760014): Remove along with broadcast_v1().
def _validate_names_and_tag(tag, *, edge_set_name, node_set_name):
"""Helper for broadcast_v1()."""
if tag == const.CONTEXT:
num_names = bool(edge_set_name is None) + bool(node_set_name is None)
if num_names != 1:
raise ValueError("With tag CONTEXT, must pass exactly 1 of "
f"edge_set_name, node_set_name; got {num_names}.")
else:
if edge_set_name is None or node_set_name is not None:
raise ValueError("Must pass edge_set_name but not node_set_name "
"for a tag other than CONTEXT.")
22 changes: 15 additions & 7 deletions tensorflow_gnn/graph/broadcast_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,12 @@
as_ragged = tf.ragged.constant


class BroadcastingTest(tf.test.TestCase, parameterized.TestCase):
"""Tests for broadcasting operations."""
class BroadcastXToYTest(tf.test.TestCase, parameterized.TestCase):
"""Tests for basic broadcasting operations broadcast_*_to_*().
For consistency, some tests run the corresponging call to the generic
broadcast_v2() function as well, but see BroadcastV2Test for more on that.
"""

@parameterized.named_parameters(
("WithAdjacency", False),
Expand Down Expand Up @@ -73,7 +77,7 @@ def testEdgeFieldFromNode(self, use_hyper_adjacency=False):
graph, "edge", const.SOURCE, feature_name=fname))
self.assertAllEqual(
expected,
broadcast_ops.broadcast_v1(
broadcast_ops.broadcast_v2(
graph, const.SOURCE, edge_set_name="edge", feature_name=fname))
for fname, expected in expected_target_fields.items():
self.assertAllEqual(
Expand All @@ -82,7 +86,7 @@ def testEdgeFieldFromNode(self, use_hyper_adjacency=False):
graph, "edge", const.TARGET, feature_name=fname))
self.assertAllEqual(
expected,
broadcast_ops.broadcast_v1(
broadcast_ops.broadcast_v2(
graph, const.TARGET, edge_set_name="edge", feature_name=fname))

@parameterized.parameters([
Expand Down Expand Up @@ -138,7 +142,7 @@ def testNodeFieldFromContext(self, description: str, context: gt.Context,
graph, "node", feature_name=fname))
self.assertAllEqual(
expected,
broadcast_ops.broadcast_v1(
broadcast_ops.broadcast_v2(
graph, const.CONTEXT, node_set_name="node", feature_name=fname))

@parameterized.parameters([
Expand Down Expand Up @@ -204,12 +208,16 @@ def testEdgeFieldFromContext(self, description: str, context: gt.Context,
graph, "edge", feature_name=fname))
self.assertAllEqual(
expected,
broadcast_ops.broadcast_v1(
broadcast_ops.broadcast_v2(
graph, const.CONTEXT, edge_set_name="edge", feature_name=fname))


class BroadcastV2Test(tf.test.TestCase, parameterized.TestCase):
"""Tests for generic broadcast_v2(), on top of already-tested basic ops."""
"""Tests for generic broadcast_v2() wrapper.
These tests assume correctness of the underlying broadcast_*_to_*() ops;
see BroadcastXtoYTest for these.
"""

def testOneEdgeSetFromTag(self):
input_graph = _get_test_graph_broadcast()
Expand Down
2 changes: 1 addition & 1 deletion tensorflow_gnn/keras/layers/convolution_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def bind_receiver_args(fn):
return lambda feature_value, **kwargs: fn(
graph, receiver_tag, **name_kwarg,
feature_value=feature_value, **kwargs)
broadcast_from_receiver = bind_receiver_args(broadcast_ops.broadcast_v1)
broadcast_from_receiver = bind_receiver_args(broadcast_ops.broadcast_v2)
pool_to_receiver = bind_receiver_args(pool_ops_v1.pool_v1)
if self._extra_receiver_ops is None:
extra_receiver_ops_kwarg = {} # Pass no argument for this.
Expand Down

0 comments on commit 8980ab1

Please sign in to comment.