@@ -5,8 +5,17 @@ function SciMLBase.solve(
5
5
sol, partials = __nlsolve_ad (prob, alg, args... ; kwargs... )
6
6
dual_soln = __nlsolve_dual_soln (sol. u, partials, prob. p)
7
7
return SciMLBase. build_solution (
8
- prob, alg, dual_soln, sol. resid; sol. retcode, sol. stats,
9
- sol. original)
8
+ prob, alg, dual_soln, sol. resid; sol. retcode, sol. stats, sol. original)
9
+ end
10
+
11
+ function SciMLBase. solve (
12
+ prob:: NonlinearLeastSquaresProblem {<: AbstractArray ,
13
+ iip, <: Union{<:AbstractArray{<:Dual{T, V, P}}} },
14
+ alg:: AbstractSimpleNonlinearSolveAlgorithm , args... ; kwargs... ) where {T, V, P, iip}
15
+ sol, partials = __nlsolve_ad (prob, alg, args... ; kwargs... )
16
+ dual_soln = __nlsolve_dual_soln (sol. u, partials, prob. p)
17
+ return SciMLBase. build_solution (
18
+ prob, alg, dual_soln, sol. resid; sol. retcode, sol. stats, sol. original)
10
19
end
11
20
12
21
for algType in (Bisection, Brent, Alefeld, Falsi, ITP, Ridder)
@@ -24,7 +33,8 @@ for algType in (Bisection, Brent, Alefeld, Falsi, ITP, Ridder)
24
33
end
25
34
end
26
35
27
- function __nlsolve_ad (prob, alg, args... ; kwargs... )
36
+ function __nlsolve_ad (
37
+ prob:: Union{IntervalNonlinearProblem, NonlinearProblem} , alg, args... ; kwargs... )
28
38
p = value (prob. p)
29
39
if prob isa IntervalNonlinearProblem
30
40
tspan = value .(prob. tspan)
@@ -55,6 +65,96 @@ function __nlsolve_ad(prob, alg, args...; kwargs...)
55
65
return sol, partials
56
66
end
57
67
68
+ function __nlsolve_ad (prob:: NonlinearLeastSquaresProblem , alg, args... ; kwargs... )
69
+ p = value (prob. p)
70
+ u0 = value (prob. u0)
71
+ newprob = NonlinearLeastSquaresProblem (prob. f, u0, p; prob. kwargs... )
72
+
73
+ sol = solve (newprob, alg, args... ; kwargs... )
74
+
75
+ uu = sol. u
76
+
77
+ # First check for custom `vjp` then custom `Jacobian` and if nothing is provided use
78
+ # nested autodiff as the last resort
79
+ if SciMLBase. has_vjp (prob. f)
80
+ if isinplace (prob)
81
+ _F = @closure (du, u, p) -> begin
82
+ resid = similar (du, length (sol. resid))
83
+ prob. f (resid, u, p)
84
+ prob. f. vjp (du, resid, u, p)
85
+ du .*= 2
86
+ return nothing
87
+ end
88
+ else
89
+ _F = @closure (u, p) -> begin
90
+ resid = prob. f (u, p)
91
+ return reshape (2 .* prob. f. vjp (resid, u, p), size (u))
92
+ end
93
+ end
94
+ elseif SciMLBase. has_jac (prob. f)
95
+ if isinplace (prob)
96
+ _F = @closure (du, u, p) -> begin
97
+ J = similar (du, length (sol. resid), length (u))
98
+ prob. f. jac (J, u, p)
99
+ resid = similar (du, length (sol. resid))
100
+ prob. f (resid, u, p)
101
+ mul! (reshape (du, 1 , :), vec (resid)' , J, 2 , false )
102
+ return nothing
103
+ end
104
+ else
105
+ _F = @closure (u, p) -> begin
106
+ return reshape (2 .* vec (prob. f (u, p))' * prob. f. jac (u, p), size (u))
107
+ end
108
+ end
109
+ else
110
+ if isinplace (prob)
111
+ _F = @closure (du, u, p) -> begin
112
+ resid = similar (du, length (sol. resid))
113
+ res = DiffResults. DiffResult (
114
+ resid, similar (du, length (sol. resid), length (u)))
115
+ _f = @closure (du, u) -> prob. f (du, u, p)
116
+ ForwardDiff. jacobian! (res, _f, resid, u)
117
+ mul! (reshape (du, 1 , :), vec (DiffResults. value (res))' ,
118
+ DiffResults. jacobian (res), 2 , false )
119
+ return nothing
120
+ end
121
+ else
122
+ # For small problems, nesting ForwardDiff is actually quite fast
123
+ if __is_extension_loaded (Val (:Zygote )) && (length (uu) + length (sol. resid) ≥ 50 )
124
+ _F = @closure (u, p) -> __zygote_compute_nlls_vjp (prob. f, u, p)
125
+ else
126
+ _F = @closure (u, p) -> begin
127
+ T = promote_type (eltype (u), eltype (p))
128
+ res = DiffResults. DiffResult (
129
+ similar (u, T, size (sol. resid)), similar (
130
+ u, T, length (sol. resid), length (u)))
131
+ ForwardDiff. jacobian! (res, Base. Fix2 (prob. f, p), u)
132
+ return reshape (
133
+ 2 .* vec (DiffResults. value (res))' * DiffResults. jacobian (res),
134
+ size (u))
135
+ end
136
+ end
137
+ end
138
+ end
139
+
140
+ f_p = __nlsolve_∂f_∂p (prob, _F, uu, p)
141
+ f_x = __nlsolve_∂f_∂u (prob, _F, uu, p)
142
+
143
+ z_arr = - f_x \ f_p
144
+
145
+ pp = prob. p
146
+ sumfun = ((z, p),) -> map (zᵢ -> zᵢ * ForwardDiff. partials (p), z)
147
+ if uu isa Number
148
+ partials = sum (sumfun, zip (z_arr, pp))
149
+ elseif p isa Number
150
+ partials = sumfun ((z_arr, pp))
151
+ else
152
+ partials = sum (sumfun, zip (eachcol (z_arr), pp))
153
+ end
154
+
155
+ return sol, partials
156
+ end
157
+
58
158
@inline function __nlsolve_∂f_∂p (prob, f:: F , u, p) where {F}
59
159
if isinplace (prob)
60
160
__f = p -> begin
0 commit comments