Skip to content

Commit afc9759

Browse files
committed
Preserve more FieldArrays with parametric eltype.
And return a `MArray` for mutable `FieldArray`
1 parent def8fc2 commit afc9759

File tree

3 files changed

+21
-18
lines changed

3 files changed

+21
-18
lines changed

src/FieldArray.jl

+16-11
Original file line numberDiff line numberDiff line change
@@ -125,14 +125,19 @@ Base.cconvert(::Type{<:Ptr}, a::FieldArray) = Base.RefValue(a)
125125
Base.unsafe_convert(::Type{Ptr{T}}, m::Base.RefValue{FA}) where {N,T,D,FA<:FieldArray{N,T,D}} =
126126
Ptr{T}(Base.unsafe_convert(Ptr{FA}, m))
127127

128-
# We can automatically preserve FieldArrays in array operations which do not
129-
# change their eltype or Size. This should cover all non-parametric FieldArray,
130-
# but for those which are parametric on the eltype the user will still need to
131-
# overload similar_type themselves.
132-
similar_type(::Type{A}, ::Type{T}, S::Size) where {N, T, A<:FieldArray{N, T}} =
133-
_fieldarray_similar_type(A, T, S, Size(A))
134-
135-
# Extra layer of dispatch to match NewSize and OldSize
136-
_fieldarray_similar_type(A, T, NewSize::S, OldSize::S) where {S} = A
137-
_fieldarray_similar_type(A, T, NewSize, OldSize) =
138-
default_similar_type(T, NewSize, length_val(NewSize))
128+
# We can preserve FieldArrays in array operations which do not change their `Size` and `eltype`.
129+
# FieldArrays with parametric `eltype` would be adapted to the new `eltype` automatically.
130+
# Otherwise, we fallback to `S/MArray` based on it's mutability.
131+
function similar_type(::Type{A}, ::Type{T}, S::Size) where {T,A<:FieldArray}
132+
A′ = Base.typeintersect(base_type(A), StaticArray{Tuple{Tuple(S)...},T,length(S)})
133+
isabstracttype(A′) || A′ === Union{} || return A′
134+
if ismutabletype(A)
135+
return mutable_similar_type(T, S, length_val(S))
136+
else
137+
return default_similar_type(T, S, length_val(S))
138+
end
139+
end
140+
@pure base_type(@nospecialize(T::Type)) = Base.unwrap_unionall(T).name.wrapper
141+
if VERSION < v"1.7"
142+
@pure ismutabletype(@nospecialize(T::Type)) = Base.unwrap_unionall(T).mutable
143+
end

test/FieldMatrix.jl

+2-3
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@
5959
yy::T
6060
end
6161

62-
StaticArrays.similar_type(::Type{<:Tensor2x2}, ::Type{T}, s::Size{(2,2)}) where {T} = Tensor2x2{T}
6362
end)
6463

6564
p = Tensor2x2(0.0, 0.0, 0.0, 0.0)
@@ -83,8 +82,8 @@
8382

8483
@test @inferred(similar_type(Tensor2x2{Float64})) == Tensor2x2{Float64}
8584
@test @inferred(similar_type(Tensor2x2{Float64}, Float32)) == Tensor2x2{Float32}
86-
@test @inferred(similar_type(Tensor2x2{Float64}, Size(3,3))) == SMatrix{3,3,Float64,9}
87-
@test @inferred(similar_type(Tensor2x2{Float64}, Float32, Size(4,4))) == SMatrix{4,4,Float32,16}
85+
@test @inferred(similar_type(Tensor2x2{Float64}, Size(3, 3))) == MMatrix{3,3,Float64,9}
86+
@test @inferred(similar_type(Tensor2x2{Float64}, Float32, Size(4, 4))) == MMatrix{4,4,Float32,16}
8887

8988
# eltype promotion
9089
@test Tuple(@inferred(Tensor2x2(1., 2, 3, 4f0))) === (1.,2.,3.,4.)

test/FieldVector.jl

+3-4
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@
6363
y::T
6464
end
6565

66-
StaticArrays.similar_type(::Type{<:Point2D}, ::Type{T}, s::Size{(2,)}) where {T} = Point2D{T}
6766
end)
6867

6968
p = Point2D(0.0, 0.0)
@@ -86,8 +85,8 @@
8685

8786
@test @inferred(similar_type(Point2D{Float64})) == Point2D{Float64}
8887
@test @inferred(similar_type(Point2D{Float64}, Float32)) == Point2D{Float32}
89-
@test @inferred(similar_type(Point2D{Float64}, Size(4))) == SVector{4,Float64}
90-
@test @inferred(similar_type(Point2D{Float64}, Float32, Size(4))) == SVector{4,Float32}
88+
@test @inferred(similar_type(Point2D{Float64}, Size(4))) == MVector{4,Float64}
89+
@test @inferred(similar_type(Point2D{Float64}, Float32, Size(4))) == MVector{4,Float32}
9190

9291
# eltype promotion
9392
@test Point2D(1f0, 2) isa Point2D{Float32}
@@ -122,7 +121,7 @@
122121
# No similar_type defined - test fallback codepath
123122
end)
124123

125-
@test @inferred(similar_type(FVT{Float64}, Float32)) == SVector{2,Float32} # Fallback code path
124+
@test @inferred(similar_type(FVT{Float64}, Float32)) == FVT{Float32}
126125
@test @inferred(similar_type(FVT{Float64}, Size(2))) == FVT{Float64}
127126
@test @inferred(similar_type(FVT{Float64}, Size(3))) == SVector{3,Float64}
128127
@test @inferred(similar_type(FVT{Float64}, Float32, Size(3))) == SVector{3,Float32}

0 commit comments

Comments
 (0)