
Adds support for Call-Context Arguments #843


Merged · 1 commit · May 21, 2025
319 changes: 223 additions & 96 deletions tf_keras/engine/base_layer.py

Large diffs are not rendered by default.
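Since the core base_layer.py diff is collapsed, here is a minimal sketch of the mechanism it adds, pieced together from the tests below (`foo_mode` is an illustrative name, not part of the API):

import numpy as np
from tf_keras import layers

class Inner(layers.Layer):
    def __init__(self):
        super().__init__()
        # Declare `foo_mode` as a call-context arg this layer consumes.
        self._register_call_context_args("foo_mode")

    def call(self, x, foo_mode=None):
        return x + (1 if foo_mode else 0)

class Outer(layers.Layer):
    def __init__(self):
        super().__init__()
        self._register_call_context_args("foo_mode")
        self.inner = Inner()

    def call(self, x):
        # No explicit forwarding: registered context args propagate down
        # the call stack automatically, the way `training` always has.
        return self.inner(x)

layer = Outer()
assert int(layer(np.array(0), foo_mode=True)) == 1  # reaches Inner
assert int(layer(np.array(0))) == 0  # unset; defaults to None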

165 changes: 165 additions & 0 deletions tf_keras/engine/base_layer_test.py
@@ -1106,6 +1106,171 @@ def __init__(self, var1, var2, var3=None, **kwargs):
        with self.assertRaises(NotImplementedError):
            config = layer.get_config()

    def test_call_context_args_with_custom_layers_propagates_args(self):
        class Inner(layers.Layer):
            def __init__(self):
                super().__init__()
                self._register_call_context_args("foo_mode")

            def call(self, x, foo_mode=None):
                return x + (1 if foo_mode else 0)

        class Outer(layers.Layer):
            def __init__(self):
                super().__init__()
                self._register_call_context_args("foo_mode")
                self.inner = Inner()

            def call(self, x):
                # Outer doesn't even need to re-inject explicitly:
                # our base class will propagate foo_mode automatically
                return self.inner(x)

        layer = Outer()
        self.assertEqual(int(layer(np.array(0), foo_mode=True)), 1)
        self.assertEqual(int(layer(np.array(0))), 0)

    def test_register_call_context_arguments_success(self):
        """Validate that registering call-context args works as expected."""

        class MyLayer(layers.Layer):
            def call(self, x):
                return x

        layer = MyLayer()

        layer._register_call_context_args("foo_mode")

        self.assertCountEqual(
            layer._call_context_args, ("foo_mode", "training")
        )

    def test_register_call_context_arguments_after_call_raises_error(self):
        """Validate that registering call-context args after the layer has
        been called raises an error."""

        class MyLayer(layers.Layer):
            def call(self, x):
                return x

        layer = MyLayer()
        layer(np.array(0))
        with self.assertRaisesRegex(
            RuntimeError,
            "Cannot add call-context args after the layer has been called.",
        ):
            layer._register_call_context_args("foo_mode")

    def test_nested_context_args_follow_priority_order(self):
        """Validate that call-context args are propagated correctly
        through multiple layers, and that the most specific value is used
        when multiple values are passed down the call-stack.
        """

        class Inner(base_layer.Layer):
            def __init__(self):
                super().__init__(name="inner_layer")
                self._register_call_context_args("foo_mode")

            def call(self, inputs, foo_mode=None):
                return inputs + (1 if foo_mode else 0)

        class Middle(base_layer.Layer):
            def __init__(self):
                super().__init__(name="middle_layer")
                self._inner_layer = Inner()

            def call(self, inputs):
                return self._inner_layer(inputs)

        class Outer(base_layer.Layer):
            def __init__(self):
                super().__init__(name="outer_layer")
                self._middle = Middle()

            def call(self, inputs):
                return self._middle(inputs)

        layer = Outer()
        layer._register_call_context_args("foo_mode")

        # The value of foo_mode is set to True in the call to Outer,
        # so it should automatically propagate to Inner through Middle.
        self.assertEqual(int(layer(np.array(0), foo_mode=True)), 1)
        self.assertEqual(int(layer(np.array(0))), 0)

    def test_context_arg_propagation_without_declaration_does_not_resolve(self):
        """Validate that a layer does not resolve a propagated arg if it is
        not declared as a call-context arg in the layer itself."""

        class Inner(layers.Layer):
            def call(self, x, foo_mode=None):
                return x + (1 if foo_mode else 0)

        class Wrapper(layers.Layer):
            def __init__(self):
                super().__init__()
                self.inner = Inner()

            def call(self, x):
                return self.inner(x)

        layer = Wrapper()
        layer._register_call_context_args("foo_mode")

        # foo_mode is set to True in the call to Wrapper. However, it is
        # not declared as a call-context arg in Inner, so it does not
        # resolve to True inside Inner and instead falls back to its
        # default (None, which is falsy).
        self.assertEqual(int(layer(np.array(0), foo_mode=True)), 0)

    def test_call_context_args_with_models_as_layers_propagates_args(self):
        """Validate that call-context args are propagated correctly
        through functional and sequential models when used as layers.
        """

        class InnerLayer(base_layer.Layer):
            def __init__(self):
                super().__init__(name="inner_layer")
                self._register_call_context_args("foo")

            def call(self, inputs, foo=None):
                if foo:
                    return inputs + 1.0
                return inputs

        class OuterLayer(base_layer.Layer):
            def __init__(self):
                super().__init__(name="outer_layer")
                self._inner_layer = InnerLayer()

            def call(self, inputs):
                return self._inner_layer(inputs)

        sample_input = tf.constant([[1.0, 2.0], [3.0, 4.0]], dtype="float32")

        # Sequential model
        seq = sequential.Sequential([OuterLayer()])
        seq._register_call_context_args("foo")

        out_true = seq(sample_input, foo=True)
        self.assertAllEqual(out_true, sample_input + 1.0)

        out_false = seq(sample_input, foo=False)
        self.assertAllEqual(out_false, sample_input)

        # Functional model
        inp = input_layer.Input((2,))
        outer = OuterLayer()(inp)
        model = training_lib.Model(inputs=[inp], outputs=[outer])
        model._register_call_context_args("foo")

        out_true = model(sample_input, foo=True)
        self.assertAllEqual(out_true, sample_input + 1.0)

        out_false = model(sample_input, foo=False)
        self.assertAllEqual(out_false, sample_input)


@test_utils.run_v2_only
class SymbolicSupportTest(test_combinations.TestCase):
22 changes: 17 additions & 5 deletions tf_keras/engine/base_layer_utils.py
@@ -480,7 +480,8 @@ class CallContext:
      layer: The `Layer` whose `call` is currently active.
      inputs: The inputs to the currently active `Layer`.
      build_graph: Whether currently inside a Graph or FuncGraph.
      training: Whether currently executing in training or inference mode.
      call_context_args: The call-context arguments being propagated
        through the call-stack.
      saving: Whether currently saving to SavedModel.
      frozen: Whether currently executing inside a `Layer` with `trainable`
        set to `False`.
@@ -495,21 +496,25 @@ def __init__(self):
            "layer": None,
            "inputs": None,
            "build_graph": False,
            "call_context_args": dict(),
            "training": None,
            "saving": None,
        }
        # TODO(b/150169018): This logic can be replaced after the Functional
        # API refactor.
        self._in_keras_graph = False

    def enter(self, layer, inputs, build_graph, training, saving=None):
    def enter(
        self, layer, inputs, build_graph, call_context_args=dict(), saving=None
    ):
        """Push a Layer and its inputs and state onto the current call context.

        Args:
            layer: The `Layer` whose `call` is currently active.
            inputs: The inputs to the currently active `Layer`.
            build_graph: Whether currently inside a Graph or FuncGraph.
            training: Whether currently executing in training or inference mode.
            call_context_args: The call-context arguments being propagated
                through the call-stack.
            saving: Whether currently saving to SavedModel.

        Returns:
@@ -519,7 +524,7 @@ def enter(self, layer, inputs, build_graph, training, saving=None):
            "layer": layer,
            "inputs": inputs,
            "build_graph": build_graph,
            "training": training,
            "call_context_args": call_context_args,
            "saving": saving,
        }
        return CallContextManager(self, state)
@@ -538,7 +543,14 @@ def build_graph(self):

    @property
    def training(self):
        return self._state["training"]
        return self.call_context_args.get("training", None)

    @property
    def call_context_args(self):
        return self._state["call_context_args"]

    def get_call_context_arg(self, arg_name):
        return self.call_context_args.get(arg_name, None)

    @property
    def saving(self):
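As a standalone illustration of the generalization above (a mock, not the tf_keras class itself): the dedicated `training` slot becomes one entry in a generic per-call dict, and the old property survives as a compatibility accessor over that dict:

class MiniCallContext:
    def __init__(self):
        self._state = {"call_context_args": {}}

    def enter(self, call_context_args=None):
        # The real enter() also tracks layer/inputs/build_graph/saving.
        self._state["call_context_args"] = call_context_args or {}

    @property
    def training(self):
        # Backwards-compatible: `training` is now read from the dict.
        return self._state["call_context_args"].get("training", None)

    def get_call_context_arg(self, arg_name):
        return self._state["call_context_args"].get(arg_name, None)

ctx = MiniCallContext()
ctx.enter(call_context_args={"training": True, "foo_mode": False})
assert ctx.training is True
assert ctx.get_call_context_arg("foo_mode") is False
assert ctx.get_call_context_arg("unregistered") is None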
11 changes: 10 additions & 1 deletion tf_keras/engine/base_layer_v1.py
@@ -132,6 +132,7 @@ def __init__(
        self, trainable=True, name=None, dtype=None, dynamic=False, **kwargs
    ):
        self._instrument_layer_creation()
        self._called = False

        # These properties should be set by the user via keyword arguments.
        # note that 'dtype', 'input_shape' and 'batch_input_shape'
@@ -165,6 +166,8 @@ def __init__(
        self._input_spec = None
        self.supports_masking = False

        self._call_context_args = {"training"}

        self._init_set_name(name)
        self._activity_regularizer = regularizers.get(
            kwargs.pop("activity_regularizer", None)
@@ -705,6 +708,7 @@ def __call__(self, *args, **kwargs):
            RuntimeError: if `super().__init__()` was not called in the
                constructor.
        """
        self._called = True
        self._assert_built_as_v1()

        if not hasattr(self, "_thread_local"):
@@ -803,7 +807,12 @@ def _convert_non_tensor(x):
        if build_graph and base_layer_utils.needs_keras_history(inputs):
            base_layer_utils.create_keras_history(inputs)

        with call_context.enter(self, inputs, build_graph, training_value):
        with call_context.enter(
            self,
            inputs,
            build_graph,
            call_context_args={"training": training_value},
        ):
            # Check input assumptions set after layer building, e.g. input
            # shape.
            if build_graph:
4 changes: 4 additions & 0 deletions tf_keras/layers/core/tf_op_layer.py
@@ -259,6 +259,10 @@ def _call_wrapper(*args, **kwargs):

        self._call_spec.expects_training_arg = False
        self._call_spec.expects_mask_arg = False
        # Clear the call-context arguments for the layer's call method.
        # Otherwise, Keras ends up injecting context arguments into the
        # op-call when the call method accepts kwargs.
        self._call_spec._expected_context_args.clear()

    def _call_wrapper(self, *args, **kwargs):
        created_variables = []
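A sketch of why that clear matters (illustrative only, not the TFOpLambda code path): an op wrapper whose signature syntactically accepts **kwargs, but forwards everything to a raw TF op, would choke on an injected context arg:

import tensorflow as tf

def op_call_wrapper(*args, **kwargs):
    # Syntactically accepts any kwarg, so Keras would consider injecting
    # registered context args here; tf.abs cannot accept them.
    return tf.abs(*args, **kwargs)

op_call_wrapper(tf.constant(-1.0))  # fine
# op_call_wrapper(tf.constant(-1.0), training=True)  # TypeError in tf.abs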
4 changes: 3 additions & 1 deletion tf_keras/layers/rnn/base_rnn_test.py
@@ -639,7 +639,9 @@ def test_stacked_rnn_attributes(self):
            cells[0].kernel, tf.ones_like(cells[0].kernel)
        )
        # TODO(b/128682878): Remove when RNNCells are __call__'d.
        with base_layer_utils.call_context().enter(layer, x, True, None):
        with base_layer_utils.call_context().enter(
            layer, x, {"training": True}, None
        ):
            cells[0].add_update(update_1)
            cells[0].add_update(update_2)
        self.assertEqual(len(layer.updates), 2)
4 changes: 3 additions & 1 deletion tf_keras/layers/rnn/bidirectional_test.py
@@ -472,7 +472,9 @@ def test_Bidirectional_updates(self):
        _ = layer(x)
        assert not layer.updates
        # TODO(b/128684069): Remove when Wrapper sublayers are __call__'d.
        with base_layer_utils.call_context().enter(layer, x, True, None):
        with base_layer_utils.call_context().enter(
            layer, x, {"training": True}, None
        ):
            layer.forward_layer.add_update(x_reachable_update)
            layer.forward_layer.add_update(1)
            layer.backward_layer.add_update(x_reachable_update)
14 changes: 13 additions & 1 deletion tf_keras/layers/rnn/cell_wrappers.py
@@ -52,9 +52,21 @@ def __init__(self, cell, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.cell = cell
        cell_call_spec = tf_inspect.getfullargspec(cell.call)
        accepts_kwargs = cell_call_spec.varkw is not None

        self._call_spec.expects_training_arg = (
            "training" in cell_call_spec.args
        ) or (cell_call_spec.varkw is not None)
        ) or accepts_kwargs

        # Filter _expected_context_args. An argument is kept if:
        # 1. It's an explicit argument in cell_call_spec.args, OR
        # 2. The cell accepts arbitrary keyword arguments (**kwargs),
        #    meaning it could potentially handle the context argument.
        self._call_spec._expected_context_args = {
            arg
            for arg in self._call_spec._expected_context_args
            if (arg in cell_call_spec.args) or accepts_kwargs
        }

    def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
        """Calls the wrapped cell and performs the wrapping logic.
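The filtering above can be exercised in isolation. A standalone sketch using stdlib `inspect` (the wrapper itself uses `tf_inspect`), with hypothetical cell call signatures:

import inspect

def filter_context_args(expected_context_args, cell_call):
    """Keep only the context args the wrapped cell can actually accept."""
    spec = inspect.getfullargspec(cell_call)
    accepts_kwargs = spec.varkw is not None
    return {
        arg
        for arg in expected_context_args
        if (arg in spec.args) or accepts_kwargs
    }

# Hypothetical cell call signatures:
def call_explicit(inputs, training=None):  # explicit `training` only
    return inputs

def call_catch_all(inputs, **kwargs):  # accepts anything via **kwargs
    return inputs

assert filter_context_args({"training", "foo"}, call_explicit) == {"training"}
assert filter_context_args({"training", "foo"}, call_catch_all) == {
    "training",
    "foo",
}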