Skip to content

fix isinplace inference and add inference tests #1019

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions src/problems/optimization_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,7 @@ function OptimizationProblem(
OptimizationProblem{isinplace(f)}(f, args...; kwargs...)
end
function OptimizationProblem(f, args...; kwargs...)
isinplace(f, 2, has_two_dispatches = false)
OptimizationProblem{true}(OptimizationFunction{true}(f), args...; kwargs...)
OptimizationProblem(OptimizationFunction(f), args...; kwargs...)
end

function OptimizationFunction(
Expand Down
22 changes: 15 additions & 7 deletions src/scimlfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4199,7 +4199,10 @@ IntervalNonlinearFunction(f::IntervalNonlinearFunction; kwargs...) = f
struct NoAD <: AbstractADType end

(f::OptimizationFunction)(args...) = f.f(args...)
OptimizationFunction(args...; kwargs...) = OptimizationFunction{true}(args...; kwargs...)
function OptimizationFunction(f, args...; kwargs...)
isinplace(f, 2, outofplace_param_number=2)
OptimizationFunction{true}(f, args...; kwargs...)
end

function OptimizationFunction{iip}(f, adtype::AbstractADType = NoAD();
grad = nothing, fg = nothing, hess = nothing, hv = nothing, fgh = nothing,
Expand Down Expand Up @@ -4251,8 +4254,9 @@ end
(f::MultiObjectiveOptimizationFunction)(args...) = f.f(args...)

# Convenience constructor
function MultiObjectiveOptimizationFunction(args...; kwargs...)
MultiObjectiveOptimizationFunction{true}(args...; kwargs...)
function MultiObjectiveOptimizationFunction(f, args...; kwargs...)
isinplace(f, 2, outofplace_param_number=2)
MultiObjectiveOptimizationFunction{true}(f, args...; kwargs...)
end

# Constructor with keyword arguments
Expand Down Expand Up @@ -4339,15 +4343,17 @@ function BVPFunction{iip, specialize, twopoint}(f, bc;
if iip_f
jac = update_coefficients! #(J,u,p,t)
else
jac = (u, p, t) -> update_coefficients!(deepcopy(jac_prototype), u, p, t)
jac_prototype_copy = deepcopy(jac_prototype)
jac = (u, p, t) -> update_coefficients!(jac_prototype_copy, u, p, t)
end
end

if bcjac === nothing && isa(bcjac_prototype, AbstractSciMLOperator)
if iip_bc
bcjac = update_coefficients! #(J,u,p,t)
else
bcjac = (u, p, t) -> update_coefficients!(deepcopy(bcjac_prototype), u, p, t)
bcjac_prototype_copy = deepcopy(bcjac_prototype)
bcjac = (u, p, t) -> update_coefficients!(bcjac_prototype_copy, u, p, t)
end
end

Expand Down Expand Up @@ -4512,15 +4518,17 @@ function DynamicalBVPFunction{iip, specialize, twopoint}(f, bc;
if iip_f
jac = update_coefficients! #(J,u,p,t)
else
jac = (u, p, t) -> update_coefficients!(deepcopy(jac_prototype), u, p, t)
jac_prototype_copy = deepcopy(jac_prototype)
jac = (u, p, t) -> update_coefficients!(jac_prototype_copy, u, p, t)
end
end

if bcjac === nothing && isa(bcjac_prototype, AbstractSciMLOperator)
if iip_bc
bcjac = update_coefficients! #(J,u,p,t)
else
bcjac = (u, p, t) -> update_coefficients!(deepcopy(bcjac_prototype), u, p, t)
bcjac_prototype_copy = deepcopy(jac_prototype)
bcjac = (u, p, t) -> update_coefficients!(bcjac_prototype_copy, u, p, t)
end
end

Expand Down
221 changes: 12 additions & 209 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,157 +42,6 @@ function num_types_in_tuple(sig::UnionAll)
length(Base.unwrap_unionall(sig).parameters)
end

const NO_METHODS_ERROR_MESSAGE = """
No methods were found for the model function passed to the equation solver.
The function `f` needs to have dispatches, for example, for an ODEProblem
`f` must define either `f(u,p,t)` or `f(du,u,p,t)`. For more information
on how the model function `f` should be defined, consult the docstring for
the appropriate `AbstractSciMLFunction`.
"""

struct NoMethodsError <: Exception
fname::String
end

function Base.showerror(io::IO, e::NoMethodsError)
println(io, NO_METHODS_ERROR_MESSAGE)
print(io, "Offending function: ")
printstyled(io, e.fname; bold = true, color = :red)
end

const TOO_MANY_ARGUMENTS_ERROR_MESSAGE = """
All methods for the model function `f` had too many arguments. For example,
an ODEProblem `f` must define either `f(u,p,t)` or `f(du,u,p,t)`. This error
can be thrown if you define an ODE model for example as `f(du,u,p1,p2,t)`.
For more information on the required number of arguments for the function
you were defining, consult the documentation for the `SciMLProblem` or
`SciMLFunction` type that was being constructed.

A common reason for this occurrence is due to following the MATLAB or SciPy
convention for parameter passing, i.e. to add each parameter as an argument.
In the SciML convention, if you wish to pass multiple parameters, use a
struct or other collection to hold the parameters. For example, here is the
parameterized Lorenz equation:

```julia
function lorenz(du,u,p,t)
du[1] = p[1]*(u[2]-u[1])
du[2] = u[1]*(p[2]-u[3]) - u[2]
du[3] = u[1]*u[2] - p[3]*u[3]
end
u0 = [1.0;0.0;0.0]
p = [10.0,28.0,8/3]
tspan = (0.0,100.0)
prob = ODEProblem(lorenz,u0,tspan,p)
```

Notice that `f` is defined with a single `p`, an array which matches the definition
of the `p` in the `ODEProblem`. Note that `p` can be any Julia struct.
"""

struct TooManyArgumentsError <: Exception
fname::String
f::Any
end

function Base.showerror(io::IO, e::TooManyArgumentsError)
println(io, TOO_MANY_ARGUMENTS_ERROR_MESSAGE)
print(io, "Offending function: ")
printstyled(io, e.fname; bold = true, color = :red)
println(io, "\nMethods:")
println(io, methods(e.f))
end

const TOO_FEW_ARGUMENTS_ERROR_MESSAGE_OPTIMIZATION = """
All methods for the model function `f` had too few arguments. For example,
an OptimizationProblem `f` must define `f(u,p)` where `u` is the optimization
state and `p` are the parameters of the optimization (commonly, the hyperparameters
of the simulation).

A common reason for this error is from defining a single-input loss function
`f(u)`. While parameters are not required, a loss function which takes parameters
is required, i.e. `f(u,p)`. If you have a function `f(u)`, ignored parameters
can be easily added using a closure, i.e. `OptimizationProblem((u,_)->f(u),...)`.

For example, here is a parameterized optimization problem:

```julia
using Optimization, OptimizationOptimJL
rosenbrock(u,p) = (p[1] - u[1])^2 + p[2] * (u[2] - u[1]^2)^2
u0 = zeros(2)
p = [1.0,100.0]

prob = OptimizationProblem(rosenbrock,u0,p)
sol = solve(prob,NelderMead())
```

and a parameter-less example:

```julia
using Optimization, OptimizationOptimJL
rosenbrock(u,p) = (1 - u[1])^2 + (u[2] - u[1]^2)^2
u0 = zeros(2)

prob = OptimizationProblem(rosenbrock,u0)
sol = solve(prob,NelderMead())
```
"""

const TOO_FEW_ARGUMENTS_ERROR_MESSAGE = """
All methods for the model function `f` had too few arguments. For example,
an ODEProblem `f` must define either `f(u,p,t)` or `f(du,u,p,t)`. This error
can be thrown if you define an ODE model for example as `f(u,t)`. The parameters
`p` are not optional in the definition of `f`! For more information on the required
number of arguments for the function you were defining, consult the documentation
for the `SciMLProblem` or `SciMLFunction` type that was being constructed.

For example, here is the no parameter Lorenz equation. The two valid versions
are out of place:

```julia
function lorenz(u,p,t)
du1 = 10.0*(u[2]-u[1])
du2 = u[1]*(28.0-u[3]) - u[2]
du3 = u[1]*u[2] - 8/3*u[3]
[du1,du2,du3]
end
u0 = [1.0;0.0;0.0]
tspan = (0.0,100.0)
prob = ODEProblem(lorenz,u0,tspan)
```

and in-place:

```julia
function lorenz!(du,u,p,t)
du[1] = 10.0*(u[2]-u[1])
du[2] = u[1]*(28.0-u[3]) - u[2]
du[3] = u[1]*u[2] - 8/3*u[3]
end
u0 = [1.0;0.0;0.0]
tspan = (0.0,100.0)
prob = ODEProblem(lorenz!,u0,tspan)
```
"""

struct TooFewArgumentsError <: Exception
fname::String
f::Any
isoptimization::Bool
end

function Base.showerror(io::IO, e::TooFewArgumentsError)
if e.isoptimization
println(io, TOO_FEW_ARGUMENTS_ERROR_MESSAGE_OPTIMIZATION)
else
println(io, TOO_FEW_ARGUMENTS_ERROR_MESSAGE)
end
print(io, "Offending function: ")
printstyled(io, e.fname; bold = true, color = :red)
println(io, "\nMethods:")
println(io, methods(e.f))
end

const ARGUMENTS_ERROR_MESSAGE = """
Methods dispatches for the model function `f` do not match the required number.
For example, an ODEProblem `f` must define either `f(u,p,t)` or `f(du,u,p,t)`.
Expand All @@ -207,6 +56,12 @@ struct FunctionArgumentsError <: Exception
f::Any
end

# backward compat in case anyone is using these.
# TODO: remove at next major version
const TooManyArgumentsError = FunctionArgumentsError
const TooFewArgumentsError = FunctionArgumentsError
const NoMethodsError = FunctionArgumentsError

function Base.showerror(io::IO, e::FunctionArgumentsError)
println(io, ARGUMENTS_ERROR_MESSAGE)
print(io, "Offending function: ")
Expand Down Expand Up @@ -246,66 +101,14 @@ form is disabled and the 2-argument signature is ensured to be matched.
function isinplace(f, inplace_param_number, fname = "f", iip_preferred = true;
has_two_dispatches = true, isoptimization = false,
outofplace_param_number = inplace_param_number - 1)
nargs = numargs(f)
iip_dispatch = any(x -> x == inplace_param_number, nargs)
oop_dispatch = any(x -> x == outofplace_param_number, nargs)

if length(nargs) == 0
throw(NoMethodsError(fname))
end

if !iip_dispatch && !oop_dispatch && !isoptimization
if all(>(inplace_param_number), nargs)
throw(TooManyArgumentsError(fname, f))
elseif all(<(outofplace_param_number), nargs) && has_two_dispatches
# Possible extra safety?
# Find if there's a `f(args...)` dispatch
# If so, no error
_parameters = if methods(f).ms[1].sig isa UnionAll
Base.unwrap_unionall(methods(f).ms[1].sig).parameters
else
methods(f).ms[1].sig.parameters
end

for i in 1:length(nargs)
if nargs[i] < inplace_param_number &&
any(isequal(Vararg{Any}), _parameters)
# If varargs, assume iip
return iip_preferred
end
end

# No varargs detected, error that there are dispatches but not the right ones

throw(TooFewArgumentsError(fname, f, isoptimization))
else
throw(FunctionArgumentsError(fname, f))
end
elseif oop_dispatch && !iip_dispatch && !has_two_dispatches

# Possible extra safety?
# Find if there's a `f(args...)` dispatch
# If so, no error
for i in 1:length(nargs)
if nargs[i] < inplace_param_number &&
any(isequal(Vararg{Any}), methods(f).ms[1].sig.parameters)
# If varargs, assume iip
return iip_preferred
end
end

throw(TooFewArgumentsError(fname, f, isoptimization))
if iip_preferred
hasmethod(f, ntuple(_->Any, inplace_param_number)) && return true
hasmethod(f, ntuple(_->Any, outofplace_param_number)) && return false
else
if iip_preferred
# Equivalent to, if iip_dispatch exists, treat as iip
# Otherwise, it's oop
iip_dispatch
else
# Equivalent to, if oop_dispatch exists, treat as oop
# Otherwise, it's iip
!oop_dispatch
end
hasmethod(f, ntuple(_->Any, outofplace_param_number)) && return false
hasmethod(f, ntuple(_->Any, inplace_param_number)) && return true
end
throw(FunctionArgumentsError(fname, f))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This error message loses a lot of information. Can we in the error path do a method check and throw the more informative error message?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO the better method is to just put the expected nargs in the exception and then we can make the exception printing in charge of that logic. That way none of the complicated bits are in the main function.

end

isinplace(f::AbstractSciMLFunction{iip}) where {iip} = iip
Expand Down
2 changes: 1 addition & 1 deletion test/aqua.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ end
# for method_ambiguity in ambs
# @show method_ambiguity
# end
@warn "Number of method ambiguities: $(length(ambs))"
!isempty(ambs) &&@warn "Number of method ambiguities: $(length(ambs))"
@test length(ambs) ≤ 8
end

Expand Down
Loading
Loading