diff --git a/src/DiffEqNoiseProcess.jl b/src/DiffEqNoiseProcess.jl index 13939a1..1eba9e5 100644 --- a/src/DiffEqNoiseProcess.jl +++ b/src/DiffEqNoiseProcess.jl @@ -11,6 +11,8 @@ import RandomNumbers, Random123 import DiffEqBase: isinplace, solve, AbstractNoiseProcess, DEIntegrator, AbstractNoiseProblem +import SciMLBase: remake + import PoissonRandom, Distributions import QuadGK, Optim diff --git a/src/noise_interfaces/common.jl b/src/noise_interfaces/common.jl index e34fc92..248bae5 100644 --- a/src/noise_interfaces/common.jl +++ b/src/noise_interfaces/common.jl @@ -88,3 +88,34 @@ function Base.reverse(W::AbstractNoiseProcess) end return backwardnoise end + +function Base.similar(np::NoiseProcess, ::Type{NoiseProcess} = NoiseProcess) + NoiseProcess{isinplace(np)}(0.0, 0.0, np.Z isa AbstractVector ? np.Z[1] : np.Z, np.dist, + np.bridge; + rswm = np.rswm, save_everystep = np.save_everystep, + rng = deepcopy(np.rng), + reset = np.reset, reseed = np.reseed, + continuous = np.continuous, + cache = np.cache) +end + +function Base.copy(np::NoiseProcess) + np2 = similar(np) + for f in propertynames(np) + setfield!(np2, f, getfield(np, f)) + end + np2 +end + +function SciMLBase.remake(np::NoiseProcess; kwargs...) + np_new = copy(np) + inits = (t0 = :t, W0 = :W, Z0 = :Z) + for kwarg in kwargs + if first(kwarg) in keys(inits) + setfield!(np_new, inits[first(kwarg)], [second(kwarg)]) + else + setfield!(np_new, kwarg...) + end + end + np_new +end diff --git a/test/noise_process_remake.jl b/test/noise_process_remake.jl new file mode 100644 index 0000000..2bade15 --- /dev/null +++ b/test/noise_process_remake.jl @@ -0,0 +1,23 @@ +@testset "Remake" begin + using SciMLBase, DiffEqNoiseProcess, Test, Random + W = WienerProcess(0.0, 1.0, 1.0, rng = Random.Xoshiro(42)) + dt = 0.1 + W.dt = dt + u = nothing + p = nothing # for state-dependent distributions + calculate_step!(W, dt, u, p) + for i in 1:10 + accept_step!(W, dt, u, p) + end + W2 = copy(W) + for prop in propertynames(W) + @test getfield(W, prop) === getfield(W2, prop) + end + rng2 = Random.Xoshiro(43) + W3 = remake(W2, rng = rng2) + @test W3.rng === rng2 + W.rng = rng2 + for prop in propertynames(W) + @test getfield(W, prop) === getfield(W3, prop) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 33538ce..0bb429e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,6 +8,7 @@ using Test include("correlated.jl") include("noise_wrapper.jl") include("noise_function.jl") + include("noise_process_remake.jl") include("VBT_test.jl") include("noise_grid.jl") include("noise_approximation.jl")