diff --git a/Project.toml b/Project.toml index eb7333ae47..faf9e4eaf9 100644 --- a/Project.toml +++ b/Project.toml @@ -129,6 +129,7 @@ LabelledArrays = "1.3" Latexify = "0.11, 0.12, 0.13, 0.14, 0.15, 0.16" Libdl = "1" LinearAlgebra = "1" +LinearSolve = "3" Logging = "1" MLStyle = "0.4.17" ModelingToolkitStandardLibrary = "2.20" @@ -148,7 +149,7 @@ RecursiveArrayTools = "3.26" Reexport = "0.2, 1" RuntimeGeneratedFunctions = "0.5.9" SCCNonlinearSolve = "1.0.0" -SciMLBase = "2.91.1" +SciMLBase = "2.100.0" SciMLPublic = "1.0.0" SciMLStructures = "1.7" Serialization = "1" @@ -180,6 +181,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9" Ipopt_jll = "9cc047cb-c261-5740-88fc-0cf96f7bdcc7" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" +LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" ModelingToolkitStandardLibrary = "16a59e39-deab-5bd0-87e4-056b12336739" NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" @@ -205,4 +207,4 @@ Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["AmplNLWriter", "BenchmarkTools", "BoundaryValueDiffEqMIRK", "BoundaryValueDiffEqAscher", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "OrdinaryDiffEqCore", "OrdinaryDiffEqDefault", "REPL", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET", "OrdinaryDiffEqNonlinearSolve", "Logging", "OptimizationBase"] +test = ["AmplNLWriter", "BenchmarkTools", "BoundaryValueDiffEqMIRK", "BoundaryValueDiffEqAscher", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "OrdinaryDiffEqCore", "OrdinaryDiffEqDefault", "REPL", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET", "OrdinaryDiffEqNonlinearSolve", "Logging", "OptimizationBase", "LinearSolve"] diff --git a/docs/src/API/codegen.md b/docs/src/API/codegen.md index cd76d34522..4f31405174 100644 --- a/docs/src/API/codegen.md +++ b/docs/src/API/codegen.md @@ -21,6 +21,8 @@ ModelingToolkit.generate_constraint_hessian ModelingToolkit.generate_control_jacobian ModelingToolkit.build_explicit_observed_function ModelingToolkit.generate_control_function +ModelingToolkit.generate_update_A +ModelingToolkit.generate_update_b ``` For functions such as jacobian calculation which require symbolic computation, there @@ -42,6 +44,7 @@ ModelingToolkit.cost_hessian_sparsity ModelingToolkit.calculate_constraint_jacobian ModelingToolkit.calculate_constraint_hessian ModelingToolkit.calculate_control_jacobian +ModelingToolkit.calculate_A_b ``` All code generation eventually calls `build_function_wrapper`. diff --git a/docs/src/API/problems.md b/docs/src/API/problems.md index d308785d77..72147a7e09 100644 --- a/docs/src/API/problems.md +++ b/docs/src/API/problems.md @@ -29,7 +29,7 @@ SciMLBase.DiscreteProblem SciMLBase.ImplicitDiscreteProblem ``` -## Nonlinear systems +## Linear and Nonlinear systems ```@docs SciMLBase.NonlinearFunction @@ -41,6 +41,7 @@ SciMLBase.IntervalNonlinearFunction SciMLBase.IntervalNonlinearProblem ModelingToolkit.HomotopyContinuationProblem SciMLBase.HomotopyNonlinearFunction +SciMLBase.LinearProblem ``` ## Optimization and optimal control diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 0a7fdd5209..de360d624d 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -188,6 +188,7 @@ include("problems/jumpproblem.jl") include("problems/initializationproblem.jl") include("problems/sccnonlinearproblem.jl") include("problems/bvproblem.jl") +include("problems/linearproblem.jl") include("modelingtoolkitize/common.jl") include("modelingtoolkitize/odeproblem.jl") diff --git a/src/problems/compatibility.jl b/src/problems/compatibility.jl index e5608ce4f1..9d5abf926e 100644 --- a/src/problems/compatibility.jl +++ b/src/problems/compatibility.jl @@ -169,3 +169,12 @@ function check_no_equations(sys::System, T) """)) end end + +function check_affine(sys::System, T) + if !isaffine(sys) + throw(SystemCompatibilityError(""" + A non-affine system cannot be used to construct a `$T`. Consider a + `NonlinearProblem` instead. + """)) + end +end diff --git a/src/problems/docs.jl b/src/problems/docs.jl index 94b77b8772..17bc2c83c6 100644 --- a/src/problems/docs.jl +++ b/src/problems/docs.jl @@ -391,3 +391,32 @@ $PROBLEM_INTERNALS_HEADER $PROBLEM_INTERNAL_KWARGS """ SciMLBase.IntervalNonlinearProblem + +@doc """ + SciMLBase.LinearProblem(sys::System, op; kwargs...) + SciMLBase.LinearProblem{iip}(sys::System, op; kwargs...) + +Build a `LinearProblem` given a system `sys` and operating point `op`. `iip` is a boolean +indicating whether the problem should be in-place. The operating point should be an +iterable collection of key-value pairs mapping variables/parameters in the system to the +(initial) values they should take in `LinearProblem`. Any values not provided will +fallback to the corresponding default (if present). + +Note that since `u0` is optional for `LinearProblem`, values of unknowns do not need to be +specified in `op` to create a `LinearProblem`. In such a case, `prob.u0` will be `nothing` +and attempting to symbolically index the problem with an unknown, observable, or expression +depending on unknowns/observables will error. + +Updating the parameters automatically updates the `A` and `b` arrays. + +# Keyword arguments + +$PROBLEM_KWARGS +$(prob_fun_common_kwargs(LinearProblem, false)) + +All other keyword arguments are forwarded to the $func constructor. + +$PROBLEM_INTERNALS_HEADER + +$PROBLEM_INTERNAL_KWARGS +""" SciMLBase.LinearProblem diff --git a/src/problems/linearproblem.jl b/src/problems/linearproblem.jl new file mode 100644 index 0000000000..26d5c932bd --- /dev/null +++ b/src/problems/linearproblem.jl @@ -0,0 +1,98 @@ +function SciMLBase.LinearProblem(sys::System, op; kwargs...) + SciMLBase.LinearProblem{true}(sys, op; kwargs...) +end + +function SciMLBase.LinearProblem(sys::System, op::StaticArray; kwargs...) + SciMLBase.LinearProblem{false}(sys, op; kwargs...) +end + +function SciMLBase.LinearProblem{iip}( + sys::System, op; check_length = true, expression = Val{false}, + check_compatibility = true, sparse = false, eval_expression = false, + eval_module = @__MODULE__, checkbounds = false, cse = true, + u0_constructor = identity, u0_eltype = nothing, kwargs...) where {iip} + check_complete(sys, LinearProblem) + check_compatibility && check_compatible_system(LinearProblem, sys) + + _, u0, p = process_SciMLProblem( + EmptySciMLFunction{iip}, sys, op; check_length, expression, + build_initializeprob = false, symbolic_u0 = true, u0_constructor, u0_eltype, + kwargs...) + + if any(x -> symbolic_type(x) != NotSymbolic(), u0) + u0 = nothing + end + + u0Type = typeof(op) + floatT = if u0 === nothing + calculate_float_type(op, u0Type) + else + eltype(u0) + end + u0_eltype = something(u0_eltype, floatT) + + u0_constructor = get_p_constructor(u0_constructor, u0Type, u0_eltype) + + A, b = calculate_A_b(sys; sparse) + update_A = generate_update_A(sys, A; expression, wrap_gfw = Val{true}, eval_expression, + eval_module, checkbounds, cse, kwargs...) + update_b = generate_update_b(sys, b; expression, wrap_gfw = Val{true}, eval_expression, + eval_module, checkbounds, cse, kwargs...) + observedfun = ObservedFunctionCache( + sys; steady_state = false, expression, eval_expression, eval_module, checkbounds, + cse) + + if expression == Val{true} + symbolic_interface = quote + update_A = $update_A + update_b = $update_b + sys = $sys + observedfun = $observedfun + $(SciMLBase.SymbolicLinearInterface)( + update_A, update_b, sys, observedfun, nothing) + end + get_A = build_explicit_observed_function( + sys, A; param_only = true, eval_expression, eval_module) + if sparse + get_A = SparseArrays.sparse ∘ get_A + end + get_b = build_explicit_observed_function( + sys, b; param_only = true, eval_expression, eval_module) + A = u0_constructor(get_A(p)) + b = u0_constructor(get_b(p)) + else + symbolic_interface = SciMLBase.SymbolicLinearInterface( + update_A, update_b, sys, observedfun, nothing) + A = u0_constructor(update_A(p)) + b = u0_constructor(update_b(p)) + end + + kwargs = (; u0, process_kwargs(sys; kwargs...)..., f = symbolic_interface) + args = (; A, b, p) + + return maybe_codegen_scimlproblem(expression, LinearProblem{iip}, args; kwargs...) +end + +# For remake +function SciMLBase.get_new_A_b( + sys::AbstractSystem, f::SciMLBase.SymbolicLinearInterface, p, A, b; kw...) + if ArrayInterface.ismutable(A) + f.update_A!(A, p) + f.update_b!(b, p) + else + # The generated function has both IIP and OOP variants + A = StaticArraysCore.similar_type(A)(f.update_A!(p)) + b = StaticArraysCore.similar_type(b)(f.update_b!(p)) + end + return A, b +end + +function check_compatible_system(T::Type{LinearProblem}, sys::System) + check_time_independent(sys, T) + check_affine(sys, T) + check_not_dde(sys) + check_no_cost(sys, T) + check_no_constraints(sys, T) + check_no_jumps(sys, T) + check_no_noise(sys, T) +end diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 1d92416455..e643be904d 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -1780,13 +1780,13 @@ function preface(sys::AbstractSystem) end function islinear(sys::AbstractSystem) - rhs = [eq.rhs for eq in equations(sys)] + rhs = [eq.rhs for eq in full_equations(sys)] all(islinear(r, unknowns(sys)) for r in rhs) end function isaffine(sys::AbstractSystem) - rhs = [eq.rhs for eq in equations(sys)] + rhs = [eq.rhs for eq in full_equations(sys)] all(isaffine(r, unknowns(sys)) for r in rhs) end diff --git a/src/systems/codegen.jl b/src/systems/codegen.jl index 96ff14f58a..4a68f935e8 100644 --- a/src/systems/codegen.jl +++ b/src/systems/codegen.jl @@ -1130,3 +1130,90 @@ function build_explicit_observed_function(sys, ts; return f end end + +""" + $(TYPEDSIGNATURES) + +Return matrix `A` and vector `b` such that the system `sys` can be represented as +`A * x = b` where `x` is `unknowns(sys)`. Errors if the system is not affine. + +# Keyword arguments + +- `sparse`: return a sparse `A`. +""" +function calculate_A_b(sys::System; sparse = false) + rhss = [eq.rhs for eq in full_equations(sys)] + dvs = unknowns(sys) + + A = Matrix{Any}(undef, length(rhss), length(dvs)) + b = Vector{Any}(undef, length(rhss)) + for (i, rhs) in enumerate(rhss) + # mtkcompile makes this `0 ~ rhs` which typically ends up giving + # unknowns negative coefficients. If given the equations `A * x ~ b` + # it will simplify to `0 ~ b - A * x`. Thus this negation usually leads + # to more comprehensible user API. + resid = -rhs + for (j, var) in enumerate(dvs) + p, q, islinear = Symbolics.linear_expansion(resid, var) + if !islinear + throw(ArgumentError("System is not linear. Equation $((0 ~ rhs)) is not linear in unknown $var.")) + end + A[i, j] = p + resid = q + end + # negate beucause `resid` is the residual on the LHS + b[i] = -resid + end + + @assert all(Base.Fix1(isassigned, A), eachindex(A)) + @assert all(Base.Fix1(isassigned, A), eachindex(b)) + + if sparse + A = SparseArrays.sparse(A) + end + return A, b +end + +""" + $(TYPEDSIGNATURES) + +Given a system `sys` and the `A` from [`calculate_A_b`](@ref) generate the function that +updates `A` given the parameter object. + +# Keyword arguments + +$GENERATE_X_KWARGS + +All other keyword arguments are forwarded to [`build_function_wrapper`](@ref). +""" +function generate_update_A(sys::System, A::AbstractMatrix; expression = Val{true}, + wrap_gfw = Val{false}, eval_expression = false, eval_module = @__MODULE__, kwargs...) + ps = reorder_parameters(sys) + + res = build_function_wrapper(sys, A, ps...; p_start = 1, expression = Val{true}, + similarto = typeof(A), kwargs...) + return maybe_compile_function(expression, wrap_gfw, (1, 1, is_split(sys)), res; + eval_expression, eval_module) +end + +""" + $(TYPEDSIGNATURES) + +Given a system `sys` and the `b` from [`calculate_A_b`](@ref) generate the function that +updates `b` given the parameter object. + +# Keyword arguments + +$GENERATE_X_KWARGS + +All other keyword arguments are forwarded to [`build_function_wrapper`](@ref). +""" +function generate_update_b(sys::System, b::AbstractVector; expression = Val{true}, + wrap_gfw = Val{false}, eval_expression = false, eval_module = @__MODULE__, kwargs...) + ps = reorder_parameters(sys) + + res = build_function_wrapper(sys, b, ps...; p_start = 1, expression = Val{true}, + similarto = typeof(b), kwargs...) + return maybe_compile_function(expression, wrap_gfw, (1, 1, is_split(sys)), res; + eval_expression, eval_module) +end diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index eab33e00b3..93fa988b7d 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -727,6 +727,7 @@ function SciMLBase.late_binding_update_u0_p( prob, sys::AbstractSystem, u0, p, t0, newu0, newp) supports_initialization(sys) || return newu0, newp prob isa IntervalNonlinearProblem && return newu0, newp + prob isa LinearProblem && return newu0, newp initdata = prob.f.initialization_data meta = initdata === nothing ? nothing : initdata.metadata diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 4ec875c83e..e24f37331f 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -376,7 +376,7 @@ function varmap_to_vars(varmap::AbstractDict, vars::Vector; if toterm !== nothing add_toterms!(varmap; toterm) end - if check + if check && !allow_symbolic missing_vars = missingvars(varmap, vars; toterm) if !isempty(missing_vars) if is_initializeprob @@ -387,7 +387,7 @@ function varmap_to_vars(varmap::AbstractDict, vars::Vector; end end evaluate_varmap!(varmap, vars; limit = substitution_limit) - vals = map(x -> varmap[x], vars) + vals = map(x -> get(varmap, x, x), vars) if !allow_symbolic missingsyms = Any[] missingvals = Any[] @@ -1192,6 +1192,20 @@ function float_type_from_varmap(varmap, floatT = Bool) return float(floatT) end +""" + $(TYPEDSIGNATURES) + +Calculate the floating point type to use from the given `varmap` by looking at variables +with a constant value. `u0Type` takes priority if it is a real-valued array type. +""" +function calculate_float_type(varmap, u0Type::Type, floatT = Bool) + if u0Type <: AbstractArray && eltype(u0Type) <: Real && eltype(u0Type) != Union{} + return float(eltype(u0Type)) + else + return float_type_from_varmap(varmap, floatT) + end +end + """ $(TYPEDSIGNATURES) @@ -1208,6 +1222,41 @@ function calculate_resid_prototype(N::Int, u0, p) return zeros(u0ElType, N) end +""" + $(TYPEDSIGNATURES) + +Given the user-provided value of `u0_constructor`, the container type of user-provided +`op`, the desired floating point type and whether a symbolic `u0` is allowed, return the +updated `u0_constructor`. +""" +function get_u0_constructor(u0_constructor, u0Type::Type, floatT::Type, symbolic_u0::Bool) + u0_constructor === identity || return u0_constructor + u0Type <: StaticArray || return u0_constructor + return function (vals) + elT = if symbolic_u0 && any(x -> symbolic_type(x) != NotSymbolic(), vals) + nothing + else + floatT + end + SymbolicUtils.Code.create_array(u0Type, elT, Val(1), Val(length(vals)), vals...) + end +end + +""" + $(TYPEDSIGNATURES) + +Given the user-provided value of `p_constructor`, the container type of user-provided `op`, +ans the desired floating point type, return the updated `p_constructor`. +""" +function get_p_constructor(p_constructor, pType::Type, floatT::Type) + p_constructor === identity || return p_constructor + pType <: StaticArray || return p_constructor + return function (vals) + SymbolicUtils.Code.create_array( + pType, floatT, Val(ndims(vals)), Val(size(vals)), vals...) + end +end + """ $(TYPEDSIGNATURES) @@ -1274,26 +1323,15 @@ function process_SciMLProblem( missing_unknowns, missing_pars = build_operating_point!(sys, op, u0map, pmap, defs, dvs, ps) - floatT = Bool - if u0Type <: AbstractArray && eltype(u0Type) <: Real && eltype(u0Type) != Union{} - floatT = float(eltype(u0Type)) - else - floatT = float_type_from_varmap(op, floatT) - end - + floatT = calculate_float_type(op, u0Type) u0_eltype = something(u0_eltype, floatT) if !is_time_dependent(sys) || is_initializesystem(sys) add_observed_equations!(op, obs) end - if u0_constructor === identity && u0Type <: StaticArray - u0_constructor = vals -> SymbolicUtils.Code.create_array( - u0Type, floatT, Val(1), Val(length(vals)), vals...) - end - if p_constructor === identity && pType <: StaticArray - p_constructor = vals -> SymbolicUtils.Code.create_array( - pType, floatT, Val(1), Val(length(vals)), vals...) - end + + u0_constructor = get_u0_constructor(u0_constructor, u0Type, u0_eltype, symbolic_u0) + p_constructor = get_p_constructor(p_constructor, pType, floatT) if build_initializeprob kws = maybe_build_initialization_problem( diff --git a/test/linearproblem.jl b/test/linearproblem.jl new file mode 100644 index 0000000000..2c5c130a1b --- /dev/null +++ b/test/linearproblem.jl @@ -0,0 +1,189 @@ +using ModelingToolkit +using LinearSolve +using SciMLBase +using StaticArrays +using SparseArrays +using Test +using ModelingToolkit: t_nounits as t, D_nounits as D, SystemCompatibilityError + +@testset "Rejects non-affine systems" begin + @variables x y + @mtkbuild sys = System([0 ~ x^2 + y, 0 ~ x - y]) + @test_throws SystemCompatibilityError LinearProblem(sys, nothing) +end + +@variables x[1:3] [irreducible = true] +@parameters p[1:3, 1:3] q[1:3] + +@mtkbuild sys = System([p * x ~ q]) +# sanity check +@test length(unknowns(sys)) == length(equations(sys)) == 3 +A = Float64[1 2 3; 4 3.5 1.7; 5.2 1.8 9.7] +b = Float64[2, 5, 8] +ps = [p => A, q => b] + +@testset "Basics" begin + # Ensure it works without providing `u0` + prob = LinearProblem(sys, ps) + @test prob.u0 === nothing + @test SciMLBase.isinplace(prob) + @test prob.A ≈ A + @test prob.b ≈ b + @test eltype(prob.A) == Float64 + @test eltype(prob.b) == Float64 + + @test prob.ps[p * q] ≈ A * b + + sol = solve(prob) + # https://github.com/SciML/LinearSolve.jl/issues/532 + @test_broken SciMLBase.successful_retcode(sol) + @test prob.A * sol.u - prob.b≈zeros(3) atol=1e-10 + + A2 = rand(3, 3) + b2 = rand(3) + @testset "remake" begin + prob2 = remake(prob; p = [p => A2, q => b2]) + @test prob2.ps[p] ≈ A2 + @test prob2.ps[q] ≈ b2 + @test prob2.A ≈ A2 + @test prob2.b ≈ b2 + end + + prob.ps[p] = A2 + @test prob.A ≈ A2 + prob.ps[q] = b2 + @test prob.b ≈ b2 + A2[1, 1] = prob.ps[p[1, 1]] = 1.5 + @test prob.A ≈ A2 + b2[1] = prob.ps[q[1]] = 2.5 + @test prob.b ≈ b2 + + @testset "expression = Val{true}" begin + prob3e = LinearProblem(sys, ps; expression = Val{true}) + @test prob3e isa Expr + prob3 = eval(prob3e) + + @test prob3.u0 === nothing + @test SciMLBase.isinplace(prob3) + @test prob3.A ≈ A + @test prob3.b ≈ b + @test eltype(prob3.A) == Float64 + @test eltype(prob3.b) == Float64 + + @test prob3.ps[p * q] ≈ A * b + + sol = solve(prob3) + # https://github.com/SciML/LinearSolve.jl/issues/532 + @test_broken SciMLBase.successful_retcode(sol) + @test prob3.A * sol.u - prob3.b≈zeros(3) atol=1e-10 + end +end + +@testset "With `u0`" begin + prob = LinearProblem(sys, [x => ones(3); ps]) + @test prob.u0 ≈ ones(3) + @test SciMLBase.isinplace(prob) + @test eltype(prob.u0) == Float64 + + # Observed should work + @test prob[x[1] + x[2]] ≈ 2.0 + + @testset "expression = Val{true}" begin + prob3e = LinearProblem(sys, [x => ones(3); ps]; expression = Val{true}) + @test prob3e isa Expr + prob3 = eval(prob3e) + @test prob3.u0 ≈ ones(3) + @test eltype(prob3.u0) == Float64 + end +end + +@testset "SArray OOP form" begin + prob = LinearProblem(sys, SVector{2}(ps)) + @test prob.A isa SMatrix{3, 3, Float64} + @test prob.b isa SVector{3, Float64} + @test !SciMLBase.isinplace(prob) + @test prob.ps[p * q] ≈ A * b + + sol = solve(prob) + # https://github.com/SciML/LinearSolve.jl/issues/532 + @test_broken SciMLBase.successful_retcode(sol) + @test prob.A * sol.u - prob.b≈zeros(3) atol=1e-10 + + A2 = rand(3, 3) + b2 = rand(3) + @testset "remake" begin + prob2 = remake(prob; p = [p => A2, q => b2]) + # Despite passing `Array` to `remake` + @test prob2.A isa SMatrix{3, 3, Float64} + @test prob2.b isa SVector{3, Float64} + @test prob2.ps[p] ≈ A2 + @test prob2.ps[q] ≈ b2 + @test prob2.A ≈ A2 + @test prob2.b ≈ b2 + end + + @testset "expression = Val{true}" begin + prob3e = LinearProblem(sys, SVector{2}(ps); expression = Val{true}) + @test prob3e isa Expr + prob3 = eval(prob3e) + @test prob3.A isa SMatrix{3, 3, Float64} + @test prob3.b isa SVector{3, Float64} + @test !SciMLBase.isinplace(prob3) + @test prob3.ps[p * q] ≈ A * b + + sol = solve(prob3) + # https://github.com/SciML/LinearSolve.jl/issues/532 + @test_broken SciMLBase.successful_retcode(sol) + @test prob3.A * sol.u - prob3.b≈zeros(3) atol=1e-10 + end +end + +@testset "u0_constructor" begin + prob = LinearProblem{false}(sys, ps; u0_constructor = x -> SArray{Tuple{size(x)...}}(x)) + @test prob.A isa SMatrix{3, 3, Float64} + @test prob.b isa SVector{3, Float64} + @test prob.ps[p * q] ≈ A * b +end + +@testset "sparse form" begin + prob = LinearProblem(sys, ps; sparse = true) + @test issparse(prob.A) + @test !issparse(prob.b) + + sol = solve(prob) + # This might end up failing because of + # https://github.com/SciML/LinearSolve.jl/issues/532 + @test SciMLBase.successful_retcode(sol) + + A2 = rand(3, 3) + prob.ps[p] = A2 + @test prob.A ≈ A2 + b2 = rand(3) + prob.ps[q] = b2 + @test prob.b ≈ b2 + + A2 = rand(3, 3) + b2 = rand(3) + @testset "remake" begin + prob2 = remake(prob; p = [p => A2, q => b2]) + @test issparse(prob2.A) + @test !issparse(prob2.b) + @test prob2.ps[p] ≈ A2 + @test prob2.ps[q] ≈ b2 + @test prob2.A ≈ A2 + @test prob2.b ≈ b2 + end + + @testset "expression = Val{true}" begin + prob3e = LinearProblem(sys, ps; sparse = true, expression = Val{true}) + @test prob3e isa Expr + prob3 = eval(prob3e) + @test issparse(prob3.A) + @test !issparse(prob3.b) + + sol = solve(prob3) + # This might end up failing because of + # https://github.com/SciML/LinearSolve.jl/issues/532 + @test SciMLBase.successful_retcode(sol) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 31e4cc609f..c108c7898d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -98,6 +98,7 @@ end @safetestset "Namespacing test" include("namespacing.jl") @safetestset "Subsystem replacement" include("substitute_component.jl") @safetestset "Linearization Tests" include("linearize.jl") + @safetestset "LinearProblem Tests" include("linearproblem.jl") end end