Skip to content

Commit fe0bda9

Browse files
Implement Base.instantiate - take 2 (#1118)
* implement `instantiate` - get rid of BasicDimensionalStyle * fix setindex! for opaquearray to make some error messages clearer * fix materialize!
1 parent 2e5a812 commit fe0bda9

File tree

2 files changed

+57
-83
lines changed

2 files changed

+57
-83
lines changed

src/array/broadcast.jl

Lines changed: 55 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import Base.Broadcast: BroadcastStyle, DefaultArrayStyle, Style
1+
import Base.Broadcast: BroadcastStyle, DefaultArrayStyle, Style, AbstractArrayStyle, Unknown
22

33
const STRICT_BROADCAST_CHECKS = Ref(true)
44
const STRICT_BROADCAST_DOCS = """
@@ -35,10 +35,9 @@ strict_broadcast!(x::Bool) = STRICT_BROADCAST_CHECKS[] = x
3535
# It preserves the dimension names.
3636
# `S` should be the `BroadcastStyle` of the wrapped type.
3737
# Copied from NamedDims.jl (thanks @oxinabox).
38-
struct BasicDimensionalStyle{N} <: AbstractArrayStyle{Any} end
39-
40-
struct DimensionalStyle{S<:BroadcastStyle} <: AbstractArrayStyle{Any} end
41-
DimensionalStyle(::S) where {S} = DimensionalStyle{S}()
38+
struct DimensionalStyle{S <: AbstractArrayStyle, N} <: AbstractArrayStyle{N} end
39+
DimensionalStyle(::S) where S<:AbstractArrayStyle{N} where N = DimensionalStyle{S, N}()
40+
DimensionalStyle(::S) where {S<:DimensionalStyle} = S() # avoid nested dimensionalstyle
4241
DimensionalStyle(::S, ::Val{N}) where {S,N} = DimensionalStyle(S(Val(N)))
4342
DimensionalStyle(::Val{N}) where N = DimensionalStyle{DefaultArrayStyle{N}}()
4443
function DimensionalStyle(a::BroadcastStyle, b::BroadcastStyle)
@@ -51,86 +50,59 @@ function DimensionalStyle(a::BroadcastStyle, b::BroadcastStyle)
5150
end
5251
end
5352

54-
function BroadcastStyle(::Type{<:AbstractDimArray{T,N,D,A}}) where {T,N,D,A}
55-
inner_style = typeof(BroadcastStyle(A))
56-
return DimensionalStyle{inner_style}()
57-
end
58-
BroadcastStyle(::Type{<:AbstractBasicDimArray{T,N}}) where {T,N} =
59-
BasicDimensionalStyle{N}()
60-
53+
BroadcastStyle(::Type{<:AbstractDimArray{T,N,D,A}}) where {T,N,D,A} =
54+
DimensionalStyle(BroadcastStyle(A))
55+
BroadcastStyle(::Type{<:AbstractBasicDimArray{T,N,D}}) where {T,N,D} =
56+
DimensionalStyle(DefaultArrayStyle{N}())
6157
BroadcastStyle(::DimensionalStyle, ::Base.Broadcast.Unknown) = Unknown()
62-
BroadcastStyle(::Base.Broadcast.Unknown, ::DimensionalStyle) = Unknown()
6358
BroadcastStyle(::DimensionalStyle{A}, ::DimensionalStyle{B}) where {A, B} = DimensionalStyle(A(), B())
64-
BroadcastStyle(::DimensionalStyle{A}, b::Style) where {A} = DimensionalStyle(A(), b)
65-
BroadcastStyle(a::Style, ::DimensionalStyle{B}) where {B} = DimensionalStyle(a, B())
59+
BroadcastStyle(::DimensionalStyle{A}, b::AbstractArrayStyle{N}) where {A,N} = DimensionalStyle(A(), b)
60+
BroadcastStyle(::DimensionalStyle{A}, b::DefaultArrayStyle{N}) where {A,N} = DimensionalStyle(A(), b) # ambiguity
6661
BroadcastStyle(::DimensionalStyle{A}, b::Style{Tuple}) where {A} = DimensionalStyle(A(), b)
67-
BroadcastStyle(a::Style{Tuple}, ::DimensionalStyle{B}) where {B} = DimensionalStyle(a, B())
68-
# We need to implement copy because if the wrapper array type does not
69-
# support setindex then the `similar` based default method will not work
70-
function Broadcast.copy(bc::Broadcasted{DimensionalStyle{S}}) where S
71-
A = _firstdimarray(bc)
72-
data = copy(_unwrap_broadcasted(bc))
73-
74-
A isa Nothing && return data # No AbstractDimArray
75-
76-
bdims = _broadcasted_dims(bc)
77-
_comparedims_broadcast(A, bdims...)
78-
79-
data isa AbstractArray || return data # result is a scalar
8062

81-
# unwrap AbstractDimArray data
82-
data = data isa AbstractDimArray ? parent(data) : data
83-
dims = format(Dimensions.promotedims(bdims...; skip_length_one=true), data)
84-
return rebuild(A; data, dims, refdims=refdims(A), name=Symbol(""))
85-
end
86-
function Broadcast.copy(bc::Broadcasted{BasicDimensionalStyle{N}}) where N
63+
# override base instantiate to check dimensions as well as axes
64+
@inline function Broadcast.instantiate(bc::Broadcasted{<:DimensionalStyle{S}}) where S
8765
A = _firstdimarray(bc)
88-
data = collect(bc)
89-
A isa Nothing && return data # No AbstractDimArray
90-
66+
# check if there is any DimArray and unwrap immediately if no
67+
isnothing(A) && return Broadcast.instantiate(_unwrap_broadcasted(bc))
9168
bdims = _broadcasted_dims(bc)
69+
if bc.axes isa Nothing
70+
axes = Base.Broadcast.combine_axes(_unwrap_broadcasted(bc).args...)
71+
ds = Dimensions.promotedims(bdims...; skip_length_one=true)
72+
length(axes) == length(ds) ||
73+
throw(ArgumentError("Number of broadcasted dimensions $(length(axes)) larger than $(ds)"))
74+
axes = map(Dimensions.DimUnitRange, axes, ds)
75+
else # bc already has axes which might have dimensions, e.g. when assigning to a DimArray
76+
axes = bc.axes
77+
Base.Broadcast.check_broadcast_axes(axes, bc.args...)
78+
ds = dims(axes)
79+
isnothing(ds) || _comparedims_broadcast(A, ds, bdims...)
80+
end
9281
_comparedims_broadcast(A, bdims...)
93-
94-
data isa AbstractArray || return data # result is a scalar
95-
96-
# Return an AbstractDimArray
97-
dims = format(Dimensions.promotedims(bdims...; skip_length_one=true), data)
98-
return dimconstructor(dims)(data, dims; refdims=refdims(A), name=Symbol(""))
82+
return Broadcasted(bc.style, bc.f, bc.args, axes)
9983
end
100-
101-
function Base.copyto!(dest::AbstractArray, bc::Broadcasted{DimensionalStyle{S}}) where S
102-
fda = _firstdimarray(bc)
103-
isnothing(fda) || _comparedims_broadcast(fda, _broadcasted_dims(bc)...)
104-
copyto!(dest, _unwrap_broadcasted(bc))
105-
end
106-
function Base.copyto!(dest::AbstractArray, bc::Broadcasted{BasicDimensionalStyle{N}}) where N
107-
fda = _firstdimarray(bc)
108-
isnothing(fda) || _comparedims_broadcast(fda, _broadcasted_dims(bc)...)
109-
copyto!(dest, bc)
110-
end
111-
112-
@inline function Base.Broadcast.materialize!(dest::AbstractDimArray, bc::Base.Broadcast.Broadcasted{<:Any})
113-
# Need to check whether the dims are compatible in dest,
114-
# which are already stripped when sent to copyto!
115-
_comparedims_broadcast(dest, dims(dest), _broadcasted_dims(bc)...)
116-
style = DimensionalData.DimensionalStyle(Base.Broadcast.combine_styles(parent(dest), bc))
117-
Base.Broadcast.materialize!(style, parent(dest), bc)
118-
return dest
119-
end
120-
121-
function Base.similar(bc::Broadcast.Broadcasted{DimensionalStyle{S}}, ::Type{T}) where {S,T}
84+
# Define copy because the inner style S might override copy (e.g. DiskArrays)
85+
function Base.copy(bc::Broadcasted{<:DimensionalStyle{S}}) where S
86+
data = copy(_unwrap_broadcasted(bc))
87+
data isa AbstractArray || return data # in the 0-d case data can be a scalar
88+
# let similar do the work - it will usually call rebuild unless A isa AbstractBasicDimArray
12289
A = _firstdimarray(bc)
123-
data = similar(_unwrap_broadcasted(bc), T, size(bc))
124-
dims, refdims = slicedims(A, axes(bc))
125-
return rebuild(A; data, dims, refdims, name=Symbol(""))
90+
similar(A; data, dims = dims(axes(bc)))
12691
end
127-
function Base.similar(bc::Broadcast.Broadcasted{BasicDimensionalStyle{N}}, ::Type{T}) where {N,T}
92+
# similar is usually only called in broadcast_preserving_zero_d
93+
function Base.similar(bc::Broadcasted{<:DimensionalStyle{S}}, ::Type{T}) where {S,T}
12894
A = _firstdimarray(bc)
129-
data = similar(A, T, size(bc))
130-
dims, refdims = slicedims(A, axes(bc))
131-
return dimconstructor(dims)(data, dims; refdims, name=Symbol(""))
95+
data = similar(_unwrap_broadcasted(bc), T)
96+
similar(A; data, dims = dims(axes(bc)))
13297
end
13398

99+
@inline function Base.materialize!(::DimensionalStyle, dest, bc::Broadcasted)
100+
# check dimensions
101+
bci = Broadcast.instantiate(Broadcasted(bc.style, bc.f, bc.args, axes(dest)))
102+
# unwrap before copying
103+
Base.copyto!(_unwrap_broadcasted(dest), _unwrap_broadcasted(bci))
104+
return dest
105+
end
134106

135107
"""
136108
@d broadcast_expression options
@@ -407,29 +379,31 @@ end
407379
# Recursively unwraps `AbstractDimArray`s and `DimensionalStyle`s.
408380
# replacing the `AbstractDimArray`s with the wrapped array,
409381
# and `DimensionalStyle` with the wrapped `BroadcastStyle`.
410-
function _unwrap_broadcasted(bc::Broadcasted{DimensionalStyle{S}}) where S
382+
383+
function _unwrap_broadcasted(bc::Broadcasted{<:DimensionalStyle{S}}) where {S}
411384
innerargs = map(_unwrap_broadcasted, bc.args)
412-
return Broadcasted{S}(bc.f, innerargs)
385+
return Broadcasted{S}(bc.f, innerargs, _unwrap_broadcasted(bc.axes))
413386
end
414387
_unwrap_broadcasted(x) = x
415388
_unwrap_broadcasted(nda::AbstractDimArray) = parent(nda)
416-
_unwrap_broadcasted(boda::BroadcastOptionsDimArray) = parent(parent(boda))
417-
389+
_unwrap_broadcasted(bda::AbstractBasicDimArray) = OpaqueArray(bda)
390+
_unwrap_broadcasted(boda::BroadcastOptionsDimArray) = _unwrap_broadcasted(parent(boda))
391+
_unwrap_broadcasted(t::Tuple) = map(_unwrap_broadcasted, t)
392+
_unwrap_broadcasted(du::Dimensions.DimUnitRange) = parent(du)
418393
# Get the first dimensional array in the broadcast
419394
_firstdimarray(x::Broadcasted) = _firstdimarray(x.args)
420-
_firstdimarray(x::Tuple{<:AbstractBasicDimArray,Vararg}) = x[1]
421-
_firstdimarray(x::AbstractBasicDimArray) = x
422-
_firstdimarray(ext::Base.Broadcast.Extruded) = _firstdimarray(ext.x)
423-
function _firstdimarray(x::Tuple{<:Union{Broadcasted,Base.Broadcast.Extruded},Vararg})
395+
function _firstdimarray(x::Tuple)
424396
found = _firstdimarray(x[1])
425397
if found isa Nothing
426398
_firstdimarray(tail(x))
427399
else
428400
found
429401
end
430402
end
431-
_firstdimarray(x::Tuple) = _firstdimarray(tail(x))
432403
_firstdimarray(x::Tuple{}) = nothing
404+
_firstdimarray(ext::Base.Broadcast.Extruded) = _firstdimarray(ext.x)
405+
_firstdimarray(x::AbstractBasicDimArray) = x
406+
_firstdimarray(x) = nothing
433407

434408
# Make sure all arrays have the same dims, and return them
435409
_broadcasted_dims(bc::Broadcasted) = _broadcasted_dims(bc.args...)

src/opaque.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,5 @@ OpaqueArray(st::P) where P<:AbstractDimStack{<:Any,T,N} where {T,N} = OpaqueArra
1515
Base.size(A::OpaqueArray) = size(A.parent)
1616
Base.getindex(A::OpaqueArray, I::Union{StandardIndices,Not{<:StandardIndices}}...) =
1717
Base.getindex(A.parent, I...)
18-
Base.setindex!(A::OpaqueArray, I::Union{StandardIndices,Not{<:StandardIndices}}...) =
19-
Base.setindex!(A.parent, I...)
18+
Base.setindex!(A::OpaqueArray, x, I::Union{StandardIndices,Not{<:StandardIndices}}...) =
19+
Base.setindex!(A.parent, x, I...)

0 commit comments

Comments
 (0)