Skip to content

Improvements to rules for norm #337

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 22 commits into from
May 10, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.7.63"
version = "0.7.64"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
149 changes: 90 additions & 59 deletions src/rulesets/LinearAlgebra/norm.jl
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@ function frule((_, Δx), ::typeof(norm), x)
y = norm(x)
return y, _norm2_forward(x, Δx, norm(x))
end

function frule((_, Δx), ::typeof(norm), x::Number, p::Real)
y = norm(x, p)
∂y = if iszero(Δx) || iszero(p)
@@ -17,15 +18,12 @@ function frule((_, Δx), ::typeof(norm), x::Number, p::Real)
return y, ∂y
end

function rrule(
::typeof(norm),
x::Union{StridedArray, LinearAlgebra.AbstractTriangular, Diagonal},
p::Real,
)
function rrule(::typeof(norm), x::AbstractArray{<:Number}, p::Real)
y = LinearAlgebra.norm(x, p)
function norm_pullback(Δy)
∂x = Thunk() do
return if isempty(x) || p == 0
function norm_pullback_p(Δy)
∂x = InplaceableThunk(
# out-of-place versions
@thunk(if isempty(x) || p == 0
zero.(x) .* (zero(y) * zero(real(Δy)))
elseif p == 2
_norm2_back(x, y, Δy)
@@ -37,35 +35,52 @@ function rrule(
_normInf_back(x, y, Δy)
else
_normp_back_x(x, p, y, Δy)
end)
, # in-place versions -- can be fixed when actually useful?
dx -> if isempty(x) || p == 0
dx
elseif p == 2
_norm2_back!(dx, x, y, Δy)
elseif p == 1
_norm1_back!(dx, x, y, Δy)
elseif p == Inf
dx .+= _normInf_back(x, y, Δy) # not really in-place! could perhaps be improved
elseif p == -Inf
dx .+= _normInf_back(x, y, Δy)
else
dx .+= _normp_back_x(x, p, y, Δy)
end
end
)
∂p = @thunk _normp_back_p(x, p, y, Δy)
return (NO_FIELDS, ∂x, ∂p)
end
norm_pullback(::Zero) = (NO_FIELDS, Zero(), Zero())
return y, norm_pullback
norm_pullback_p(::Zero) = (NO_FIELDS, Zero(), Zero())
return y, norm_pullback_p
end
function rrule(
::typeof(norm),
x::Union{StridedArray, LinearAlgebra.AbstractTriangular, Diagonal},
)

function rrule(::typeof(norm), x::AbstractArray{<:Number})
y = LinearAlgebra.norm(x)
function norm_pullback(Δy)
∂x = if isempty(x)
zero.(x) .* (zero(y) * zero(real(Δy)))
else
_norm2_back(x, y, Δy)
end
function norm_pullback_2(Δy)
∂x = InplaceableThunk(
@thunk(if isempty(x)
zero.(x) .* (zero(y) * zero(real(Δy)))
else
_norm2_back(x, y, Δy)
end)
,
dx -> if isempty(x)
dx
else
_norm2_back!(dx, x, y, Δy)
end
)
return (NO_FIELDS, ∂x)
end
norm_pullback(::Zero) = (NO_FIELDS, Zero())
return y, norm_pullback
norm_pullback_2(::Zero) = (NO_FIELDS, Zero())
return y, norm_pullback_2
end
function rrule(
::typeof(norm),
x::Union{LinearAlgebra.TransposeAbsVec, LinearAlgebra.AdjointAbsVec},
p::Real,
)

function rrule(::typeof(norm), x::LinearAlgebra.AdjOrTransAbsVec{<:Number}, p::Real)
y, inner_pullback = rrule(norm, parent(x), p)
function norm_pullback(Δy)
(∂self, ∂x′, ∂p) = inner_pullback(Δy)
@@ -75,6 +90,7 @@ function rrule(
end
return y, norm_pullback
end

function rrule(::typeof(norm), x::Number, p::Real)
y = norm(x, p)
function norm_pullback(Δy)
@@ -94,11 +110,7 @@ end
##### `normp`
#####

function rrule(
::typeof(LinearAlgebra.normp),
x::Union{StridedArray, LinearAlgebra.AbstractTriangular, Diagonal},
p,
)
function rrule(::typeof(LinearAlgebra.normp), x::AbstractArray{<:Number}, p)
y = LinearAlgebra.normp(x, p)
function normp_pullback(Δy)
∂x = @thunk _normp_back_x(x, p, y, Δy)
@@ -111,15 +123,24 @@ end

function _normp_back_x(x, p, y, Δy)
c = real(Δy) / y
∂x = similar(x)
broadcast!(∂x, x) do xi
∂x = map(x) do xi
a = norm(xi)
∂xi = xi * ((a / y)^(p - 2) * c)
return ifelse(isfinite(∂xi), ∂xi, zero(∂xi))
end
return ∂x
end

# Pullback of the p-norm with respect to `x` for structured matrices
# (Diagonal, UpperTriangular, etc.). The element-wise rule is applied to the
# backing array `parent(x)`, and the result is re-wrapped in the same
# structured type as `x`.
function _normp_back_x(x::WithSomeZeros, p, y, Δy)
    scale = real(Δy) / y
    backing = parent(x)
    ∂backing = map(backing) do el
        magnitude = norm(el)
        grad = el * ((magnitude / y)^(p - 2) * scale)
        # Non-finite entries (e.g. from dividing by zero) are replaced by zero.
        return isfinite(grad) ? grad : zero(grad)
    end
    return withsomezeros_rewrap(x, ∂backing)
end

function _normp_back_p(x, p, y, Δy)
y > 0 && isfinite(y) && !iszero(p) || return zero(real(Δy)) * zero(y) / one(p)
s = sum(x) do xi
@@ -135,20 +156,14 @@ end
##### `normMinusInf`/`normInf`
#####

function rrule(
::typeof(LinearAlgebra.normMinusInf),
x::Union{StridedArray, LinearAlgebra.AbstractTriangular, Diagonal},
)
function rrule(::typeof(LinearAlgebra.normMinusInf), x::AbstractArray{<:Number})
y = LinearAlgebra.normMinusInf(x)
normMinusInf_pullback(Δy) = (NO_FIELDS, _normInf_back(x, y, Δy))
normMinusInf_pullback(::Zero) = (NO_FIELDS, Zero())
return y, normMinusInf_pullback
end

function rrule(
::typeof(LinearAlgebra.normInf),
x::Union{StridedArray,LinearAlgebra.AbstractTriangular,Diagonal},
)
function rrule(::typeof(LinearAlgebra.normInf), x::AbstractArray{<:Number})
y = LinearAlgebra.normInf(x)
normInf_pullback(Δy) = (NO_FIELDS, _normInf_back(x, y, Δy))
normInf_pullback(::Zero) = (NO_FIELDS, Zero())
@@ -172,19 +187,26 @@ end
##### `norm1`
#####

function rrule(
::typeof(LinearAlgebra.norm1),
x::Union{StridedArray,LinearAlgebra.AbstractTriangular,Diagonal},
)
function rrule(::typeof(LinearAlgebra.norm1), x::AbstractArray{<:Number})
y = LinearAlgebra.norm1(x)
norm1_pullback(Δy) = (NO_FIELDS, _norm1_back(x, y, Δy))
norm1_pullback(Δy) = (NO_FIELDS, InplaceableThunk(
@thunk(_norm1_back(x, y, Δy)),
dx -> _norm1_back!(dx, x, y, Δy),
))
norm1_pullback(::Zero) = (NO_FIELDS, Zero())
return y, norm1_pullback
end

# Pullback of the 1-norm with respect to `x`: each element receives
# `sign(xi) * real(Δy)`. `y` (the primal norm) is unused but kept so that all
# `_norm*_back` helpers share one signature.
#
# The result is built out of place by broadcasting, so its element type is
# promoted as needed. Filling `similar(x)` in place instead would throw
# `InexactError` for integer `x` combined with a fractional cotangent.
function _norm1_back(x, y, Δy)
    return sign.(x) .* real(Δy)
end
# Pullback of the 1-norm for structured matrices: the rule is applied to the
# backing array `parent(x)` and the result is re-wrapped in the type of `x`.
function _norm1_back(x::WithSomeZeros, y, Δy)
    backing = parent(x)
    Δ = real(Δy)
    ∂backing = @. sign(backing) * Δ
    return withsomezeros_rewrap(x, ∂backing)
end
# In-place pullback of the 1-norm: accumulates `sign(xi) * real(Δy)` into the
# existing cotangent `∂x` and returns that same (mutated) array.
function _norm1_back!(∂x, x, y, Δy)
    Δ = real(Δy)
    @. ∂x += sign(x) * Δ
    return ∂x
end

@@ -197,12 +219,12 @@ function frule((_, Δx), ::typeof(LinearAlgebra.norm2), x)
return y, _norm2_forward(x, Δx, y)
end

function rrule(
::typeof(LinearAlgebra.norm2),
x::Union{StridedArray,LinearAlgebra.AbstractTriangular,Diagonal},
)
function rrule(::typeof(LinearAlgebra.norm2), x::AbstractArray{<:Number})
y = LinearAlgebra.norm2(x)
norm2_pullback(Δy) = (NO_FIELDS, _norm2_back(x, y, Δy))
norm2_pullback(Δy) = (NO_FIELDS, InplaceableThunk(
@thunk(_norm2_back(x, y, Δy)),
dx -> _norm2_back!(dx, x, y, Δy),
))
norm2_pullback(::Zero) = (NO_FIELDS, Zero())
return y, norm2_pullback
end
@@ -212,16 +234,24 @@ function _norm2_forward(x, Δx, y)
return ∂y
end
# Pullback of the 2-norm with respect to `x`: `x .* (real(Δy) / y)`.
# `pinv(y)` is used in place of `1 / y`, so a zero norm yields a zero scale
# factor rather than `Inf`.
#
# The result is built out of place by broadcasting, so its element type is
# promoted as needed. Filling `similar(x)` in place instead would throw
# `InexactError` for integer `x`.
function _norm2_back(x, y, Δy)
    return x .* (real(Δy) * pinv(y))
end
# Pullback of the 2-norm for structured matrices (Diagonal, UpperTriangular,
# etc.): the backing array `parent(x)` is scaled by `real(Δy) * pinv(y)` and
# the result is re-wrapped in the same structured type as `x`.
function _norm2_back(x::WithSomeZeros, y, Δy)
    ∂x_data = parent(x) .* (real(Δy) * pinv(y))
    return withsomezeros_rewrap(x, ∂x_data)
end
# In-place pullback of the 2-norm: accumulates `x .* (real(Δy) * pinv(y))`
# into the existing cotangent `∂x`.
function _norm2_back!(∂x, x, y, Δy)
    scale = real(Δy) * pinv(y)
    ∂x .+= x .* scale
    # The mutated array must be returned to the caller.
    return ∂x
end

#####
##### `normalize`
#####

function rrule(::typeof(normalize), x::AbstractVector, p::Real)
function rrule(::typeof(normalize), x::AbstractVector{<:Number}, p::Real)
nrm, inner_pullback = rrule(norm, x, p)
Ty = typeof(first(x) / nrm)
y = copyto!(similar(x, Ty), x)
@@ -236,7 +266,8 @@ function rrule(::typeof(normalize), x::AbstractVector, p::Real)
normalize_pullback(::Zero) = (NO_FIELDS, Zero(), Zero())
return y, normalize_pullback
end
function rrule(::typeof(normalize), x::AbstractVector)

function rrule(::typeof(normalize), x::AbstractVector{<:Number})
nrm = LinearAlgebra.norm2(x)
Ty = typeof(first(x) / nrm)
y = copyto!(similar(x, Ty), x)
33 changes: 33 additions & 0 deletions src/rulesets/LinearAlgebra/utils.jl
Original file line number Diff line number Diff line change
@@ -43,3 +43,36 @@ Symmetric
````
"""
function _unionall_wrapper(::Type{T}) where {T}
    # Strip the type parameters, e.g. `Vector{Int}` becomes `Array`.
    return T.name.wrapper
end

"""
WithSomeZeros{T}

This is a union of LinearAlgebra types, all of which are partly structural zeros,
with a simple backing array given by `parent(x)`. All have methods of
`withsomezeros_rewrap` to re-create them from such a backing array.

This exists to solve a type instability, as broadcasting for instance
`λ .* Diagonal(rand(3))` gives a dense matrix when `λ == Inf`.
But `withsomezeros_rewrap(x, λ .* parent(x))` is type-stable.
"""
WithSomeZeros{T} = Union{
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would call these StructuredSparseArray

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You approve of the mechanism, #337 (comment)?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am willing to give it a shot.
We can always change it later.
It's not going to lead to wrong behaviour AFAICT.

It seems unfortunate not to take advantage of the fact that we know where the zeros are,
and we know that the pullback is going to map zeros to zeros, since linear.
So we should be able to skip some.
But I don't know that there is a generic API for our structurally sparse matrices to know whether an index will be zero.

Copy link
Member Author

@mcabbott mcabbott May 4, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe I misunderstand you, but both λ .* Diagonal(rand(3)) and this function do know where the zeros are, and do O(N) work. That's the only really sparse one.

For UpperTriangular, I haven't tried to time this against broadcasting... there could be trade-offs, maybe broadcasting skips half, but if so it needs lots of if statements. Frankly I doubt that anyone has ever called norm(::UpperTriangular) outside a test, though. So perhaps thinking about that can wait until this finds wider use where someone does need to care.

Copy link
Member Author

@mcabbott mcabbott May 4, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would also be good to fix this instability upstream. Can't we argue that the off-diagonal elements are a strong zero like false, and make NaN .* Diagonal(rand(3)) just work? Is there an issue?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like all structural zeros should be strong, yes.
I was sure I had seen Julia displaying that behaviour on SparseCSC matrices, but I can't reproduce it right now.

Diagonal{T},
UpperTriangular{T},
UnitUpperTriangular{T},
# UpperHessenberg{T}, # doesn't exist in Julia 1.0
LowerTriangular{T},
UnitLowerTriangular{T},
}
# Generate one `withsomezeros_rewrap(x, data)` method per structured type:
# each re-wraps `data` with the same constructor as the type of its first
# argument.
for S in (
    :Diagonal,
    :UpperTriangular,
    :UnitUpperTriangular,
    # :UpperHessenberg,
    :LowerTriangular,
    :UnitLowerTriangular,
)
    @eval function withsomezeros_rewrap(::$S, x)
        return $S(x)
    end
end

# Bidiagonal, Tridiagonal have more complicated storage.
# AdjOrTransUpperOrUnitUpperTriangular would need adjoint(parent(parent()))
67 changes: 56 additions & 11 deletions test/rulesets/LinearAlgebra/norm.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
@testset "norm functions" begin

# First test the un-exported functions which norm(A,p) calls
# ==========================================================

@testset "$fnorm(x::Array{$T,$(length(sz))})" for
fnorm in (
LinearAlgebra.norm1,
@@ -23,8 +27,10 @@
kwargs = NamedTuple()
end

fnorm === LinearAlgebra.norm2 && @testset "frule" begin
test_frule(fnorm, x)
if fnorm === LinearAlgebra.norm2
@testset "frule" begin
test_frule(fnorm, x)
end
end
@testset "rrule" begin
test_rrule(fnorm, x; kwargs...)
@@ -36,7 +42,27 @@
@test extern(rrule(fnorm, zero(x))[2](ȳ)[2]) ≈ zero(x)
@test rrule(fnorm, x)[2](Zero())[2] isa Zero
end
ndims(x) > 1 && @testset "non-strided" begin
xp = if x isa Matrix
view(x, [1,2,3], 1:3)
elseif x isa Array{T,3}
PermutedDimsArray(x, (1,2,3))
end
@test !(xp isa StridedArray)
test_rrule(fnorm, xp ⊢ rand(T, size(xp)))
end
T == Float64 && ndims(x) == 1 && @testset "Integer input" begin
x = [1,2,3]
int_fwd, int_back = rrule(fnorm, x)
float_fwd, float_back = rrule(fnorm, float(x))
@test int_fwd ≈ float_fwd
@test unthunk(int_back(1.0)[2]) ≈ unthunk(float_back(1.0)[2])
end
end

# Next test norm(A, p=2) -- two methods
# =====================================

@testset "norm(x::Array{$T,$(length(sz))})" for
T in (Float64, ComplexF64),
sz in [(0,), (3,), (3, 3), (3, 2, 1)]
@@ -53,15 +79,23 @@
@testset "rrule" begin
test_rrule(norm, x)
x isa Matrix && @testset "$MT" for MT in (Diagonal, UpperTriangular, LowerTriangular)
# we don't check inference on older julia versions. Improvements to
# inference mean on 1.5+ it works, and that is good enough
test_rrule(norm, MT(x); check_inferred=VERSION>=v"1.5")
end

ȳ = rand_tangent(norm(x))
@test extern(rrule(norm, zero(x))[2](ȳ)[2]) ≈ zero(x)
@test rrule(norm, x)[2](Zero())[2] isa Zero
end
ndims(x) > 1 && @testset "non-strided" begin
xp = if x isa Matrix
view(x, [1,2,3], 1:3)
elseif x isa Array{T,3}
PermutedDimsArray(x, (1,2,3))
end
@test !(xp isa StridedArray)
test_frule(norm, xp ⊢ rand(T, size(xp)))
test_rrule(norm, xp ⊢ rand(T, size(xp))) # rand_tangent does not work here because eltype(xp)==Int
end
end
@testset "$fnorm(x::Array{$T,$(length(sz))}, $p) with size $sz" for
fnorm in (norm, LinearAlgebra.normp),
@@ -71,7 +105,7 @@

x = randn(T, sz)
# finite differences is unstable if maxabs (minabs) values are not well
# separated from other values
# separated from other values (same as above)
if p == Inf
if !isempty(x)
x[end] = 1000rand(T)
@@ -87,20 +121,24 @@
kwargs = NamedTuple()
end


test_rrule(fnorm, x, p; kwargs...)
x isa Matrix && @testset "$MT" for MT in (Diagonal, UpperTriangular, LowerTriangular)
test_rrule(fnorm, MT(x), p;
#Don't check inference on old julia, what matters is that works on new
check_inferred=VERSION>=v"1.5", kwargs...
)
test_rrule(fnorm, MT(x), p; kwargs..., check_inferred=VERSION>=v"1.5")
end

ȳ = rand_tangent(fnorm(x, p))
@test extern(rrule(fnorm, zero(x), p)[2](ȳ)[2]) ≈ zero(x)
@test rrule(fnorm, x, p)[2](Zero())[2] isa Zero
T == Float64 && sz == (3,) && @testset "Integer input, p=$p" begin
x = [1,2,3]
int_fwd, int_back = rrule(fnorm, x, p)
float_fwd, float_back = rrule(fnorm, float(x), p)
@test int_fwd ≈ float_fwd
@test unthunk(int_back(1.0)[2]) ≈ unthunk(float_back(1.0)[2])
end
end
@testset "norm($fdual(::Vector{$T}), p)" for
# Extra test for norm(adjoint vector, p)
@testset "norm($fdual(::Vector{$T}), 2.5)" for
T in (Float64, ComplexF64),
fdual in (adjoint, transpose)

@@ -111,6 +149,10 @@
ȳ = rand_tangent(norm(x, p))
@test extern(rrule(norm, x, p)[2](ȳ)[2]) isa typeof(x)
end

# Scalar norm(x, p)
# =================

@testset "norm(x::$T, p)" for T in (Float64, ComplexF64)
@testset "p = $p" for p in (-1.0, 2.0, 2.5)
test_frule(norm, randn(T), p)
@@ -136,6 +178,9 @@
end
end

# normalize(x, p) and normalize(A, p)
# ===================================

@testset "normalize" begin
@testset "x::Vector{$T}" for T in (Float64, ComplexF64)
x = randn(T, 3)