We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
LinearInterpolation
1 parent 737bf39 commit 685a8f3Copy full SHA for 685a8f3
diffrax/global_interpolation.py
@@ -125,10 +125,13 @@ def _index(_ys):
125
prev_t = self.ts[index]
126
next_t = self.ts[index + 1]
127
diff_t = next_t - prev_t
128
-
129
- return (
130
- prev_ys**ω + (next_ys**ω - prev_ys**ω) * (fractional_part / diff_t)
131
- ).ω
+ diff_nonzero = diff_t >= jnp.finfo(diff_t.dtype).eps
+ safe_diff = jnp.where(diff_nonzero, diff_t, jnp.ones_like(diff_t))
+ return jnp.where(
+ diff_nonzero,
132
+ (prev_ys**ω + (next_ys**ω - prev_ys**ω) * (fractional_part / safe_diff)).ω,
133
+ prev_ys
134
+ )
135
136
@eqx.filter_jit
137
def derivative(self, t: Scalar, left: bool = True) -> PyTree:
0 commit comments