Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ SciMLOperatorsStaticArraysCoreExt = "StaticArraysCore"

[compat]
Accessors = "0.1.42"
Adapt = "4"
ArrayInterface = "7.19"
Adapt = "4.5.2"
ArrayInterface = "7.24"
DocStringExtensions = "0.9.4"
LinearAlgebra = "1.10"
LoopVectorization = "0.12"
Expand Down
33 changes: 30 additions & 3 deletions src/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -340,10 +340,23 @@ Base.resize!(L::ScaledOperator, n::Integer) = (resize!(L.L, n); L)
LinearAlgebra.opnorm(L::ScaledOperator, p::Real = 2) = abs(L.λ) * opnorm(L.L, p)

"""
    update_coefficients(L::ScaledOperator, u, p, t; kwargs...)

Out-of-place coefficient update for a `ScaledOperator`. The inner operator
is updated recursively, and the scalar coefficient is evaluated and then
frozen into an immutable snapshot (via `_freeze_updated_scalar`) so the
returned operator carries no mutable state. Returns a new `ScaledOperator`;
`L` itself is left untouched.
"""
function update_coefficients(L::ScaledOperator, u, p, t; kwargs...)
    inner = update_coefficients(L.L, u, p, t; kwargs...)
    coeff = update_coefficients(L.λ, u, p, t; kwargs...)
    return ScaledOperator(_freeze_updated_scalar(coeff), inner)
end

return L
# Shared error path for in-place updates on a ScaledOperator whose scalar
# coefficient has already been frozen by an out-of-place update.
function _throw_updated_scaled_inplace()
    msg = "cannot update coefficients in-place after an out-of-place ScaledOperator update; call update_coefficients instead"
    throw(ArgumentError(msg))
end

"""
    update_coefficients!(L::ScaledOperator{<:Any, <:_UpdatedScalarOperator, <:Any}, u, p, t; kwargs...)

In-place updates are unsupported once the scalar coefficient has been
frozen into a `_UpdatedScalarOperator`; this method always raises an
`ArgumentError` directing callers to `update_coefficients`.
"""
function update_coefficients!(
        L::ScaledOperator{<:Any, <:_UpdatedScalarOperator, <:Any}, u, p, t; kwargs...)
    return _throw_updated_scaled_inplace()
end

function update_coefficients!(L::ScaledOperator, u, p, t; kwargs...)
Expand Down Expand Up @@ -445,6 +458,13 @@ end
return L.L(w, v, u, p, t, a, false; kwargs...)
end

# Application w = L*v for a ScaledOperator with a frozen scalar coefficient:
# refresh the coefficients out-of-place, then multiply in-place into w.
@inline function (L::ScaledOperator{<:Any, <:_UpdatedScalarOperator, <:Any})(
        w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t; kwargs...
)
    updated = update_coefficients(L, u, p, t; kwargs...)
    return mul!(w, updated, v)
end

# In-place with scaling: w = α*(L*v) + β*w
@inline function (L::ScaledOperator)(
w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t, α, β; kwargs...
Expand All @@ -454,6 +474,13 @@ end
return L.L(w, v, u, p, t, a, β; kwargs...)
end

# Application w = α*(L*v) + β*w for a ScaledOperator with a frozen scalar
# coefficient: refresh out-of-place, then use the 5-argument mul!.
@inline function (L::ScaledOperator{<:Any, <:_UpdatedScalarOperator, <:Any})(
        w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t, α, β; kwargs...
)
    updated = update_coefficients(L, u, p, t; kwargs...)
    return mul!(w, updated, v, α, β)
end

"""
Lazy operator addition

Expand Down
42 changes: 42 additions & 0 deletions src/scalar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,15 @@ mutable struct ScalarOperator{T <: Number, F} <: AbstractSciMLScalarOperator{T}
update_func::F
end

# Immutable snapshot of a ScalarOperator produced by out-of-place updates.
# Unlike the mutable ScalarOperator, this struct cannot be updated in
# place, which keeps out-of-place update chains free of hidden mutation
# (important for AD correctness — see _freeze_updated_scalar).
struct _UpdatedScalarOperator{T <: Number, F} <: AbstractSciMLScalarOperator{T}
    val::T # current (frozen) scalar value
    update_func::F # retained so further out-of-place updates remain possible
end

# Freeze a scalar coefficient after an out-of-place update: a mutable
# ScalarOperator is snapshotted into an immutable _UpdatedScalarOperator,
# while anything else (plain numbers, already-frozen operators) passes
# through unchanged.
_freeze_updated_scalar(α::ScalarOperator) = _UpdatedScalarOperator(α.val, α.update_func)
_freeze_updated_scalar(α) = α

"""
$SIGNATURES

Expand Down Expand Up @@ -186,13 +195,15 @@ end

# constructors

# Scalar conversion extracts the operator's current value.
function Base.convert(T::Type{<:Number}, α::ScalarOperator)
    return convert(T, α.val)
end
function Base.convert(T::Type{<:Number}, α::_UpdatedScalarOperator)
    return convert(T, α.val)
end
Base.convert(::Type{ScalarOperator}, α::Number) = ScalarOperator(α)

# Wrapping an existing scalar operator is a no-op; a UniformScaling is
# unwrapped to its scaling factor.
ScalarOperator(α::AbstractSciMLScalarOperator) = α
ScalarOperator(λ::UniformScaling) = ScalarOperator(λ.λ)

# traits

# The frozen variant displays exactly like the mutable one.
function Base.show(io::IO, α::ScalarOperator)
    return print(io, "ScalarOperator($(α.val))")
end
function Base.show(io::IO, α::_UpdatedScalarOperator)
    return print(io, "ScalarOperator($(α.val))")
end
function Base.conj(α::ScalarOperator) # TODO - test
val = conj(α.val)
update_func = (
Expand All @@ -208,12 +219,28 @@ function Base.conj(α::ScalarOperator) # TODO - test
return ScalarOperator(val; update_func = update_func, accepted_kwargs = NoKwargFilter())
end

function Base.conj(α::_UpdatedScalarOperator)
    # Mirror Base.conj(::ScalarOperator): the stored update function acts on
    # the un-conjugated value, so conjugate on the way in and again on the
    # way out. The result stays frozen (immutable).
    inner = α.update_func
    wrapped = function (oldval, u, p, t; kwargs...)
        return conj(inner(conj(oldval), u, p, t; kwargs...))
    end
    return _UpdatedScalarOperator(conj(α.val), wrapped)
end

# Identity and zero elements for any scalar operator (mutable or frozen):
# both are constant, so a plain ScalarOperator suffices.
Base.one(::AbstractSciMLScalarOperator{T}) where {T} = ScalarOperator(one(T))
Base.zero(::AbstractSciMLScalarOperator{T}) where {T} = ScalarOperator(zero(T))

# Type-level one/zero fall back to Bool-valued constants.
Base.one(::Type{<:AbstractSciMLScalarOperator}) = ScalarOperator(true)
Base.zero(::Type{<:AbstractSciMLScalarOperator}) = ScalarOperator(false)
# abs works on the current value; the frozen variant behaves identically.
Base.abs(α::ScalarOperator) = abs(α.val)
Base.abs(α::_UpdatedScalarOperator) = abs(α.val)

function LinearAlgebra.exp(α::AbstractSciMLScalarOperator)
update_func = (
Expand All @@ -226,11 +253,16 @@ function LinearAlgebra.exp(α::AbstractSciMLScalarOperator)
end

# Trait methods are defined pairwise so the frozen _UpdatedScalarOperator
# answers identically to the mutable ScalarOperator.
Base.iszero(α::ScalarOperator) = iszero(α.val)
Base.iszero(α::_UpdatedScalarOperator) = iszero(α.val)

# The only "operand" of a scalar operator is its current value.
getops(α::ScalarOperator) = (α.val,)
getops(α::_UpdatedScalarOperator) = (α.val,)
isconstant(α::ScalarOperator) = update_func_isconstant(α.update_func)
isconstant(α::_UpdatedScalarOperator) = update_func_isconstant(α.update_func)
# Division by the scalar is possible iff the current value is nonzero.
has_ldiv(α::ScalarOperator) = !iszero(α.val)
has_ldiv(α::_UpdatedScalarOperator) = !iszero(α.val)
has_ldiv!(α::ScalarOperator) = has_ldiv(α)
has_ldiv!(α::_UpdatedScalarOperator) = has_ldiv(α)

function update_coefficients!(L::ScalarOperator, u, p, t; kwargs...)
L.val = L.update_func(L.val, u, p, t; kwargs...)
Expand All @@ -241,11 +273,21 @@ function SciMLOperators.update_coefficients(L::ScalarOperator, u, p, t; kwargs..
return ScalarOperator(L.update_func(L.val, u, p, t; kwargs...), L.update_func)
end

"""
    update_coefficients(L::_UpdatedScalarOperator, u, p, t; kwargs...)

Out-of-place update of a frozen scalar: evaluate the stored update function
at the new state and return a fresh immutable snapshot carrying the same
update function.
"""
function SciMLOperators.update_coefficients(L::_UpdatedScalarOperator, u, p, t; kwargs...)
    newval = L.update_func(L.val, u, p, t; kwargs...)
    return _UpdatedScalarOperator(newval, L.update_func)
end

# Copy method to avoid aliasing: a fresh ScalarOperator sharing no mutable
# state with the original.
Base.copy(L::ScalarOperator) = ScalarOperator(L.val, L.update_func)

# Non-aliasing copy of the frozen variant (itself immutable, so this is a
# cheap reconstruction for interface symmetry with ScalarOperator).
Base.copy(L::_UpdatedScalarOperator) = _UpdatedScalarOperator(L.val, L.update_func)

# Add ScalarOperator specific implementations for the new interface
function (α::ScalarOperator)(v::AbstractArray, u, p, t; kwargs...)
α = update_coefficients(α, u, p, t; kwargs...)
Expand Down
76 changes: 76 additions & 0 deletions test/ad_semantics.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
using SciMLOperators, LinearAlgebra, Test, Zygote

using SciMLOperators: concretize

# Shared fixtures for the AD-semantics tests.
const ad_n = 3 # problem dimension
const ad_u = [0.3, -0.2, 0.7] # state passed to update_coefficients
const ad_v = [1.0, -2.0, 0.5] # vector the operators are applied to
const ad_t = 0.4 # time passed to update_coefficients
# Fixed matrix backing the MatrixOperator under test.
const ad_pmat = [
    0.0 2.0 -1.0
    1.0 0.0 0.5
    -0.25 0.75 0.0
]

# Factories build fresh operators per test so testsets do not share state.
# The scalar's update function returns the parameter p itself.
ad_scalar() = ScalarOperator(0.0, (_, _, p, _) -> p)
ad_matrix() = MatrixOperator(ad_pmat)
ad_added_operator() = MatrixOperator(Matrix{Float64}(I, ad_n, ad_n)) + ad_scalar() * ad_matrix()

# Reference results computed directly with plain arrays (no operators), so
# test failures point at the operator machinery rather than the math.
ad_expected_scaled(p) = p .* (ad_pmat * ad_v)
ad_expected_added(p) = (Matrix{Float64}(I, ad_n, ad_n) + p .* ad_pmat) * ad_v

@testset "AD semantic equivalence" begin
    p = 1.7

    @testset "ScalarOperator * MatrixOperator" begin
        L = ad_scalar() * ad_matrix()

        # Two ways of computing the same loss: concretize-then-multiply vs
        # calling the operator directly; their values and gradients must agree.
        concretized_loss(p) = sum(concretize(update_coefficients(L, ad_u, p, ad_t)) * ad_v)
        direct_loss(p) = sum(L(ad_v, ad_u, p, ad_t))

        @test concretize(update_coefficients(L, ad_u, p, ad_t)) ≈ p .* ad_pmat
        @test L(ad_v, ad_u, p, ad_t) ≈ ad_expected_scaled(p)

        # In-place 5-argument application must match the out-of-place result.
        w = similar(ad_v)
        L(w, ad_v, ad_u, p, ad_t)
        @test w ≈ ad_expected_scaled(p)

        # d/dp sum(p * (A*v)) = sum(A*v), independent of p.
        expected_grad = sum(ad_pmat * ad_v)
        @test only(Zygote.gradient(concretized_loss, p)) ≈ expected_grad
        @test only(Zygote.gradient(direct_loss, p)) ≈ expected_grad

        # After an out-of-place update the scalar is frozen: further
        # out-of-place updates work, but in-place updates must throw.
        updated_L = update_coefficients(L, ad_u, p, ad_t)
        @test updated_L(ad_v, ad_u, p + 1, ad_t) ≈ ad_expected_scaled(p + 1)
        @test_throws ArgumentError update_coefficients!(updated_L, ad_u, p + 1, ad_t)
        updated_L(w, ad_v, ad_u, p + 1, ad_t)
        @test w ≈ ad_expected_scaled(p + 1)

        # 7-argument form: w = α*(L*v) + β*w with w pre-filled to 0.25.
        w .= 0.25
        updated_L(w, ad_v, ad_u, p + 1, ad_t, 2.0, 0.5)
        @test w ≈ 2 .* ad_expected_scaled(p + 1) .+ 0.125
    end

    @testset "MatrixOperator + ScalarOperator * MatrixOperator" begin
        L = ad_added_operator()

        concretized_loss(p) = sum(concretize(update_coefficients(L, ad_u, p, ad_t)) * ad_v)
        direct_loss(p) = sum(L(ad_v, ad_u, p, ad_t))

        @test concretize(update_coefficients(L, ad_u, p, ad_t)) ≈
            Matrix{Float64}(I, ad_n, ad_n) + p .* ad_pmat
        @test L(ad_v, ad_u, p, ad_t) ≈ ad_expected_added(p)

        w = similar(ad_v)
        L(w, ad_v, ad_u, p, ad_t)
        @test w ≈ ad_expected_added(p)

        # The identity term does not depend on p, so the gradient is the
        # same as in the purely scaled case.
        expected_grad = sum(ad_pmat * ad_v)
        @test only(Zygote.gradient(concretized_loss, p)) ≈ expected_grad
        @test only(Zygote.gradient(direct_loss, p)) ≈ expected_grad
    end
end
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ end
@time @safetestset "Zygote.jl" begin
    include("zygote.jl")
end
# Checks that operator evaluation and AD gradients agree with plain-array
# references (see test/ad_semantics.jl).
@time @safetestset "AD semantics" begin
    include("ad_semantics.jl")
end
@time @safetestset "Copy methods" begin
    include("copy.jl")
end
Expand Down
14 changes: 14 additions & 0 deletions test/zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,17 @@ for (LType, L) in (
end
end
end

@testset "Zygote update_coefficients concretize scaled operator" begin
    # L = I + p*A2 where the coefficient's update function returns p.
    A1 = MatrixOperator([1.0 0.0; 0.0 1.0])
    A2 = MatrixOperator([1.0 0.0; 0.0 0.0])
    coeff = ScalarOperator(0.0, (a, u, p, t) -> p)
    L = A1 + coeff * A2

    # The [1,1] entry computed through update_coefficients/concretize must
    # match the plain-matrix formula — in value and in Zygote gradient.
    operator_entry(p) = (update_coefficients(L, 0, p, 0) |> concretize)[1, 1]
    matrix_entry(p) = ([1.0 0.0; 0.0 1.0] + p * [1.0 0.0; 0.0 0.0])[1, 1]

    p = 1.0
    @test operator_entry(p) == matrix_entry(p)
    @test Zygote.gradient(operator_entry, p)[1] == Zygote.gradient(matrix_entry, p)[1]
end
Loading