diff --git a/lib/OrdinaryDiffEqExponentialRK/src/alg_utils.jl b/lib/OrdinaryDiffEqExponentialRK/src/alg_utils.jl index 1079cb0ca22..e77d3b72c94 100644 --- a/lib/OrdinaryDiffEqExponentialRK/src/alg_utils.jl +++ b/lib/OrdinaryDiffEqExponentialRK/src/alg_utils.jl @@ -26,6 +26,37 @@ alg_order(alg::Exprb43) = 4 alg_adaptive_order(alg::Exprb32) = 2 alg_adaptive_order(alg::Exprb43) = 4 +function _expRK_has_concretization(A) + return try + has_concretization(A) + catch err + err isa UndefVarError || rethrow() + true + end +end + +_expRK_requires_krylov(f) = false +function _expRK_requires_krylov(f::SplitFunction) + A = f.f1.f + return size(A) != () && !_expRK_has_concretization(A) +end + +for Alg in (LawsonEuler, NorsettEuler, ETDRK2, ETDRK3, ETDRK4, HochOst4) + @eval function DiffEqBase.prepare_alg( + alg::$Alg, + u0::AbstractArray, + p, prob + ) + if !alg.krylov && _expRK_requires_krylov(prob.f) + return $Alg( + krylov = true, m = alg.m, iop = alg.iop, + autodiff = alg.autodiff, concrete_jac = alg.concrete_jac + ) + end + return alg + end +end + function DiffEqBase.prepare_alg( alg::ETD2, u0::AbstractArray, diff --git a/lib/OrdinaryDiffEqExponentialRK/src/exponential_rk_caches.jl b/lib/OrdinaryDiffEqExponentialRK/src/exponential_rk_caches.jl index 359819a1664..7515a97288a 100644 --- a/lib/OrdinaryDiffEqExponentialRK/src/exponential_rk_caches.jl +++ b/lib/OrdinaryDiffEqExponentialRK/src/exponential_rk_caches.jl @@ -61,6 +61,8 @@ function expRK_operators(::HochOst4, dt, A) return A21, A31, A32, A41, A42, A51, A52, A54, B1, B4, B5 end +_expRK_matrix_or_scalar(A) = size(A) == () ? convert(Number, A) : convert(AbstractMatrix, A) + # Unified constructor for constant caches for (Alg, Cache) in [ (:LawsonEuler, :LawsonEulerConstantCache), @@ -89,8 +91,7 @@ for (Alg, Cache) in [ else isa(f, SplitFunction) || throw(ArgumentError("Caching can only be used with SplitFunction")) - A = size(f.f1.f) == () ? convert(Number, f.f1.f) : - convert(AbstractMatrix, f.f1.f) + A = _expRK_matrix_or_scalar(f.f1.f) ops = expRK_operators(alg, dt, A) end if isa(f, SplitFunction) || SciMLBase.has_jac(f) @@ -143,7 +144,7 @@ function alg_cache_expRK( else KsCache = nothing # Precompute the operators - A = size(f.f1.f) == () ? convert(Number, f.f1.f) : convert(AbstractMatrix, f.f1.f) + A = _expRK_matrix_or_scalar(f.f1.f) ops = expRK_operators(alg, dt, A) end return uf, jac_config, J, ops, KsCache @@ -202,7 +203,7 @@ function alg_cache( KsCache = (Ks, expv_cache) else KsCache = nothing - A = size(f.f1.f) == () ? convert(Number, f.f1.f) : convert(AbstractMatrix, f.f1.f) + A = _expRK_matrix_or_scalar(f.f1.f) exphA = expRK_operators(alg, dt, A) end return LawsonEulerCache(u, uprev, tmp, dz, rtmp, G, du1, jac_config, uf, J, exphA, KsCache) diff --git a/lib/OrdinaryDiffEqExponentialRK/test/linear_nonlinear_convergence_tests.jl b/lib/OrdinaryDiffEqExponentialRK/test/linear_nonlinear_convergence_tests.jl index 64c94cc606d..ab8486b06e9 100644 --- a/lib/OrdinaryDiffEqExponentialRK/test/linear_nonlinear_convergence_tests.jl +++ b/lib/OrdinaryDiffEqExponentialRK/test/linear_nonlinear_convergence_tests.jl @@ -1,6 +1,7 @@ using OrdinaryDiffEqExponentialRK, Test, DiffEqDevTools, Random, LinearAlgebra, LinearSolve using OrdinaryDiffEqVerner, OrdinaryDiffEqSDIRK using OrdinaryDiffEqCore: alg_order +using SciMLBase: successful_retcode @testset "Caching Out-of-place" begin println("Caching Out-of-place") @@ -73,6 +74,27 @@ end @test sim.𝒪est[:L2] ≈ 4 atol = 0.1 end +@testset "Matrix-free SciMLOperator split" begin + u0 = ComplexF64[1.0 + 0.5im, -0.5 + 0.25im, 0.75 - 0.125im, -0.25 - 0.5im] + λ = ComplexF64[-1.0, -2.0, -3.0, -4.0] + F = FunctionOperator( + (v, u, p, t) -> v, u0; + T = ComplexF64, + islinear = true, + op_inverse = (v, u, p, t) -> v, + opnorm = (_ -> 1.0) + ) + L = cache_operator(F \ DiagonalOperator(λ) * F, u0) + @test !has_concretization(L) + + prob = SplitODEProblem(L, (u, p, t) -> zero(u), u0, (0.0, 0.1)) + sol = solve(prob, ETDRK4(), dt = 0.01, save_everystep = false) + + @test successful_retcode(sol) + @test sol.alg.krylov + @test sol.u[end]≈exp.(0.1 .* λ) .* u0 rtol=1.0e-6 +end + @info "CFNLIRK3() is broken" @testset "EPIRK Out-of-place" begin