Skip to content

feat: add LinearProblem codegen #3717

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jun 12, 2025
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ LabelledArrays = "1.3"
Latexify = "0.11, 0.12, 0.13, 0.14, 0.15, 0.16"
Libdl = "1"
LinearAlgebra = "1"
LinearSolve = "3"
Logging = "1"
MLStyle = "0.4.17"
ModelingToolkitStandardLibrary = "2.20"
Expand All @@ -148,7 +149,7 @@ RecursiveArrayTools = "3.26"
Reexport = "0.2, 1"
RuntimeGeneratedFunctions = "0.5.9"
SCCNonlinearSolve = "1.0.0"
SciMLBase = "2.91.1"
SciMLBase = "2.100.0"
SciMLPublic = "1.0.0"
SciMLStructures = "1.7"
Serialization = "1"
Expand Down Expand Up @@ -180,6 +181,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9"
Ipopt_jll = "9cc047cb-c261-5740-88fc-0cf96f7bdcc7"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
ModelingToolkitStandardLibrary = "16a59e39-deab-5bd0-87e4-056b12336739"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
Expand All @@ -205,4 +207,4 @@ Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["AmplNLWriter", "BenchmarkTools", "BoundaryValueDiffEqMIRK", "BoundaryValueDiffEqAscher", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "OrdinaryDiffEqCore", "OrdinaryDiffEqDefault", "REPL", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET", "OrdinaryDiffEqNonlinearSolve", "Logging", "OptimizationBase"]
test = ["AmplNLWriter", "BenchmarkTools", "BoundaryValueDiffEqMIRK", "BoundaryValueDiffEqAscher", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "OrdinaryDiffEqCore", "OrdinaryDiffEqDefault", "REPL", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET", "OrdinaryDiffEqNonlinearSolve", "Logging", "OptimizationBase", "LinearSolve"]
3 changes: 3 additions & 0 deletions docs/src/API/codegen.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ ModelingToolkit.generate_constraint_hessian
ModelingToolkit.generate_control_jacobian
ModelingToolkit.build_explicit_observed_function
ModelingToolkit.generate_control_function
ModelingToolkit.generate_update_A
ModelingToolkit.generate_update_b
```

For functions such as jacobian calculation which require symbolic computation, there
Expand All @@ -42,6 +44,7 @@ ModelingToolkit.cost_hessian_sparsity
ModelingToolkit.calculate_constraint_jacobian
ModelingToolkit.calculate_constraint_hessian
ModelingToolkit.calculate_control_jacobian
ModelingToolkit.calculate_A_b
```

All code generation eventually calls `build_function_wrapper`.
Expand Down
3 changes: 2 additions & 1 deletion docs/src/API/problems.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ SciMLBase.DiscreteProblem
SciMLBase.ImplicitDiscreteProblem
```

## Nonlinear systems
## Linear and Nonlinear systems

```@docs
SciMLBase.NonlinearFunction
Expand All @@ -41,6 +41,7 @@ SciMLBase.IntervalNonlinearFunction
SciMLBase.IntervalNonlinearProblem
ModelingToolkit.HomotopyContinuationProblem
SciMLBase.HomotopyNonlinearFunction
SciMLBase.LinearProblem
```

## Optimization and optimal control
Expand Down
1 change: 1 addition & 0 deletions src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ include("problems/jumpproblem.jl")
include("problems/initializationproblem.jl")
include("problems/sccnonlinearproblem.jl")
include("problems/bvproblem.jl")
include("problems/linearproblem.jl")

include("modelingtoolkitize/common.jl")
include("modelingtoolkitize/odeproblem.jl")
Expand Down
9 changes: 9 additions & 0 deletions src/problems/compatibility.jl
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,12 @@ function check_no_equations(sys::System, T)
"""))
end
end

function check_affine(sys::System, T)
if !isaffine(sys)
throw(SystemCompatibilityError("""
A non-affine system cannot be used to construct a `$T`. Consider a
`NonlinearProblem` instead.
"""))
end
end
29 changes: 29 additions & 0 deletions src/problems/docs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -391,3 +391,32 @@ $PROBLEM_INTERNALS_HEADER

$PROBLEM_INTERNAL_KWARGS
""" SciMLBase.IntervalNonlinearProblem

@doc """
SciMLBase.LinearProblem(sys::System, op; kwargs...)
SciMLBase.LinearProblem{iip}(sys::System, op; kwargs...)

Build a `LinearProblem` given a system `sys` and operating point `op`. `iip` is a boolean
indicating whether the problem should be in-place. The operating point should be an
iterable collection of key-value pairs mapping variables/parameters in the system to the
(initial) values they should take in `LinearProblem`. Any values not provided will
fallback to the corresponding default (if present).

Note that since `u0` is optional for `LinearProblem`, values of unknowns do not need to be
specified in `op` to create a `LinearProblem`. In such a case, `prob.u0` will be `nothing`
and attempting to symbolically index the problem with an unknown, observable, or expression
depending on unknowns/observables will error.

Updating the parameters automatically updates the `A` and `b` arrays.

# Keyword arguments

$PROBLEM_KWARGS
$(prob_fun_common_kwargs(LinearProblem, false))

All other keyword arguments are forwarded to the $func constructor.

$PROBLEM_INTERNALS_HEADER

$PROBLEM_INTERNAL_KWARGS
""" SciMLBase.LinearProblem
98 changes: 98 additions & 0 deletions src/problems/linearproblem.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
function SciMLBase.LinearProblem(sys::System, op; kwargs...)
SciMLBase.LinearProblem{true}(sys, op; kwargs...)
end

function SciMLBase.LinearProblem(sys::System, op::StaticArray; kwargs...)
SciMLBase.LinearProblem{false}(sys, op; kwargs...)
end

function SciMLBase.LinearProblem{iip}(
sys::System, op; check_length = true, expression = Val{false},
check_compatibility = true, sparse = false, eval_expression = false,
eval_module = @__MODULE__, checkbounds = false, cse = true,
u0_constructor = identity, u0_eltype = nothing, kwargs...) where {iip}
check_complete(sys, LinearProblem)
check_compatibility && check_compatible_system(LinearProblem, sys)

_, u0, p = process_SciMLProblem(
EmptySciMLFunction{iip}, sys, op; check_length, expression,
build_initializeprob = false, symbolic_u0 = true, u0_constructor, u0_eltype,
kwargs...)

if any(x -> symbolic_type(x) != NotSymbolic(), u0)
u0 = nothing
end

u0Type = typeof(op)
floatT = if u0 === nothing
calculate_float_type(op, u0Type)
else
eltype(u0)
end
u0_eltype = something(u0_eltype, floatT)

u0_constructor = get_p_constructor(u0_constructor, u0Type, u0_eltype)

A, b = calculate_A_b(sys; sparse)
update_A = generate_update_A(sys, A; expression, wrap_gfw = Val{true}, eval_expression,
eval_module, checkbounds, cse, kwargs...)
update_b = generate_update_b(sys, b; expression, wrap_gfw = Val{true}, eval_expression,
eval_module, checkbounds, cse, kwargs...)
observedfun = ObservedFunctionCache(
sys; steady_state = false, expression, eval_expression, eval_module, checkbounds,
cse)

if expression == Val{true}
symbolic_interface = quote
update_A = $update_A
update_b = $update_b
sys = $sys
observedfun = $observedfun
$(SciMLBase.SymbolicLinearInterface)(
update_A, update_b, sys, observedfun, nothing)
end
get_A = build_explicit_observed_function(
sys, A; param_only = true, eval_expression, eval_module)
if sparse
get_A = SparseArrays.sparse ∘ get_A
end
get_b = build_explicit_observed_function(
sys, b; param_only = true, eval_expression, eval_module)
A = u0_constructor(get_A(p))
b = u0_constructor(get_b(p))
else
symbolic_interface = SciMLBase.SymbolicLinearInterface(
update_A, update_b, sys, observedfun, nothing)
A = u0_constructor(update_A(p))
b = u0_constructor(update_b(p))
end

kwargs = (; u0, process_kwargs(sys; kwargs...)..., f = symbolic_interface)
args = (; A, b, p)

return maybe_codegen_scimlproblem(expression, LinearProblem{iip}, args; kwargs...)
end

# For remake
function SciMLBase.get_new_A_b(
sys::AbstractSystem, f::SciMLBase.SymbolicLinearInterface, p, A, b; kw...)
if ArrayInterface.ismutable(A)
f.update_A!(A, p)
f.update_b!(b, p)
else
# The generated function has both IIP and OOP variants
A = StaticArraysCore.similar_type(A)(f.update_A!(p))
b = StaticArraysCore.similar_type(b)(f.update_b!(p))
end
return A, b
end

function check_compatible_system(T::Type{LinearProblem}, sys::System)
check_time_independent(sys, T)
check_affine(sys, T)
check_not_dde(sys)
check_no_cost(sys, T)
check_no_constraints(sys, T)
check_no_jumps(sys, T)
check_no_noise(sys, T)
end
4 changes: 2 additions & 2 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1780,13 +1780,13 @@ function preface(sys::AbstractSystem)
end

function islinear(sys::AbstractSystem)
rhs = [eq.rhs for eq in equations(sys)]
rhs = [eq.rhs for eq in full_equations(sys)]

all(islinear(r, unknowns(sys)) for r in rhs)
end

function isaffine(sys::AbstractSystem)
rhs = [eq.rhs for eq in equations(sys)]
rhs = [eq.rhs for eq in full_equations(sys)]

all(isaffine(r, unknowns(sys)) for r in rhs)
end
Expand Down
87 changes: 87 additions & 0 deletions src/systems/codegen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1130,3 +1130,90 @@ function build_explicit_observed_function(sys, ts;
return f
end
end

"""
$(TYPEDSIGNATURES)

Return matrix `A` and vector `b` such that the system `sys` can be represented as
`A * x = b` where `x` is `unknowns(sys)`. Errors if the system is not affine.

# Keyword arguments

- `sparse`: return a sparse `A`.
"""
function calculate_A_b(sys::System; sparse = false)
rhss = [eq.rhs for eq in full_equations(sys)]
dvs = unknowns(sys)

A = Matrix{Any}(undef, length(rhss), length(dvs))
b = Vector{Any}(undef, length(rhss))
for (i, rhs) in enumerate(rhss)
# mtkcompile makes this `0 ~ rhs` which typically ends up giving
# unknowns negative coefficients. If given the equations `A * x ~ b`
# it will simplify to `0 ~ b - A * x`. Thus this negation usually leads
# to more comprehensible user API.
resid = -rhs
for (j, var) in enumerate(dvs)
p, q, islinear = Symbolics.linear_expansion(resid, var)
if !islinear
throw(ArgumentError("System is not linear. Equation $((0 ~ rhs)) is not linear in unknown $var."))
end
A[i, j] = p
resid = q
end
# negate beucause `resid` is the residual on the LHS
b[i] = -resid
end

@assert all(Base.Fix1(isassigned, A), eachindex(A))
@assert all(Base.Fix1(isassigned, A), eachindex(b))

if sparse
A = SparseArrays.sparse(A)
end
return A, b
end

"""
$(TYPEDSIGNATURES)

Given a system `sys` and the `A` from [`calculate_A_b`](@ref) generate the function that
updates `A` given the parameter object.

# Keyword arguments

$GENERATE_X_KWARGS

All other keyword arguments are forwarded to [`build_function_wrapper`](@ref).
"""
function generate_update_A(sys::System, A::AbstractMatrix; expression = Val{true},
wrap_gfw = Val{false}, eval_expression = false, eval_module = @__MODULE__, kwargs...)
ps = reorder_parameters(sys)

res = build_function_wrapper(sys, A, ps...; p_start = 1, expression = Val{true},
similarto = typeof(A), kwargs...)
return maybe_compile_function(expression, wrap_gfw, (1, 1, is_split(sys)), res;
eval_expression, eval_module)
end

"""
$(TYPEDSIGNATURES)

Given a system `sys` and the `b` from [`calculate_A_b`](@ref) generate the function that
updates `b` given the parameter object.

# Keyword arguments

$GENERATE_X_KWARGS

All other keyword arguments are forwarded to [`build_function_wrapper`](@ref).
"""
function generate_update_b(sys::System, b::AbstractVector; expression = Val{true},
wrap_gfw = Val{false}, eval_expression = false, eval_module = @__MODULE__, kwargs...)
ps = reorder_parameters(sys)

res = build_function_wrapper(sys, b, ps...; p_start = 1, expression = Val{true},
similarto = typeof(b), kwargs...)
return maybe_compile_function(expression, wrap_gfw, (1, 1, is_split(sys)), res;
eval_expression, eval_module)
end
1 change: 1 addition & 0 deletions src/systems/nonlinear/initializesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,7 @@ function SciMLBase.late_binding_update_u0_p(
prob, sys::AbstractSystem, u0, p, t0, newu0, newp)
supports_initialization(sys) || return newu0, newp
prob isa IntervalNonlinearProblem && return newu0, newp
prob isa LinearProblem && return newu0, newp

initdata = prob.f.initialization_data
meta = initdata === nothing ? nothing : initdata.metadata
Expand Down
Loading
Loading