rrule for broadcasted cast of sparse array #586

Closed
wants to merge 2 commits into from

CarloLucibello
Contributor

@CarloLucibello CarloLucibello commented Feb 7, 2022

This also slightly relaxes the signatures of the sparse constructors.

The tests fail, but maybe it is just a matter of passing an appropriate output tangent to test_rrule?

Closes FluxML/Zygote.jl#810

@CarloLucibello CarloLucibello changed the title from "rrule for findnz" to "rrule for findnz and broadcasted cast of sparse array" Feb 8, 2022
@CarloLucibello
Contributor Author

CarloLucibello commented Feb 8, 2022

I think the implementations of both features are correct, but:

  • I don't know how to test the broadcasted rule. Maybe just test with Zygote via test_rrule(..., rrule_f = ZYGOTE_SOMETHING)?
  • I can't get the tests for findnz to work; I get a very long FiniteDifferences error.

Sorry for adding two loosely related features in the same PR; I was a bit lazy.

@mcabbott
Member

mcabbott commented Feb 8, 2022

I think JuliaDiff/ChainRulesTestUtils.jl#234 has some hints for testing broadcasted rules.

That said, the rules aren't doing any mathematics, so perhaps there's not much value in finite differencing. CRTU (ChainRulesTestUtils) also tries to test various tangent types, but maybe that isn't working here.
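
One way to sanity-check the rule without finite differencing is to call it directly and inspect what the pullback returns. This is a minimal sketch, assuming ChainRulesCore and SparseArrays are available; the rule body is copied from this PR so the snippet is self-contained, and the input values are made up for illustration.

```julia
using ChainRulesCore
using SparseArrays

# Rule as proposed in this PR, repeated here so the snippet runs on its own.
function ChainRulesCore.rrule(::typeof(Broadcast.broadcasted), T::Type{<:Number}, x::AbstractSparseArray)
    proj = ProjectTo(x)  # remembers the eltype and sparsity pattern of x
    broadcasted_cast_sparse(Δ) = (NoTangent(), NoTangent(), proj(Δ))
    return T.(x), broadcasted_cast_sparse
end

# Illustrative input: 3 stored entries in a length-5 sparse vector.
x = sparsevec([1, 3, 4], [1.0, 2.0, 3.0], 5)

y, pullback = rrule(Broadcast.broadcasted, Float32, x)

# Feed a dense Float32 cotangent through the pullback.
Δ = ones(Float32, 5)
_, _, x̄ = pullback(Δ)

# The projection should bring the cotangent back to the eltype and
# sparsity pattern of x.
@show typeof(y) typeof(x̄) findnz(x̄)
```

This only checks types and structure, not derivatives, which seems to match the point that the rule itself does no arithmetic.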

@CarloLucibello
Contributor Author

CarloLucibello commented Feb 14, 2022

Pfff, why is Julia 1.6 now failing on an old test?

@CarloLucibello
Contributor Author

Should be ready now

Member

@mzgubic mzgubic left a comment

Thanks! The findnz part LGTM, but I didn't check the broadcasting part.

@mzgubic
Member

mzgubic commented Feb 15, 2022

Actually, for testing we need the FiniteDifferences compat entry to be set to 0.12.23.

Comment on lines 53 to 61
function rrule(::typeof(Broadcast.broadcasted), T::Type{<:Number}, x::AbstractSparseArray)
    proj = ProjectTo(x)
    function broadcasted_cast_sparse(Δ)
        return NoTangent(), NoTangent(), proj(Δ)
    end
    return T.(x), broadcasted_cast_sparse
end
Member

I think this is the first broadcast rule in here. I'm not entirely sure how I feel about this.

T.(x) is eager. Would there be downsides to leaving it lazy, like broadcasted(T, x)? That may mean more things fuse, or it may mean that subsequent Zygote broadcasting rules fail because they want an array.

Similarly, proj(Δ) won't accept something like a Broadcasted object, if we explore fused gradients that way.

These are the sort of questions that make me unsure we want to commit to one design here. I guess having one or two won't hurt too much, we can fix them here later if needed.
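
To make the eager versus lazy distinction concrete, here is a small sketch using only Base and SparseArrays. The input values are made up for illustration; the point is just that broadcasted(T, x) returns a lazy wrapper rather than an array until it is materialized.

```julia
using SparseArrays

# Illustrative input.
x = sparsevec([1, 3, 4], [1.0, 2.0, 3.0], 5)

# Eager: the dot syntax materializes the result immediately.
eager = Float32.(x)

# Lazy: a Broadcast.Broadcasted wrapper; nothing has been computed yet.
lazy = Broadcast.broadcasted(Float32, x)

lazy isa Broadcast.Broadcasted        # true
lazy isa AbstractArray                # false, so array-only code will not accept it
Broadcast.materialize(lazy) == eager  # true
```

Returning the lazy object from a rule would therefore hand downstream code something that is not an AbstractArray, which is the concern raised above.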

Zygote's present rule fails; I believe it's trying to make a sparse array of tuples:

julia> gradient(x -> sum(Float32.(x)), sprand(10, 0.5))
ERROR: MethodError: no method matching zero(::Type{Any})
Closest candidates are:
  zero(::Type{Union{Missing, T}}) where T at ~/.julia/dev/julia/usr/share/julia/base/missing.jl:105
  zero(::Union{Type{P}, P}) where P<:Dates.Period at ~/.julia/dev/julia/usr/share/julia/stdlib/v1.8/Dates/src/periods.jl:53
  zero(::CartesianIndex{N}) where N at ~/.julia/dev/julia/usr/share/julia/base/multidimensional.jl:106
  ...
Stacktrace:
  [1] zero(#unused#::Type{Any})
    @ Base ./missing.jl:106
  [2] _zeros_eltypes
    @ ~/.julia/dev/julia/usr/share/julia/stdlib/v1.8/SparseArrays/src/higherorderfns.jl:208 [inlined]
  [3] _noshapecheck_map(::typeof(first), ::SparseVector{Any, Int64})
    @ SparseArrays.HigherOrderFns ~/.julia/dev/julia/usr/share/julia/stdlib/v1.8/SparseArrays/src/higherorderfns.jl:164
  [4] map(f::typeof(first), A::SparseVector{Any, Int64})
    @ SparseArrays.HigherOrderFns ~/.julia/dev/julia/usr/share/julia/stdlib/v1.8/SparseArrays/src/higherorderfns.jl:147
  [5] adjoint
    @ ~/.julia/packages/Zygote/FPUm3/src/lib/broadcast.jl:187 [inlined]
  [6] _pullback(__context__::Zygote.Context, 714::typeof(Base.Broadcast.broadcasted), 715::SparseArrays.HigherOrderFns.SparseVecStyle, f::Type{Float32}, args::SparseVector{Float64, Int64})
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65

But perhaps Zygote ought to have a rule for broadcasting (::Type).(x), since what's written here would work for arbitrary arrays I believe, not just sparse, and would be faster than what it does now.
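
As a rough illustration of that idea, the same pattern can be written for any AbstractArray. This sketch assumes ChainRulesCore is available and attaches the rule to a hypothetical helper, cast_elements, rather than to Broadcast.broadcasted, so it does not overload anything in Base; it is not the rule that ended up in Zygote.

```julia
using ChainRulesCore

# Hypothetical helper standing in for `T.(x)`; not part of any package.
cast_elements(T::Type{<:Number}, x::AbstractArray) = T.(x)

function ChainRulesCore.rrule(::typeof(cast_elements), T::Type{<:Number}, x::AbstractArray)
    proj = ProjectTo(x)  # remembers the eltype and shape of x
    # The cast does no arithmetic, so the cotangent is only projected back.
    cast_pullback(Δ) = (NoTangent(), NoTangent(), proj(Δ))
    return cast_elements(T, x), cast_pullback
end

# Illustrative dense input.
x = [1.0 2.0; 3.0 4.0]
y, pullback = rrule(cast_elements, Float32, x)
_, _, x̄ = pullback(ones(Float32, 2, 2))

@show typeof(y) typeof(x̄)
```

The pullback never looks at individual elements, which is presumably why such a rule could be faster than a generic broadcast adjoint.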

Contributor Author

Maybe it's ok to have the first broadcast rule here cover a limited use case that Zygote doesn't currently support, and see how it goes? There are plenty of issues related to sparse arrays in the ML ecosystem, and I'd like to get things going for downstream packages.

Member

Would merging the findnz part and moving the broadcasting into a separate PR be an option? The findnz part is complete, while the broadcasting may require in-depth discussion.

Contributor Author

ok

@CarloLucibello
Contributor Author

The non-controversial findnz part is now in #590. Once that is merged, I will update this PR to contain only the broadcast part.

@mzgubic mzgubic changed the title from "rrule for findnz and broadcasted cast of sparse array" to "rrule for broadcasted cast of sparse array" Feb 18, 2022
@CarloLucibello
Contributor Author

So this is just a very specific instance of the general rule for broadcasting we have in Zygote. Would it be ok to merge this now and go with some generic rule (the one from Zygote or something else) later?

@mzgubic
Member

mzgubic commented Feb 21, 2022

Should we close this given FluxML/Zygote.jl#1171 or would you prefer to keep it open?

@CarloLucibello
Contributor Author

Yeah, this is not needed anymore.

rofinn added a commit that referenced this pull request Sep 13, 2022
rofinn added a commit that referenced this pull request Sep 14, 2022
Development

Successfully merging this pull request may close these issues.

gradient calculation with explicit type cast is broken
3 participants