Skip to content

Commit f78fd8c

Browse files
Raise exception on batch_size mismatch for stateful RNNs (#21742)
* fix stateful rnn (pulling from existing PR that closed)
* Update keras/src/layers/rnn/rnn.py
  Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
* ensure we don't recompute expected batch size for each call
* address comments from francois
---------
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 4bc6576 commit f78fd8c

File tree

2 files changed

+41
-0
lines changed

2 files changed

+41
-0
lines changed

keras/src/layers/rnn/rnn.py

Lines changed: 19 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -212,6 +212,7 @@ def __init__(
212212
self.supports_masking = True
213213
self.input_spec = None
214214
self.states = None
215+
self._expected_batch_size = None
215216

216217
state_size = getattr(self.cell, "state_size", None)
217218
if state_size is None:
@@ -283,6 +284,9 @@ def build(self, sequences_shape, initial_state_shape=None):
283284
f"batch size: sequence.shape={sequences_shape}"
284285
)
285286
self._create_state_variables(sequences_shape[0])
287+
self._expected_batch_size = ops.shape(
288+
tree.flatten(self.states)[0]
289+
)[0]
286290

287291
@tracking.no_automatic_dependency_tracking
288292
def _create_state_variables(self, batch_size):
@@ -382,6 +386,21 @@ def call(
382386
initial_state = self.get_initial_state(
383387
batch_size=ops.shape(sequences)[0]
384388
)
389+
if self.stateful:
390+
actual_batch_size = ops.shape(sequences)[0]
391+
if (
392+
self._expected_batch_size is not None
393+
and actual_batch_size is not None
394+
and actual_batch_size != self._expected_batch_size
395+
):
396+
raise ValueError(
397+
f"If an RNN is stateful, the batch size of the "
398+
f"input sequences must be the same as the batch "
399+
f"size of the initial state. \n"
400+
f"- Expected batch size: {self._expected_batch_size}\n"
401+
f"- Received batch size: {actual_batch_size}"
402+
)
403+
385404
# RNN expect the states in a list, even if single state.
386405
if not tree.is_nested(initial_state):
387406
initial_state = [initial_state]

keras/src/layers/rnn/rnn_test.py

Lines changed: 22 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -381,4 +381,26 @@ def test_serialization(self):
381381
layer = layers.RNN(OneStateRNNCell(2), return_sequences=False)
382382
self.run_class_serialization_test(layer)
383383

384+
def test_stateful_batch_size_mismatch_raises(self):
385+
from keras.src.models import Functional
386+
387+
batch_size = 4
388+
timesteps = 5
389+
features = 3
390+
391+
layer = layers.RNN(TwoStatesRNNCell(2), stateful=True)
392+
inputs = layers.Input(
393+
shape=(timesteps, features), batch_size=batch_size
394+
)
395+
model = Functional(inputs, layer(inputs))
396+
397+
# Call once with correct batch size
398+
x = ops.random.uniform(shape=(batch_size, timesteps, features))
399+
_ = model(x)
400+
401+
# Expect ValueError when called with incorrect batch size
402+
with self.assertRaisesRegex(ValueError, "batch size"):
403+
x_bad = ops.random.uniform(shape=(1, timesteps, features))
404+
model(x_bad)
405+
384406
# TODO: test masking

0 commit comments

Comments
 (0)