Skip to content

Commit 2744895

Browse files
test for broadcasted sparse
1 parent 37e65b9 commit 2744895

File tree

4 files changed

+30
-8
lines changed

4 files changed

+30
-8
lines changed

Project.toml

-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ version = "1.28.0"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
7-
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
87
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
98
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
109
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/rulesets/SparseArrays/sparsematrix.jl

+3-7
Original file line numberDiff line numberDiff line change
@@ -50,16 +50,12 @@ function rrule(::typeof(findnz), v::AbstractSparseVector)
5050
return (I, V), findnz_pullback
5151
end
5252

53-
function rrule(::typeof(broadcast), T::Type{<:Number}, x::AbstractSparseArray)
53+
function rrule(::typeof(Broadcast.broadcasted), T::Type{<:Number}, x::AbstractSparseArray)
5454
proj = ProjectTo(x)
55-
function broadcast_cast_sparse(Δ)
55+
function broadcasted_cast_sparse(Δ)
5656
return NoTangent(), NoTangent(), proj(Δ)
5757
end
58-
T.(x), broadcast_cast_sparse
58+
T.(x), broadcasted_cast_sparse
5959
end
6060

61-
# These rules help with testing, and won't hurt:
62-
# They are correct as we always `collect` the primal result as we need that
63-
# information for the reverse pass
64-
# ChainRules.rrule(::typeof(SparseMatrixCSC∘broadcast), T::Type{<:Number}, x::AbstractSparseArray) = rrule(broadcast, T, x)
6561

test/rulesets/SparseArrays/sparsematrix.jl

+5
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,8 @@ end
2727
test_rrule(findnz, A dA, output_tangent=(zeros(length(I)), zeros(length(J)), V̄))
2828
end
2929

30+
@testset "broadcasted cast SparseMatrixCSC" begin
31+
A = sprand(5, 5, 0.5)
32+
test_rrule(_broadcast, Float32, A, rtol=1e-5)
33+
end
34+

test/test_helpers.jl

+22
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,28 @@ function ChainRulesCore.rrule(::typeof(make_two_vec), x)
7575
return make_two_vec(x), make_two_vec_pullback
7676
end
7777

78+
"""
79+
Helper function for testing rrule(::typeof(Broadcast.broadcasted), ...)
80+
81+
# Examples
82+
83+
```julia
84+
using SparseArrays, ChainRulesTestUtils
85+
86+
A = sprand(5, 5, 0.5)
87+
test_rrule(_broadcast, Float32, A, rtol=1e-5)
88+
```
89+
"""
90+
_broadcast(f::F, x...) where F = broadcast(f, x...)
91+
92+
# we need this due to https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/234
93+
function ChainRulesCore.rrule(::typeof(_broadcast), f::F, args...) where F
94+
rr = rrule(Broadcast.broadcasted, f, args...)
95+
rr === nothing && return nothing
96+
y, pb = rr
97+
Broadcast.materialize(Broadcast.instantiate(y)), pb
98+
end
99+
78100
@testset "test_helpers.jl" begin
79101

80102
@testset "Multiplier" begin

0 commit comments

Comments
 (0)