Skip to content

Commit dc37e20

Browse files
authored
Fix 0-dim broadcasting bug (#174)
1 parent 8cfd7f9 commit dc37e20

File tree

5 files changed

+82
-62
lines changed

5 files changed

+82
-62
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.10.0"
4+
version = "0.10.1"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/BlockArraysExtensions/BlockArraysExtensions.jl

Lines changed: 65 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,43 @@ using SparseArraysBase:
2828
setunstoredindex!,
2929
storedlength
3030

31+
function view!(a::AbstractArray{<:Any,N}, index::Block{N}) where {N}
32+
return view!(a, Tuple(index)...)
33+
end
34+
function view!(a::AbstractArray{<:Any,N}, index::Vararg{Block{1},N}) where {N}
35+
blocks(a)[Int.(index)...] = blocks(a)[Int.(index)...]
36+
return blocks(a)[Int.(index)...]
37+
end
38+
# Fix ambiguity error.
39+
function view!(a::AbstractArray{<:Any,0})
40+
blocks(a)[] = blocks(a)[]
41+
return blocks(a)[]
42+
end
43+
44+
function view!(a::AbstractArray{<:Any,N}, index::BlockIndexRange{N}) where {N}
45+
# TODO: Is there a better code pattern for this?
46+
indices = ntuple(N) do dim
47+
return Tuple(Block(index))[dim][index.indices[dim]]
48+
end
49+
return view!(a, indices...)
50+
end
51+
function view!(a::AbstractArray{<:Any,N}, index::Vararg{BlockIndexRange{1},N}) where {N}
52+
b = view!(a, Block.(index)...)
53+
r = map(index -> only(index.indices), index)
54+
return @view b[r...]
55+
end
56+
57+
using MacroTools: @capture
58+
is_getindex_expr(expr::Expr) = (expr.head === :ref)
59+
is_getindex_expr(x) = false
60+
macro view!(expr)
61+
if !is_getindex_expr(expr)
62+
error("@view must be used with getindex syntax (as `@view! a[i,j,...]`)")
63+
end
64+
@capture(expr, array_[indices__])
65+
return :(view!($(esc(array)), $(esc.(indices)...)))
66+
end
67+
3168
# A return type for `blocks(array)` when `array` isn't blocked.
3269
# Represents a vector with just that single block.
3370
struct SingleBlockView{N,Array<:AbstractArray{<:Any,N}} <: AbstractArray{Array,N}
@@ -568,12 +605,34 @@ function Base.getindex(a::BlockView{<:Any,N}, index::Vararg{Int,N}) where {N}
568605
return blocks(parent(a))[Int.(a.block)...][index...]
569606
end
570607
function Base.setindex!(a::BlockView{<:Any,N}, value, index::Vararg{Int,N}) where {N}
571-
I = Int.(a.block)
572-
if !isstored(blocks(parent(a)), I...)
573-
unstored_value = getunstoredindex(blocks(parent(a)), I...)
574-
setunstoredindex!(blocks(parent(a)), unstored_value, I...)
575-
end
576-
blocks(parent(a))[I...][index...] = value
608+
b = @view! parent(a)[a.block...]
609+
b[index...] = value
610+
return a
611+
end
612+
function Base.fill!(a::BlockView, value)
613+
b = @view! parent(a)[a.block...]
614+
fill!(b, value)
615+
end
616+
using Base.Broadcast: AbstractArrayStyle, Broadcasted, broadcasted
617+
materialize_blockviews(x) = x
618+
materialize_blockviews(a::BlockView) = blocks(parent(a))[Int.(a.block)...]
619+
function materialize_blockviews(bc::Broadcasted)
620+
return broadcasted(bc.f, map(materialize_blockviews, bc.args)...)
621+
end
622+
function Base.copyto!(a::BlockView, bc::Broadcasted)
623+
b = @view! parent(a)[a.block...]
624+
bc′ = materialize_blockviews(bc)
625+
copyto!(b, bc′)
626+
return a
627+
end
628+
function Base.copyto!(a::BlockView, bc::Broadcasted{<:AbstractArrayStyle{0}})
629+
b = @view! parent(a)[a.block...]
630+
copyto!(b, bc)
631+
return a
632+
end
633+
function Base.copyto!(a::BlockView, src::AbstractArray)
634+
b = @view! parent(a)[a.block...]
635+
copyto!(b, src)
577636
return a
578637
end
579638

@@ -602,43 +661,6 @@ function ArrayLayouts.sub_materialize(a::BlockView)
602661
return blocks(parent(a))[Int.(a.block)...]
603662
end
604663

605-
function view!(a::AbstractArray{<:Any,N}, index::Block{N}) where {N}
606-
return view!(a, Tuple(index)...)
607-
end
608-
function view!(a::AbstractArray{<:Any,N}, index::Vararg{Block{1},N}) where {N}
609-
blocks(a)[Int.(index)...] = blocks(a)[Int.(index)...]
610-
return blocks(a)[Int.(index)...]
611-
end
612-
# Fix ambiguity error.
613-
function view!(a::AbstractArray{<:Any,0})
614-
blocks(a)[] = blocks(a)[]
615-
return blocks(a)[]
616-
end
617-
618-
function view!(a::AbstractArray{<:Any,N}, index::BlockIndexRange{N}) where {N}
619-
# TODO: Is there a better code pattern for this?
620-
indices = ntuple(N) do dim
621-
return Tuple(Block(index))[dim][index.indices[dim]]
622-
end
623-
return view!(a, indices...)
624-
end
625-
function view!(a::AbstractArray{<:Any,N}, index::Vararg{BlockIndexRange{1},N}) where {N}
626-
b = view!(a, Block.(index)...)
627-
r = map(index -> only(index.indices), index)
628-
return @view b[r...]
629-
end
630-
631-
using MacroTools: @capture
632-
is_getindex_expr(expr::Expr) = (expr.head === :ref)
633-
is_getindex_expr(x) = false
634-
macro view!(expr)
635-
if !is_getindex_expr(expr)
636-
error("@view must be used with getindex syntax (as `@view! a[i,j,...]`)")
637-
end
638-
@capture(expr, array_[indices__])
639-
return :(view!($(esc(array)), $(esc.(indices)...)))
640-
end
641-
642664
# SVD additions
643665
# -------------
644666
using LinearAlgebra: Algorithm

src/blocksparsearrayinterface/blocksparsearrayinterface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ end
353353
# TODO: Maybe use `map` over `blocks(a)` or something
354354
# like that.
355355
for b in BlockRange(a)
356-
a[b] .= value
356+
fill!(@view!(a[b]), value)
357357
end
358358
return a
359359
end

src/blocksparsearrayinterface/broadcast.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ end
6363
# which is logic that is handled by `fill!`.
6464
function copyto_blocksparse!(dest::AbstractArray, bc::Broadcasted{<:AbstractArrayStyle{0}})
6565
# `[]` is used to unwrap zero-dimensional arrays.
66-
value = @allowscalar bc.f(bc.args...)[]
66+
bcf = Broadcast.flatten(bc)
67+
value = @allowscalar bcf.f(map(arg -> arg[], bcf.args)...)
6768
return @interface BlockSparseArrayInterface() fill!(dest, value)
6869
end
6970

test/test_map.jl

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ using DerivableInterfaces: zero!
1414
using GPUArraysCore: @allowscalar
1515
using JLArrays: JLArray
1616
using SparseArraysBase: storedlength
17+
using StableRNGs: StableRNG
1718
using Test: @test, @test_broken, @test_throws, @testset
1819

1920
elts = (Float32, Float64, ComplexF32)
@@ -31,14 +32,7 @@ arrayts = (Array, JLArray)
3132
@test blockstoredlength(a) == 2
3233
@test storedlength(a) == 2 * 4 + 3 * 3
3334

34-
# TODO: Broken on GPU.
35-
if arrayt Array
36-
a = dev(BlockSparseArray{elt}(undef, [2, 3], [3, 4]))
37-
@test_broken a[Block(1, 2)] .= 2
38-
end
39-
40-
# TODO: Broken on GPU.
41-
a = BlockSparseArray{elt}(undef, [2, 3], [3, 4])
35+
a = dev(BlockSparseArray{elt}(undef, [2, 3], [3, 4]))
4236
a[Block(1, 2)] .= 2
4337
@test eltype(a) == elt
4438
@test all(==(2), a[Block(1, 2)])
@@ -48,14 +42,7 @@ arrayts = (Array, JLArray)
4842
@test blockstoredlength(a) == 1
4943
@test storedlength(a) == 2 * 4
5044

51-
# TODO: Broken on GPU.
52-
if arrayt Array
53-
a = dev(BlockSparseArray{elt}(undef, [2, 3], [3, 4]))
54-
@test_broken a[Block(1, 2)] .= 0
55-
end
56-
57-
# TODO: Broken on GPU.
58-
a = BlockSparseArray{elt}(undef, [2, 3], [3, 4])
45+
a = dev(BlockSparseArray{elt}(undef, [2, 3], [3, 4]))
5946
a[Block(1, 2)] .= 0
6047
@test eltype(a) == elt
6148
@test iszero(a[Block(1, 1)])
@@ -83,6 +70,16 @@ arrayts = (Array, JLArray)
8370
@test blocktype(a′) <: arrayt{Float32,3}
8471
@test axes(a′) == (blockedrange([2, 4]), blockedrange([2, 5]), blockedrange([2, 2]))
8572

73+
# Regression test for 0-dimensional in-place broadcasting.
74+
rng = StableRNG(123)
75+
a = dev(BlockSparseArray{elt}(undef))
76+
@allowscalar a[] = randn(rng, elt)
77+
b = dev(BlockSparseArray{elt}(undef))
78+
@allowscalar b[] = randn(rng, elt)
79+
c = similar(a)
80+
c .= 2 .* a .+ 3 .* b
81+
@allowscalar @test c[] == 2 * a[] + 3 * b[]
82+
8683
a = dev(BlockSparseArray{elt}(undef, ([2, 3], [3, 4])))
8784
@views for b in [Block(1, 2), Block(2, 1)]
8885
a[b] = dev(randn(elt, size(a[b])))

0 commit comments

Comments
 (0)