diff --git a/Project.toml b/Project.toml index 279bd47..a72048c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "FiniteDifferences" uuid = "26cc04aa-876d-5657-8c51-4c34ba976000" -version = "0.12.18" +version = "0.13.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/to_vec.jl b/src/to_vec.jl index 838078d..43e740d 100644 --- a/src/to_vec.jl +++ b/src/to_vec.jl @@ -18,6 +18,12 @@ function to_vec(z::Complex) return [real(z), imag(z)], Complex_from_vec end +# Integers cannot be perturbed! +function to_vec(x::Integer) + Integer_from_vec(v) = x + return Bool[], Integer_from_vec +end + # Base case -- if x is already a Vector{<:Real} there's no conversion necessary. to_vec(x::Vector{<:Real}) = (x, identity) @@ -37,9 +43,9 @@ end # chunk of the time. function to_vec(x::T) where {T} Base.isstructtype(T) || throw(error("Expected a struct type")) - isempty(fieldnames(T)) && return (Bool[], _ -> x) # Singleton types + is_singleton(x) && return (Bool[], _ -> x) # Singleton types - val_vecs_and_backs = map(name -> to_vec(getfield(x, name)), fieldnames(T)) + val_vecs_and_backs = get_val_vecs_and_backs(x) vals = first.(val_vecs_and_backs) backs = last.(val_vecs_and_backs) @@ -56,6 +62,16 @@ function to_vec(x::T) where {T} return v, structtype_from_vec end +# Type-stable way to determine whether a type has any fields. +@generated function is_singleton(x) + return isempty(fieldnames(x)) ? :true : :false +end + +# Type-stable way to call `to_vec` on each field. +@generated function get_val_vecs_and_backs(x) + return Expr(:tuple, map(name -> :(to_vec(x.$name)), fieldnames(x))...) +end + function to_vec(x::DenseVector) x_vecs_and_backs = map(to_vec, x) x_vecs, backs = first.(x_vecs_and_backs), last.(x_vecs_and_backs) @@ -79,83 +95,11 @@ function to_vec(x::DenseArray) return x_vec, Array_from_vec end -# Some specific subtypes of AbstractArray. -function to_vec(x::Base.ReshapedArray{<:Any, 1}) - x_vec, from_vec = to_vec(parent(x)) - function ReshapedArray_from_vec(x_vec) - p = from_vec(x_vec) - return Base.ReshapedArray(p, x.dims, x.mi) - end - - return x_vec, ReshapedArray_from_vec -end - # To return a SubArray we would endup needing to copy the `parent` of `x` in `from_vec` # which doesn't seem particularly useful. So we just convert the view into a copy. # we might be able to do something more performant but this seems good for now. to_vec(x::Base.SubArray) = to_vec(copy(x)) -function to_vec(x::T) where {T<:LinearAlgebra.AbstractTriangular} - x_vec, back = to_vec(Matrix(x)) - function AbstractTriangular_from_vec(x_vec) - return T(reshape(back(x_vec), size(x))) - end - return x_vec, AbstractTriangular_from_vec -end - -function to_vec(x::T) where {T<:LinearAlgebra.HermOrSym} - x_vec, back = to_vec(Matrix(x)) - function HermOrSym_from_vec(x_vec) - return T(back(x_vec), x.uplo) - end - return x_vec, HermOrSym_from_vec -end - -function to_vec(X::Diagonal) - x_vec, back = to_vec(Matrix(X)) - function Diagonal_from_vec(x_vec) - return Diagonal(back(x_vec)) - end - return x_vec, Diagonal_from_vec -end - -function to_vec(X::Transpose) - x_vec, back = to_vec(Matrix(X)) - function Transpose_from_vec(x_vec) - return Transpose(permutedims(back(x_vec))) - end - return x_vec, Transpose_from_vec -end - -function to_vec(x::Transpose{<:Any, <:AbstractVector}) - x_vec, back = to_vec(Matrix(x)) - Transpose_from_vec(x_vec) = Transpose(vec(back(x_vec))) - return x_vec, Transpose_from_vec -end - -function to_vec(X::Adjoint) - x_vec, back = to_vec(Matrix(X)) - function Adjoint_from_vec(x_vec) - return Adjoint(conj!(permutedims(back(x_vec)))) - end - return x_vec, Adjoint_from_vec -end - -function to_vec(x::Adjoint{<:Any, <:AbstractVector}) - x_vec, back = to_vec(Matrix(x)) - Adjoint_from_vec(x_vec) = Adjoint(conj!(vec(back(x_vec)))) - return x_vec, Adjoint_from_vec -end - -function to_vec(X::T) where {T<:PermutedDimsArray} - x_vec, back = to_vec(parent(X)) - function PermutedDimsArray_from_vec(x_vec) - X_parent = back(x_vec) - return T(X_parent) - end - return x_vec, PermutedDimsArray_from_vec -end - # Factorizations function to_vec(x::F) where {F <: SVD} @@ -170,14 +114,6 @@ function to_vec(x::F) where {F <: SVD} return x_vec, SVD_from_vec end -function to_vec(x::Cholesky) - x_vec, back = to_vec(x.factors) - function Cholesky_from_vec(v) - return Cholesky(back(v), x.uplo, x.info) - end - return x_vec, Cholesky_from_vec -end - function to_vec(x::S) where {U, S <: Union{LinearAlgebra.QRCompactWYQ{U}, LinearAlgebra.QRCompactWY{U}}} # x.T is composed of upper triangular blocks. The subdiagonals elements # of the blocks are abitrary. We make sure to set all of them to zero @@ -203,6 +139,11 @@ end # Non-array data structures +function to_vec(x::Tuple{}) + Tuple_from_vec(v) = () + return Bool[], Tuple_from_vec +end + function to_vec(x::Tuple) x_vecs_and_backs = map(to_vec, x) x_vecs, x_backs = first.(x_vecs_and_backs), last.(x_vecs_and_backs) @@ -260,3 +201,7 @@ function FiniteDifferences.to_vec(t::Thunk) Thunk_from_vec = v -> @thunk(back(v)) return v, Thunk_from_vec end + +# Things that aren't struct types and aren't differentiable. +to_vec(x::Char) = Bool[], _ -> x +to_vec(x::Symbol) = Bool[], _ -> x diff --git a/test/to_vec.jl b/test/to_vec.jl index 7e430c8..58314b6 100644 --- a/test/to_vec.jl +++ b/test/to_vec.jl @@ -68,6 +68,14 @@ function test_to_vec(x::T; check_inferred=true) where {T} end @testset "to_vec" begin + + @testset "Integer" begin + # Under ChainRules semantics, Integers cannot be perturbed. `to_vec` is primarily a + # tool designed to work with ChainRules, so we employ the same semantics here. + test_to_vec(5) + @test length(to_vec(5)[1]) == 0 + end + @testset "$T" for T in (Float32, ComplexF32, Float64, ComplexF64) if T == Float64 test_to_vec(1.0) @@ -171,6 +179,7 @@ end end @testset "Tuples" begin + test_to_vec(()) test_to_vec((5, 4)) test_to_vec((5, randn(T, 5)); check_inferred = VERSION ≥ v"1.2") # broken on Julia 1.6.0, fixed on 1.6.1 test_to_vec((randn(T, 4), randn(T, 4, 3, 2), 1); check_inferred=false)