Skip to content
This repository was archived by the owner on Apr 16, 2025. It is now read-only.

Commit 5afc4ba

Browse files
committed
Test the new NonlinearSolveBase.jl
1 parent fd7d216 commit 5afc4ba

14 files changed

+644
-99
lines changed

Manifest.toml

+574
Large diffs are not rendered by default.

Project.toml

+8-6
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,42 @@
11
name = "SimpleNonlinearSolve"
22
uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7"
33
authors = ["SciML"]
4-
version = "1.5.0"
4+
version = "1.6.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
88
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
99
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
10-
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
1110
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
1211
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
1312
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
1413
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1514
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1615
MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb"
16+
NonlinearSolveBase = "be0214bd-f91f-a760-ac4e-3421ce2b2da0"
1717
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1818
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1919
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
2020
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
2121

2222
[weakdeps]
2323
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
24+
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
2425
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
2526
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2627
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2728

2829
[extensions]
29-
SimpleNonlinearSolveChainRulesCoreExt = "ChainRulesCore"
30+
SimpleNonlinearSolveChainRulesCoreDiffEqBaseExt = ["ChainRulesCore", "DiffEqBase"]
3031
SimpleNonlinearSolvePolyesterForwardDiffExt = "PolyesterForwardDiff"
3132
SimpleNonlinearSolveStaticArraysExt = "StaticArrays"
3233
SimpleNonlinearSolveZygoteExt = "Zygote"
3334

3435
[compat]
3536
ADTypes = "0.2.6"
3637
AllocCheck = "0.1.1"
37-
ArrayInterface = "7.7"
3838
Aqua = "0.8"
39+
ArrayInterface = "7.7"
3940
CUDA = "5.2"
4041
ChainRulesCore = "1.22"
4142
ConcreteStructs = "0.2.3"
@@ -48,6 +49,7 @@ LinearAlgebra = "1.10"
4849
LinearSolve = "2.25"
4950
MaybeInplace = "0.1.1"
5051
NonlinearProblemLibrary = "0.1.2"
52+
NonlinearSolveBase = "1"
5153
Pkg = "1.10"
5254
PolyesterForwardDiff = "0.1.1"
5355
PrecompileTools = "1.2"
@@ -66,12 +68,12 @@ julia = "1.10"
6668
AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a"
6769
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
6870
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
69-
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
7071
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
7172
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
7273
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
7374
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
7475
NonlinearProblemLibrary = "b7050fa9-e91f-4b37-bcee-a89a063da141"
76+
NonlinearSolveBase = "be0214bd-f91f-a760-ac4e-3421ce2b2da0"
7577
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
7678
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
7779
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -83,4 +85,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
8385
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
8486

8587
[targets]
86-
test = ["Aqua", "AllocCheck", "DiffEqBase", "ForwardDiff", "LinearAlgebra", "LinearSolve", "NonlinearProblemLibrary", "Pkg", "Random", "ReTestItems", "SciMLSensitivity", "StaticArrays", "Zygote", "CUDA", "PolyesterForwardDiff", "Reexport", "Test", "FiniteDiff"]
88+
test = ["Aqua", "AllocCheck", "NonlinearSolveBase", "ForwardDiff", "LinearAlgebra", "LinearSolve", "NonlinearProblemLibrary", "Pkg", "Random", "ReTestItems", "SciMLSensitivity", "StaticArrays", "Zygote", "CUDA", "PolyesterForwardDiff", "Reexport", "Test", "FiniteDiff"]

ext/SimpleNonlinearSolveChainRulesCoreExt.jl renamed to ext/SimpleNonlinearSolveChainRulesCoreDiffEqBaseExt.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
module SimpleNonlinearSolveChainRulesCoreExt
1+
module SimpleNonlinearSolveChainRulesCoreDiffEqBaseExt
22

33
using ChainRulesCore, DiffEqBase, SciMLBase, SimpleNonlinearSolve
44

src/SimpleNonlinearSolve.jl

+31-25
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,23 @@ module SimpleNonlinearSolve
33
import PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidations
44

55
@recompile_invalidations begin
6-
using ADTypes, ArrayInterface, ConcreteStructs, DiffEqBase, FastClosures, FiniteDiff,
7-
ForwardDiff, Reexport, LinearAlgebra, SciMLBase
8-
9-
import DiffEqBase: AbstractNonlinearTerminationMode,
10-
AbstractSafeNonlinearTerminationMode,
11-
AbstractSafeBestNonlinearTerminationMode,
12-
NonlinearSafeTerminationReturnCode, get_termination_mode,
13-
NONLINEARSOLVE_DEFAULT_NORM
6+
using ADTypes, ArrayInterface, FiniteDiff, ForwardDiff, NonlinearSolveBase, Reexport,
7+
LinearAlgebra, SciMLBase
8+
9+
import ConcreteStructs: @concrete
1410
import DiffResults
11+
import FastClosures: @closure
1512
import ForwardDiff: Dual
1613
import MaybeInplace: @bb, setindex_trait, CanSetindex, CannotSetindex
14+
import NonlinearSolveBase: AbstractNonlinearTerminationMode,
15+
AbstractSafeNonlinearTerminationMode,
16+
AbstractSafeBestNonlinearTerminationMode,
17+
get_termination_mode, NONLINEARSOLVE_DEFAULT_NORM
1718
import SciMLBase: AbstractNonlinearAlgorithm, build_solution, isinplace, _unwrap_val
1819
import StaticArraysCore: StaticArray, SVector, SMatrix, SArray, MArray, MMatrix, Size
1920
end
2021

21-
@reexport using ADTypes, SciMLBase
22+
@reexport using ADTypes, SciMLBase # TODO: Reexport NonlinearSolveBase after the situation with NonlinearSolve.jl is resolved
2223

2324
abstract type AbstractSimpleNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end
2425
abstract type AbstractBracketingAlgorithm <: AbstractSimpleNonlinearSolveAlgorithm end
@@ -58,23 +59,28 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Nothing, args...;
5859
end
5960

6061
# By Pass the highlevel checks for NonlinearProblem for Simple Algorithms
61-
function SciMLBase.solve(
62-
prob::NonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm,
63-
args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...)
64-
if sensealg === nothing && haskey(prob.kwargs, :sensealg)
65-
sensealg = prob.kwargs[:sensealg]
66-
end
67-
new_u0 = u0 !== nothing ? u0 : prob.u0
68-
new_p = p !== nothing ? p : prob.p
69-
return __internal_solve_up(
70-
prob, sensealg, new_u0, u0 === nothing, new_p, p === nothing,
71-
alg, args...; prob.kwargs..., kwargs...)
72-
end
62+
# Using eval to prevent ambiguity
63+
for pType in (NonlinearProblem, NonlinearLeastSquaresProblem)
64+
@eval begin
65+
function SciMLBase.solve(
66+
prob::$(pType), alg::AbstractSimpleNonlinearSolveAlgorithm, args...;
67+
sensealg = nothing, u0 = nothing, p = nothing, kwargs...)
68+
if sensealg === nothing && haskey(prob.kwargs, :sensealg)
69+
sensealg = prob.kwargs[:sensealg]
70+
end
71+
new_u0 = u0 !== nothing ? u0 : prob.u0
72+
new_p = p !== nothing ? p : prob.p
73+
return __internal_solve_up(
74+
prob, sensealg, new_u0, u0 === nothing, new_p, p === nothing,
75+
alg, args...; prob.kwargs..., kwargs...)
76+
end
7377

74-
function __internal_solve_up(_prob::NonlinearProblem, sensealg, u0, u0_changed, p,
75-
p_changed, alg::AbstractSimpleNonlinearSolveAlgorithm, args...; kwargs...)
76-
prob = u0_changed || p_changed ? remake(_prob; u0, p) : _prob
77-
return SciMLBase.__solve(prob, alg, args...; kwargs...)
78+
function __internal_solve_up(_prob::$(pType), sensealg, u0, u0_changed, p,
79+
p_changed, alg::AbstractSimpleNonlinearSolveAlgorithm, args...; kwargs...)
80+
prob = u0_changed || p_changed ? remake(_prob; u0, p) : _prob
81+
return SciMLBase.__solve(prob, alg, args...; kwargs...)
82+
end
83+
end
7884
end
7985

8086
@setup_workload begin

src/bracketing/bisection.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Bisection, args...
2626
left, right = prob.tspan
2727
fl, fr = f(left), f(right)
2828

29-
abstol = __get_tolerance(nothing, abstol,
29+
abstol = NonlinearSolveBase.get_tolerance(nothing, abstol,
3030
promote_type(eltype(first(prob.tspan)), eltype(last(prob.tspan))))
3131

3232
if iszero(fl)

src/bracketing/brent.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Brent, args...;
1313
fl, fr = f(left), f(right)
1414
ϵ = eps(convert(typeof(fl), 1))
1515

16-
abstol = __get_tolerance(nothing, abstol,
16+
abstol = NonlinearSolveBase.get_tolerance(nothing, abstol,
1717
promote_type(eltype(first(prob.tspan)), eltype(last(prob.tspan))))
1818

1919
if iszero(fl)

src/bracketing/falsi.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Falsi, args...;
1212
left, right = prob.tspan
1313
fl, fr = f(left), f(right)
1414

15-
abstol = __get_tolerance(nothing, abstol,
15+
abstol = NonlinearSolveBase.get_tolerance(nothing, abstol,
1616
promote_type(eltype(first(prob.tspan)), eltype(last(prob.tspan))))
1717

1818
if iszero(fl)

src/bracketing/itp.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::ITP, args...;
5858
left, right = prob.tspan
5959
fl, fr = f(left), f(right)
6060

61-
abstol = __get_tolerance(nothing, abstol,
61+
abstol = NonlinearSolveBase.get_tolerance(nothing, abstol,
6262
promote_type(eltype(first(prob.tspan)), eltype(last(prob.tspan))))
6363

6464
if iszero(fl)

src/bracketing/ridder.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Ridder, args...;
1212
left, right = prob.tspan
1313
fl, fr = f(left), f(right)
1414

15-
abstol = __get_tolerance(nothing, abstol,
15+
abstol = NonlinearSolveBase.get_tolerance(nothing, abstol,
1616
promote_type(eltype(first(prob.tspan)), eltype(last(prob.tspan))))
1717

1818
if iszero(fl)

src/linesearch.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ function (cache::LiFukushimaLineSearchCache)(u, δu)
7373
fx_norm = ϕ(T(0))
7474

7575
# Non-Blocking exit if the norm is NaN or Inf
76-
DiffEqBase.NAN_CHECK(fx_norm) && return cache.α
76+
NonlinearSolveBase.NAN_CHECK(fx_norm) && return cache.α
7777

7878
# Early Terminate based on Eq. 2.7
7979
du_norm = NONLINEARSOLVE_DEFAULT_NORM(δu)
@@ -84,12 +84,12 @@ function (cache::LiFukushimaLineSearchCache)(u, δu)
8484
fxλp_norm = ϕ(λ₂)
8585

8686
if cache.nan_maxiters !== nothing
87-
if DiffEqBase.NAN_CHECK(fxλp_norm)
87+
if NonlinearSolveBase.NAN_CHECK(fxλp_norm)
8888
nan_converged = false
8989
for _ in 1:(cache.nan_maxiters)
9090
λ₁, λ₂ = λ₂, cache.β * λ₂
9191
fxλp_norm = ϕ(λ₂)
92-
nan_converged = DiffEqBase.NAN_CHECK(fxλp_norm)::Bool
92+
nan_converged = NonlinearSolveBase.NAN_CHECK(fxλp_norm)::Bool
9393
nan_converged && break
9494
end
9595
nan_converged || return cache.α

src/nlsolve/lbroyden.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ function __static_solve(prob::NonlinearProblem{<:SArray}, alg::SimpleLimitedMemo
121121

122122
U, Vᵀ = __init_low_rank_jacobian(vec(x), vec(fx), threshold)
123123

124-
abstol = __get_tolerance(x, abstol, eltype(x))
124+
abstol = NonlinearSolveBase.get_tolerance(x, abstol, eltype(x))
125125

126126
xo, δx, fo, δf = x, -fx, fx, fx
127127

src/utils.jl

+17-55
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ except `cache` (& `J` if not nothing) are mutated.
7777
function value_and_jacobian(ad, f::F, y, x::X, p, cache; J = nothing) where {F, X}
7878
if isinplace(f)
7979
_f = (du, u) -> f(du, u, p)
80-
if DiffEqBase.has_jac(f)
80+
if SciMLBase.has_jac(f)
8181
f.jac(J, x, p)
8282
_f(y, x)
8383
return y, J
@@ -97,7 +97,7 @@ function value_and_jacobian(ad, f::F, y, x::X, p, cache; J = nothing) where {F,
9797
end
9898
else
9999
_f = Base.Fix2(f, p)
100-
if DiffEqBase.has_jac(f)
100+
if SciMLBase.has_jac(f)
101101
return _f(x), f.jac(x, p)
102102
elseif ad isa AutoForwardDiff
103103
if ArrayInterface.can_setindex(x)
@@ -124,7 +124,7 @@ end
124124
function __polyester_forwarddiff_jacobian! end
125125

126126
function value_and_jacobian(ad, f::F, y, x::Number, p, cache; J = nothing) where {F}
127-
if DiffEqBase.has_jac(f)
127+
if SciMLBase.has_jac(f)
128128
return f(x, p), f.jac(x, p)
129129
elseif ad isa AutoForwardDiff
130130
T = typeof(__standard_tag(ad.tag, x))
@@ -152,7 +152,7 @@ function jacobian_cache(ad, f::F, y, x::X, p) where {F, X <: AbstractArray}
152152
if isinplace(f)
153153
_f = (du, u) -> f(du, u, p)
154154
J = similar(y, length(y), length(x))
155-
if DiffEqBase.has_jac(f)
155+
if SciMLBase.has_jac(f)
156156
return J, nothing
157157
elseif ad isa AutoForwardDiff || ad isa AutoPolyesterForwardDiff
158158
return J, __get_jacobian_config(ad, _f, y, x)
@@ -163,7 +163,7 @@ function jacobian_cache(ad, f::F, y, x::X, p) where {F, X <: AbstractArray}
163163
end
164164
else
165165
_f = Base.Fix2(f, p)
166-
if DiffEqBase.has_jac(f)
166+
if SciMLBase.has_jac(f)
167167
return nothing, nothing
168168
elseif ad isa AutoForwardDiff
169169
J = ArrayInterface.can_setindex(x) ? similar(y, length(y), length(x)) : nothing
@@ -292,58 +292,27 @@ function init_termination_cache(abstol, reltol, du, u, ::Nothing)
292292
return init_termination_cache(abstol, reltol, du, u, AbsNormTerminationMode())
293293
end
294294
function init_termination_cache(abstol, reltol, du, u, tc::AbstractNonlinearTerminationMode)
295-
T = promote_type(eltype(du), eltype(u))
296-
abstol = __get_tolerance(u, abstol, T)
297-
reltol = __get_tolerance(u, reltol, T)
298295
tc_cache = init(du, u, tc; abstol, reltol)
299-
return DiffEqBase.get_abstol(tc_cache), DiffEqBase.get_reltol(tc_cache), tc_cache
296+
return (NonlinearSolveBase.get_abstol(tc_cache),
297+
NonlinearSolveBase.get_reltol(tc_cache), tc_cache)
300298
end
301299

302300
function check_termination(tc_cache, fx, x, xo, prob, alg)
303301
return check_termination(tc_cache, fx, x, xo, prob, alg,
304-
DiffEqBase.get_termination_mode(tc_cache))
305-
end
306-
function check_termination(tc_cache, fx, x, xo, prob, alg,
307-
::AbstractNonlinearTerminationMode)
308-
if Bool(tc_cache(fx, x, xo))
309-
return build_solution(prob, alg, x, fx; retcode = ReturnCode.Success)
310-
end
311-
return nothing
312-
end
313-
function check_termination(tc_cache, fx, x, xo, prob, alg,
314-
::AbstractSafeNonlinearTerminationMode)
315-
if Bool(tc_cache(fx, x, xo))
316-
if tc_cache.retcode == NonlinearSafeTerminationReturnCode.Success
317-
retcode = ReturnCode.Success
318-
elseif tc_cache.retcode == NonlinearSafeTerminationReturnCode.PatienceTermination
319-
retcode = ReturnCode.ConvergenceFailure
320-
elseif tc_cache.retcode == NonlinearSafeTerminationReturnCode.ProtectiveTermination
321-
retcode = ReturnCode.Unstable
322-
else
323-
error("Unknown termination code: $(tc_cache.retcode)")
324-
end
325-
return build_solution(prob, alg, x, fx; retcode)
326-
end
327-
return nothing
302+
NonlinearSolveBase.get_termination_mode(tc_cache))
328303
end
304+
329305
function check_termination(tc_cache, fx, x, xo, prob, alg,
330-
::AbstractSafeBestNonlinearTerminationMode)
306+
mode::AbstractNonlinearTerminationMode)
331307
if Bool(tc_cache(fx, x, xo))
332-
if tc_cache.retcode == NonlinearSafeTerminationReturnCode.Success
333-
retcode = ReturnCode.Success
334-
elseif tc_cache.retcode == NonlinearSafeTerminationReturnCode.PatienceTermination
335-
retcode = ReturnCode.ConvergenceFailure
336-
elseif tc_cache.retcode == NonlinearSafeTerminationReturnCode.ProtectiveTermination
337-
retcode = ReturnCode.Unstable
338-
else
339-
error("Unknown termination code: $(tc_cache.retcode)")
340-
end
341-
if isinplace(prob)
342-
prob.f(fx, x, prob.p)
343-
else
344-
fx = prob.f(x, prob.p)
308+
if mode isa AbstractSafeBestNonlinearTerminationMode
309+
if isinplace(prob)
310+
prob.f(fx, x, prob.p)
311+
else
312+
fx = prob.f(x, prob.p)
313+
end
345314
end
346-
return build_solution(prob, alg, tc_cache.u, fx; retcode)
315+
return build_solution(prob, alg, x, fx; retcode = tc_cache.retcode)
347316
end
348317
return nothing
349318
end
@@ -382,12 +351,5 @@ end
382351
@inline __reshape(x::Number, args...) = x
383352
@inline __reshape(x::AbstractArray, args...) = reshape(x, args...)
384353

385-
# Override cases which might be used in a kernel launch
386-
__get_tolerance(x, η, ::Type{T}) where {T} = DiffEqBase._get_tolerance(η, T)
387-
function __get_tolerance(x::Union{SArray, Number}, ::Nothing, ::Type{T}) where {T}
388-
η = real(oneunit(T)) * (eps(real(one(T))))^(real(T)(0.8))
389-
return T(η)
390-
end
391-
392354
# Extension
393355
function __zygote_compute_nlls_vjp end

test/core/23_test_problems_tests.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
@testsetup module RobustnessTesting
2-
using LinearAlgebra, NonlinearProblemLibrary, DiffEqBase, Test
2+
using LinearAlgebra, NonlinearProblemLibrary, NonlinearSolveBase, SciMLBase, Test
33

44
problems = NonlinearProblemLibrary.problems
55
dicts = NonlinearProblemLibrary.dicts

test/core/rootfind_tests.jl

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
@testsetup module RootfindingTesting
22
using Reexport
33
@reexport using AllocCheck,
4-
LinearSolve, StaticArrays, Random, LinearAlgebra, ForwardDiff, DiffEqBase
4+
LinearSolve, StaticArrays, Random, LinearAlgebra, ForwardDiff,
5+
NonlinearSolveBase
56
import PolyesterForwardDiff
67

78
quadratic_f(u, p) = u .* u .- p
@@ -89,7 +90,7 @@ end
8990
end
9091
end
9192

92-
@testitem "Derivative Free Metods" setup=[RootfindingTesting] begin
93+
@testitem "Derivative Free Methods" setup=[RootfindingTesting] begin
9394
@testset "$(nameof(typeof(alg)))" for alg in [SimpleBroyden(), SimpleKlement(),
9495
SimpleDFSane(), SimpleLimitedMemoryBroyden(),
9596
SimpleBroyden(; linesearch = Val(true)),

0 commit comments

Comments
 (0)