diff --git a/tensorflow_gnn/__init__.py b/tensorflow_gnn/__init__.py index 2308ec21..797a218b 100644 --- a/tensorflow_gnn/__init__.py +++ b/tensorflow_gnn/__init__.py @@ -152,7 +152,7 @@ broadcast_node_to_edges = broadcast_ops.broadcast_node_to_edges broadcast_context_to_nodes = broadcast_ops.broadcast_context_to_nodes broadcast_context_to_edges = broadcast_ops.broadcast_context_to_edges -broadcast = broadcast_ops.broadcast_v1 # TODO(b/265760014): Switch to v2. +broadcast = broadcast_ops.broadcast_v2 pool_edges_to_node = pool_ops_v1.pool_edges_to_node pool_nodes_to_context = pool_ops_v1.pool_nodes_to_context pool_edges_to_context = pool_ops_v1.pool_edges_to_context diff --git a/tensorflow_gnn/graph/broadcast_ops.py b/tensorflow_gnn/graph/broadcast_ops.py index 4289b66e..5ca63573 100644 --- a/tensorflow_gnn/graph/broadcast_ops.py +++ b/tensorflow_gnn/graph/broadcast_ops.py @@ -178,11 +178,8 @@ def _broadcast_context(graph_tensor: GraphTensor, repeats_sum_hint=node_or_edge_set.spec.total_size) -# TODO(b/265760014): Export as tfgnn.broadcast() and remove broadcast_v1(). -# The difference is that v2 supports multiple node/edge sets and v1 does not, -# mirroring the difference between pool_v2() and pool_v1(). def broadcast_v2( - graph: GraphTensor, + graph_tensor: GraphTensor, from_tag: IncidentNodeOrContextTag, *, edge_set_name: Union[Sequence[EdgeSetName], EdgeSetName, None] = None, @@ -207,7 +204,7 @@ def broadcast_v2( image of `tfgnn.pool()`, which comes in handy for some algorithms. Args: - graph: A scalar GraphTensor. + graph_tensor: A scalar GraphTensor. from_tag: Values are broadcast from context if this is `tfgnn.CONTEXT` or from the incident node on each edge with this tag. edge_set_name: The name of the edge set to which values are broadcast, or @@ -230,10 +227,10 @@ def broadcast_v2( If a list of names was specified, the result is a list of tensors, with parallel indices. """ - gt.check_scalar_graph_tensor(graph, "broadcast()") + gt.check_scalar_graph_tensor(graph_tensor, "broadcast()") edge_set_names, node_set_names, got_sequence_args = ( tag_utils.get_edge_or_node_set_name_args_for_tag( - graph.spec, from_tag, + graph_tensor.spec, from_tag, edge_set_name=edge_set_name, node_set_name=node_set_name, function_name="broadcast()")) del edge_set_name, node_set_name # Replaced by their cleaned-up versions. @@ -244,82 +241,18 @@ def broadcast_v2( if from_tag == const.CONTEXT: if edge_set_names is not None: - result = [broadcast_context_to_edges(graph, name, **feature_kwargs) + result = [broadcast_context_to_edges(graph_tensor, name, **feature_kwargs) for name in edge_set_names] else: - result = [broadcast_context_to_nodes(graph, name, **feature_kwargs) + result = [broadcast_context_to_nodes(graph_tensor, name, **feature_kwargs) for name in node_set_names] else: - result = [broadcast_node_to_edges(graph, name, from_tag, **feature_kwargs) - for name in edge_set_names] + result = [ + broadcast_node_to_edges(graph_tensor, name, from_tag, **feature_kwargs) + for name in edge_set_names] if got_sequence_args: return result else: assert len(result) == 1 return result[0] - - -# TODO(b/265760014): Remove in favor of broadcast_v2(). -# The difference is that v2 supports multiple node/edge sets and v1 does not, -# mirroring the difference between pool_v2() and pool_v1(). -def broadcast_v1(graph_tensor: GraphTensor, - from_tag: const.IncidentNodeOrContextTag, - *, - edge_set_name: Optional[EdgeSetName] = None, - node_set_name: Optional[NodeSetName] = None, - feature_value: Optional[Field] = None, - feature_name: Optional[FieldName] = None) -> Field: - """Broadcasts values from nodes to edges, or from context to nodes or edges. - - This function broadcasts from context if `from_tag=tfgnn.CONTEXT` and - broadcasts from incident nodes to edges if `from_tag` is an ordinary node tag - like `tfgnn.SOURCE` or `tfgnn.TARGET`. Most user code will not need this - flexibility and can directly call one of the underlying functions - `broadcast_node_to_edges()`, `broadcast_context_to_nodes()`, or - `broadcast_context_to_edges()`. - - Args: - graph_tensor: A scalar GraphTensor. - from_tag: Values are broadcast from context if this is `tfgnn.CONTEXT` or - from the incident node on each edge with this tag. - edge_set_name: The name of the edge set to which values are broadcast. - node_set_name: The name of the node set to which values are broadcast. - Can only be set with `from_tag=tfgnn.CONTEXT`. Either edge_set_name or - node_set_name must be set. - feature_value: As for the underlying broadcast_*() function. - feature_name: As for the underlying broadcast_*() function. - Exactly one of feature_name or feature_value must be set. - - Returns: - The result of the underlying broadcast_*() function. - """ - _validate_names_and_tag( - from_tag, edge_set_name=edge_set_name, node_set_name=node_set_name) - if from_tag == const.CONTEXT: - if node_set_name is not None: - return broadcast_context_to_nodes( - graph_tensor, node_set_name=node_set_name, - feature_value=feature_value, feature_name=feature_name) - else: - return broadcast_context_to_edges( - graph_tensor, edge_set_name=edge_set_name, - feature_value=feature_value, feature_name=feature_name) - else: - return broadcast_node_to_edges( - graph_tensor, edge_set_name=edge_set_name, node_tag=from_tag, - feature_value=feature_value, feature_name=feature_name) - - -# TODO(b/265760014): Remove along with broadcast_v1(). -def _validate_names_and_tag(tag, *, edge_set_name, node_set_name): - """Helper for broadcast_v1().""" - if tag == const.CONTEXT: - num_names = bool(edge_set_name is None) + bool(node_set_name is None) - if num_names != 1: - raise ValueError("With tag CONTEXT, must pass exactly 1 of " - f"edge_set_name, node_set_name; got {num_names}.") - else: - if edge_set_name is None or node_set_name is not None: - raise ValueError("Must pass edge_set_name but not node_set_name " - "for a tag other than CONTEXT.") diff --git a/tensorflow_gnn/graph/broadcast_ops_test.py b/tensorflow_gnn/graph/broadcast_ops_test.py index ff7a0368..d72ed10e 100644 --- a/tensorflow_gnn/graph/broadcast_ops_test.py +++ b/tensorflow_gnn/graph/broadcast_ops_test.py @@ -28,8 +28,12 @@ as_ragged = tf.ragged.constant -class BroadcastingTest(tf.test.TestCase, parameterized.TestCase): - """Tests for broadcasting operations.""" +class BroadcastXToYTest(tf.test.TestCase, parameterized.TestCase): + """Tests for basic broadcasting operations broadcast_*_to_*(). + + For consistency, some tests run the corresponging call to the generic + broadcast_v2() function as well, but see BroadcastV2Test for more on that. + """ @parameterized.named_parameters( ("WithAdjacency", False), @@ -73,7 +77,7 @@ def testEdgeFieldFromNode(self, use_hyper_adjacency=False): graph, "edge", const.SOURCE, feature_name=fname)) self.assertAllEqual( expected, - broadcast_ops.broadcast_v1( + broadcast_ops.broadcast_v2( graph, const.SOURCE, edge_set_name="edge", feature_name=fname)) for fname, expected in expected_target_fields.items(): self.assertAllEqual( @@ -82,7 +86,7 @@ def testEdgeFieldFromNode(self, use_hyper_adjacency=False): graph, "edge", const.TARGET, feature_name=fname)) self.assertAllEqual( expected, - broadcast_ops.broadcast_v1( + broadcast_ops.broadcast_v2( graph, const.TARGET, edge_set_name="edge", feature_name=fname)) @parameterized.parameters([ @@ -138,7 +142,7 @@ def testNodeFieldFromContext(self, description: str, context: gt.Context, graph, "node", feature_name=fname)) self.assertAllEqual( expected, - broadcast_ops.broadcast_v1( + broadcast_ops.broadcast_v2( graph, const.CONTEXT, node_set_name="node", feature_name=fname)) @parameterized.parameters([ @@ -204,12 +208,16 @@ def testEdgeFieldFromContext(self, description: str, context: gt.Context, graph, "edge", feature_name=fname)) self.assertAllEqual( expected, - broadcast_ops.broadcast_v1( + broadcast_ops.broadcast_v2( graph, const.CONTEXT, edge_set_name="edge", feature_name=fname)) class BroadcastV2Test(tf.test.TestCase, parameterized.TestCase): - """Tests for generic broadcast_v2(), on top of already-tested basic ops.""" + """Tests for generic broadcast_v2() wrapper. + + These tests assume correctness of the underlying broadcast_*_to_*() ops; + see BroadcastXtoYTest for these. + """ def testOneEdgeSetFromTag(self): input_graph = _get_test_graph_broadcast() diff --git a/tensorflow_gnn/keras/layers/convolution_base.py b/tensorflow_gnn/keras/layers/convolution_base.py index b514e45b..18949d6f 100644 --- a/tensorflow_gnn/keras/layers/convolution_base.py +++ b/tensorflow_gnn/keras/layers/convolution_base.py @@ -305,7 +305,7 @@ def bind_receiver_args(fn): return lambda feature_value, **kwargs: fn( graph, receiver_tag, **name_kwarg, feature_value=feature_value, **kwargs) - broadcast_from_receiver = bind_receiver_args(broadcast_ops.broadcast_v1) + broadcast_from_receiver = bind_receiver_args(broadcast_ops.broadcast_v2) pool_to_receiver = bind_receiver_args(pool_ops_v1.pool_v1) if self._extra_receiver_ops is None: extra_receiver_ops_kwarg = {} # Pass no argument for this.