diff --git a/_typos.toml b/_typos.toml index a527fd96..ae08f1be 100644 --- a/_typos.toml +++ b/_typos.toml @@ -1,6 +1,7 @@ [default] locale = "en" extend-ignore-re = [ + "OT", ] [files] diff --git a/ext/CTFlowsSciMLExt.jl b/ext/CTFlowsSciMLExt.jl index ca3f4869..61818feb 100644 --- a/ext/CTFlowsSciMLExt.jl +++ b/ext/CTFlowsSciMLExt.jl @@ -160,9 +160,9 @@ function Strategies.metadata(::Type{SciML}) ), Strategies.OptionDefinition(; name = :save_everystep, - type = Bool, - default = Options.NotProvided, - description = "Save the solution at every solver step.", + type = Union{Bool, Symbol}, + default = :auto, + description = "Save the solution at every solver step. Set `true`/`false` to force, or `:auto` to infer from call pattern (false for `flow(t0, x0[, p0], tf)`, true for `flow((t0, tf), x0[, p0])`).", ), Strategies.OptionDefinition(; name = :saveat, @@ -173,9 +173,9 @@ function Strategies.metadata(::Type{SciML}) ), Strategies.OptionDefinition(; name = :dense, - type = Bool, - default = true, - description = "Whether to save extra pieces for dense (continuous) output.", + type = Union{Bool, Symbol}, + default = :auto, + description = "Dense output. Set `true`/`false` to force, or `:auto` to infer from call pattern (false for `flow(t0, x0[, p0], tf)`, true for `flow((t0, tf), x0[, p0])`).", ), Strategies.OptionDefinition(; name = :save_idxs, @@ -253,9 +253,9 @@ function Strategies.metadata(::Type{SciML}) ), Strategies.OptionDefinition(; name = :save_start, - type = Bool, - default = Options.NotProvided, - description = "Whether to save the initial condition.", + type = Union{Bool, Symbol}, + default = :auto, + description = "Save initial condition in solution. Set `true`/`false` to force, or `:auto` to infer from call pattern (false for `flow(t0, x0[, p0], tf)`, true for `flow((t0, tf), x0[, p0])`).", ), Strategies.OptionDefinition(; name = :save_end, @@ -280,16 +280,148 @@ end # build_sciml_integrator — actual implementation # ============================================================================= +# ============================================================================= +# Config-dependent option resolution +# ============================================================================= + +""" + _AUTO_OPTION_KEYS + +Tuple of option keys that support automatic resolution based on configuration type. + +These options use the `:auto` sentinel value in their metadata and are resolved +dynamically during integrator construction: +- For `PointConfig`: set to `false` (only final state needed) +- For `TrajectoryConfig`: set to `true` (full trajectory storage needed) + +Users can override automatic resolution by providing explicit `true`/`false` values +when constructing the integrator. +""" +const _AUTO_OPTION_KEYS = (:dense, :save_everystep, :save_start) + """ $(TYPEDSIGNATURES) -Build a `SciML` integrator with validated options. +Build a `SciML` integrator with validated options and pre-computed config-specific options. + +This function constructs a SciML integrator with automatic resolution of config-dependent +options. Options in `_AUTO_OPTION_KEYS` support the `:auto` sentinel value, which is +resolved based on the configuration type used during integration: +- For `PointConfig` (e.g., `flow(t0, x0, tf)`): options set to `false` to minimize memory + since only the final state is needed +- For `TrajectoryConfig` (e.g., `flow((t0, tf), x0)`): options set to `true` to enable + full trajectory storage and interpolation + +The resolved options are pre-computed and cached in the integrator for performance, +avoiding repeated resolution during integration. + +# Arguments +- `::Type{CTFlows.Integrators.SciMLTag}`: The SciML integrator tag type. +- `mode::Symbol`: Validation mode for strategy options (`:strict` or `:permissive`). +- `kwargs...`: User-provided option values. Explicit `true`/`false` values override + automatic `:auto` resolution. + +# Returns +- `CTFlows.Integrators.SciML`: Parametric SciML integrator with cached `options_point` + and `options_trajectory` fields. + +# Notes +- The `:auto` sentinel is defined in option metadata as `Union{Bool, Symbol}` with + default `:auto`. +- Pre-computation happens at construction time, not during integration. +- Config-specific options are returned by `Integrators.build_options` based on dispatch + on the configuration type. + +See also: [`CTFlows.Integrators.build_options`](@ref), [`CTFlows.Integrators.SciML`](@ref), +[`CTFlows.Common.PointConfig`](@ref), [`CTFlows.Common.TrajectoryConfig`](@ref). """ function CTFlows.Integrators.build_sciml_integrator( ::Type{CTFlows.Integrators.SciMLTag}; mode::Symbol = :strict, kwargs..., ) opts = Strategies.build_strategy_options(SciML; mode = mode, kwargs...) - return CTFlows.Integrators.SciML(opts) + raw = Strategies.options_dict(opts) + + # Pre-compute options for PointConfig + options_point = copy(raw) + for key in _AUTO_OPTION_KEYS + get(options_point, key, :auto) === :auto && (options_point[key] = false) + end + + # Pre-compute options for TrajectoryConfig + options_trajectory = copy(raw) + for key in _AUTO_OPTION_KEYS + get(options_trajectory, key, :auto) === :auto && (options_trajectory[key] = true) + end + + return CTFlows.Integrators.SciML{typeof(opts), typeof(options_point), typeof(options_trajectory)}( + opts, options_point, options_trajectory + ) +end + +# ============================================================================= +# build_options — config-dependent option resolution +# ============================================================================= + +""" +$(TYPEDSIGNATURES) + +Return pre-computed solver options for PointConfig. + +For a PointConfig, options like `dense`, `save_everystep`, and `save_start` +are set to `false` to minimize memory since only the final state is needed. + +# Arguments +- `integ::SciML`: The SciML integrator with pre-computed option caches. +- `config::Common.PointConfig`: The point configuration. + +# Returns +- `Dict{Symbol,Any}`: Pre-computed options optimized for PointConfig. + +See also: [`Integrators.build_options`](@ref), [`Integrators.SciML`](@ref). +""" +function Integrators.build_options(integ::SciML, config::Common.PointConfig) + return integ.options_point +end + +""" +$(TYPEDSIGNATURES) + +Return pre-computed solver options for TrajectoryConfig. + +For a TrajectoryConfig, options like `dense`, `save_everystep`, and `save_start` +are set to `true` to enable full trajectory storage and interpolation. + +# Arguments +- `integ::SciML`: The SciML integrator with pre-computed option caches. +- `config::Common.TrajectoryConfig`: The trajectory configuration. + +# Returns +- `Dict{Symbol,Any}`: Pre-computed options optimized for TrajectoryConfig. + +See also: [`Integrators.build_options`](@ref), [`Integrators.SciML`](@ref). +""" +function Integrators.build_options(integ::SciML, config::Common.TrajectoryConfig) + return integ.options_trajectory +end + +""" +$(TYPEDSIGNATURES) + +Return pre-computed solver options for fallback case (Nothing). + +Defaults to TrajectoryConfig options when no configuration is provided. + +# Arguments +- `integ::SciML`: The SciML integrator with pre-computed option caches. +- `config::Nothing`: No configuration provided (fallback). + +# Returns +- `Dict{Symbol,Any}`: Pre-computed options for TrajectoryConfig (fallback). + +See also: [`Integrators.build_options`](@ref), [`Integrators.SciML`](@ref). +""" +function Integrators.build_options(integ::SciML, config::Nothing) + return integ.options_trajectory # fallback vers Trajectory par défaut end # ============================================================================= @@ -356,12 +488,13 @@ end """ $(TYPEDSIGNATURES) -Solve an `ODEProblem` using the `SciML`'s configured options. +Solve an `ODEProblem` using resolved options. Returns a `SciMLIntegrationResult` wrapping the raw `ODESolution`. # Arguments - `integ::SciML`: The SciML integrator strategy. - `prob::SciMLBase.AbstractODEProblem`: The ODE problem to solve. +- `options::Dict{Symbol,Any}`: Resolved solver options (typically from `build_options`). - `unsafe=Common.__unsafe()`: If `true`, bypass ODE solver retcode checking; if `false`, throw `SolverFailure` on integration failure. # Returns @@ -370,8 +503,7 @@ Returns a `SciMLIntegrationResult` wrapping the raw `ODESolution`. # Throws - `CTBase.Exceptions.SolverFailure`: If the ODE solver returns an unsuccessful retcode and `unsafe=false`. """ -function Integrators.solve_problem(integ::SciML, prob::SciMLBase.AbstractODEProblem; unsafe=Common.__unsafe()) - options = Strategies.options_dict(integ) +function Integrators.solve_problem(integ::SciML, prob::SciMLBase.AbstractODEProblem, options::Dict{Symbol,Any}; unsafe=Common.__unsafe()) ode_sol = SciMLBase.solve(prob; options...) if !unsafe && !SciMLBase.successful_retcode(ode_sol.retcode) throw(Exceptions.SolverFailure( diff --git a/src/Flows/calling.jl b/src/Flows/calling.jl index dd3d820d..47cc4038 100644 --- a/src/Flows/calling.jl +++ b/src/Flows/calling.jl @@ -33,8 +33,11 @@ function call(flow::Flows.AbstractFlow{TD, VD}, config::Common.AbstractConfig; v # build ode problem prob = Integrators.build_problem(int, sys, config; variable=variable) + # build config-specific options + opts = Integrators.build_options(int, config) + # integrate ode problem - result = Integrators.solve_problem(int, prob; unsafe=unsafe) + result = Integrators.solve_problem(int, prob, opts; unsafe=unsafe) # build flow solution flow_sol = Solutions.build_solution(result, sys, config) diff --git a/src/Integrators/abstract_integrator.jl b/src/Integrators/abstract_integrator.jl index 665201fe..6c5bb485 100644 --- a/src/Integrators/abstract_integrator.jl +++ b/src/Integrators/abstract_integrator.jl @@ -20,10 +20,11 @@ Methods defined on **instances** that provide the actual configuration: # Concrete Implementation -All subtypes must implement two named functions: +All subtypes must implement three named functions: - `build_problem(integrator::AbstractIntegrator, system::CTFlows.Systems.AbstractSystem, config::CTFlows.Common.AbstractConfig; variable)`: Build the ODE problem representation from a system and configuration. -- `solve_problem(integrator::AbstractIntegrator, prob)`: Solve the given ODE problem (tspan is embedded in `prob`). +- `build_options(integrator::AbstractIntegrator, config::Union{CTFlows.Common.AbstractConfig, Nothing})`: Build solver options dict for the given configuration. +- `solve_problem(integrator::AbstractIntegrator, prob, options::Dict{Symbol,Any})`: Solve the given ODE problem with resolved options (tspan is embedded in `prob`). # Throws - `CTBase.Exceptions.NotImplemented`: If the methods are not implemented by the concrete type. @@ -63,11 +64,12 @@ end """ $(TYPEDSIGNATURES) -Solve the given ODE problem. +Solve the given ODE problem with resolved options. # Arguments - `integrator::AbstractIntegrator`: The integrator strategy. - `prob`: The ODE problem to solve (type varies by concrete integrator; tspan is embedded). +- `options::Dict{Symbol,Any}`: Resolved solver options (typically from `build_options`). - `unsafe=Common.__unsafe()`: If `true`, bypass ODE solver retcode checking; if `false`, throw `SolverFailure` on integration failure. # Returns @@ -76,13 +78,39 @@ Solve the given ODE problem. # Throws - `CTBase.Exceptions.NotImplemented`: If not implemented by the concrete type. -See also: [`CTFlows.Integrators.AbstractIntegrator`](@ref), [`CTFlows.Integrators.build_problem`](@ref). +See also: [`CTFlows.Integrators.AbstractIntegrator`](@ref), [`CTFlows.Integrators.build_problem`](@ref), [`CTFlows.Integrators.build_options`](@ref). """ -function solve_problem(integrator::AbstractIntegrator, prob; unsafe=Common.__unsafe()) +function solve_problem(integrator::AbstractIntegrator, prob, options::Dict{Symbol,Any}; unsafe=Common.__unsafe()) throw(Exceptions.NotImplemented( "AbstractIntegrator solve_problem not implemented"; - required_method = "solve_problem(integrator::$(typeof(integrator)), prob; unsafe=false)", - suggestion = "Implement solve_problem(i::YourIntegrator, prob; unsafe=false) returning an AbstractIntegrationResult.", + required_method = "solve_problem(integrator::$(typeof(integrator)), prob, options::Dict{Symbol,Any}; unsafe=false)", + suggestion = "Implement solve_problem(i::YourIntegrator, prob, options::Dict; unsafe=false) returning an AbstractIntegrationResult.", context = "AbstractIntegrator solve_problem - required method implementation", )) end + +""" +$(TYPEDSIGNATURES) + +Build solver options dict for the given configuration. + +# Arguments +- `integrator::AbstractIntegrator`: The integrator strategy. +- `config::Union{CTFlows.Common.AbstractConfig, Nothing}`: The integration configuration (or `Nothing` for fallback). + +# Returns +- `Dict{Symbol,Any}`: Resolved solver options for the given configuration. + +# Throws +- `CTBase.Exceptions.NotImplemented`: If not implemented by the concrete type. + +See also: [`CTFlows.Integrators.AbstractIntegrator`](@ref), [`CTFlows.Integrators.build_problem`](@ref), [`CTFlows.Integrators.solve_problem`](@ref). +""" +function build_options(integrator::AbstractIntegrator, config::Union{Common.AbstractConfig, Nothing}) + throw(Exceptions.NotImplemented( + "AbstractIntegrator build_options not implemented"; + required_method = "build_options(integrator::$(typeof(integrator)), config::Union{Common.AbstractConfig, Nothing})", + suggestion = "Implement build_options(i::YourIntegrator, config) returning a Dict{Symbol,Any} of resolved solver options.", + context = "AbstractIntegrator build_options - required method implementation", + )) +end diff --git a/src/Integrators/sciml.jl b/src/Integrators/sciml.jl index 0d0d728f..4320ca22 100644 --- a/src/Integrators/sciml.jl +++ b/src/Integrators/sciml.jl @@ -50,10 +50,14 @@ To activate the extension, load any of: - `using DifferentialEquations` # Fields -- `options::CTSolvers.Strategies.StrategyOptions`: validated option bundle. -""" -struct SciML <: AbstractSciMLIntegrator - options::CTSolvers.Strategies.StrategyOptions +- `options::CTSolvers.Strategies.StrategyOptions`: Validated option bundle. +- `options_point::Dict{Symbol, Any}`: Pre-computed options for PointConfig. +- `options_trajectory::Dict{Symbol, Any}`: Pre-computed options for TrajectoryConfig. +""" +struct SciML{O<:CTSolvers.Strategies.StrategyOptions, OP<:Dict{Symbol, Any}, OT<:Dict{Symbol, Any}} <: AbstractSciMLIntegrator + options::O + options_point::OP + options_trajectory::OT end # ============================================================================ diff --git a/test/suite/extensions/test_sciml_extension.jl b/test/suite/extensions/test_sciml_extension.jl index 0339cb99..61316401 100644 --- a/test/suite/extensions/test_sciml_extension.jl +++ b/test/suite/extensions/test_sciml_extension.jl @@ -215,8 +215,11 @@ function test_sciml_extension() # Build ODE problem prob = Integrators.build_problem(integ, sys, config; variable=nothing) + # Build options + opts = Integrators.build_options(integ, config) + # Solve - result = Integrators.solve_problem(integ, prob) + result = Integrators.solve_problem(integ, prob, opts) Test.@test result isa CTFlowsSciMLExt.SciMLIntegrationResult Test.@test result isa Solutions.AbstractIntegrationResult @@ -227,13 +230,15 @@ function test_sciml_extension() # Create a simple ODE problem that will fail with maxiters=1 prob = ODEProblem((du, u, p, t) -> du .= u, [1.0], (0.0, 1.0)) integ = Integrators.SciML(maxiters=1) + config = Common.PointConfig(0.0, [1.0], 1.0) + opts = Integrators.build_options(integ, config) # Test that solve_problem throws SolverFailure when unsafe=false - Test.@test_throws Exceptions.SolverFailure Integrators.solve_problem(integ, prob; unsafe=false) + Test.@test_throws Exceptions.SolverFailure Integrators.solve_problem(integ, prob, opts; unsafe=false) # Test the exception contains correct fields try - Integrators.solve_problem(integ, prob; unsafe=false) + Integrators.solve_problem(integ, prob, opts; unsafe=false) catch e Test.@test e isa Exceptions.SolverFailure Test.@test occursin("MaxIters", e.retcode) @@ -247,9 +252,11 @@ function test_sciml_extension() # Create a simple ODE problem that will fail with maxiters=1 prob = ODEProblem((du, u, p, t) -> du .= u, [1.0], (0.0, 1.0)) integ = Integrators.SciML(maxiters=1) + config = Common.PointConfig(0.0, [1.0], 1.0) + opts = Integrators.build_options(integ, config) # With unsafe=true, should not throw even with bad retcode - result = Integrators.solve_problem(integ, prob; unsafe=true) + result = Integrators.solve_problem(integ, prob, opts; unsafe=true) Test.@test result isa CTFlowsSciMLExt.SciMLIntegrationResult end end @@ -267,7 +274,8 @@ function test_sciml_extension() integ = Integrators.SciML(maxiters=1000, reltol=1e-6) prob = Integrators.build_problem(integ, sys, config; variable=nothing) - result = Integrators.solve_problem(integ, prob) + opts = Integrators.build_options(integ, config) + result = Integrators.solve_problem(integ, prob, opts) Test.@test Solutions.final_state(result) isa Vector{Float64} Test.@test length(Solutions.final_state(result)) == 2 @@ -296,8 +304,11 @@ function test_sciml_extension() # Build problem prob = Integrators.build_problem(integ, sys, config; variable=nothing) + # Build options + opts = Integrators.build_options(integ, config) + # Solve - result = Integrators.solve_problem(integ, prob) + result = Integrators.solve_problem(integ, prob, opts) # Build solution flow_sol = Solutions.build_solution(result, sys, config) @@ -316,8 +327,11 @@ function test_sciml_extension() # Build problem prob = Integrators.build_problem(integ, sys, config; variable=nothing) + # Build options + opts = Integrators.build_options(integ, config) + # Solve - result = Integrators.solve_problem(integ, prob) + result = Integrators.solve_problem(integ, prob, opts) # Build solution flow_sol = Solutions.build_solution(result, sys, config) @@ -325,6 +339,46 @@ function test_sciml_extension() Test.@test flow_sol isa Solutions.VectorFieldSolution end end + + # ==================================================================== + # UNIT TESTS - Config-Dependent Options + # ==================================================================== + + Test.@testset "Config-Dependent Options" begin + integ = Integrators.SciML() + config_point = Common.PointConfig(0.0, [1.0, 0.0], 1.0) + config_traj = Common.TrajectoryConfig((0.0, 1.0), [1.0, 0.0]) + + Test.@testset "auto defaults resolve correctly for PointConfig" begin + opts = Integrators.build_options(integ, config_point) + Test.@test opts[:dense] === false + Test.@test opts[:save_everystep] === false + Test.@test opts[:save_start] === false + end + + Test.@testset "auto defaults resolve correctly for TrajectoryConfig" begin + opts = Integrators.build_options(integ, config_traj) + Test.@test opts[:dense] === true + Test.@test opts[:save_everystep] === true + Test.@test opts[:save_start] === true + end + + Test.@testset "explicit values override auto" begin + integ_explicit = Integrators.SciML(dense=false, save_everystep=true, save_start=false) + opts = Integrators.build_options(integ_explicit, config_traj) + Test.@test opts[:dense] === false + Test.@test opts[:save_everystep] === true + Test.@test opts[:save_start] === false + end + + Test.@testset "build_options dispatch returns correct cached dicts" begin + opts_point = Integrators.build_options(integ, config_point) + opts_traj = Integrators.build_options(integ, config_traj) + Test.@test opts_point === integ.options_point + Test.@test opts_traj === integ.options_trajectory + Test.@test opts_point !== opts_traj + end + end end end diff --git a/test/suite/flows/test_calling.jl b/test/suite/flows/test_calling.jl index a0a77a77..61e0e570 100644 --- a/test/suite/flows/test_calling.jl +++ b/test/suite/flows/test_calling.jl @@ -34,13 +34,14 @@ Tracks which methods were called. """ mutable struct FakeIntegratorForCalling <: Integrators.AbstractIntegrator build_problem_called::Bool + build_options_called::Bool solve_problem_called::Bool problem_result::Any ode_solution::Any end function FakeIntegratorForCalling() - return FakeIntegratorForCalling(false, false, nothing, nothing) + return FakeIntegratorForCalling(false, false, false, nothing, nothing) end # Implement named functions instead of callables @@ -51,7 +52,12 @@ function Integrators.build_problem(integ::FakeIntegratorForCalling, system::Syst return integ.problem_result end -function Integrators.solve_problem(integ::FakeIntegratorForCalling, prob; unsafe=false) +function Integrators.build_options(integ::FakeIntegratorForCalling, config::Union{Common.AbstractConfig, Nothing}) + integ.build_options_called = true + return Dict{Symbol,Any}() +end + +function Integrators.solve_problem(integ::FakeIntegratorForCalling, prob, options::Dict{Symbol,Any}; unsafe=false) integ.solve_problem_called = true integ.ode_solution = FakeIntegrationResultForCalling() return integ.ode_solution @@ -113,6 +119,7 @@ function test_calling() # Verify all steps were called Test.@test integ.build_problem_called === true + Test.@test integ.build_options_called === true Test.@test integ.solve_problem_called === true # Verify result - for PointConfig it unwraps the vector @@ -129,6 +136,7 @@ function test_calling() result = Flows.call(flow, config; variable=0.5, unsafe=false) Test.@test integ.build_problem_called === true + Test.@test integ.build_options_called === true Test.@test integ.solve_problem_called === true Test.@test result == :fake_flow_solution end @@ -142,6 +150,7 @@ function test_calling() result = Flows.call(flow, config; variable=nothing, unsafe=false) Test.@test integ.build_problem_called === true + Test.@test integ.build_options_called === true Test.@test integ.solve_problem_called === true Test.@test result === :fake_vector_field_solution end @@ -156,6 +165,7 @@ function test_calling() result = Flows.call(flow, config; variable=nothing, unsafe=true) Test.@test integ.build_problem_called === true + Test.@test integ.build_options_called === true Test.@test integ.solve_problem_called === true Test.@test result == :fake_flow_solution end diff --git a/test/suite/integrators/test_abstract_integrator.jl b/test/suite/integrators/test_abstract_integrator.jl index 77dde295..af2740cd 100644 --- a/test/suite/integrators/test_abstract_integrator.jl +++ b/test/suite/integrators/test_abstract_integrator.jl @@ -35,13 +35,17 @@ function FakeIntegrator() return FakeIntegrator(CTSolvers.Strategies.StrategyOptions()) end -# Implement the two required callable signatures +# Implement the three required callable signatures function Integrators.build_problem(integ::FakeIntegrator, system::Systems.AbstractSystem, config::Common.AbstractConfig; variable) p = Common.ODEParameters(variable) return :fake_ode_problem end -function Integrators.solve_problem(integ::FakeIntegrator, prob) +function Integrators.build_options(integ::FakeIntegrator, config::Union{Common.AbstractConfig, Nothing}) + return Dict{Symbol,Any}() +end + +function Integrators.solve_problem(integ::FakeIntegrator, prob, options::Dict{Symbol,Any}) return :fake_ode_solution end @@ -91,9 +95,15 @@ function test_abstract_integrator() end Test.@testset "Integration signature" begin - result = Integrators.solve_problem(integ, :fake_prob) + result = Integrators.solve_problem(integ, :fake_prob, Dict{Symbol,Any}()) Test.@test result === :fake_ode_solution end + + Test.@testset "build_options signature" begin + config = Common.PointConfig(0.0, [1.0, 0.0], 1.0) + opts = Integrators.build_options(integ, config) + Test.@test opts isa Dict{Symbol,Any} + end end # ==================================================================== @@ -110,7 +120,12 @@ function test_abstract_integrator() end Test.@testset "Integration throws NotImplemented" begin - Test.@test_throws Exceptions.NotImplemented Integrators.solve_problem(integ, :fake_prob) + Test.@test_throws Exceptions.NotImplemented Integrators.solve_problem(integ, :fake_prob, Dict{Symbol,Any}()) + end + + Test.@testset "build_options throws NotImplemented" begin + config = Common.PointConfig(0.0, [1.0, 0.0], 1.0) + Test.@test_throws Exceptions.NotImplemented Integrators.build_options(integ, config) end end end