Skip to content

Commit a5e1728

Browse files
rrule for broadcasted cast of sparse matrix
Co-authored-by: Miha Zgubic <[email protected]>
1 parent b1daa7a commit a5e1728

File tree

4 files changed

+38
-1
lines changed

4 files changed

+38
-1
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1616
ChainRulesCore = "1.12"
1717
ChainRulesTestUtils = "1.5"
1818
Compat = "3.35"
19-
FiniteDifferences = "0.12.20"
19+
FiniteDifferences = "0.12.23"
2020
IrrationalConstants = "0.1.1"
2121
JuliaInterpreter = "0.8" # latest is "0.9.1"
2222
RealDot = "0.1"

src/rulesets/SparseArrays/sparsematrix.jl

+10
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,13 @@ function rrule(::typeof(findnz), v::AbstractSparseVector)
4949

5050
return (I, V), findnz_pullback
5151
end
52+
53+
function rrule(::typeof(Broadcast.broadcasted), T::Type{<:Number}, x::AbstractSparseArray)
54+
proj = ProjectTo(x)
55+
56+
function broadcasted_cast_sparse(Δ)
57+
return NoTangent(), NoTangent(), proj(Δ)
58+
end
59+
60+
return T.(x), broadcasted_cast_sparse
61+
end

test/rulesets/SparseArrays/sparsematrix.jl

+5
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,8 @@ end
3333
= rand!(similar(V))
3434
test_rrule(findnz, v dv, output_tangent=(zeros(length(I)), V̄))
3535
end
36+
37+
@testset "broadcasted cast SparseMatrixCSC" begin
38+
A = sprand(5, 5, 0.5)
39+
test_rrule(_broadcast, Float32, A, rtol=1e-4)
40+
end

test/test_helpers.jl

+22
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,28 @@ function ChainRulesCore.rrule(::typeof(make_two_vec), x)
7575
return make_two_vec(x), make_two_vec_pullback
7676
end
7777

78+
"""
79+
Helper function for testing rrule(::typeof(Broadcast.broadcasted), ...)
80+
81+
# Examples
82+
83+
```julia
84+
using SparseArrays, ChainRulesTestUtils
85+
86+
A = sprand(5, 5, 0.5)
87+
test_rrule(_broadcast, Float32, A, rtol=1e-5)
88+
```
89+
"""
90+
_broadcast(f::F, x...) where F = broadcast(f, x...)
91+
92+
# we need this due to https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/234
93+
function ChainRulesCore.rrule(::typeof(_broadcast), f::F, args...) where F
94+
rr = rrule(Broadcast.broadcasted, f, args...)
95+
rr === nothing && return nothing
96+
y, pb = rr
97+
Broadcast.materialize(Broadcast.instantiate(y)), pb
98+
end
99+
78100
@testset "test_helpers.jl" begin
79101

80102
@testset "Multiplier" begin

0 commit comments

Comments
 (0)