Skip to content
This repository was archived by the owner on Apr 16, 2025. It is now read-only.

Commit a8fda19

Browse files
Merge pull request #136 from SciML/ap/nlls_term
Use the different norms for termination
2 parents 0f02306 + e37f938 commit a8fda19

File tree

9 files changed

+37
-41
lines changed

9 files changed

+37
-41
lines changed

Project.toml

+5-5
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SimpleNonlinearSolve"
22
uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7"
33
authors = ["SciML"]
4-
version = "1.6.0"
4+
version = "1.7.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -39,13 +39,13 @@ SimpleNonlinearSolveZygoteExt = "Zygote"
3939
ADTypes = "0.2.6"
4040
AllocCheck = "0.1.1"
4141
Aqua = "0.8"
42-
ArrayInterface = "7.7"
42+
ArrayInterface = "7.8"
4343
CUDA = "5.2"
4444
ChainRulesCore = "1.22"
4545
ConcreteStructs = "0.2.3"
46-
DiffEqBase = "6.146"
46+
DiffEqBase = "6.149"
4747
DiffResults = "1.1"
48-
FastClosures = "0.3"
48+
FastClosures = "0.3.2"
4949
FiniteDiff = "2.22"
5050
ForwardDiff = "0.10.36"
5151
LinearAlgebra = "1.10"
@@ -59,7 +59,7 @@ Random = "1.10"
5959
ReTestItems = "1.23"
6060
Reexport = "1.2"
6161
ReverseDiff = "1.15"
62-
SciMLBase = "2.26.3"
62+
SciMLBase = "2.28.0"
6363
SciMLSensitivity = "7.56"
6464
StaticArrays = "1.9"
6565
StaticArraysCore = "1.4.2"

src/nlsolve/broyden.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleBroyden, args...;
4848
@bb δJ⁻¹n = copy(x)
4949
@bb δJ⁻¹ = copy(J⁻¹)
5050

51-
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fx, x,
51+
abstol, reltol, tc_cache = init_termination_cache(prob, abstol, reltol, fx, x,
5252
termination_condition)
5353

5454
ls_cache = __get_linesearch(alg) === Val(true) ?

src/nlsolve/dfsane.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{M}, args...
7070
τ_min = T(alg.τ_min)
7171
τ_max = T(alg.τ_max)
7272

73-
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fx, x,
73+
abstol, reltol, tc_cache = init_termination_cache(prob, abstol, reltol, fx, x,
7474
termination_condition)
7575

7676
fx_norm = NONLINEARSOLVE_DEFAULT_NORM(fx)^nexp

src/nlsolve/halley.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleHalley, args...;
3434
T = eltype(x)
3535

3636
autodiff = __get_concrete_autodiff(prob, alg.autodiff)
37-
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fx, x,
37+
abstol, reltol, tc_cache = init_termination_cache(prob, abstol, reltol, fx, x,
3838
termination_condition)
3939

4040
@bb xo = copy(x)

src/nlsolve/klement.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleKlement, args...;
1313
T = eltype(x)
1414
fx = _get_fx(prob, x)
1515

16-
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fx, x,
16+
abstol, reltol, tc_cache = init_termination_cache(prob, abstol, reltol, fx, x,
1717
termination_condition)
1818

1919
@bb δx = copy(x)

src/nlsolve/lbroyden.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ end
6161

6262
U, Vᵀ = __init_low_rank_jacobian(x, fx, x isa StaticArray ? threshold : Val(η))
6363

64-
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fx, x,
64+
abstol, reltol, tc_cache = init_termination_cache(prob, abstol, reltol, fx, x,
6565
termination_condition)
6666

6767
@bb xo = copy(x)

src/nlsolve/raphson.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ function SciMLBase.__solve(prob::Union{NonlinearProblem, NonlinearLeastSquaresPr
3232
@bb xo = copy(x)
3333
J, jac_cache = jacobian_cache(autodiff, prob.f, fx, x, prob.p)
3434

35-
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fx, x,
35+
abstol, reltol, tc_cache = init_termination_cache(prob, abstol, reltol, fx, x,
3636
termination_condition)
3737

3838
for i in 1:maxiters

src/nlsolve/trustRegion.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleTrustRegion, args.
8888
J, jac_cache = jacobian_cache(autodiff, prob.f, fx, x, prob.p)
8989
fx, ∇f = value_and_jacobian(autodiff, prob.f, fx, x, prob.p, jac_cache; J)
9090

91-
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fx, x,
91+
abstol, reltol, tc_cache = init_termination_cache(prob, abstol, reltol, fx, x,
9292
termination_condition)
9393

9494
# Set default trust region radius if not specified by user.

src/utils.jl

+25-29
Original file line numberDiff line numberDiff line change
@@ -288,14 +288,30 @@ end
288288
# different. NonlinearSolve is more for robust / cached solvers while SimpleNonlinearSolve
289289
# is meant for low overhead solvers, users can opt into the other termination modes but the
290290
# default is to use the least overhead version.
291-
function init_termination_cache(abstol, reltol, du, u, ::Nothing)
292-
return init_termination_cache(abstol, reltol, du, u, AbsNormTerminationMode())
291+
function init_termination_cache(prob::NonlinearProblem, abstol, reltol, du, u, ::Nothing)
292+
return init_termination_cache(prob, abstol, reltol, du, u,
293+
AbsNormTerminationMode(Base.Fix1(maximum, abs)))
293294
end
294-
function init_termination_cache(abstol, reltol, du, u, tc::AbstractNonlinearTerminationMode)
295+
function init_termination_cache(
296+
prob::NonlinearLeastSquaresProblem, abstol, reltol, du, u, ::Nothing)
297+
return init_termination_cache(prob, abstol, reltol, du, u,
298+
AbsNormTerminationMode(Base.Fix2(norm, 2)))
299+
end
300+
301+
function init_termination_cache(
302+
prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
303+
abstol, reltol, du, u, tc::AbstractNonlinearTerminationMode)
295304
T = promote_type(eltype(du), eltype(u))
296305
abstol = __get_tolerance(u, abstol, T)
297306
reltol = __get_tolerance(u, reltol, T)
298-
tc_cache = init(du, u, tc; abstol, reltol)
307+
tc_ = if hasfield(typeof(tc), :internalnorm) && tc.internalnorm === nothing
308+
internalnorm = ifelse(
309+
prob isa NonlinearProblem, Base.Fix1(maximum, abs), Base.Fix2(norm, 2))
310+
DiffEqBase.set_termination_mode_internalnorm(tc, internalnorm)
311+
else
312+
tc
313+
end
314+
tc_cache = init(du, u, tc_; abstol, reltol, use_deprecated_retcodes = Val(false))
299315
return DiffEqBase.get_abstol(tc_cache), DiffEqBase.get_reltol(tc_cache), tc_cache
300316
end
301317

@@ -305,45 +321,25 @@ function check_termination(tc_cache, fx, x, xo, prob, alg)
305321
end
306322
function check_termination(tc_cache, fx, x, xo, prob, alg,
307323
::AbstractNonlinearTerminationMode)
308-
if Bool(tc_cache(fx, x, xo))
324+
tc_cache(fx, x, xo) &&
309325
return build_solution(prob, alg, x, fx; retcode = ReturnCode.Success)
310-
end
311326
return nothing
312327
end
313328
function check_termination(tc_cache, fx, x, xo, prob, alg,
314329
::AbstractSafeNonlinearTerminationMode)
315-
if Bool(tc_cache(fx, x, xo))
316-
if tc_cache.retcode == NonlinearSafeTerminationReturnCode.Success
317-
retcode = ReturnCode.Success
318-
elseif tc_cache.retcode == NonlinearSafeTerminationReturnCode.PatienceTermination
319-
retcode = ReturnCode.ConvergenceFailure
320-
elseif tc_cache.retcode == NonlinearSafeTerminationReturnCode.ProtectiveTermination
321-
retcode = ReturnCode.Unstable
322-
else
323-
error("Unknown termination code: $(tc_cache.retcode)")
324-
end
325-
return build_solution(prob, alg, x, fx; retcode)
326-
end
330+
tc_cache(fx, x, xo) &&
331+
return build_solution(prob, alg, x, fx; retcode = tc_cache.retcode)
327332
return nothing
328333
end
329334
function check_termination(tc_cache, fx, x, xo, prob, alg,
330335
::AbstractSafeBestNonlinearTerminationMode)
331-
if Bool(tc_cache(fx, x, xo))
332-
if tc_cache.retcode == NonlinearSafeTerminationReturnCode.Success
333-
retcode = ReturnCode.Success
334-
elseif tc_cache.retcode == NonlinearSafeTerminationReturnCode.PatienceTermination
335-
retcode = ReturnCode.ConvergenceFailure
336-
elseif tc_cache.retcode == NonlinearSafeTerminationReturnCode.ProtectiveTermination
337-
retcode = ReturnCode.Unstable
338-
else
339-
error("Unknown termination code: $(tc_cache.retcode)")
340-
end
336+
if tc_cache(fx, x, xo)
341337
if isinplace(prob)
342338
prob.f(fx, x, prob.p)
343339
else
344340
fx = prob.f(x, prob.p)
345341
end
346-
return build_solution(prob, alg, tc_cache.u, fx; retcode)
342+
return build_solution(prob, alg, tc_cache.u, fx; retcode = tc_cache.retcode)
347343
end
348344
return nothing
349345
end

0 commit comments

Comments
 (0)