1- import Base. Broadcast: BroadcastStyle, DefaultArrayStyle, Style
1+ import Base. Broadcast: BroadcastStyle, DefaultArrayStyle, Style, AbstractArrayStyle, Unknown
22
33const STRICT_BROADCAST_CHECKS = Ref (true )
44const STRICT_BROADCAST_DOCS = """
@@ -35,10 +35,9 @@ strict_broadcast!(x::Bool) = STRICT_BROADCAST_CHECKS[] = x
3535# It preserves the dimension names.
3636# `S` should be the `BroadcastStyle` of the wrapped type.
3737# Copied from NamedDims.jl (thanks @oxinabox).
38- struct BasicDimensionalStyle{N} <: AbstractArrayStyle{Any} end
39-
40- struct DimensionalStyle{S<: BroadcastStyle } <: AbstractArrayStyle{Any} end
41- DimensionalStyle (:: S ) where {S} = DimensionalStyle {S} ()
38+ struct DimensionalStyle{S <: AbstractArrayStyle , N} <: AbstractArrayStyle{N} end
39+ DimensionalStyle (:: S ) where S<: AbstractArrayStyle{N} where N = DimensionalStyle {S, N} ()
40+ DimensionalStyle (:: S ) where {S<: DimensionalStyle } = S () # avoid nested dimensionalstyle
4241DimensionalStyle (:: S , :: Val{N} ) where {S,N} = DimensionalStyle (S (Val (N)))
4342DimensionalStyle (:: Val{N} ) where N = DimensionalStyle {DefaultArrayStyle{N}} ()
4443function DimensionalStyle (a:: BroadcastStyle , b:: BroadcastStyle )
@@ -51,86 +50,59 @@ function DimensionalStyle(a::BroadcastStyle, b::BroadcastStyle)
5150 end
5251end
5352
54- function BroadcastStyle (:: Type{<:AbstractDimArray{T,N,D,A}} ) where {T,N,D,A}
55- inner_style = typeof (BroadcastStyle (A))
56- return DimensionalStyle {inner_style} ()
57- end
58- BroadcastStyle (:: Type{<:AbstractBasicDimArray{T,N}} ) where {T,N} =
59- BasicDimensionalStyle {N} ()
60-
53+ BroadcastStyle (:: Type{<:AbstractDimArray{T,N,D,A}} ) where {T,N,D,A} =
54+ DimensionalStyle (BroadcastStyle (A))
55+ BroadcastStyle (:: Type{<:AbstractBasicDimArray{T,N,D}} ) where {T,N,D} =
56+ DimensionalStyle (DefaultArrayStyle {N} ())
6157BroadcastStyle (:: DimensionalStyle , :: Base.Broadcast.Unknown ) = Unknown ()
62- BroadcastStyle (:: Base.Broadcast.Unknown , :: DimensionalStyle ) = Unknown ()
6358BroadcastStyle (:: DimensionalStyle{A} , :: DimensionalStyle{B} ) where {A, B} = DimensionalStyle (A (), B ())
64- BroadcastStyle (:: DimensionalStyle{A} , b:: Style ) where {A} = DimensionalStyle (A (), b)
65- BroadcastStyle (a :: Style , :: DimensionalStyle{B } ) where {B } = DimensionalStyle (a, B ())
59+ BroadcastStyle (:: DimensionalStyle{A} , b:: AbstractArrayStyle{N} ) where {A,N } = DimensionalStyle (A (), b)
60+ BroadcastStyle (:: DimensionalStyle{A} , b :: DefaultArrayStyle{N } ) where {A,N } = DimensionalStyle (A (), b) # ambiguity
6661BroadcastStyle (:: DimensionalStyle{A} , b:: Style{Tuple} ) where {A} = DimensionalStyle (A (), b)
67- BroadcastStyle (a:: Style{Tuple} , :: DimensionalStyle{B} ) where {B} = DimensionalStyle (a, B ())
68- # We need to implement copy because if the wrapper array type does not
69- # support setindex then the `similar` based default method will not work
70- function Broadcast. copy (bc:: Broadcasted{DimensionalStyle{S}} ) where S
71- A = _firstdimarray (bc)
72- data = copy (_unwrap_broadcasted (bc))
73-
74- A isa Nothing && return data # No AbstractDimArray
75-
76- bdims = _broadcasted_dims (bc)
77- _comparedims_broadcast (A, bdims... )
78-
79- data isa AbstractArray || return data # result is a scalar
8062
81- # unwrap AbstractDimArray data
82- data = data isa AbstractDimArray ? parent (data) : data
83- dims = format (Dimensions. promotedims (bdims... ; skip_length_one= true ), data)
84- return rebuild (A; data, dims, refdims= refdims (A), name= Symbol (" " ))
85- end
86- function Broadcast. copy (bc:: Broadcasted{BasicDimensionalStyle{N}} ) where N
63+ # override base instantiate to check dimensions as well as axes
64+ @inline function Broadcast. instantiate (bc:: Broadcasted{<:DimensionalStyle{S}} ) where S
8765 A = _firstdimarray (bc)
88- data = collect (bc)
89- A isa Nothing && return data # No AbstractDimArray
90-
66+ # check if there is any DimArray and unwrap immediately if no
67+ isnothing (A) && return Broadcast. instantiate (_unwrap_broadcasted (bc))
9168 bdims = _broadcasted_dims (bc)
69+ if bc. axes isa Nothing
70+ axes = Base. Broadcast. combine_axes (_unwrap_broadcasted (bc). args... )
71+ ds = Dimensions. promotedims (bdims... ; skip_length_one= true )
72+ length (axes) == length (ds) ||
73+ throw (ArgumentError (" Number of broadcasted dimensions $(length (axes)) larger than $(ds) " ))
74+ axes = map (Dimensions. DimUnitRange, axes, ds)
75+ else # bc already has axes which might have dimensions, e.g. when assigning to a DimArray
76+ axes = bc. axes
77+ Base. Broadcast. check_broadcast_axes (axes, bc. args... )
78+ ds = dims (axes)
79+ isnothing (ds) || _comparedims_broadcast (A, ds, bdims... )
80+ end
9281 _comparedims_broadcast (A, bdims... )
93-
94- data isa AbstractArray || return data # result is a scalar
95-
96- # Return an AbstractDimArray
97- dims = format (Dimensions. promotedims (bdims... ; skip_length_one= true ), data)
98- return dimconstructor (dims)(data, dims; refdims= refdims (A), name= Symbol (" " ))
82+ return Broadcasted (bc. style, bc. f, bc. args, axes)
9983end
100-
101- function Base. copyto! (dest:: AbstractArray , bc:: Broadcasted{DimensionalStyle{S}} ) where S
102- fda = _firstdimarray (bc)
103- isnothing (fda) || _comparedims_broadcast (fda, _broadcasted_dims (bc)... )
104- copyto! (dest, _unwrap_broadcasted (bc))
105- end
106- function Base. copyto! (dest:: AbstractArray , bc:: Broadcasted{BasicDimensionalStyle{N}} ) where N
107- fda = _firstdimarray (bc)
108- isnothing (fda) || _comparedims_broadcast (fda, _broadcasted_dims (bc)... )
109- copyto! (dest, bc)
110- end
111-
112- @inline function Base. Broadcast. materialize! (dest:: AbstractDimArray , bc:: Base.Broadcast.Broadcasted{<:Any} )
113- # Need to check whether the dims are compatible in dest,
114- # which are already stripped when sent to copyto!
115- _comparedims_broadcast (dest, dims (dest), _broadcasted_dims (bc)... )
116- style = DimensionalData. DimensionalStyle (Base. Broadcast. combine_styles (parent (dest), bc))
117- Base. Broadcast. materialize! (style, parent (dest), bc)
118- return dest
119- end
120-
121- function Base. similar (bc:: Broadcast.Broadcasted{DimensionalStyle{S}} , :: Type{T} ) where {S,T}
84+ # Define copy because the inner style S might override copy (e.g. DiskArrays)
85+ function Base. copy (bc:: Broadcasted{<:DimensionalStyle{S}} ) where S
86+ data = copy (_unwrap_broadcasted (bc))
87+ data isa AbstractArray || return data # in the 0-d case data can be a scalar
88+ # let similar do the work - it will usually call rebuild unless A isa AbstractBasicDimArray
12289 A = _firstdimarray (bc)
123- data = similar (_unwrap_broadcasted (bc), T, size (bc))
124- dims, refdims = slicedims (A, axes (bc))
125- return rebuild (A; data, dims, refdims, name= Symbol (" " ))
90+ similar (A; data, dims = dims (axes (bc)))
12691end
127- function Base. similar (bc:: Broadcast.Broadcasted{BasicDimensionalStyle{N}} , :: Type{T} ) where {N,T}
92+ # similar is usually only called in broadcast_preserving_zero_d
93+ function Base. similar (bc:: Broadcasted{<:DimensionalStyle{S}} , :: Type{T} ) where {S,T}
12894 A = _firstdimarray (bc)
129- data = similar (A, T, size (bc))
130- dims, refdims = slicedims (A, axes (bc))
131- return dimconstructor (dims)(data, dims; refdims, name= Symbol (" " ))
95+ data = similar (_unwrap_broadcasted (bc), T)
96+ similar (A; data, dims = dims (axes (bc)))
13297end
13398
99+ @inline function Base. materialize! (:: DimensionalStyle , dest, bc:: Broadcasted )
100+ # check dimensions
101+ bci = Broadcast. instantiate (Broadcasted (bc. style, bc. f, bc. args, axes (dest)))
102+ # unwrap before copying
103+ Base. copyto! (_unwrap_broadcasted (dest), _unwrap_broadcasted (bci))
104+ return dest
105+ end
134106
135107"""
136108 @d broadcast_expression options
@@ -407,29 +379,31 @@ end
407379# Recursively unwraps `AbstractDimArray`s and `DimensionalStyle`s.
408380# replacing the `AbstractDimArray`s with the wrapped array,
409381# and `DimensionalStyle` with the wrapped `BroadcastStyle`.
410- function _unwrap_broadcasted (bc:: Broadcasted{DimensionalStyle{S}} ) where S
382+
383+ function _unwrap_broadcasted (bc:: Broadcasted{<:DimensionalStyle{S}} ) where {S}
411384 innerargs = map (_unwrap_broadcasted, bc. args)
412- return Broadcasted {S} (bc. f, innerargs)
385+ return Broadcasted {S} (bc. f, innerargs, _unwrap_broadcasted (bc . axes) )
413386end
414387_unwrap_broadcasted (x) = x
415388_unwrap_broadcasted (nda:: AbstractDimArray ) = parent (nda)
416- _unwrap_broadcasted (boda:: BroadcastOptionsDimArray ) = parent (parent (boda))
417-
389+ _unwrap_broadcasted (bda:: AbstractBasicDimArray ) = OpaqueArray (bda)
390+ _unwrap_broadcasted (boda:: BroadcastOptionsDimArray ) = _unwrap_broadcasted (parent (boda))
391+ _unwrap_broadcasted (t:: Tuple ) = map (_unwrap_broadcasted, t)
392+ _unwrap_broadcasted (du:: Dimensions.DimUnitRange ) = parent (du)
418393# Get the first dimensional array in the broadcast
419394_firstdimarray (x:: Broadcasted ) = _firstdimarray (x. args)
420- _firstdimarray (x:: Tuple{<:AbstractBasicDimArray,Vararg} ) = x[1 ]
421- _firstdimarray (x:: AbstractBasicDimArray ) = x
422- _firstdimarray (ext:: Base.Broadcast.Extruded ) = _firstdimarray (ext. x)
423- function _firstdimarray (x:: Tuple{<:Union{Broadcasted,Base.Broadcast.Extruded},Vararg} )
395+ function _firstdimarray (x:: Tuple )
424396 found = _firstdimarray (x[1 ])
425397 if found isa Nothing
426398 _firstdimarray (tail (x))
427399 else
428400 found
429401 end
430402end
431- _firstdimarray (x:: Tuple ) = _firstdimarray (tail (x))
432403_firstdimarray (x:: Tuple{} ) = nothing
404+ _firstdimarray (ext:: Base.Broadcast.Extruded ) = _firstdimarray (ext. x)
405+ _firstdimarray (x:: AbstractBasicDimArray ) = x
406+ _firstdimarray (x) = nothing
433407
434408# Make sure all arrays have the same dims, and return them
435409_broadcasted_dims (bc:: Broadcasted ) = _broadcasted_dims (bc. args... )
0 commit comments