@@ -212,6 +212,7 @@ def __init__(
212212 self .supports_masking = True
213213 self .input_spec = None
214214 self .states = None
215+ self ._expected_batch_size = None
215216
216217 state_size = getattr (self .cell , "state_size" , None )
217218 if state_size is None :
@@ -283,6 +284,9 @@ def build(self, sequences_shape, initial_state_shape=None):
283284 f"batch size: sequence.shape={ sequences_shape } "
284285 )
285286 self ._create_state_variables (sequences_shape [0 ])
287+ self ._expected_batch_size = ops .shape (
288+ tree .flatten (self .states )[0 ]
289+ )[0 ]
286290
287291 @tracking .no_automatic_dependency_tracking
288292 def _create_state_variables (self , batch_size ):
@@ -382,6 +386,21 @@ def call(
382386 initial_state = self .get_initial_state (
383387 batch_size = ops .shape (sequences )[0 ]
384388 )
389+ if self .stateful :
390+ actual_batch_size = ops .shape (sequences )[0 ]
391+ if (
392+ self ._expected_batch_size is not None
393+ and actual_batch_size is not None
394+ and actual_batch_size != self ._expected_batch_size
395+ ):
396+ raise ValueError (
397+ f"If an RNN is stateful, the batch size of the "
398+ f"input sequences must be the same as the batch "
399+ f"size of the initial state. \n "
400+ f"- Expected batch size: { self ._expected_batch_size } \n "
401+ f"- Received batch size: { actual_batch_size } "
402+ )
403+
385404 # RNN expect the states in a list, even if single state.
386405 if not tree .is_nested (initial_state ):
387406 initial_state = [initial_state ]
0 commit comments