diff --git a/src/basic.jl b/src/basic.jl index d137131e..7501b7b0 100644 --- a/src/basic.jl +++ b/src/basic.jl @@ -15,6 +15,7 @@ function Base.one(L::AbstractSciMLOperator) end Base.convert(::Type{AbstractMatrix}, ii::IdentityOperator) = Diagonal(ones(Bool, ii.len)) +has_concretization(::IdentityOperator) = true # Copy method to avoid aliasing - IdentityOperator has no mutable fields, can return self Base.copy(L::IdentityOperator) = L @@ -138,6 +139,7 @@ function Base.zero(L::AbstractSciMLOperator) end Base.convert(::Type{AbstractMatrix}, nn::NullOperator) = Diagonal(zeros(Bool, nn.len)) +has_concretization(::NullOperator) = true # Copy method to avoid aliasing - NullOperator has no mutable fields, can return self Base.copy(L::NullOperator) = L @@ -303,6 +305,7 @@ end function Base.convert(::Type{AbstractMatrix}, L::ScaledOperator) return convert(Number, L.λ) * convert(AbstractMatrix, L.L) end +has_concretization(L::ScaledOperator) = has_concretization(L.λ) & has_concretization(L.L) # traits function Base.show(io::IO, L::ScaledOperator{T}) where {T} @@ -568,6 +571,7 @@ end function Base.convert(::Type{AbstractMatrix}, L::AddedOperator) return sum(op -> convert(AbstractMatrix, op), L.ops) end +has_concretization(L::AddedOperator) = all(has_concretization, L.ops) # traits function Base.show(io::IO, L::AddedOperator) @@ -808,6 +812,7 @@ end function Base.convert(::Type{AbstractMatrix}, L::ComposedOperator) return prod(op -> convert(AbstractMatrix, op), L.ops) end +has_concretization(L::ComposedOperator) = all(has_concretization, L.ops) # traits function Base.show(io::IO, L::ComposedOperator) @@ -1150,6 +1155,7 @@ Base.:/(A::AbstractSciMLOperator, B::AbstractSciMLOperator) = A * inv(B) function Base.convert(::Type{AbstractMatrix}, L::InvertedOperator) return inv(convert(AbstractMatrix, L.L)) end +has_concretization(L::InvertedOperator) = has_concretization(L.L) function Base.show(io::IO, L::InvertedOperator) print(io, "1 / ") diff --git a/src/func.jl b/src/func.jl index bfe37f12..c4fa023d 100644 --- a/src/func.jl +++ b/src/func.jl @@ -765,6 +765,7 @@ end islinear(L::FunctionOperator) = L.traits.islinear isconvertible(L::FunctionOperator) = L.traits.isconvertible +has_concretization(L::FunctionOperator) = isconvertible(L) isconstant(L::FunctionOperator) = L.traits.isconstant has_adjoint(L::FunctionOperator) = !(L.op_adjoint isa Nothing) has_mul(::FunctionOperator{iip}) where {iip} = true diff --git a/src/interface.jl b/src/interface.jl index 4cc687e7..2801dc70 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -465,6 +465,7 @@ for pred in ( :isposdef, ) @eval function LinearAlgebra.$pred(L::AbstractSciMLOperator) + has_concretization(L) || return false if !isconvertible(L) @warn """using convert-based fallback in $($pred).""" end diff --git a/src/left.jl b/src/left.jl index 584b8a6e..1366107b 100644 --- a/src/left.jl +++ b/src/left.jl @@ -110,6 +110,7 @@ for (op, LType, VType) in ( L.L ) ) + @eval has_concretization(L::$LType) = has_concretization(L.L) # traits @eval Base.size(L::$LType) = size(L.L) |> reverse diff --git a/src/matrix.jl b/src/matrix.jl index a2d52292..1ddd27a7 100644 --- a/src/matrix.jl +++ b/src/matrix.jl @@ -447,6 +447,7 @@ function Base.convert( end Base.convert(::Type{AbstractMatrix}, L::InvertibleOperator) = convert(AbstractMatrix, L.L) +has_concretization(L::InvertibleOperator) = has_concretization(L.L) # traits function Base.show(io::IO, L::InvertibleOperator) diff --git a/src/scalar.jl b/src/scalar.jl index 7dea75fe..8abc15ed 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -35,6 +35,7 @@ has_mul!(::AbstractSciMLScalarOperator) = true isconcrete(::AbstractSciMLScalarOperator) = true islinear(::AbstractSciMLScalarOperator) = true has_adjoint(::AbstractSciMLScalarOperator) = true +has_concretization(::AbstractSciMLScalarOperator) = true Base.:*(α::AbstractSciMLScalarOperator, u::AbstractArray) = convert(Number, α) * u Base.:\(α::AbstractSciMLScalarOperator, u::AbstractArray) = convert(Number, α) \ u diff --git a/src/tensor.jl b/src/tensor.jl index 9a0a0ed8..89cef769 100644 --- a/src/tensor.jl +++ b/src/tensor.jl @@ -180,6 +180,7 @@ end islinear(L::TensorProductOperator) = reduce(&, islinear.(L.ops)) isconvertible(::TensorProductOperator) = false +has_concretization(L::TensorProductOperator) = all(has_concretization, L.ops) Base.iszero(L::TensorProductOperator) = reduce(|, iszero.(L.ops)) has_adjoint(L::TensorProductOperator) = reduce(&, has_adjoint.(L.ops)) has_mul(L::TensorProductOperator) = reduce(&, has_mul.(L.ops)) diff --git a/src/utils.jl b/src/utils.jl index 01a38d46..1506122e 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -81,3 +81,4 @@ _unwrap_val(x) = x _unwrap_val(::Val{X}) where {X} = X has_concretization(::AbstractSciMLOperator) = false +has_concretization(::Union{AbstractMatrix, UniformScaling, Factorization, Number}) = true diff --git a/test/basic.jl b/test/basic.jl index dac104b8..6ad4a6bf 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -420,6 +420,29 @@ end @test ldiv!(rand(N), op, u) ≈ op \ u end +@testset "has_concretization composites" begin + A = MatrixOperator(rand(N, N)) + B = MatrixOperator(rand(N, N)) + F = FunctionOperator( + (du, u, p, t) -> copyto!(du, u), + zeros(N), + zeros(N); + isinplace = true, + T = Float64, + islinear = true + ) + + @test has_concretization(A) + @test has_concretization(2A) + @test has_concretization(A + B) + @test has_concretization(A * B) + @test has_concretization(inv(A)) + @test !has_concretization(F) + @test !has_concretization(F * A) + @test !has_concretization(A + F) + @test !ishermitian(F * A) +end + @testset "Adjoint, Transpose" begin for ( op,