@@ -77,7 +77,7 @@ except `cache` (& `J` if not nothing) are mutated.
77
77
function value_and_jacobian (ad, f:: F , y, x:: X , p, cache; J = nothing ) where {F, X}
78
78
if isinplace (f)
79
79
_f = (du, u) -> f (du, u, p)
80
- if DiffEqBase . has_jac (f)
80
+ if SciMLBase . has_jac (f)
81
81
f. jac (J, x, p)
82
82
_f (y, x)
83
83
return y, J
@@ -97,7 +97,7 @@ function value_and_jacobian(ad, f::F, y, x::X, p, cache; J = nothing) where {F,
97
97
end
98
98
else
99
99
_f = Base. Fix2 (f, p)
100
- if DiffEqBase . has_jac (f)
100
+ if SciMLBase . has_jac (f)
101
101
return _f (x), f. jac (x, p)
102
102
elseif ad isa AutoForwardDiff
103
103
if ArrayInterface. can_setindex (x)
124
124
function __polyester_forwarddiff_jacobian! end
125
125
126
126
function value_and_jacobian (ad, f:: F , y, x:: Number , p, cache; J = nothing ) where {F}
127
- if DiffEqBase . has_jac (f)
127
+ if SciMLBase . has_jac (f)
128
128
return f (x, p), f. jac (x, p)
129
129
elseif ad isa AutoForwardDiff
130
130
T = typeof (__standard_tag (ad. tag, x))
@@ -152,7 +152,7 @@ function jacobian_cache(ad, f::F, y, x::X, p) where {F, X <: AbstractArray}
152
152
if isinplace (f)
153
153
_f = (du, u) -> f (du, u, p)
154
154
J = similar (y, length (y), length (x))
155
- if DiffEqBase . has_jac (f)
155
+ if SciMLBase . has_jac (f)
156
156
return J, nothing
157
157
elseif ad isa AutoForwardDiff || ad isa AutoPolyesterForwardDiff
158
158
return J, __get_jacobian_config (ad, _f, y, x)
@@ -163,7 +163,7 @@ function jacobian_cache(ad, f::F, y, x::X, p) where {F, X <: AbstractArray}
163
163
end
164
164
else
165
165
_f = Base. Fix2 (f, p)
166
- if DiffEqBase . has_jac (f)
166
+ if SciMLBase . has_jac (f)
167
167
return nothing , nothing
168
168
elseif ad isa AutoForwardDiff
169
169
J = ArrayInterface. can_setindex (x) ? similar (y, length (y), length (x)) : nothing
@@ -292,58 +292,27 @@ function init_termination_cache(abstol, reltol, du, u, ::Nothing)
292
292
return init_termination_cache (abstol, reltol, du, u, AbsNormTerminationMode ())
293
293
end
294
294
function init_termination_cache (abstol, reltol, du, u, tc:: AbstractNonlinearTerminationMode )
295
- T = promote_type (eltype (du), eltype (u))
296
- abstol = __get_tolerance (u, abstol, T)
297
- reltol = __get_tolerance (u, reltol, T)
298
295
tc_cache = init (du, u, tc; abstol, reltol)
299
- return DiffEqBase. get_abstol (tc_cache), DiffEqBase. get_reltol (tc_cache), tc_cache
296
+ return (NonlinearSolveBase. get_abstol (tc_cache),
297
+ NonlinearSolveBase. get_reltol (tc_cache), tc_cache)
300
298
end
301
299
302
300
function check_termination (tc_cache, fx, x, xo, prob, alg)
303
301
return check_termination (tc_cache, fx, x, xo, prob, alg,
304
- DiffEqBase. get_termination_mode (tc_cache))
305
- end
306
- function check_termination (tc_cache, fx, x, xo, prob, alg,
307
- :: AbstractNonlinearTerminationMode )
308
- if Bool (tc_cache (fx, x, xo))
309
- return build_solution (prob, alg, x, fx; retcode = ReturnCode. Success)
310
- end
311
- return nothing
312
- end
313
- function check_termination (tc_cache, fx, x, xo, prob, alg,
314
- :: AbstractSafeNonlinearTerminationMode )
315
- if Bool (tc_cache (fx, x, xo))
316
- if tc_cache. retcode == NonlinearSafeTerminationReturnCode. Success
317
- retcode = ReturnCode. Success
318
- elseif tc_cache. retcode == NonlinearSafeTerminationReturnCode. PatienceTermination
319
- retcode = ReturnCode. ConvergenceFailure
320
- elseif tc_cache. retcode == NonlinearSafeTerminationReturnCode. ProtectiveTermination
321
- retcode = ReturnCode. Unstable
322
- else
323
- error (" Unknown termination code: $(tc_cache. retcode) " )
324
- end
325
- return build_solution (prob, alg, x, fx; retcode)
326
- end
327
- return nothing
302
+ NonlinearSolveBase. get_termination_mode (tc_cache))
328
303
end
304
+
329
305
function check_termination (tc_cache, fx, x, xo, prob, alg,
330
- :: AbstractSafeBestNonlinearTerminationMode )
306
+ mode :: AbstractNonlinearTerminationMode )
331
307
if Bool (tc_cache (fx, x, xo))
332
- if tc_cache. retcode == NonlinearSafeTerminationReturnCode. Success
333
- retcode = ReturnCode. Success
334
- elseif tc_cache. retcode == NonlinearSafeTerminationReturnCode. PatienceTermination
335
- retcode = ReturnCode. ConvergenceFailure
336
- elseif tc_cache. retcode == NonlinearSafeTerminationReturnCode. ProtectiveTermination
337
- retcode = ReturnCode. Unstable
338
- else
339
- error (" Unknown termination code: $(tc_cache. retcode) " )
340
- end
341
- if isinplace (prob)
342
- prob. f (fx, x, prob. p)
343
- else
344
- fx = prob. f (x, prob. p)
308
+ if mode isa AbstractSafeBestNonlinearTerminationMode
309
+ if isinplace (prob)
310
+ prob. f (fx, x, prob. p)
311
+ else
312
+ fx = prob. f (x, prob. p)
313
+ end
345
314
end
346
- return build_solution (prob, alg, tc_cache . u , fx; retcode)
315
+ return build_solution (prob, alg, x , fx; retcode = tc_cache . retcode)
347
316
end
348
317
return nothing
349
318
end
382
351
@inline __reshape (x:: Number , args... ) = x
383
352
@inline __reshape (x:: AbstractArray , args... ) = reshape (x, args... )
384
353
385
- # Override cases which might be used in a kernel launch
386
- __get_tolerance (x, η, :: Type{T} ) where {T} = DiffEqBase. _get_tolerance (η, T)
387
- function __get_tolerance (x:: Union{SArray, Number} , :: Nothing , :: Type{T} ) where {T}
388
- η = real (oneunit (T)) * (eps (real (one (T))))^ (real (T)(0.8 ))
389
- return T (η)
390
- end
391
-
392
354
# Extension
393
355
function __zygote_compute_nlls_vjp end
0 commit comments