From 8980ab1f1b0d7fb7aa77b4bc282f46bb3db1f141 Mon Sep 17 00:00:00 2001 From: Arno Eigenwillig Date: Mon, 22 May 2023 01:12:51 -0700 Subject: [PATCH] Switch tfgnn.broadcast() to implementation v2 and remove v1. This adds support for broadcasting to multiple edge sets (or node sets). Matching changes are planned for tfgnn.pool() and the Keras wrappers of both. PiperOrigin-RevId: 533987359 --- tensorflow_gnn/__init__.py | 2 +- tensorflow_gnn/graph/broadcast_ops.py | 85 ++----------------- tensorflow_gnn/graph/broadcast_ops_test.py | 22 +++-- .../keras/layers/convolution_base.py | 2 +- 4 files changed, 26 insertions(+), 85 deletions(-) diff --git a/tensorflow_gnn/__init__.py b/tensorflow_gnn/__init__.py index 2308ec21..797a218b 100644 --- a/tensorflow_gnn/__init__.py +++ b/tensorflow_gnn/__init__.py @@ -152,7 +152,7 @@ broadcast_node_to_edges = broadcast_ops.broadcast_node_to_edges broadcast_context_to_nodes = broadcast_ops.broadcast_context_to_nodes broadcast_context_to_edges = broadcast_ops.broadcast_context_to_edges -broadcast = broadcast_ops.broadcast_v1 # TODO(b/265760014): Switch to v2. +broadcast = broadcast_ops.broadcast_v2 pool_edges_to_node = pool_ops_v1.pool_edges_to_node pool_nodes_to_context = pool_ops_v1.pool_nodes_to_context pool_edges_to_context = pool_ops_v1.pool_edges_to_context diff --git a/tensorflow_gnn/graph/broadcast_ops.py b/tensorflow_gnn/graph/broadcast_ops.py index 4289b66e..5ca63573 100644 --- a/tensorflow_gnn/graph/broadcast_ops.py +++ b/tensorflow_gnn/graph/broadcast_ops.py @@ -178,11 +178,8 @@ def _broadcast_context(graph_tensor: GraphTensor, repeats_sum_hint=node_or_edge_set.spec.total_size) -# TODO(b/265760014): Export as tfgnn.broadcast() and remove broadcast_v1(). -# The difference is that v2 supports multiple node/edge sets and v1 does not, -# mirroring the difference between pool_v2() and pool_v1(). def broadcast_v2( - graph: GraphTensor, + graph_tensor: GraphTensor, from_tag: IncidentNodeOrContextTag, *, edge_set_name: Union[Sequence[EdgeSetName], EdgeSetName, None] = None, @@ -207,7 +204,7 @@ def broadcast_v2( image of `tfgnn.pool()`, which comes in handy for some algorithms. Args: - graph: A scalar GraphTensor. + graph_tensor: A scalar GraphTensor. from_tag: Values are broadcast from context if this is `tfgnn.CONTEXT` or from the incident node on each edge with this tag. edge_set_name: The name of the edge set to which values are broadcast, or @@ -230,10 +227,10 @@ def broadcast_v2( If a list of names was specified, the result is a list of tensors, with parallel indices. """ - gt.check_scalar_graph_tensor(graph, "broadcast()") + gt.check_scalar_graph_tensor(graph_tensor, "broadcast()") edge_set_names, node_set_names, got_sequence_args = ( tag_utils.get_edge_or_node_set_name_args_for_tag( - graph.spec, from_tag, + graph_tensor.spec, from_tag, edge_set_name=edge_set_name, node_set_name=node_set_name, function_name="broadcast()")) del edge_set_name, node_set_name # Replaced by their cleaned-up versions. @@ -244,82 +241,18 @@ def broadcast_v2( if from_tag == const.CONTEXT: if edge_set_names is not None: - result = [broadcast_context_to_edges(graph, name, **feature_kwargs) + result = [broadcast_context_to_edges(graph_tensor, name, **feature_kwargs) for name in edge_set_names] else: - result = [broadcast_context_to_nodes(graph, name, **feature_kwargs) + result = [broadcast_context_to_nodes(graph_tensor, name, **feature_kwargs) for name in node_set_names] else: - result = [broadcast_node_to_edges(graph, name, from_tag, **feature_kwargs) - for name in edge_set_names] + result = [ + broadcast_node_to_edges(graph_tensor, name, from_tag, **feature_kwargs) + for name in edge_set_names] if got_sequence_args: return result else: assert len(result) == 1 return result[0] - - -# TODO(b/265760014): Remove in favor of broadcast_v2(). -# The difference is that v2 supports multiple node/edge sets and v1 does not, -# mirroring the difference between pool_v2() and pool_v1(). -def broadcast_v1(graph_tensor: GraphTensor, - from_tag: const.IncidentNodeOrContextTag, - *, - edge_set_name: Optional[EdgeSetName] = None, - node_set_name: Optional[NodeSetName] = None, - feature_value: Optional[Field] = None, - feature_name: Optional[FieldName] = None) -> Field: - """Broadcasts values from nodes to edges, or from context to nodes or edges. - - This function broadcasts from context if `from_tag=tfgnn.CONTEXT` and - broadcasts from incident nodes to edges if `from_tag` is an ordinary node tag - like `tfgnn.SOURCE` or `tfgnn.TARGET`. Most user code will not need this - flexibility and can directly call one of the underlying functions - `broadcast_node_to_edges()`, `broadcast_context_to_nodes()`, or - `broadcast_context_to_edges()`. - - Args: - graph_tensor: A scalar GraphTensor. - from_tag: Values are broadcast from context if this is `tfgnn.CONTEXT` or - from the incident node on each edge with this tag. - edge_set_name: The name of the edge set to which values are broadcast. - node_set_name: The name of the node set to which values are broadcast. - Can only be set with `from_tag=tfgnn.CONTEXT`. Either edge_set_name or - node_set_name must be set. - feature_value: As for the underlying broadcast_*() function. - feature_name: As for the underlying broadcast_*() function. - Exactly one of feature_name or feature_value must be set. - - Returns: - The result of the underlying broadcast_*() function. - """ - _validate_names_and_tag( - from_tag, edge_set_name=edge_set_name, node_set_name=node_set_name) - if from_tag == const.CONTEXT: - if node_set_name is not None: - return broadcast_context_to_nodes( - graph_tensor, node_set_name=node_set_name, - feature_value=feature_value, feature_name=feature_name) - else: - return broadcast_context_to_edges( - graph_tensor, edge_set_name=edge_set_name, - feature_value=feature_value, feature_name=feature_name) - else: - return broadcast_node_to_edges( - graph_tensor, edge_set_name=edge_set_name, node_tag=from_tag, - feature_value=feature_value, feature_name=feature_name) - - -# TODO(b/265760014): Remove along with broadcast_v1(). -def _validate_names_and_tag(tag, *, edge_set_name, node_set_name): - """Helper for broadcast_v1().""" - if tag == const.CONTEXT: - num_names = bool(edge_set_name is None) + bool(node_set_name is None) - if num_names != 1: - raise ValueError("With tag CONTEXT, must pass exactly 1 of " - f"edge_set_name, node_set_name; got {num_names}.") - else: - if edge_set_name is None or node_set_name is not None: - raise ValueError("Must pass edge_set_name but not node_set_name " - "for a tag other than CONTEXT.") diff --git a/tensorflow_gnn/graph/broadcast_ops_test.py b/tensorflow_gnn/graph/broadcast_ops_test.py index ff7a0368..d72ed10e 100644 --- a/tensorflow_gnn/graph/broadcast_ops_test.py +++ b/tensorflow_gnn/graph/broadcast_ops_test.py @@ -28,8 +28,12 @@ as_ragged = tf.ragged.constant -class BroadcastingTest(tf.test.TestCase, parameterized.TestCase): - """Tests for broadcasting operations.""" +class BroadcastXToYTest(tf.test.TestCase, parameterized.TestCase): + """Tests for basic broadcasting operations broadcast_*_to_*(). + + For consistency, some tests run the corresponging call to the generic + broadcast_v2() function as well, but see BroadcastV2Test for more on that. + """ @parameterized.named_parameters( ("WithAdjacency", False), @@ -73,7 +77,7 @@ def testEdgeFieldFromNode(self, use_hyper_adjacency=False): graph, "edge", const.SOURCE, feature_name=fname)) self.assertAllEqual( expected, - broadcast_ops.broadcast_v1( + broadcast_ops.broadcast_v2( graph, const.SOURCE, edge_set_name="edge", feature_name=fname)) for fname, expected in expected_target_fields.items(): self.assertAllEqual( @@ -82,7 +86,7 @@ def testEdgeFieldFromNode(self, use_hyper_adjacency=False): graph, "edge", const.TARGET, feature_name=fname)) self.assertAllEqual( expected, - broadcast_ops.broadcast_v1( + broadcast_ops.broadcast_v2( graph, const.TARGET, edge_set_name="edge", feature_name=fname)) @parameterized.parameters([ @@ -138,7 +142,7 @@ def testNodeFieldFromContext(self, description: str, context: gt.Context, graph, "node", feature_name=fname)) self.assertAllEqual( expected, - broadcast_ops.broadcast_v1( + broadcast_ops.broadcast_v2( graph, const.CONTEXT, node_set_name="node", feature_name=fname)) @parameterized.parameters([ @@ -204,12 +208,16 @@ def testEdgeFieldFromContext(self, description: str, context: gt.Context, graph, "edge", feature_name=fname)) self.assertAllEqual( expected, - broadcast_ops.broadcast_v1( + broadcast_ops.broadcast_v2( graph, const.CONTEXT, edge_set_name="edge", feature_name=fname)) class BroadcastV2Test(tf.test.TestCase, parameterized.TestCase): - """Tests for generic broadcast_v2(), on top of already-tested basic ops.""" + """Tests for generic broadcast_v2() wrapper. + + These tests assume correctness of the underlying broadcast_*_to_*() ops; + see BroadcastXtoYTest for these. + """ def testOneEdgeSetFromTag(self): input_graph = _get_test_graph_broadcast() diff --git a/tensorflow_gnn/keras/layers/convolution_base.py b/tensorflow_gnn/keras/layers/convolution_base.py index b514e45b..18949d6f 100644 --- a/tensorflow_gnn/keras/layers/convolution_base.py +++ b/tensorflow_gnn/keras/layers/convolution_base.py @@ -305,7 +305,7 @@ def bind_receiver_args(fn): return lambda feature_value, **kwargs: fn( graph, receiver_tag, **name_kwarg, feature_value=feature_value, **kwargs) - broadcast_from_receiver = bind_receiver_args(broadcast_ops.broadcast_v1) + broadcast_from_receiver = bind_receiver_args(broadcast_ops.broadcast_v2) pool_to_receiver = bind_receiver_args(pool_ops_v1.pool_v1) if self._extra_receiver_ops is None: extra_receiver_ops_kwarg = {} # Pass no argument for this.