Skip to content

Commit 1ac005d

Browse files
Tests that a jump at t1 is saved.
1 parent 5d9f6b9 commit 1ac005d

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

test/test_adaptive_stepsize_controller.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,3 +336,28 @@ def test_implicit_solver_with_clip_controller(new: bool):
336336
max_steps=16384,
337337
saveat=diffrax.SaveAt(t1=True),
338338
)
339+
340+
341+
# https://github.com/patrick-kidger/diffrax/issues/663
342+
# `jump_ts` sets the time we step to as `prevbefore` the time provided.
343+
# Clipping at t1 saves us! We need to clip at at least 1 ULP.
344+
def test_jump_at_t1_with_large_t1_in_float32():
345+
t0 = jnp.array(0.0, dtype=jnp.float32)
346+
t1 = jnp.array(1e3, dtype=jnp.float32)
347+
dt0 = jnp.array(0.01, dtype=jnp.float32)
348+
y0 = jnp.array(1, dtype=jnp.float32)
349+
saveat = diffrax.SaveAt(ts=t1[None])
350+
ssc = diffrax.ClipStepSizeController(
351+
diffrax.PIDController(atol=1e-6, rtol=1e-6), jump_ts=t1[None]
352+
)
353+
sol = diffrax.diffeqsolve(
354+
diffrax.ODETerm(lambda t, y, args: -y),
355+
diffrax.Heun(),
356+
t0=t0,
357+
t1=t1,
358+
dt0=dt0,
359+
y0=y0,
360+
stepsize_controller=ssc,
361+
saveat=saveat,
362+
)
363+
assert sol.ts == jnp.array([t1])

0 commit comments

Comments
 (0)