Skip to content

Fixed L2 error for complex valued problems #145

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 120 additions & 0 deletions debug_non_diag.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
using Random
using StochasticDiffEq, DiffEqDevTools, Test
using SDEProblemLibrary: prob_sde_additivesystem

prob = prob_sde_additivesystem
prob = SDEProblem(prob.f, prob.g, prob.u0, (0.0, 0.1), prob.p)

reltols = 1.0 ./ 10.0 .^ (1:4)
abstols = reltols#[0.0 for i in eachindex(reltols)]
setups = [Dict(:alg => SRIW1())
Dict(:alg => EM(), :dts => 1.0 ./ 5.0 .^ ((1:length(reltols)) .+ 1),
:adaptive => false)
Dict(:alg => RKMil(), :dts => 1.0 ./ 5.0 .^ ((1:length(reltols)) .+ 1),
:adaptive => false)
Dict(:alg => SRIW1(), :dts => 1.0 ./ 5.0 .^ ((1:length(reltols)) .+ 1),
:adaptive => false)
Dict(:alg => SRA1(), :dts => 1.0 ./ 5.0 .^ ((1:length(reltols)) .+ 1),
:adaptive => false)
Dict(:alg => SRA1())]
_names = ["SRIW1", "EM", "RKMil", "SRIW1 Fixed", "SRA1 Fixed", "SRA1"]
test_dt = 0.1
# wp = WorkPrecisionSet(prob, abstols, reltols, setups, test_dt;
# numruns = 5, names = _names, error_estimate = :l2)

# se = get_sample_errors(prob, setups[1], numruns = 100, solution_runs = 100)
# se = get_sample_errors(prob, setups[1], numruns = [5, 10, 25, 50, 100, 1000],
# solution_runs = 100)

println("Now weak error without analytical solution")

prob2 = SDEProblem((du, u, p, t) -> prob.f(du, u, p, t), prob.g, prob.u0, (0.0, 0.1),
prob.p)
test_dt = 1 / 10^4
appxsol_setup = Dict(:alg => SRIW1(), :abstol => 1e-4, :reltol => 1e-4)

wp = WorkPrecisionSet(prob2, abstols, reltols, setups, test_dt;
appxsol_setup = appxsol_setup,
numruns = 5, names = _names, error_estimate = :l2)

using Test
using OrdinaryDiffEq, StochasticDiffEq, DiffEqDevTools, Plots
import SDEProblemLibrary: prob_sde_additivesystem

using Random
Random.seed!(123)

gr()

@testset "Analyticless SDE WorkPrecisionSet" begin
prob0 = prob_sde_additivesystem
prob = SDEProblem((du, u, p, t) -> prob0.f(du, u, p, t), prob0.g, prob0.u0, (0.0, 0.1),
prob0.p)

reltols = 1.0 ./ 10.0 .^ (1:5)
abstols = reltols#[0.0 for i in eachindex(reltols)]
setups = [Dict(:alg => SRIW1())
Dict(:alg => EM(), :dts => 1.0 ./ 5.0 .^ ((1:length(reltols)) .+ 1))
Dict(:alg => RKMil(), :dts => 1.0 ./ 5.0 .^ ((1:length(reltols)) .+ 1),
:adaptive => false)
Dict(:alg => SRIW1(), :dts => 1.0 ./ 5.0 .^ ((1:length(reltols)) .+ 1),
:adaptive => false)
Dict(:alg => SRA1(), :dts => 1.0 ./ 5.0 .^ ((1:length(reltols)) .+ 1),
:adaptive => false)
Dict(:alg => SRA1())]
names = ["SRIW1", "EM", "RKMil", "SRIW1 Fixed", "SRA1 Fixed", "SRA1"]
test_dt = 0.1
wp = WorkPrecisionSet(prob, abstols, reltols, setups, test_dt;
numruns = 5, names = names, error_estimate = :l2,
appxsol_setup = Dict(:alg => RKMilCommute(), :abstol => 1e-4, :reltol => 1e-4))

plt = @test_nowarn plot(wp)
for i in 1:length(names)
@test plt[1][i][:x] ≈ getproperty(wp[i].errors, wp[i].error_estimate)
@test plt[1][i][:label] == names[i]
end
end

@testset failfast=true "Non-diagonal SDE WorkPrecisionSet" begin
# Linear SDE system
f_lin = function (du, u, p, t)
du = -0.5 .* u
end

g_lin = function (du, u, p, t)
du[1, 1] = im
du[2, 1] = im
du[3, 1] = 0.1
du[1, 2] = 0.1
du[2, 2] = 0.1
du[3, 2] = 0.2
end

tspan = (0.0, 1.0)
noise_rate_prototype = zeros(ComplexF64, 3, 2)
noise = StochasticDiffEq.RealWienerProcess!(0.0, [0.0, 0.0], [0.0, 0.0])
prob = SDEProblem(SDEFunction(f_lin, g_lin),
ComplexF64[1.0, 0.0, 0.0], tspan, noise = noise, noise_rate_prototype = noise_rate_prototype)

reltols = 1.0 ./ 10.0 .^ (1:5)
abstols = reltols#[0.0 for i in eachindex(reltols)]
setups = [Dict(:alg => EM(), :dts => 1.0 ./ 5.0 .^ ((1:length(reltols)) .+ 1))
Dict(:alg => RKMilGeneral(; ii_approx = IICommutative()),
:dts => 1.0 ./ 5.0 .^ ((1:length(reltols)) .+ 1),
:adaptive => false);
Dict(:alg => EulerHeun(), :dts => 1.0 ./ 5.0 .^ ((1:length(reltols)) .+ 1),
:adaptive => false)
Dict(:alg => LambaEulerHeun(),
:dts => 1.0 ./ 5.0 .^ ((1:length(reltols)) .+ 1), :adaptive => true)]
names = ["EM", "RKMilGeneral", "EulerHeun Fixed", "LambaEulerHeun"]
test_dt = 0.1#(1.0 / 5.0)^6
wp = WorkPrecisionSet(prob, abstols, reltols, setups, test_dt;
numruns = 5, names = names, error_estimate = :l2,
appxsol_setup = Dict(:alg => RKMilGeneral(; ii_approx = IICommutative())), maxiters = 1e7)

plt = @test_nowarn plot(wp)
for i in 1:length(names)
@test plt[1][i][:x] ≈ getproperty(wp[i].errors, wp[i].error_estimate)
@test plt[1][i][:label] == names[i]
end
end
2 changes: 1 addition & 1 deletion src/DiffEqDevTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using LinearAlgebra, Distributed

using Statistics

import Base: length
import Base: length, isless

import DiffEqBase: AbstractODEProblem, AbstractDDEProblem, AbstractDDEAlgorithm,
AbstractODESolution, AbstractRODEProblem, AbstractSDEProblem,
Expand Down
89 changes: 77 additions & 12 deletions src/benchmark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,8 @@ function WorkPrecision(prob::AbstractBVProblem, alg, abstols, reltols, dts = not
let _prob = _prob
timeseries_errors = error_estimate ∈ TIMESERIES_ERRORS
dense_errors = error_estimate ∈ DENSE_ERRORS
println(
"timeseries_errors: $timeseries_errors dense_errors: $dense_errors ", @__LINE__)
for i in 1:N
if dts === nothing
sol = solve(_prob, alg; kwargs..., abstol = abstols[i],
Expand Down Expand Up @@ -464,7 +466,9 @@ function WorkPrecisionSet(prob,
print_names = false, names = nothing, appxsol = nothing,
error_estimate = :final,
test_dt = nothing, kwargs...)
println(@__LINE__)
N = length(setups)
@assert test_dt isa Nothing || test_dt isa Real
@assert names === nothing || length(setups) == length(names)
wps = Vector{WorkPrecision}(undef, N)
if names === nothing
Expand All @@ -487,14 +491,27 @@ function WorkPrecisionSet(prob,
end

@def error_calculation begin
println("line: ", @__LINE__)
if !DiffEqBase.has_analytic(prob.f)
println("No analytic line: ", @__LINE__)
t = prob.tspan[1]:test_dt:prob.tspan[2]
brownian_values = cumsum([[zeros(size(prob.u0))];
[sqrt(test_dt) * randn(size(prob.u0))
for i in 1:(length(t) - 1)]])
brownian_values2 = cumsum([[zeros(size(prob.u0))];
[sqrt(test_dt) * randn(size(prob.u0))
for i in 1:(length(t) - 1)]])
if prob.noise_rate_prototype === nothing
brownian_values = cumsum([[zeros(size(prob.u0))];
[sqrt(test_dt) * randn(size(prob.u0))
for i in 1:(length(t) - 1)]])
brownian_values2 = cumsum([[zeros(size(prob.u0))];
[sqrt(test_dt) * randn(size(prob.u0))
for i in 1:(length(t) - 1)]])
else
brownian_values = cumsum([[zeros(size(prob.noise_rate_prototype, 2))];
[sqrt(test_dt) *
randn(size(prob.noise_rate_prototype, 2))
for i in 1:(length(t) - 1)]])
brownian_values2 = cumsum([[zeros(size(prob.noise_rate_prototype, 2))];
[sqrt(test_dt) *
randn(size(prob.noise_rate_prototype, 2))
for i in 1:(length(t) - 1)]])
end
np = NoiseGrid(t, brownian_values, brownian_values2)
_prob = remake(prob, noise = np)
true_sol = solve(_prob, appxsol_setup[:alg]; kwargs..., appxsol_setup...)
Expand All @@ -520,28 +537,69 @@ end
_dts = get(setups[k], :dts, zeros(length(_abstols)))
filtered_setup = filter(p -> p.first in DiffEqBase.allowedkeywords, setups[k])

println("error_estimate: ", error_estimate, " line: ", @__LINE__)
println("timeseries_errors: $timeseries_errors dense_errors: $dense_errors line: ",
@__LINE__)
sol = solve(_prob, setups[k][:alg];
kwargs..., filtered_setup..., abstol = _abstols[j],
reltol = _reltols[j], dt = _dts[j],
timeseries_errors = timeseries_errors,
dense_errors = dense_errors)
DiffEqBase.has_analytic(prob.f) ? err_sol = sol : err_sol = appxtrue(sol, true_sol)
tmp_solutions[i, j, k] = err_sol
println(err_sol.errors)
end
end

import Base: isless
function Base.isless(a::Union{Complex, AbstractFloat}, b::Union{Complex, AbstractFloat})
# This is a workaround for the issue that SciMLBase incorrectly infers the solutions
# errors type. Not having this leads to crash while calculating the median.
return Real(a) < Real(b)
end

import SciMLBase: EnsembleTestSolution
function fix_SciMLBase_error_type_issue(sol::EnsembleTestSolution)
# This is a workaround for the issue that SciMLBase incorrectly infers the type of the
# EnsembleTestSolution.errors from the solution's state vector element type, instead of from
# the solution's error type. If the solution state is a Vector{ComplexF64}, this incorrectly leads
# errors with type Dict{Symbol, Vector{ComplexF64}} which should be Dict{Symbol, Vector{Float64}}.
T = eltype(Real.(collect(sol.errors)[1][2]))
new_errors = Dict{Symbol, Vector{T}}()
for (key, value) in sol.errors
new_errors[key] = Real.(sol.errors[key])
end
new_weak_errors = Dict{Symbol, T}()
for (key, value) in sol.weak_errors
new_weak_errors[key] = Real(sol.weak_errors[key])
end
new_error_means = Dict{Symbol, T}()
for (key, value) in sol.error_means
new_error_means[key] = Real(sol.error_means[key])
end
new_error_medians = Dict{Symbol, T}()
for (key, value) in sol.error_medians
new_error_medians[key] = Real(sol.error_medians[key])
end
return EnsembleTestSolution(sol.u, new_errors, new_weak_errors, new_error_means,
new_error_medians, sol.elapsedTime, sol.converged)
end

function WorkPrecisionSet(prob::AbstractRODEProblem, abstols, reltols, setups,
test_dt = nothing;
numruns = 20, numruns_error = 20,
print_names = false, names = nothing, appxsol_setup = nothing,
error_estimate = :final, parallel_type = :none,
kwargs...)
println(@__LINE__)
@assert test_dt isa Nothing || test_dt isa Real
@assert names === nothing || length(setups) == length(names)
timeseries_errors = DiffEqBase.has_analytic(prob.f) &&
error_estimate ∈ TIMESERIES_ERRORS
timeseries_errors = error_estimate ∈ TIMESERIES_ERRORS
weak_timeseries_errors = error_estimate ∈ WEAK_TIMESERIES_ERRORS
weak_dense_errors = error_estimate ∈ WEAK_DENSE_ERRORS
dense_errors = DiffEqBase.has_analytic(prob.f) && error_estimate ∈ DENSE_ERRORS
dense_errors = error_estimate ∈ DENSE_ERRORS
println(
"timeseries_errors: $timeseries_errors weak_timeseries_errors: $weak_timeseries_errors dense_errors: $dense_errors weak_dense_errors: $weak_dense_errors line: ", @__LINE__)
N = length(setups)
M = length(abstols)
times = Array{Float64}(undef, M, N)
Expand All @@ -565,10 +623,11 @@ function WorkPrecisionSet(prob::AbstractRODEProblem, abstols, reltols, setups,

_solutions_k = [[EnsembleSolution(tmp_solutions[:, j, k], 0.0, true) for j in 1:M]
for k in 1:N]
solutions = [[DiffEqBase.calculate_ensemble_errors(sim;
solutions = [[fix_SciMLBase_error_type_issue(DiffEqBase.calculate_ensemble_errors(sim;
weak_timeseries_errors = weak_timeseries_errors,
weak_dense_errors = weak_dense_errors)
weak_dense_errors = weak_dense_errors))
for sim in sol_k] for sol_k in _solutions_k]
println("mean errors: ", solutions[1][1].error_means)
if error_estimate ∈ WEAK_ERRORS
errors = [[solutions[j][i].weak_errors for i in 1:M] for j in 1:N]
else
Expand All @@ -595,6 +654,8 @@ function WorkPrecisionSet(prob::AbstractRODEProblem, abstols, reltols, setups,
GC.gc()
for j in 1:M
for i in 1:numruns
# println("$(names[k]) ($i/$numruns) the number j $j")
# println("$(_abstols[k]) $(_reltols[k]) $(_dts[k])")
time_tmp[i] = @elapsed sol = solve(prob, setups[k][:alg];
kwargs..., filtered_setup...,
abstol = _abstols[k][j],
Expand Down Expand Up @@ -623,6 +684,7 @@ function WorkPrecisionSet(prob::AbstractEnsembleProblem, abstols, reltols, setup
expected_value = nothing,
error_estimate = :weak_final, ensemblealg = EnsembleThreads(),
kwargs...)
println(@__LINE__)
@assert names === nothing || length(setups) == length(names)

weak_timeseries_errors = error_estimate ∈ WEAK_TIMESERIES_ERRORS
Expand Down Expand Up @@ -730,6 +792,7 @@ function WorkPrecisionSet(prob::AbstractBVProblem,
print_names = false, names = nothing, appxsol = nothing,
error_estimate = :final,
test_dt = nothing, kwargs...)
println(@__LINE__)
N = length(setups)
@assert names === nothing || length(setups) == length(names)
wps = Vector{WorkPrecision}(undef, N)
Expand Down Expand Up @@ -788,8 +851,10 @@ function get_sample_errors(prob::AbstractRODEProblem, setup, test_dt = nothing;
_dt = prob.tspan[2] - prob.tspan[1]
if prob.u0 isa Number
W = sqrt(_dt) * randn()
else
elseif prob.noise_rate_prototype isa Nothing
W = sqrt(_dt) * randn(size(prob.u0))
else
W = sqrt(_dt) * randn(size(prob.noise_rate_prototype, 2))
end
prob.f.analytic(prob.u0, prob.p, prob.tspan[2], W)
end
Expand Down
45 changes: 35 additions & 10 deletions src/test_solution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,24 +50,29 @@ function appxtrue(sol::AbstractODESolution, sol2::TestSolution)
errors = Dict(:final => recursive_mean(abs.(sol.u[end] - _sol.u[end])))
if _sol.dense
timeseries_analytic = _sol(sol.t)
errors[:l∞] = maximum(vecvecapply((x) -> abs.(x), sol - timeseries_analytic))
errors[:l2] = sqrt(recursive_mean(vecvecapply((x) -> float(x) .^ 2,
errors[:l∞] = maximum(vecvecapply(
(x) -> abs.(x), sol - timeseries_analytic))
errors[:l2] = sqrt(recursive_mean(vecvecapply(
(x) -> Real.(conj.(float(x)) .* float(x)),
sol - timeseries_analytic)))
densetimes = collect(range(sol.t[1], stop = sol.t[end], length = 100))
interp_u = sol(densetimes)
interp_analytic = _sol(densetimes)
interp_errors = Dict(
:L∞ => maximum(vecvecapply((x) -> abs.(x),
interp_u - interp_analytic)),
:L2 => sqrt(recursive_mean(vecvecapply((x) -> float(x) .^ 2,
:L2 => sqrt(recursive_mean(vecvecapply(
(x) -> Real.(conj.(float(x)) .* float(x)),
interp_u -
interp_analytic))))
errors = merge(errors, interp_errors)
else
timeseries_analytic = sol2.u
if sol.t == sol2.t
errors[:l∞] = maximum(vecvecapply((x) -> abs.(x), sol - timeseries_analytic))
errors[:l2] = sqrt(recursive_mean(vecvecapply((x) -> float(x) .^ 2,
errors[:l∞] = maximum(vecvecapply(
(x) -> abs.(x), sol - timeseries_analytic))
errors[:l2] = sqrt(recursive_mean(vecvecapply(
(x) -> Real.(conj.(float(x)) .* float(x)),
sol - timeseries_analytic)))
end
end
Expand All @@ -84,10 +89,14 @@ calculated.
function appxtrue(sol::AbstractODESolution, sol2::AbstractODESolution;
timeseries_errors = sol2.dense, dense_errors = sol2.dense)
errors = Dict(:final => recursive_mean(abs.(sol.u[end] - sol2.u[end])))
println("Made it to appxtrue")
if sol2.dense
println("Made it to dense")
timeseries_analytic = sol2(sol.t)
errors[:l∞] = maximum(vecvecapply((x) -> abs.(x), sol - timeseries_analytic))
errors[:l2] = sqrt(recursive_mean(vecvecapply((x) -> float(x) .^ 2,
errors[:l∞] = maximum(vecvecapply(
(x) -> abs.(x), sol - timeseries_analytic))
errors[:l2] = sqrt(recursive_mean(vecvecapply(
(x) -> Real.(conj.(float(x)) .* float(x)),
sol - timeseries_analytic)))
if dense_errors
densetimes = collect(range(sol.t[1], stop = sol.t[end], length = 100))
Expand All @@ -96,16 +105,32 @@ function appxtrue(sol::AbstractODESolution, sol2::AbstractODESolution;
interp_errors = Dict(
:L∞ => maximum(vecvecapply((x) -> abs.(x),
interp_u - interp_analytic)),
:L2 => sqrt(recursive_mean(vecvecapply((x) -> float(x) .^ 2,
:L2 => sqrt(recursive_mean(vecvecapply(
(x) -> Real.(conj.(float(x)) .* float(x)),
interp_u -
interp_analytic))))
errors = merge(errors, interp_errors)
end
else
println("Made it to timeseries")
timeseries_analytic = sol2.u
println("typeof timeseries_analytic: ", typeof(timeseries_analytic))
println("timeseries_errors: ", timeseries_errors,
" does sol.t==sol2.t: ", sol.t == sol2.t)
if timeseries_errors && sol.t == sol2.t
errors[:l∞] = maximum(vecvecapply((x) -> abs.(x), sol - timeseries_analytic))
errors[:l2] = sqrt(recursive_mean(vecvecapply((x) -> float(x) .^ 2,
errors[:l∞] = maximum(vecvecapply(
(x) -> abs.(x), sol - timeseries_analytic))
errors[:l2] = sqrt(recursive_mean(vecvecapply(
(x) -> Real.(conj.(float(x)) .* float(x)),
sol - timeseries_analytic)))
elseif timeseries_errors
println(stderr,
"Warning: timeseries_errors requested but appxsol_setup's timesteps do not match solution. Switching to interpolated solution for this run.")
timeseries_analytic = sol2(sol.t)
errors[:l∞] = maximum(vecvecapply(
(x) -> abs.(x), sol - timeseries_analytic))
errors[:l2] = sqrt(recursive_mean(vecvecapply(
(x) -> Real.(conj.(float(x)) .* float(x)),
sol - timeseries_analytic)))
end
end
Expand Down
Loading
Loading