diff --git a/Project.toml b/Project.toml
index d7bb14ed..f56d2733 100644
--- a/Project.toml
+++ b/Project.toml
@@ -23,8 +23,8 @@ SciMLOperatorsStaticArraysCoreExt = "StaticArraysCore"
 
 [compat]
 Accessors = "0.1.42"
-Adapt = "4"
-ArrayInterface = "7.19"
+Adapt = "4.5.2"
+ArrayInterface = "7.24"
 DocStringExtensions = "0.9.4"
 LinearAlgebra = "1.10"
 LoopVectorization = "0.12"
diff --git a/src/basic.jl b/src/basic.jl
index c7a34060..82577f00 100644
--- a/src/basic.jl
+++ b/src/basic.jl
@@ -340,10 +340,23 @@ Base.resize!(L::ScaledOperator, n::Integer) = (resize!(L.L, n); L)
 LinearAlgebra.opnorm(L::ScaledOperator, p::Real = 2) = abs(L.λ) * opnorm(L.L, p)
 
 function update_coefficients(L::ScaledOperator, u, p, t; kwargs...)
-    @reset L.L = update_coefficients(L.L, u, p, t; kwargs...)
-    @reset L.λ = update_coefficients(L.λ, u, p, t; kwargs...)
+    λ = _freeze_updated_scalar(update_coefficients(L.λ, u, p, t; kwargs...))
+    L_inner = update_coefficients(L.L, u, p, t; kwargs...)
+    return ScaledOperator(λ, L_inner)
+end
 
-    return L
+function _throw_updated_scaled_inplace()
+    throw(
+        ArgumentError(
+            "cannot update coefficients in-place after an out-of-place ScaledOperator update; call update_coefficients instead",
+        )
+    )
+end
+
+function update_coefficients!(
+        L::ScaledOperator{<:Any, <:_UpdatedScalarOperator, <:Any}, u, p, t; kwargs...
+    )
+    _throw_updated_scaled_inplace()
 end
 
 function update_coefficients!(L::ScaledOperator, u, p, t; kwargs...)
@@ -445,6 +458,13 @@ end
     return L.L(w, v, u, p, t, a, false; kwargs...)
 end
 
+@inline function (L::ScaledOperator{<:Any, <:_UpdatedScalarOperator, <:Any})(
+        w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t; kwargs...
+    )
+    L = update_coefficients(L, u, p, t; kwargs...)
+    return mul!(w, L, v)
+end
+
 # In-place with scaling: w = α*(L*v) + β*w
 @inline function (L::ScaledOperator)(
         w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t, α, β; kwargs...
@@ -454,6 +474,13 @@ end
     return L.L(w, v, u, p, t, a, β; kwargs...)
 end
 
+@inline function (L::ScaledOperator{<:Any, <:_UpdatedScalarOperator, <:Any})(
+        w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t, α, β; kwargs...
+    )
+    L = update_coefficients(L, u, p, t; kwargs...)
+    return mul!(w, L, v, α, β)
+end
+
 """
 Lazy operator addition
 
diff --git a/src/scalar.jl b/src/scalar.jl
index 2404e026..4ce9e3d3 100644
--- a/src/scalar.jl
+++ b/src/scalar.jl
@@ -126,6 +126,15 @@ mutable struct ScalarOperator{T <: Number, F} <: AbstractSciMLScalarOperator{T}
     update_func::F
 end
 
+# Immutable snapshot used by out-of-place updates after a ScalarOperator has evaluated its state.
+struct _UpdatedScalarOperator{T <: Number, F} <: AbstractSciMLScalarOperator{T}
+    val::T
+    update_func::F
+end
+
+_freeze_updated_scalar(α) = α
+_freeze_updated_scalar(α::ScalarOperator) = _UpdatedScalarOperator(α.val, α.update_func)
+
 """
 $SIGNATURES
 
@@ -186,6 +195,7 @@ end
 
 # constructors
 Base.convert(T::Type{<:Number}, α::ScalarOperator) = convert(T, α.val)
+Base.convert(T::Type{<:Number}, α::_UpdatedScalarOperator) = convert(T, α.val)
 Base.convert(::Type{ScalarOperator}, α::Number) = ScalarOperator(α)
 
 ScalarOperator(α::AbstractSciMLScalarOperator) = α
@@ -193,6 +203,7 @@ ScalarOperator(λ::UniformScaling) = ScalarOperator(λ.λ)
 
 # traits
 Base.show(io::IO, α::ScalarOperator) = print(io, "ScalarOperator($(α.val))")
+Base.show(io::IO, α::_UpdatedScalarOperator) = print(io, "ScalarOperator($(α.val))")
 function Base.conj(α::ScalarOperator) # TODO - test
     val = conj(α.val)
     update_func = (
@@ -208,12 +219,28 @@ function Base.conj(α::ScalarOperator) # TODO - test
     return ScalarOperator(val; update_func = update_func, accepted_kwargs = NoKwargFilter())
 end
 
+function Base.conj(α::_UpdatedScalarOperator)
+    val = conj(α.val)
+    update_func = (
+        oldval, u, p, t;
+        kwargs...,
+    ) -> α.update_func(
+        oldval |> conj,
+        u,
+        p,
+        t;
+        kwargs...
+    ) |> conj
+    return _UpdatedScalarOperator(val, update_func)
+end
+
 Base.one(::AbstractSciMLScalarOperator{T}) where {T} = ScalarOperator(one(T))
 Base.zero(::AbstractSciMLScalarOperator{T}) where {T} = ScalarOperator(zero(T))
 Base.one(::Type{<:AbstractSciMLScalarOperator}) = ScalarOperator(true)
 Base.zero(::Type{<:AbstractSciMLScalarOperator}) = ScalarOperator(false)
 
 Base.abs(α::ScalarOperator) = abs(α.val)
+Base.abs(α::_UpdatedScalarOperator) = abs(α.val)
 
 function LinearAlgebra.exp(α::AbstractSciMLScalarOperator)
     update_func = (
@@ -226,11 +253,16 @@ function LinearAlgebra.exp(α::AbstractSciMLScalarOperator)
 end
 
 Base.iszero(α::ScalarOperator) = iszero(α.val)
+Base.iszero(α::_UpdatedScalarOperator) = iszero(α.val)
 getops(α::ScalarOperator) = (α.val,)
+getops(α::_UpdatedScalarOperator) = (α.val,)
 isconstant(α::ScalarOperator) = update_func_isconstant(α.update_func)
+isconstant(α::_UpdatedScalarOperator) = update_func_isconstant(α.update_func)
 
 has_ldiv(α::ScalarOperator) = !iszero(α.val)
+has_ldiv(α::_UpdatedScalarOperator) = !iszero(α.val)
 has_ldiv!(α::ScalarOperator) = has_ldiv(α)
+has_ldiv!(α::_UpdatedScalarOperator) = has_ldiv(α)
 
 function update_coefficients!(L::ScalarOperator, u, p, t; kwargs...)
     L.val = L.update_func(L.val, u, p, t; kwargs...)
@@ -241,11 +273,21 @@ function SciMLOperators.update_coefficients(L::ScalarOperator, u, p, t; kwargs..
     return ScalarOperator(L.update_func(L.val, u, p, t; kwargs...), L.update_func)
 end
 
+function SciMLOperators.update_coefficients(
+        L::_UpdatedScalarOperator, u, p, t; kwargs...
+    )
+    return _UpdatedScalarOperator(L.update_func(L.val, u, p, t; kwargs...), L.update_func)
+end
+
 # Copy method to avoid aliasing
 function Base.copy(L::ScalarOperator)
     return ScalarOperator(L.val, L.update_func)
 end
 
+function Base.copy(L::_UpdatedScalarOperator)
+    return _UpdatedScalarOperator(L.val, L.update_func)
+end
+
 # Add ScalarOperator specific implementations for the new interface
 function (α::ScalarOperator)(v::AbstractArray, u, p, t; kwargs...)
     α = update_coefficients(α, u, p, t; kwargs...)
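
Taken together, the src/basic.jl and src/scalar.jl changes give ScaledOperator value semantics on update: update_coefficients evaluates the scalar, freezes it into an immutable _UpdatedScalarOperator snapshot, and returns a new operator, while update_coefficients! on that frozen result throws. A minimal sketch of the intended behavior, assuming this branch of SciMLOperators.jl (the 2x2 matrix and parameter values are illustrative, not taken from the PR):

    using SciMLOperators, LinearAlgebra

    # Parameter-dependent coefficient: the update function simply returns p.
    α = ScalarOperator(0.0, (a, u, p, t) -> p)
    L = α * MatrixOperator([1.0 2.0; 3.0 4.0])

    v = [1.0, 1.0]

    # Calling the operator evaluates the scalar at (u, p, t) and applies L.
    L(v, 0, 5.0, 0.0)                   # ≈ 5.0 .* ([1.0 2.0; 3.0 4.0] * v)

    # Out-of-place update: returns a new ScaledOperator whose scalar has been
    # frozen into an _UpdatedScalarOperator; the original L is left untouched.
    L_up = update_coefficients(L, 0, 5.0, 0.0)

    # In-place mutation of the frozen operator is rejected by the new guard:
    # update_coefficients!(L_up, 0, 6.0, 0.0)   # throws ArgumentError
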
diff --git a/test/ad_semantics.jl b/test/ad_semantics.jl
new file mode 100644
index 00000000..4650ef3b
--- /dev/null
+++ b/test/ad_semantics.jl
@@ -0,0 +1,76 @@
+using SciMLOperators, LinearAlgebra, Test, Zygote
+
+using SciMLOperators: concretize
+
+const ad_n = 3
+const ad_u = [0.3, -0.2, 0.7]
+const ad_v = [1.0, -2.0, 0.5]
+const ad_t = 0.4
+const ad_pmat = [
+    0.0 2.0 -1.0
+    1.0 0.0 0.5
+    -0.25 0.75 0.0
+]
+
+ad_scalar() = ScalarOperator(0.0, (_, _, p, _) -> p)
+ad_matrix() = MatrixOperator(ad_pmat)
+ad_added_operator() = MatrixOperator(Matrix{Float64}(I, ad_n, ad_n)) + ad_scalar() * ad_matrix()
+
+function ad_expected_scaled(p)
+    return p .* (ad_pmat * ad_v)
+end
+
+function ad_expected_added(p)
+    return (Matrix{Float64}(I, ad_n, ad_n) + p .* ad_pmat) * ad_v
+end
+
+@testset "AD semantic equivalence" begin
+    p = 1.7
+
+    @testset "ScalarOperator * MatrixOperator" begin
+        L = ad_scalar() * ad_matrix()
+
+        concretized_loss(p) = sum(concretize(update_coefficients(L, ad_u, p, ad_t)) * ad_v)
+        direct_loss(p) = sum(L(ad_v, ad_u, p, ad_t))
+
+        @test concretize(update_coefficients(L, ad_u, p, ad_t)) ≈ p .* ad_pmat
+        @test L(ad_v, ad_u, p, ad_t) ≈ ad_expected_scaled(p)
+
+        w = similar(ad_v)
+        L(w, ad_v, ad_u, p, ad_t)
+        @test w ≈ ad_expected_scaled(p)
+
+        expected_grad = sum(ad_pmat * ad_v)
+        @test only(Zygote.gradient(concretized_loss, p)) ≈ expected_grad
+        @test only(Zygote.gradient(direct_loss, p)) ≈ expected_grad
+
+        updated_L = update_coefficients(L, ad_u, p, ad_t)
+        @test updated_L(ad_v, ad_u, p + 1, ad_t) ≈ ad_expected_scaled(p + 1)
+        @test_throws ArgumentError update_coefficients!(updated_L, ad_u, p + 1, ad_t)
+        updated_L(w, ad_v, ad_u, p + 1, ad_t)
+        @test w ≈ ad_expected_scaled(p + 1)
+
+        w .= 0.25
+        updated_L(w, ad_v, ad_u, p + 1, ad_t, 2.0, 0.5)
+        @test w ≈ 2 .* ad_expected_scaled(p + 1) .+ 0.125
+    end
+
+    @testset "MatrixOperator + ScalarOperator * MatrixOperator" begin
+        L = ad_added_operator()
+
+        concretized_loss(p) = sum(concretize(update_coefficients(L, ad_u, p, ad_t)) * ad_v)
+        direct_loss(p) = sum(L(ad_v, ad_u, p, ad_t))
+
+        @test concretize(update_coefficients(L, ad_u, p, ad_t)) ≈
+            Matrix{Float64}(I, ad_n, ad_n) + p .* ad_pmat
+        @test L(ad_v, ad_u, p, ad_t) ≈ ad_expected_added(p)
+
+        w = similar(ad_v)
+        L(w, ad_v, ad_u, p, ad_t)
+        @test w ≈ ad_expected_added(p)
+
+        expected_grad = sum(ad_pmat * ad_v)
+        @test only(Zygote.gradient(concretized_loss, p)) ≈ expected_grad
+        @test only(Zygote.gradient(direct_loss, p)) ≈ expected_grad
+    end
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index d27c6e77..95903e74 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -28,6 +28,9 @@ end
     @time @safetestset "Zygote.jl" begin
         include("zygote.jl")
     end
+    @time @safetestset "AD semantics" begin
+        include("ad_semantics.jl")
+    end
     @time @safetestset "Copy methods" begin
         include("copy.jl")
     end
diff --git a/test/zygote.jl b/test/zygote.jl
index ea3a46d4..1a80fb2a 100644
--- a/test/zygote.jl
+++ b/test/zygote.jl
@@ -111,3 +111,17 @@ for (LType, L) in (
         end
     end
 end
+
+@testset "Zygote update_coefficients concretize scaled operator" begin
+    A1 = MatrixOperator([1.0 0.0; 0.0 1.0])
+    A2 = MatrixOperator([1.0 0.0; 0.0 0.0])
+    coeff = ScalarOperator(0.0, (a, u, p, t) -> p)
+    L = A1 + coeff * A2
+
+    operator_entry(p) = (update_coefficients(L, 0, p, 0) |> concretize)[1, 1]
+    matrix_entry(p) = ([1.0 0.0; 0.0 1.0] + p * [1.0 0.0; 0.0 0.0])[1, 1]
+
+    p = 1.0
+    @test operator_entry(p) == matrix_entry(p)
+    @test Zygote.gradient(operator_entry, p)[1] == Zygote.gradient(matrix_entry, p)[1]
+end
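
The tests above pin down the semantic contract: differentiating a loss through the lazy operator path (out-of-place update_coefficients followed by concretize) must agree with differentiating the equivalent hand-written matrix expression. The same check condensed into a standalone sketch, assuming Zygote and this branch; the matrix and parameter values below are illustrative:

    using SciMLOperators, LinearAlgebra, Zygote
    using SciMLOperators: concretize

    A = [0.0 2.0; 1.0 0.0]
    coeff = ScalarOperator(0.0, (a, u, p, t) -> p)   # coefficient equals p
    L = MatrixOperator(Matrix{Float64}(I, 2, 2)) + coeff * MatrixOperator(A)
    v = [1.0, -2.0]

    # Loss through the lazy operator, updated out-of-place and concretized.
    operator_loss(p) = sum(concretize(update_coefficients(L, 0, p, 0.0)) * v)

    # The same quantity written directly in terms of dense matrices.
    matrix_loss(p) = sum((Matrix{Float64}(I, 2, 2) + p * A) * v)

    p = 1.7
    @assert operator_loss(p) ≈ matrix_loss(p)
    @assert only(Zygote.gradient(operator_loss, p)) ≈
            only(Zygote.gradient(matrix_loss, p))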