rrule for broadcasted cast of sparse array #586
`findnz` and broadcasted cast of sparse array
I think that the implementations of both features are correct but
Sorry for adding 2 loosely related features in the same PR, I've been a bit lazy.
I think JuliaDiff/ChainRulesTestUtils.jl#234 has some hints for testing. Although the rules aren't doing any mathematics, so perhaps there's not much value in finite differencing. CRTU does also try to test various types of tangent etc., but maybe this isn't working.
Pfff, why is Julia 1.6 now failing on an old test?
Should be ready now.
Thanks! The `findnz` part LGTM, but I didn't check the broadcasting.
Actually, for testing, we need the compat for FiniteDifferences to be set to
function rrule(::typeof(Broadcast.broadcasted), T::Type{<:Number}, x::AbstractSparseArray)
    proj = ProjectTo(x)
    function broadcasted_cast_sparse(Δ)
        return NoTangent(), NoTangent(), proj(Δ)
    end
    return T.(x), broadcasted_cast_sparse
end
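For readers who want to try the rule in isolation, here is a minimal sketch of calling it directly. It assumes ChainRulesCore and SparseArrays are installed, and it re-declares the method on `ChainRulesCore.rrule` in a script rather than inside the package; the input vector and the cotangent are made up for illustration.

```julia
using ChainRulesCore
using SparseArrays

# Same rule as in the diff, declared on ChainRulesCore.rrule so a script can call it.
function ChainRulesCore.rrule(::typeof(Broadcast.broadcasted), T::Type{<:Number}, x::AbstractSparseArray)
    proj = ProjectTo(x)  # remembers the sparsity pattern and element type of x
    function broadcasted_cast_sparse(Δ)
        # No tangent for `broadcasted` or for the type T; the cast itself is the identity.
        return NoTangent(), NoTangent(), proj(Δ)
    end
    return T.(x), broadcasted_cast_sparse
end

# Illustrative input: length-4 sparse vector with stored entries at indices 1 and 3.
x = sparsevec([1, 3], [1.0, 2.0], 4)

y, pullback = rrule(Broadcast.broadcasted, Float32, x)

# A dense cotangent is projected back onto the stored entries of x.
_, _, dx = pullback(ones(Float32, 4))
```

The pullback does no arithmetic; the only work is `ProjectTo` restoring the sparse structure and element type of the input.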
I think this is the first broadcast rule in here. I'm not entirely sure how I feel about this.
`T.(x)` is eager; would there be downsides to leaving it lazy, like `broadcasted(T, x)`? That may mean more things fuse, or it may mean that subsequent Zygote broadcasting rules fail because they want an array.
Similarly, `proj(Δ)` won't accept something like a `Broadcasted`, if we explore fused gradients that way.
These are the sort of questions that make me unsure we want to commit to one design here. I guess having one or two won't hurt too much; we can fix them here later if needed.
Zygote's present rule fails; I believe it's trying to make a sparse array of tuples:
julia> gradient(x -> sum(Float32.(x)), sprand(10, 0.5))
ERROR: MethodError: no method matching zero(::Type{Any})
Closest candidates are:
zero(::Type{Union{Missing, T}}) where T at ~/.julia/dev/julia/usr/share/julia/base/missing.jl:105
zero(::Union{Type{P}, P}) where P<:Dates.Period at ~/.julia/dev/julia/usr/share/julia/stdlib/v1.8/Dates/src/periods.jl:53
zero(::CartesianIndex{N}) where N at ~/.julia/dev/julia/usr/share/julia/base/multidimensional.jl:106
...
Stacktrace:
[1] zero(#unused#::Type{Any})
@ Base ./missing.jl:106
[2] _zeros_eltypes
@ ~/.julia/dev/julia/usr/share/julia/stdlib/v1.8/SparseArrays/src/higherorderfns.jl:208 [inlined]
[3] _noshapecheck_map(::typeof(first), ::SparseVector{Any, Int64})
@ SparseArrays.HigherOrderFns ~/.julia/dev/julia/usr/share/julia/stdlib/v1.8/SparseArrays/src/higherorderfns.jl:164
[4] map(f::typeof(first), A::SparseVector{Any, Int64})
@ SparseArrays.HigherOrderFns ~/.julia/dev/julia/usr/share/julia/stdlib/v1.8/SparseArrays/src/higherorderfns.jl:147
[5] adjoint
@ ~/.julia/packages/Zygote/FPUm3/src/lib/broadcast.jl:187 [inlined]
[6] _pullback(__context__::Zygote.Context, 714::typeof(Base.Broadcast.broadcasted), 715::SparseArrays.HigherOrderFns.SparseVecStyle, f::Type{Float32}, args::SparseVector{Float64, Int64})
@ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65
But perhaps Zygote ought to have a rule for broadcasting `(::Type).(x)`, since what's written here would work for arbitrary arrays I believe, not just sparse, and would be faster than what it does now.
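The eager-versus-lazy distinction raised in this comment can be seen without any AD package. The sketch below uses only Base and SparseArrays; the input vector is made up for illustration.

```julia
using SparseArrays

x = sparsevec([1, 3], [1.0, 2.0], 4)

# Dot syntax materialises the result immediately.
eager = Float32.(x)

# Calling `broadcasted` directly returns a lazy wrapper that has not been evaluated yet.
lazy = Broadcast.broadcasted(Float32, x)

# Materialising the lazy object gives the same array as the eager call.
materialized = Broadcast.materialize(lazy)
```

A rule that returned `lazy` would hand downstream code a `Broadcast.Broadcasted` rather than an array, which is the trade-off discussed above.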
Maybe it's OK to have the first broadcast rule here cover a limited use case not currently supported by Zygote, and see how it goes? There are plenty of issues related to sparse arrays in the ML ecosystem; I'd like to get things going for downstream packages.
Would merging the `findnz` part and moving the broadcasting into a separate PR be an option? `findnz` is complete, while broadcasting possibly requires in-depth discussion.
ok
The non-controversial
Co-authored-by: Miha Zgubic <[email protected]>
So this is just a very specific instance of the general rule for broadcasting we have in Zygote. Would it be OK to merge this now and go with some generic rule (the one from Zygote or something else) later?
Should we close this given FluxML/Zygote.jl#1171 or would you prefer to keep it open?
Yeah, this is not needed anymore.
This also relaxes the signature for sparse constructors a bit.
Tests fail, but maybe it is just a matter of passing an appropriate output tangent to `test_rrule`?
Close FluxML/Zygote.jl#810