Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ Dictionaries = "0.4"
FiniteDifferences = "0.12"
LRUCache = "1.0.2"
LinearAlgebra = "1"
MatrixAlgebraKit = "0.6.6"
MatrixAlgebraKit = "0.6.7"
Mooncake = "0.5.27"
OhMyThreads = "0.8.0"
Printf = "1"
Expand Down
45 changes: 43 additions & 2 deletions src/factorizations/pullbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ for pullback! in (:qr_null_pullback!, :lq_null_pullback!)
return Δt
end
end

# Build a `SectorDict` that maps every block sector of `t` to `Colon()`.
# NOTE(review): presumably used as the "keep everything" index set when no
# truncation is applied — confirm against the callers.
function _notrunc_ind(t)
    return SectorDict(sector => Colon() for sector in blocksectors(t))
end

for pullback! in (:svd_pullback!, :eig_pullback!, :eigh_pullback!)
Expand All @@ -51,8 +50,50 @@ for pullback_trunc! in (:svd_trunc_pullback!, :eig_trunc_pullback!, :eigh_trunc_
foreachblock(Δt, t) do c, (Δb, b)
Fc = block.(F, Ref(c))
ΔFc = block.(ΔF, Ref(c))
return MAK.$pullback_trunc!(Δb, b, Fc, ΔFc; kwargs...)
MAK.$pullback_trunc!(Δb, b, Fc, ΔFc; kwargs...)
return nothing
end
return Δt
end
end

# Blockwise `AbstractTensorMap` methods for MAK's QR/LQ gauge-removal functions:
# every block of the tangents is forwarded, together with the matching blocks of
# `A`, `F₁` and `F₂`, to the method of the same name that handles the block data.
for fact in (:qr, :lq)
    gauge_fix! = Symbol(:remove_, fact, :_gauge_dependence!)
    @eval function MAK.$gauge_fix!(
            ΔF₁::AbstractTensorMap, ΔF₂::AbstractTensorMap, A, F₁, F₂;
            kwargs...
        )
        foreachblock(ΔF₁, ΔF₂, A, F₁, F₂) do _, blockdata
            Δf₁, Δf₂, a, f₁, f₂ = blockdata
            MAK.$gauge_fix!(Δf₁, Δf₂, a, f₁, f₂; kwargs...)
            # Return `nothing` so the result of the blockwise call is discarded.
            return nothing
        end
        return ΔF₁, ΔF₂
    end
    # No blockwise method is defined for the `remove_*_null_gauge_dependence!`
    # variants: they are already captured by the MAK implementation.
end

# Blockwise `AbstractTensorMap` methods for MAK's eig/eigh gauge-removal functions:
# each block of `ΔV` is forwarded, together with the matching blocks of `D` and
# `V`, to the method of the same name that handles the block data.
for decomp in (:eig, :eigh)
    gauge_fix! = Symbol(:remove_, decomp, :_gauge_dependence!)
    @eval function MAK.$gauge_fix!(ΔV::AbstractTensorMap, D, V; kwargs...)
        foreachblock(ΔV, D, V) do _, blockdata
            Δv, d, v = blockdata
            MAK.$gauge_fix!(Δv, d, v; kwargs...)
            # Return `nothing` so the result of the blockwise call is discarded.
            return nothing
        end
        return ΔV
    end
end
# Blockwise `AbstractTensorMap` method for MAK's SVD gauge-removal function:
# each pair of blocks of `ΔU` and `ΔVᴴ` is forwarded, together with the matching
# blocks of `U`, `S` and `Vᴴ`, to the method that handles the block data.
# Keyword arguments are passed through unchanged; returns `(ΔU, ΔVᴴ)`.
function MAK.remove_svd_gauge_dependence!(
        ΔU::AbstractTensorMap, ΔVᴴ::AbstractTensorMap, U, S, Vᴴ; kwargs...
    )
    foreachblock(ΔU, ΔVᴴ, U, S, Vᴴ) do _, blockdata
        Δu, Δvᴴ, u, s, vᴴ = blockdata
        MAK.remove_svd_gauge_dependence!(Δu, Δvᴴ, u, s, vᴴ; kwargs...)
        # Return `nothing` so the result of the blockwise call is discarded.
        return nothing
    end
    return ΔU, ΔVᴴ
end
15 changes: 8 additions & 7 deletions test/chainrules/factorizations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ using LinearAlgebra
using Zygote
using MatrixAlgebraKit
using MatrixAlgebraKit: diagview

using MatrixAlgebraKit: remove_qr_gauge_dependence!, remove_lq_gauge_dependence!,
remove_eigh_gauge_dependence!, remove_eig_gauge_dependence!, remove_svd_gauge_dependence!

# Tests
# -----
Expand Down Expand Up @@ -52,7 +53,7 @@ for V in spacelist
@test_logs (:warn, r"^`qr") match_mode = :any full_pb((ΔQ, ΔR))
end

remove_qrgauge_dependence!(ΔQ, t, Q)
remove_qr_gauge_dependence!(ΔQ, ΔR, t, Q, R)

test_ad_rrule(qr_full, t; fkwargs, atol, rtol, output_tangent = (ΔQ, ΔR))
test_ad_rrule(
Expand Down Expand Up @@ -90,7 +91,7 @@ for V in spacelist
# @test_logs (:warn, r"^`lq") match_mode = :any full_pb((ΔL, ΔQ))
end

remove_lqgauge_dependence!(ΔQ, t, Q)
remove_lq_gauge_dependence!(ΔL, ΔQ, t, L, Q)

test_ad_rrule(lq_full, t; fkwargs, atol, rtol, output_tangent = (ΔL, ΔQ))
test_ad_rrule(
Expand All @@ -114,7 +115,7 @@ for V in spacelist
Δv = rand_tangent(v)
Δd = rand_tangent(d)
Δd2 = randn!(similar(d, space(d)))
remove_eiggauge_dependence!(Δv, d, v)
remove_eig_gauge_dependence!(Δv, d, v)

test_ad_rrule(eig_full, t; output_tangent = (Δd, Δv), atol, rtol)
test_ad_rrule(first ∘ eig_full, t; output_tangent = Δd, atol, rtol)
Expand All @@ -126,7 +127,7 @@ for V in spacelist
Δv = rand_tangent(v)
Δd = rand_tangent(d)
Δd2 = randn!(similar(d, space(d)))
remove_eighgauge_dependence!(Δv, d, v)
remove_eigh_gauge_dependence!(Δv, d, v)

# necessary for FiniteDifferences to not complain
eigh_full′ = eigh_full ∘ project_hermitian
Expand Down Expand Up @@ -155,7 +156,7 @@ for V in spacelist
USVᴴ = svd_compact(t)
ΔU, ΔS, ΔVᴴ = rand_tangent.(USVᴴ)
ΔS2 = randn!(similar(ΔS, space(ΔS)))
ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, USVᴴ...; degeneracy_atol)
ΔU, ΔVᴴ = remove_svd_gauge_dependence!(ΔU, ΔVᴴ, USVᴴ...; degeneracy_atol)

# test_ad_rrule(svd_full, t; output_tangent = (ΔU, ΔS, ΔVᴴ), atol, rtol)
# test_ad_rrule(svd_full, t; output_tangent = (ΔU, ΔS2, ΔVᴴ), atol, rtol)
Expand All @@ -170,7 +171,7 @@ for V in spacelist
trunc = truncspace(V_trunc)
USVᴴ_trunc = svd_trunc(t; trunc)
ΔUSVᴴ_trunc = (rand_tangent.(Base.front(USVᴴ_trunc))..., zero(last(USVᴴ_trunc)))
remove_svdgauge_dependence!(
remove_svd_gauge_dependence!(
ΔUSVᴴ_trunc[1], ΔUSVᴴ_trunc[3], Base.front(USVᴴ_trunc)...; degeneracy_atol
)
test_ad_rrule(
Expand Down
26 changes: 14 additions & 12 deletions test/mooncake/factorizations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ using TensorKit
using TensorOperations
using VectorInterface: Zero, One
using MatrixAlgebraKit
using MatrixAlgebraKit: remove_qr_gauge_dependence!, remove_lq_gauge_dependence!,
remove_eigh_gauge_dependence!, remove_eig_gauge_dependence!, remove_svd_gauge_dependence!
using Mooncake
using Random

Expand All @@ -25,7 +27,7 @@ eltypes = (Float64, ComplexF64)
# qr_full/qr_null requires being careful with gauges
QR = qr_full(A)
ΔQR = Mooncake.randn_tangent(rng, QR)
remove_qrgauge_dependence!(ΔQR[1], A, QR[1])
remove_qr_gauge_dependence!(ΔQR..., A, QR...)
Mooncake.TestUtils.test_rule(rng, qr_full, A; output_tangent = ΔQR, atol, rtol, mode, is_primitive = false)
# TODO:
# Mooncake.TestUtils.test_rule(rng, qr_null, A; atol, rtol, mode, is_primitive = false)
Expand All @@ -37,7 +39,7 @@ eltypes = (Float64, ComplexF64)
# qr_full/qr_null requires being careful with gauges
QR = qr_full(A)
ΔQR = Mooncake.randn_tangent(rng, QR)
remove_qrgauge_dependence!(ΔQR[1], A, QR[1])
remove_qr_gauge_dependence!(ΔQR..., A, QR...)
Mooncake.TestUtils.test_rule(rng, qr_full, A; output_tangent = ΔQR, atol, rtol, mode, is_primitive = false)
# TODO:
# Mooncake.TestUtils.test_rule(rng, qr_null, A; atol, rtol, mode, is_primitive = false)
Expand All @@ -51,7 +53,7 @@ eltypes = (Float64, ComplexF64)
# lq_full/lq_null requires being careful with gauges
LQ = lq_full(A)
ΔLQ = Mooncake.randn_tangent(rng, LQ)
remove_lqgauge_dependence!(ΔLQ[2], A, LQ[2])
remove_lq_gauge_dependence!(ΔLQ..., A, LQ...)
Mooncake.TestUtils.test_rule(rng, lq_full, A; output_tangent = ΔLQ, atol, rtol, mode, is_primitive = false)
# TODO:
# Mooncake.TestUtils.test_rule(rng, lq_null, A; atol, rtol, mode, is_primitive = false)
Expand All @@ -63,7 +65,7 @@ eltypes = (Float64, ComplexF64)
# lq_full/lq_null requires being careful with gauges
LQ = lq_full(A)
ΔLQ = Mooncake.randn_tangent(rng, LQ)
remove_lqgauge_dependence!(ΔLQ[2], A, LQ[2])
remove_lq_gauge_dependence!(ΔLQ..., A, LQ...)
Mooncake.TestUtils.test_rule(rng, lq_full, A; output_tangent = ΔLQ, atol, rtol, mode, is_primitive = false)
# TODO:
# Mooncake.TestUtils.test_rule(rng, lq_null, A; atol, rtol, mode, is_primitive = false)
Expand All @@ -73,13 +75,13 @@ eltypes = (Float64, ComplexF64)
for t in (randn(T, V[1] ← V[1]), rand(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2]))
DV = eig_full(t)
ΔDV = Mooncake.randn_tangent(rng, DV)
remove_eiggauge_dependence!(ΔDV[2], DV...)
remove_eig_gauge_dependence!(ΔDV[2], DV...)
Mooncake.TestUtils.test_rule(rng, eig_full, t; output_tangent = ΔDV, atol, rtol, mode, is_primitive = false)

th = project_hermitian(t)
DV = eigh_full(th)
ΔDV = Mooncake.randn_tangent(rng, DV)
remove_eighgauge_dependence!(ΔDV[2], DV...)
remove_eigh_gauge_dependence!(ΔDV[2], DV...)
Mooncake.TestUtils.test_rule(rng, eigh_full ∘ project_hermitian, th; output_tangent = ΔDV, atol, rtol, mode, is_primitive = false)
end
end
Expand All @@ -88,20 +90,20 @@ eltypes = (Float64, ComplexF64)
for t in (randn(T, V[1] ← V[1]), randn(T, V[1] ⊗ V[2] ← (V[3] ⊗ V[4] ⊗ V[5])'))
USVᴴ = svd_compact(t)
ΔUSVᴴ = Mooncake.randn_tangent(rng, USVᴴ)
remove_svdgauge_dependence!(ΔUSVᴴ[1], ΔUSVᴴ[3], USVᴴ...)
remove_svd_gauge_dependence!(ΔUSVᴴ[1], ΔUSVᴴ[3], USVᴴ...)
Mooncake.TestUtils.test_rule(rng, svd_compact, t; output_tangent = ΔUSVᴴ, atol, rtol, mode, is_primitive = false)

# USVᴴ = svd_full(t)
# ΔUSVᴴ = Mooncake.randn_tangent(rng, USVᴴ)
# remove_svdgauge_dependence!(ΔUSVᴴ[1], ΔUSVᴴ[3], USVᴴ...)
# Mooncake.TestUtils.test_rule(rng, svd_full, t; output_tangent = ΔUSVᴴ, atol, rtol, mode, is_primitive = false)
USVᴴ = svd_full(t)
ΔUSVᴴ = Mooncake.randn_tangent(rng, USVᴴ)
remove_svd_gauge_dependence!(ΔUSVᴴ[1], ΔUSVᴴ[3], USVᴴ...)
Mooncake.TestUtils.test_rule(rng, svd_full, t; output_tangent = ΔUSVᴴ, atol, rtol, mode, is_primitive = false)

V_trunc = spacetype(t)(c => min(size(b)...) ÷ 2 for (c, b) in blocks(t))
trunc = truncspace(V_trunc)
alg = MatrixAlgebraKit.select_algorithm(svd_trunc, t, nothing; trunc)
USVᴴtrunc = svd_trunc(t, alg)
ΔUSVᴴtrunc = (Mooncake.randn_tangent(rng, Base.front(USVᴴtrunc))..., zero(last(USVᴴtrunc)))
remove_svdgauge_dependence!(ΔUSVᴴtrunc[1], ΔUSVᴴtrunc[3], Base.front(USVᴴtrunc)...)
remove_svd_gauge_dependence!(ΔUSVᴴtrunc[1], ΔUSVᴴtrunc[3], Base.front(USVᴴtrunc)...)
Mooncake.TestUtils.test_rule(rng, svd_trunc, t, alg; output_tangent = ΔUSVᴴtrunc, atol, rtol, mode)
end
end
Expand Down
74 changes: 0 additions & 74 deletions test/setup.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ export random_fusion
export sectorlist, fast_sectorlist
# export dim_isapprox
export default_spacelist, factorization_spacelist, ad_spacelist
export remove_qrgauge_dependence!, remove_lqgauge_dependence!
export remove_eiggauge_dependence!, remove_eighgauge_dependence!, remove_svdgauge_dependence!
export test_ad_rrule
export _isunitary, _isone

Expand Down Expand Up @@ -398,78 +396,6 @@ function ad_spacelist(fast_tests::Bool)
return fast_tests ? (Vtr, VRepU₁, VfHubbard, VRepA4Twistedℤ₄) : (Vtr, VRepℤ₂, VRepCU₁, VfHubbard, VRepA4Twistedℤ₄, VIBMRepA4)
end

# Gauge-fixing tangents for AD factorization tests
# -------------------------------------------------
# Gauge-fix the tangent `ΔQ` of a full QR factor, block by block, in place.
# With `(m, n)` the size of the corresponding block of `t` and `k = min(m, n)`,
# columns `k+1:m` of each block of `ΔQ` are replaced by `Q₁ * (Q₁' * ΔQ₂)`, where
# `Q₁` holds the first `k` columns of the block of `Q`.
# NOTE(review): presumably `Q₁` has orthonormal columns, making this a projection
# onto its column span — confirm. Returns `ΔQ`.
function remove_qrgauge_dependence!(ΔQ, t, Q)
    for (sector, ΔQblock) in blocks(ΔQ)
        nrows, ncols = size(block(t, sector))
        k = min(nrows, ncols)
        Q₁ = view(block(Q, sector), 1:nrows, 1:k)
        ΔQ₂ = view(ΔQblock, :, (k + 1):nrows)
        # Materialise the coefficients first so `mul!` never reads its own output.
        coeffs = Q₁' * ΔQ₂
        mul!(ΔQ₂, Q₁, coeffs)
    end
    return ΔQ
end
# Gauge-fix the tangent `ΔQ` of a full LQ factor, block by block, in place.
# With `(m, n)` the size of the corresponding block of `t` and `k = min(m, n)`,
# rows `k+1:n` of each block of `ΔQ` are replaced by `(ΔQ₂ * Q₁') * Q₁`, where
# `Q₁` holds the first `k` rows of the block of `Q`.
# NOTE(review): presumably `Q₁` has orthonormal rows, making this a projection
# onto its row span — confirm. Returns `ΔQ`.
function remove_lqgauge_dependence!(ΔQ, t, Q)
    for (sector, ΔQblock) in blocks(ΔQ)
        nrows, ncols = size(block(t, sector))
        k = min(nrows, ncols)
        Q₁ = view(block(Q, sector), 1:k, 1:ncols)
        ΔQ₂ = view(ΔQblock, (k + 1):ncols, :)
        # Materialise the coefficients first so `mul!` never reads its own output.
        coeffs = ΔQ₂ * Q₁'
        mul!(ΔQ₂, coeffs, Q₁)
    end
    return ΔQ
end
# Gauge-fix the eigenvector tangent `ΔV` of an `eig` decomposition, in place.
# Forms `V' * ΔV`, zeroes every entry `(i, j)` whose eigenvalues differ by at
# least `degeneracy_atol`, and subtracts `(V / (V' * V)) * gaugepart` from `ΔV`.
# Returns `ΔV`.
function remove_eiggauge_dependence!(
        ΔV, D, V; degeneracy_atol = MatrixAlgebraKit.default_pullback_degeneracy_atol(D)
    )
    gaugepart = V' * ΔV
    for (sector, gblock) in blocks(gaugepart)
        evals = diagview(block(D, sector))
        # An explicit loop is used on purpose: the equivalent broadcasted mask
        # assignment was observed to fail in the test suite only and could not be
        # reproduced in an interactive session.
        for j in axes(gblock, 2), i in axes(gblock, 1)
            if abs(evals[i] - evals[j]) >= degeneracy_atol
                gblock[i, j] = 0
            end
        end
    end
    gram = V' * V
    # Five-argument `mul!`: ΔV ← ΔV - (V / gram) * gaugepart.
    mul!(ΔV, V / gram, gaugepart, -1, 1)
    return ΔV
end
# Gauge-fix the eigenvector tangent `ΔV` of an `eigh` decomposition, in place.
# Forms `V' * ΔV`, passes it through `project_antihermitian!`, zeroes every entry
# `(i, j)` whose eigenvalues differ by at least `degeneracy_atol`, and subtracts
# `V * gaugepart` from `ΔV`. Returns `ΔV`.
function remove_eighgauge_dependence!(
        ΔV, D, V; degeneracy_atol = MatrixAlgebraKit.default_pullback_degeneracy_atol(D)
    )
    overlap = V' * ΔV
    gaugepart = project_antihermitian!(overlap)
    for (sector, gblock) in blocks(gaugepart)
        evals = diagview(block(D, sector))
        # An explicit loop is used on purpose: the equivalent broadcasted mask
        # assignment was observed to fail in the test suite only and could not be
        # reproduced in an interactive session.
        for j in axes(gblock, 2), i in axes(gblock, 1)
            if abs(evals[i] - evals[j]) >= degeneracy_atol
                gblock[i, j] = 0
            end
        end
    end
    # Five-argument `mul!`: ΔV ← ΔV - V * gaugepart.
    mul!(ΔV, V, gaugepart, -1, 1)
    return ΔV
end
# Gauge-fix the tangents `(ΔU, ΔVᴴ)` of an SVD, in place.
# Forms `U' * ΔU + Vᴴ * ΔVᴴ'`, passes it through `project_antihermitian!`, zeroes
# every entry `(i, j)` whose singular values differ by at least `degeneracy_atol`,
# and subtracts `U * gaugepart` from `ΔU`. `ΔVᴴ` is only read, never written.
# Returns `(ΔU, ΔVᴴ)`.
function remove_svdgauge_dependence!(
        ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = MatrixAlgebraKit.default_pullback_degeneracy_atol(S)
    )
    overlap = U' * ΔU + Vᴴ * ΔVᴴ'
    gaugepart = project_antihermitian!(overlap)
    for (sector, gblock) in blocks(gaugepart)
        svals = diagview(block(S, sector))
        # An explicit loop is used on purpose: the equivalent broadcasted mask
        # assignment was observed to fail in the test suite only and could not be
        # reproduced in an interactive session.
        for j in axes(gblock, 2), i in axes(gblock, 1)
            if abs(svals[i] - svals[j]) >= degeneracy_atol
                gblock[i, j] = 0
            end
        end
    end
    # Five-argument `mul!`: ΔU ← ΔU - U * gaugepart.
    mul!(ΔU, U, gaugepart, -1, 1)
    return ΔU, ΔVᴴ
end

# ChainRules test utilities
# -------------------------
Expand Down
Loading