Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions lib/OrdinaryDiffEqNonlinearSolve/src/initialize_dae.jl
Original file line number Diff line number Diff line change
Expand Up @@ -396,8 +396,9 @@ function _initialize_dae!(integrator::OrdinaryDiffEqCore.ODEIntegrator, prob::OD
M = integrator.f.mass_matrix
M isa UniformScaling && return
update_coefficients!(M, u, p, t)
algebraic_vars = [all(iszero, x) for x in eachcol(M)]
algebraic_eqs = [all(iszero, x) for x in eachrow(M)]
algebraic_vars = vec(all(iszero, M, dims = 1))
algebraic_eqs = vec(all(iszero, M, dims = 2))

(iszero(algebraic_vars) || iszero(algebraic_eqs)) && return
tmp = get_tmp_cache(integrator)[1]

Expand Down Expand Up @@ -456,7 +457,7 @@ function _initialize_dae!(integrator::OrdinaryDiffEqCore.ODEIntegrator, prob::OD
nlsolve = default_nlsolve(alg.nlsolve, isinplace, u, nlprob, isAD)

nlsol = solve(nlprob, nlsolve; abstol = alg.abstol, reltol = integrator.opts.reltol)
alg_u .= nlsol
alg_u .= nlsol.u

recursivecopy!(integrator.uprev, integrator.u)
if alg_extrapolates(integrator.alg)
Expand Down
2 changes: 2 additions & 0 deletions test/gpu/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
OrdinaryDiffEqNonlinearSolve = "127b3ac7-2247-4354-8eb6-78cf4e7c58e8"
OrdinaryDiffEqRosenbrock = "43230ef6-c299-4910-a778-202eb28ce4ce"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"

[compat]
Expand Down
65 changes: 65 additions & 0 deletions test/gpu/simple_dae.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
using OrdinaryDiffEqRosenbrock
using OrdinaryDiffEqNonlinearSolve
using CUDA
using LinearAlgebra
using Adapt
using SparseArrays
using Test

#=
du[1] = -u[1]
du[2] = -0.5*u[2]
0 = u[1] + u[2] - u[3]
0 = -u[1] + u[2] - u[4]
=#

function dae!(du, u, p, t)
mul!(du, p, u)
end

p = [-1 0 0 0
1 -0.5 0 0
1 1 -1 0
-1 1 0 -1]

# mass_matrix = [1 0 0 0
# 0 1 0 0
# 0 0 0 0
# 0 0 0 0]
mass_matrix = Diagonal([1, 1, 0, 0])
jac_prototype = sparse(map(x -> iszero(x) ? 0.0 : 1.0, p))

u0 = [1.0, 1.0, 0.5, 0.5] # force init
odef = ODEFunction(dae!, mass_matrix = mass_matrix, jac_prototype = jac_prototype)

tspan = (0.0, 5.0)
prob = ODEProblem(odef, u0, tspan, p)
sol = solve(prob, Rodas5P())

# gpu version
mass_matrix_d = adapt(CuArray, mass_matrix)

# TODO: jac_prototype fails
# jac_prototype_d = adapt(CuArray, jac_prototype)
# jac_prototype_d = CUDA.CUSPARSE.CuSparseMatrixCSR(jac_prototype)
jac_prototype_d = nothing

u0_d = adapt(CuArray, u0)
p_d = adapt(CuArray, p)
odef_d = ODEFunction(dae!, mass_matrix = mass_matrix_d, jac_prototype = jac_prototype_d)
prob_d = ODEProblem(odef_d, u0_d, tspan, p_d)
sol_d = solve(prob_d, Rodas5P())

@testset "Test constraints in GPU sol" begin
for t in sol_d.t
u = Vector(sol_d(t))
@test isapprox(u[1] + u[2], u[3]; atol = 1e-6)
@test isapprox(-u[1] + u[2], u[4]; atol = 1e-6)
end
end

@testset "Compare GPU to CPU solution" begin
for t in tspan[begin]:0.1:tspan[end]
@test Vector(sol_d(t)) ≈ sol(t)
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ end
@time @safetestset "Linear Exponential GPU" include("gpu/linear_exp.jl")
@time @safetestset "Reaction-Diffusion Stiff Solver GPU" include("gpu/reaction_diffusion_stiff.jl")
@time @safetestset "Scalar indexing bug bypass" include("gpu/hermite_test.jl")
@time @safetestset "simple dae on GPU" include("gpu/simple_dae.jl")
end

if !is_APPVEYOR && GROUP == "QA"
Expand Down
Loading