diff --git a/diffrax/integrate.py b/diffrax/integrate.py index 814c90d9..1e814f9f 100644 --- a/diffrax/integrate.py +++ b/diffrax/integrate.py @@ -249,7 +249,7 @@ def body_fun(state): error_order, state.controller_state, ) - assert jnp.result_type(keep_step) is jnp.dtype(bool) + assert jnp.result_type(keep_step) in [bool, jnp.dtype(bool)] # # Do some book-keeping.