diff --git a/Project.toml b/Project.toml index d91b1cbd8..87f0f1504 100644 --- a/Project.toml +++ b/Project.toml @@ -19,7 +19,10 @@ ArrayInterfaceBandedMatricesExt = "BandedMatrices" ArrayInterfaceBlockBandedMatricesExt = "BlockBandedMatrices" ArrayInterfaceCUDAExt = "CUDA" ArrayInterfaceGPUArraysCoreExt = "GPUArraysCore" +ArrayInterfaceOffsetArraysExt = "OffsetArrays" ArrayInterfaceStaticArraysCoreExt = "StaticArraysCore" +ArrayInterfaceStaticArraysExt = "StaticArrays" +ArrayInterfaceStaticExt = "Static" ArrayInterfaceTrackerExt = "Tracker" [extras] @@ -30,21 +33,26 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" +Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["SafeTestsets", "Pkg", "Test", "Aqua", "Random", "SparseArrays", "SuiteSparse", "BandedMatrices", "BlockBandedMatrices", "GPUArraysCore", "StaticArrays", "StaticArraysCore", "Tracker"] +test = ["SafeTestsets", "Pkg", "Test", "Aqua", "Random", "SparseArrays", "SuiteSparse", "BandedMatrices", "BlockBandedMatrices", "GPUArraysCore", "OffsetArrays", "StaticArrays", "StaticArraysCore", "Static", "Tracker"] [weakdeps] BandedMatrices = "aae01518-5342-5314-be14-df237901396f" BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" +OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" +Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" diff --git a/docs/src/indexing.md b/docs/src/indexing.md index ed1282e74..35f255097 100644 --- a/docs/src/indexing.md +++ b/docs/src/indexing.md @@ -14,6 +14,7 @@ ArrayInterface.can_change_size ArrayInterface.can_setindex ArrayInterface.fast_scalar_indexing ArrayInterface.ismutable +ArrayInterface.is_splat_index ArrayInterface.ndims_index ArrayInterface.ndims_shape ArrayInterface.defines_strides @@ -22,6 +23,11 @@ ArrayInterface.ensures_sorted ArrayInterface.indices_do_not_alias ArrayInterface.instances_do_not_alias ArrayInterface.device +ArrayInterface.known_first +ArrayInterface.known_step +ArrayInterface.known_last +ArrayInterface.known_size +ArrayInterface.known_length ``` ## Allowed Indexing Functions @@ -46,4 +52,4 @@ and index translations. ArrayInterface.ArrayIndex ArrayInterface.GetIndex ArrayInterface.SetIndex! -``` \ No newline at end of file +``` diff --git a/ext/ArrayInterfaceOffsetArraysExt.jl b/ext/ArrayInterfaceOffsetArraysExt.jl new file mode 100644 index 000000000..645dc7e9f --- /dev/null +++ b/ext/ArrayInterfaceOffsetArraysExt.jl @@ -0,0 +1,21 @@ +module ArrayInterfaceOffsetArraysExt + +if isdefined(Base, :get_extension) + using ArrayInterface + using OffsetArrays +else + using ..ArrayInterface + using ..OffsetArrays +end + +ArrayInterface.parent_type(@nospecialize T::Type{<:OffsetArrays.IdOffsetRange}) = fieldtype(T, :parent) +ArrayInterface.parent_type(@nospecialize T::Type{<:OffsetArray}) = fieldtype(T, :parent) + +function ArrayInterface.known_size(@nospecialize T::Type{<:OffsetArrays.IdOffsetRange}) + ArrayInterface.known_size(ArrayInterface.parent_type(T)) +end +function ArrayInterface.known_size(@nospecialize T::Type{<:OffsetArray}) + ArrayInterface.known_size(ArrayInterface.parent_type(T)) +end + +end diff --git a/ext/ArrayInterfaceStaticArraysCoreExt.jl b/ext/ArrayInterfaceStaticArraysCoreExt.jl index 5c555f638..8303af5dc 100644 --- a/ext/ArrayInterfaceStaticArraysCoreExt.jl +++ b/ext/ArrayInterfaceStaticArraysCoreExt.jl @@ -32,4 +32,13 @@ end ArrayInterface.restructure(x::StaticArraysCore.SArray{S}, y) where {S} = StaticArraysCore.SArray{S}(y) +function ArrayInterface.known_size(::Type{<:StaticArraysCore.StaticArray{S}}) where {S} + @isdefined(S) ? tuple(S.parameters...) : ntuple(_-> nothing, ndims(T)) +end + +function ArrayInterface.known_length(T::Type{<:StaticArraysCore.StaticArray}) + sz = ArrayInterface.known_size(T) + isa(sz, Tuple{Vararg{Nothing}}) ? nothing : prod(sz) +end + end diff --git a/ext/ArrayInterfaceStaticArraysExt.jl b/ext/ArrayInterfaceStaticArraysExt.jl new file mode 100644 index 000000000..133574377 --- /dev/null +++ b/ext/ArrayInterfaceStaticArraysExt.jl @@ -0,0 +1,26 @@ +module ArrayInterfaceStaticArraysExt + +if isdefined(Base, :get_extension) + import ArrayInterface + import StaticArrays +else + import ..ArrayInterface + import ..StaticArrays +end + +ArrayInterface.known_first(@nospecialize T::Type{<:StaticArrays.SOneTo}) = 1 +ArrayInterface.known_last(::Type{StaticArrays.SOneTo{N}}) where {N} = @isdefined(N) ? N::Int : nothing + +function ArrayInterface.known_first(::Type{<:StaticArrays.SUnitRange{S}}) where {S} + @isdefined(S) ? S::Int : nothing +end +function ArrayInterface.known_size(::Type{<:StaticArrays.SUnitRange{<:Any, L}}) where {L} + @isdefined(L) ? (L::Int,) : (nothing,) +end +function ArrayInterface.known_last(::Type{<:StaticArrays.SUnitRange{S, L}}) where {S, L} + start = @isdefined(S) ? S::Int : nothing + len = @isdefined(L) ? L::Int : nothing + (start === nothing || len === nothing) ? nothing : (start + len - 1) +end + +end diff --git a/ext/ArrayInterfaceStaticExt.jl b/ext/ArrayInterfaceStaticExt.jl new file mode 100644 index 000000000..3480d6aca --- /dev/null +++ b/ext/ArrayInterfaceStaticExt.jl @@ -0,0 +1,19 @@ +module ArrayInterfaceStaticExt + +if isdefined(Base, :get_extension) + import ArrayInterface + import Static +else + import ..ArrayInterface + import ..Static +end + +ArrayInterface.known_first(::Type{<:Static.OptionallyStaticUnitRange{Static.StaticInt{F}}}) where {F} = F::Int +ArrayInterface.known_first(::Type{<:Static.OptionallyStaticStepRange{Static.StaticInt{F}}}) where {F} = F::Int + +ArrayInterface.known_step(::Type{<:Static.OptionallyStaticStepRange{<:Any,Static.StaticInt{S}}}) where {S} = S::Int + +ArrayInterface.known_last(::Type{<:Static.OptionallyStaticUnitRange{<:Any,Static.StaticInt{L}}}) where {L} = L::Int +ArrayInterface.known_last(::Type{<:Static.OptionallyStaticStepRange{<:Any,<:Any,Static.StaticInt{L}}}) where {L} = L::Int + +end diff --git a/src/ArrayInterface.jl b/src/ArrayInterface.jl index ed616a87f..32bfff076 100644 --- a/src/ArrayInterface.jl +++ b/src/ArrayInterface.jl @@ -17,6 +17,7 @@ else end end end + @assume_effects :total __parameterless_type(T)=Base.typename(T).wrapper parameterless_type(x) = parameterless_type(typeof(x)) parameterless_type(x::Type) = __parameterless_type(x) @@ -486,6 +487,7 @@ end function cholesky_instance(A::Union{SparseMatrixCSC,Symmetric{<:Number,<:SparseMatrixCSC}}, pivot = DEFAULT_CHOLESKY_PIVOT) cholesky(sparse(similar(A, 1, 1)), check = false) end + """ cholesky_instance(a::Number, pivot = LinearAlgebra.RowMaximum()) -> a @@ -837,6 +839,13 @@ Base.@propagate_inbounds function Base.getindex(ind::TridiagonalIndex, i::Int) end end +""" + is_splat_index(::Type{T}) -> Bool + +Returns `true` if `T` is a type that splats across multiple dimensions. +""" +is_splat_index(T::Type) = false +is_splat_index(@nospecialize(x)) = is_splat_index(typeof(x)) """ ndims_index(::Type{I}) -> Int @@ -866,7 +875,7 @@ ndims_index(::Type{CartesianIndices{0, Tuple{}}}) = 1 ndims_index(@nospecialize T::Type{<:AbstractArray{Bool}}) = ndims(T) ndims_index(@nospecialize T::Type{<:AbstractArray}) = ndims_index(eltype(T)) ndims_index(@nospecialize T::Type{<:Base.LogicalIndex}) = ndims(fieldtype(T, :mask)) -ndims_index(T::Type) = 1 +ndims_index(@nospecialize(T::Type)) = 1 ndims_index(@nospecialize(i)) = ndims_index(typeof(i)) """ @@ -887,7 +896,7 @@ julia> ndims(CartesianIndices((2,2))[[CartesianIndex(1, 1), CartesianIndex(1, 2) 1 """ -ndims_shape(T::DataType) = ndims_index(T) +ndims_shape(T::Type) = ndims_index(T) ndims_shape(::Type{Colon}) = 1 ndims_shape(@nospecialize T::Type{<:CartesianIndices}) = ndims(T) ndims_shape(@nospecialize T::Type{<:Union{Number, Base.AbstractCartesianIndex}}) = 0 @@ -895,8 +904,6 @@ ndims_shape(@nospecialize T::Type{<:AbstractArray{Bool}}) = 1 ndims_shape(@nospecialize T::Type{<:AbstractArray}) = ndims(T) ndims_shape(x) = ndims_shape(typeof(x)) - - """ instances_do_not_alias(::Type{T}) -> Bool @@ -1030,6 +1037,237 @@ ensures_sorted(@nospecialize( T::Type{<:AbstractRange})) = true ensures_sorted(T::Type) = is_forwarding_wrapper(T) ? ensures_sorted(parent_type(T)) : false ensures_sorted(@nospecialize(x)) = ensures_sorted(typeof(x)) +""" + known_first(I::Type) -> Union{Int, Nothing} + +Return the first index in an index range of type `I` when known at compile time. +Otherwise, return `nothing`. + +See also: [`ArrayInterface.known_last`](@ref), [`ArrayInterface.known_step`](@ref) + +```julia +julia> known_first(typeof(1:4)) +nothing + +julia> known_first(typeof(Base.OneTo(4))) +1 +``` +""" +known_first(x) = known_first(typeof(x)) +known_first(T::Type) = is_forwarding_wrapper(T) ? known_first(parent_type(T)) : nothing +known_first(::Type{<:Base.OneTo}) = 1 +known_first(@nospecialize T::Type{<:LinearIndices}) = 1 +known_first(@nospecialize T::Type{<:Base.IdentityUnitRange}) = known_first(parent_type(T)) +@inline function known_first(::Type{<:CartesianIndices{N, R}}) where {N, R} + tup = ntuple(i -> known_first(fieldtype(R, i)), Val(N)) + isa(tup, NTuple{N, Int}) ? CartesianIndex(tup) : nothing +end + +""" + known_last(::Type{T}) -> Union{Int, Nothing} + +Return the last index in an index range of type `I` when known at compile time. +Otherwise, return `nothing`. + +See also: [`ArrayInterface.known_first`](@ref), [`ArrayInterface.known_step`](@ref) + +```julia +julia> known_last(typeof(1:4)) +nothing + +julia> known_first(typeof(static(1):static(4))) +4 + +``` +""" +known_last(x) = known_last(typeof(x)) +known_last(T::Type) = is_forwarding_wrapper(T) ? known_last(parent_type(T)) : nothing +@inline function known_last(::Type{<:CartesianIndices{N, R}}) where {N, R} + tup = ntuple(i -> known_last(fieldtype(R, i)), Val(N)) + isa(tup, NTuple{N, Int}) ? CartesianIndex(tup) : nothing +end + +""" + known_step(I::Type) -> Union{Int, Nothing} + +Return the step size for an index range of type `I` when known at compile time. +Otherwise, return `nothing`. + +See also: [`ArrayInterface.known_first`](@ref), [`ArrayInterface.known_last`](@ref) + +```julia +julia> known_step(typeof(1:2:8)) +nothing + +julia> known_step(typeof(1:4)) +1 + +``` +""" +known_step(x) = known_step(typeof(x)) +known_step(T::Type) = is_forwarding_wrapper(T) ? known_step(parent_type(T)) : nothing +known_step(@nospecialize T::Type{<:AbstractUnitRange}) = 1 + +""" + known_size(::Type{T}) -> Tuple + known_size(::Type{T}, dim) -> Union{Int, Nothing} + +Returns the size of each dimension of `A` or along dimension `dim` of `A` that is known at +compile time. If a dimension does not have a known size along a dimension then `nothing` is +returned in its position. +""" +@inline known_size(x, dim::Integer) = ndims(x) < dim ? 1 : known_size(x)[dim] +known_size(x) = known_size(typeof(x)) +@inline function known_size(T::Type) + if is_forwarding_wrapper(T) + return known_size(parent_type(T)) + elseif isa(Base.IteratorSize(T), Base.HasShape) + return ntuple(_ -> nothing, ndims(T)) + else + return (known_length(T),) + end +end +@inline known_size(@nospecialize T::Type{<:Number}) = () +@inline known_size(@nospecialize T::Type{<:VecAdjTrans}) = (1, known_length(parent_type(T))) +@inline function known_size(@nospecialize T::Type{<:MatAdjTrans}) + s1, s2 = known_size(parent_type(T)) + (s2, s1) +end +function known_size(::Type{<:PermutedDimsArray{<:Any, N, I1, I2, P}}) where {N, I1, I2, P} + psize = known_size(P) + ntuple(i -> getfield(psize, getfield(I1, i)), Val{N}()) +end +function known_size(@nospecialize T::Type{<:Diagonal}) + s = known_length(parent_type(T)) + (s, s) +end +known_size(@nospecialize T::Type{<:Union{Symmetric,Hermitian}}) = known_size(parent_type(T)) +@inline function known_size(::Type{<:Base.ReinterpretArray{T,N,S,A,IsReshaped}}) where {T,N,S,A,IsReshaped} + psize = known_size(A) + if IsReshaped + if sizeof(S) > sizeof(T) + return (div(sizeof(S), sizeof(T)), psize...) + elseif sizeof(S) < sizeof(T) + return Base.tail(psize) + else + return psize + end + else + if Base.issingletontype(T) || first(psize) === nothing + return psize + else + return (div(first(psize) * sizeof(S), sizeof(T)), Base.tail(psize)...) + end + end +end +known_size(::Type{<:Base.IdentityUnitRange{I}}) where {I} = known_size(I) +known_size(::Type{<:Base.Generator{I}}) where {I} = known_size(I) +known_size(::Type{<:Iterators.Reverse{I}}) where {I} = known_size(I) +known_size(::Type{<:Iterators.Enumerate{I}}) where {I} = known_size(I) +known_size(::Type{<:Iterators.Accumulate{<:Any,I}}) where {I} = known_size(I) +known_size(::Type{<:Iterators.Pairs{<:Any,<:Any,I}}) where {I} = known_size(I) +@inline function known_size(::Type{<:Iterators.ProductIterator{T}}) where {T} + ntuple(i -> known_length(fieldtype(T, i)), Val(known_length(T))) +end +@inline function known_size(@nospecialize T::Type{<:AbstractRange}) + if is_forwarding_wrapper(T) + return known_size(parent_type(T)) + else + start = known_first(T) + s = known_step(T) + stop = known_last(T) + if isa(stop, Int) && isa(s, Int) && isa(start, Int) + if s > 0 + return (stop < start ? 0 : div(stop - start, s) + 1,) + else + return (stop > start ? 0 : div(start - stop, -s) + 1,) + end + else + return (nothing,) + end + end +end + +@inline function known_size(@nospecialize T::Type{<:Union{LinearIndices,CartesianIndices}}) + I = fieldtype(T, :indices) + ntuple(i -> known_length(fieldtype(I, i)), Val(ndims(T))) +end + +@inline function known_size(T::Type{<:SubArray}) + I = fieldtype(T, :indices) + ninds = fieldcount(I) + if ninds === 1 + I_1 = fieldtype(I, 1) + return I_1 <: Base.Slice ? (known_length(parent_type(T)),) : known_size(I_1) + else + psize = known_size(parent_type(T)) + ndi_summed = cumsum(map_tuple_type(ndims_index, I)) + sz = ntuple(Val{nfields(ndi_summed)}()) do i + I_i = fieldtype(I, i) + if I_i <: Base.Slice + getfield(psize, getfield(ndi_summed, i)) + else + known_size(I_i) + end + end + return flatten_tuples(sz) + end +end + +# 1. `Zip` doesn't check that its collections are compatible (same size) at construction, +# but we assume as much b/c otherwise it will error while iterating. So we promote to the +# known size if matching a `Nothing` and `Int` size. +# 2. `promote_shape(::Tuple{Vararg{IntType}}, ::Tuple{Vararg{IntType}})` promotes +# trailing dimensions (which must be of size 1), to `static(1)`. We want to stick to +# `Nothing` and `Int` types, so we do one last pass to ensure everything is dynamic +@inline function known_size(::Type{<:Iterators.Zip{T}}) where {T} + reduce(promote_known_shape, map_tuple_type(known_size, T)) +end +function promote_known_shape(x::Tuple{Vararg{Union{Nothing,Int}, XN}}, y::Tuple{Vararg{Union{Nothing,Int}, YN}}) where {XN, YN} + if XN >= YN + ntuple(Val{XN}()) do i + x_i = getfield(x, i) + x_i === nothing ? i > YN ? 1 : getfield(y, i) : x_i + end + else + return promote_known_shape(y, x) + end +end + +""" + known_length(::Type{T}) -> Union{Int, Nothing} + +If `length` of an instance of type `T` is known at compile time, return it. +Otherwise, return `nothing`. +""" +known_length(x) = known_length(typeof(x)) +function known_length(::Type{T}) where {T} + if isa(Base.IteratorSize(T), Base.HasShape) + # this is a multidimensional iterator so we assume that known_size is defined + sz = known_size(T) + len = 1 + for sz_i in sz + isa(sz_i, Int) || return nothing + len *= sz_i + end + return len + else + # if it is an iterator with length it's compile time length is not provided + return nothing + end +end + +known_length(::Type{<:NamedTuple{L}}) where {L} = nfields(L) +known_length(@nospecialize T::Type{<:Base.Slice}) = known_length(parent_type(T)) +known_length(@nospecialize T::Type{<:Tuple}) = fieldcount(T) +known_length(@nospecialize T::Type{<:Number}) = 1 +known_length(::Type{<:Base.AbstractCartesianIndex{N}}) where {N} = N::Int +function known_length(::Type{<:Iterators.Flatten{I}}) where {I} + lenitr = known_length(I) + lenelt = known_length(eltype(I)) + (lenelt isa Int && lenitr isa Int) ? (lenitr * lenelt) : nothing +end + ## Extensions import Requires @@ -1039,6 +1277,8 @@ import Requires Requires.@require BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0" begin include("../ext/ArrayInterfaceBlockBandedMatricesExt.jl") end Requires.@require GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" begin include("../ext/ArrayInterfaceGPUArraysCoreExt.jl") end Requires.@require StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" begin include("../ext/ArrayInterfaceStaticArraysCoreExt.jl") end + Requires.@require StaticArrays = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" begin include("../ext/ArrayInterfaceStaticArraysExt.jl") end + Requires.@require Static = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" begin include("../ext/ArrayInterfaceStaticExt.jl") end Requires.@require CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" begin include("../ext/ArrayInterfaceCUDAExt.jl") end Requires.@require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin include("../ext/ArrayInterfaceTrackerExt.jl") end end diff --git a/test/core.jl b/test/core.jl index bd0cd6cf3..f21a6ade5 100644 --- a/test/core.jl +++ b/test/core.jl @@ -262,7 +262,6 @@ end @testset "linearalgebra instances" begin for A in [rand(2,2), rand(Float32,2,2), rand(BigFloat,2,2)] - @test ArrayInterface.lu_instance(A) isa typeof(lu(A)) @test ArrayInterface.qr_instance(A) isa typeof(qr(A)) @@ -282,4 +281,53 @@ end end @test ArrayInterface.ldlt_instance(SymTridiagonal(A' * A)) isa typeof(ldlt(SymTridiagonal(A' * A))) end -end \ No newline at end of file +end + +@testset "known values" begin + CI = CartesianIndices((2, 2)) + + @test isnothing(@inferred(ArrayInterface.known_first(typeof(1:4)))) + @test isone(@inferred(ArrayInterface.known_first(Base.OneTo(4)))) + @test isone(@inferred(ArrayInterface.known_first(Base.IdentityUnitRange(Base.OneTo(4))))) + @test isone(@inferred(ArrayInterface.known_first(LinearIndices((1, 1, 1))))) + @test isone(@inferred(ArrayInterface.known_first(typeof(Base.OneTo(4))))) + @test @inferred(ArrayInterface.known_first(typeof(CI))) == CartesianIndex(1, 1) + @test @inferred(ArrayInterface.known_first(typeof(CI))) == CartesianIndex(1, 1) + + @test isnothing(@inferred(ArrayInterface.known_last(1:4))) + @test isnothing(@inferred(ArrayInterface.known_last(typeof(1:4)))) + @test @inferred(ArrayInterface.known_last(typeof(CI))) === nothing + + @test isnothing(@inferred(ArrayInterface.known_step(typeof(1:0.2:4)))) + @test isone(@inferred(ArrayInterface.known_step(1:4))) + @test isone(@inferred(ArrayInterface.known_step(typeof(1:4)))) + @test isone(@inferred(ArrayInterface.known_step(typeof(Base.Slice(1:4))))) + @test isone(@inferred(ArrayInterface.known_step(typeof(view(1:4, 1:2))))) + + A = zeros(3, 4, 5); + A[:] = 1:60 + Ap = @view(PermutedDimsArray(A, (3, 1, 2))[:, 1:2, 1])'; + Ar = reinterpret(Float32, A); + A_trailingdim = zeros(2, 3, 4, 1) + D = @view(A[:, 2:2:4, :]); + A2 = zeros(4, 3, 5) + A2r = reinterpret(ComplexF64, A2) + + @test @inferred(ArrayInterface.known_size(1)) === () + @test @inferred(ArrayInterface.known_size([1, 1]')) === (1, nothing) + @test @inferred(ArrayInterface.known_size(view([1, 1]', :, 1))) === (1, ) + @test @inferred(ArrayInterface.known_size(Diagonal(view([1, 1]', :, 1)))) === (1, 1) + @test @inferred(ArrayInterface.known_size(view(rand(4), reshape(1:4, 2, 2)))) == (nothing, nothing) + @test @inferred(ArrayInterface.known_size(A)) === (nothing, nothing, nothing) + @test @inferred(ArrayInterface.known_size(Ap)) === (nothing, nothing) + @test @inferred(ArrayInterface.known_size(Ar)) === (nothing, nothing, nothing,) + @test ArrayInterface.known_size(Ar, 1) === nothing + @test ArrayInterface.known_size(Ar, 4) === 1 + @test @inferred(ArrayInterface.known_size(A2)) === (nothing, nothing, nothing) + @test @inferred(ArrayInterface.known_size(A2r)) === (nothing, nothing, nothing) + + @test @inferred(ArrayInterface.known_length(1)) === 1 + @test @inferred(ArrayInterface.known_length(Base.Slice(1:2))) === nothing + @test @inferred(ArrayInterface.known_length(CartesianIndex(1, 2, 3))) === 3 + @test @inferred(ArrayInterface.known_length((x = 1, y = 2))) === 2 +end diff --git a/test/offsetarrays.jl b/test/offsetarrays.jl new file mode 100644 index 000000000..6353bf4f6 --- /dev/null +++ b/test/offsetarrays.jl @@ -0,0 +1,15 @@ + +using ArrayInterface +using OffsetArrays +using StaticArrays +using Test + +oa = OffsetArray([1, 2]', 1, 1) +@test @inferred(ArrayInterface.known_size(oa)) == (1, nothing) +@test @inferred(ArrayInterface.known_length(oa)) === nothing + + +id = OffsetArrays.IdOffsetRange(SOneTo(10), 1) +@test @inferred(ArrayInterface.known_size(id)) == (10, ) +@test @inferred(ArrayInterface.known_length(id)) == 10 + diff --git a/test/runtests.jl b/test/runtests.jl index ec3493fd8..72a8ba0df 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -14,10 +14,12 @@ end @time @safetestset "BlockBandedMatrices" begin include("blockbandedmatrices.jl") end @time @safetestset "Core" begin include("core.jl") end @time @safetestset "StaticArraysCore" begin include("staticarrayscore.jl") end + @time @safetestset "StaticArrays" begin include("staticarrays.jl") end + @time @safetestset "Static" begin include("static.jl") end end if GROUP == "GPU" activate_gpu_env() @time @safetestset "CUDA" begin include("gpu/cuda.jl") end end -end \ No newline at end of file +end diff --git a/test/static.jl b/test/static.jl new file mode 100644 index 000000000..382b2a4c5 --- /dev/null +++ b/test/static.jl @@ -0,0 +1,11 @@ + +using ArrayInterface +using Static +using Test + +iprod = Iterators.product(static(1):static(2), static(1):static(3), static(1):static(4)) +@test @inferred(ArrayInterface.known_size(iprod)) === (2, 3, 4) + +iflat = Iterators.flatten(iprod) +@test @inferred(ArrayInterface.known_size(iflat)) === (72,) + diff --git a/test/staticarrays.jl b/test/staticarrays.jl new file mode 100644 index 000000000..a6292e6f1 --- /dev/null +++ b/test/staticarrays.jl @@ -0,0 +1,52 @@ + +using ArrayInterface +using StaticArrays +using Test + +so = SOneTo(10) +@test ArrayInterface.known_first(typeof(so)) == first(so) +@test ArrayInterface.known_last(typeof(so)) == last(so) +@test ArrayInterface.known_length(typeof(so)) == length(so) + +su = StaticArrays.SUnitRange(2, 10) +@test ArrayInterface.known_first(typeof(su)) == first(su) +@test ArrayInterface.known_last(typeof(su)) == last(su) +@test ArrayInterface.known_length(typeof(su)) == length(su) + +S = @SArray(zeros(2, 3, 4)) +Sp = @view(PermutedDimsArray(S, (3, 1, 2))[2:3, 1:2, :]); +Sp2 = @view(PermutedDimsArray(S, (3, 2, 1))[2:3, :, :]); +Mp = @view(PermutedDimsArray(S, (3, 1, 2))[:, 2, :])'; +Mp2 = @view(PermutedDimsArray(S, (3, 1, 2))[2:3, :, 2])'; + + +irev = Iterators.reverse(S) +igen = Iterators.map(identity, S) +iacc = Iterators.accumulate(+, S) + +ienum = enumerate(S) +ipairs = pairs(S) +izip = zip(S, S) + +@test @inferred(ArrayInterface.known_size(S)) === (2, 3, 4) +@test @inferred(ArrayInterface.known_size(irev)) === (2, 3, 4) +@test @inferred(ArrayInterface.known_size(igen)) === (2, 3, 4) +@test @inferred(ArrayInterface.known_size(iacc)) === (2, 3, 4) +@test @inferred(ArrayInterface.known_size(ienum)) === (2, 3, 4) +@test @inferred(ArrayInterface.known_size(izip)) === (2, 3, 4) +@test @inferred(ArrayInterface.known_size(ipairs)) === (2, 3, 4) +@test @inferred(ArrayInterface.known_size(zip(S, zeros(2, 3, 4, 1)))) === (2, 3, 4, 1) +@test @inferred(ArrayInterface.known_size(zip(zeros(2, 3, 4, 1), S))) === (2, 3, 4, 1) +@test @inferred(ArrayInterface.known_length(Iterators.flatten(((x, y) for x in 0:1 for y in 'a':'c')))) === nothing +@test ArrayInterface.known_length(S) == length(S) + + +@test @inferred(ArrayInterface.known_size(S)) === (2, 3, 4) +@test @inferred(ArrayInterface.known_size(Sp)) === (nothing, nothing, 3) +@test @inferred(ArrayInterface.known_size(Sp2)) === (nothing, 3, 2) +@test ArrayInterface.known_size(Sp2, 1) === nothing +@test ArrayInterface.known_size(Sp2, 2) === 3 +@test ArrayInterface.known_size(Sp2, 3) === 2 +@test @inferred(ArrayInterface.known_size(Mp)) === (3, 4) +@test @inferred(ArrayInterface.known_size(Mp2)) === (2, nothing) + diff --git a/test/staticarrayscore.jl b/test/staticarrayscore.jl index 420a05c74..f3f740ae1 100644 --- a/test/staticarrayscore.jl +++ b/test/staticarrayscore.jl @@ -1,3 +1,4 @@ + using StaticArrays, ArrayInterface, Test using LinearAlgebra using ArrayInterface: undefmatrix, zeromatrix @@ -43,3 +44,4 @@ zr = ArrayInterface.restructure(x, z) end end end +