diff --git a/Project.toml b/Project.toml index 65bfac1..0c3f88b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "FiniteDifferences" uuid = "26cc04aa-876d-5657-8c51-4c34ba976000" -version = "0.12.22" +version = "0.12.23" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -8,6 +8,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Richardson = "708f8203-808e-40c0-ba2d-98a6953ed40d" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [compat] diff --git a/src/FiniteDifferences.jl b/src/FiniteDifferences.jl index 0a44f84..82dd64c 100644 --- a/src/FiniteDifferences.jl +++ b/src/FiniteDifferences.jl @@ -5,6 +5,7 @@ using LinearAlgebra using Printf using Random using Richardson +using SparseArrays using StaticArrays export to_vec, grad, jacobian, jvp, j′vp diff --git a/src/to_vec.jl b/src/to_vec.jl index 66bb1f9..37287f9 100644 --- a/src/to_vec.jl +++ b/src/to_vec.jl @@ -156,6 +156,30 @@ function to_vec(X::T) where {T<:PermutedDimsArray} return x_vec, PermutedDimsArray_from_vec end +function to_vec(v::SparseVector) + inds, _ = findnz(v) + sizes = size(v) + + x_vec, back = to_vec(collect(v)) + function SparseVector_from_vec(x_v) + v_values = back(x_v) + return sparsevec(inds, v_values[inds], sizes...) + end + return x_vec, SparseVector_from_vec +end + +function to_vec(m::SparseMatrixCSC) + is, js, _ = findnz(m) + sizes = size(m) + + x_vec, back = to_vec(collect(m)) + function SparseMatrixCSC_from_vec(x_v) + v_values = back(x_v) + return sparse(is, js, [v_values[i, j] for (i, j) in zip(is, js)], sizes...) + end + return x_vec, SparseMatrixCSC_from_vec +end + # Factorizations function to_vec(x::F) where {F <: SVD} diff --git a/test/runtests.jl b/test/runtests.jl index e0b776a..63039f7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,6 +4,7 @@ using FiniteDifferences using LinearAlgebra using Printf using Random +using SparseArrays using StaticArrays using Test diff --git a/test/to_vec.jl b/test/to_vec.jl index 08f24bd..54e18ab 100644 --- a/test/to_vec.jl +++ b/test/to_vec.jl @@ -129,6 +129,11 @@ end ) end + @testset "SparseArrays" begin + test_to_vec(sparsevec([1 2 0; 0 0 3; 0 4 0.0])) + test_to_vec(sparse([1 2 0; 0 0 3; 0 4 0.0])) + end + @testset "Factorizations" begin # (100, 100) is needed to test for the NaNs that can appear in the # qr(M).T matrix