diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index 4f7bc7771f..eab33e00b3 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -61,6 +61,12 @@ function generate_initializesystem_timevarying(sys::AbstractSystem; isempty(trueobs) || filter_delay_equations_variables!(sys, trueobs) vars = unique([unknowns(sys); getfield.(trueobs, :lhs)]) vars_set = Set(vars) # for efficient in-lookup + arrvars = Set() + for var in vars + if iscall(var) && operation(var) === getindex + push!(arrvars, first(arguments(var))) + end + end eqs_ics = Equation[] defs = copy(defaults(sys)) # copy so we don't modify sys.defaults @@ -71,9 +77,13 @@ function generate_initializesystem_timevarying(sys::AbstractSystem; # PREPROCESSING op = anydict(op) + if isempty(op) + op = copy(defs) + end + scalarize_vars_in_varmap!(op, arrvars) u0map = anydict() pmap = anydict() - build_operating_point!(sys, op, u0map, pmap, defs, unknowns(sys), + build_operating_point!(sys, op, u0map, pmap, Dict(), unknowns(sys), parameters(sys; initial_parameters = true)) for (k, v) in op if has_parameter_dependency_with_lhs(sys, k) && is_variable_floatingpoint(k) @@ -144,7 +154,7 @@ function generate_initializesystem_timevarying(sys::AbstractSystem; # 3) process other variables for var in vars - if var ∈ keys(defs) + if var ∈ keys(op) push!(eqs_ics, var ~ defs[var]) elseif var ∈ keys(guesses) push!(defs, var => guesses[var]) @@ -238,7 +248,7 @@ function generate_initializesystem_timeindependent(sys::AbstractSystem; op = anydict(op) u0map = anydict() pmap = anydict() - build_operating_point!(sys, op, u0map, pmap, defs, unknowns(sys), + build_operating_point!(sys, op, u0map, pmap, Dict(), unknowns(sys), parameters(sys; initial_parameters = true)) for (k, v) in op if has_parameter_dependency_with_lhs(sys, k) && is_variable_floatingpoint(k) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 2016b1efd8..65b1e4ba18 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -523,6 +523,24 @@ function scalarize_varmap!(varmap::AbstractDict) return varmap end +""" + $(TYPEDSIGNATURES) + +For each array variable in `vars`, scalarize the corresponding entry in `varmap`. +If a scalarized entry already exists, it is not overridden. +""" +function scalarize_vars_in_varmap!(varmap::AbstractDict, vars) + for var in vars + symbolic_type(var) == ArraySymbolic() || continue + is_sized_array_symbolic(var) || continue + haskey(varmap, var) || continue + for i in eachindex(var) + haskey(varmap, var[i]) && continue + varmap[var[i]] = varmap[var][i] + end + end +end + function get_temporary_value(p, floatT = Float64) stype = symtype(unwrap(p)) return if stype == Real diff --git a/test/initializationsystem.jl b/test/initializationsystem.jl index e12d05479f..06d0752076 100644 --- a/test/initializationsystem.jl +++ b/test/initializationsystem.jl @@ -1664,3 +1664,10 @@ end sol = solve(prob, Tsit5()) @test SciMLBase.successful_retcode(sol) end + +@testset "Defaults removed with ` => nothing` aren't retained" begin + @variables x(t)[1:2] + @mtkbuild sys = System([D(x[1]) ~ -x[1], x[1] + x[2] ~ 3], t; defaults = [x[1] => 1]) + prob = ODEProblem(sys, [x[1] => nothing, x[2] => 1], (0.0, 1.0)) + @test SciMLBase.initialization_status(prob) == SciMLBase.FULLY_DETERMINED +end