From 924c989704431769d1b7ebf06974803f2e15afb9 Mon Sep 17 00:00:00 2001 From: Mohamed Tarek Date: Sun, 26 May 2024 01:15:00 +0400 Subject: [PATCH] make 3-arg dot rrule partially lazy --- Project.toml | 2 ++ src/ChainRules.jl | 1 + src/rulesets/LinearAlgebra/dense.jl | 3 ++- 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 7732e7d1d..683b80866 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6" +LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RealDot = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9" @@ -27,6 +28,7 @@ Distributed = "1" FiniteDifferences = "0.12.20" GPUArraysCore = "0.1.0" IrrationalConstants = "0.1.1, 0.2" +LazyArrays = "1, 2" JLArrays = "0.1" JuliaInterpreter = "0.8,0.9" LinearAlgebra = "1" diff --git a/src/ChainRules.jl b/src/ChainRules.jl index 6d33a22e7..5ac5c45f7 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -14,6 +14,7 @@ using RealDot: realdot using SparseArrays using Statistics using StructArrays +using LazyArrays: @~ # Basically everything this package does is overloading these, so we make an exception # to the normal rule of only overload via `ChainRulesCore.rrule`. diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index a5edd6cd5..4c69f8dad 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -35,7 +35,8 @@ function rrule(::typeof(dot), x::AbstractVector{<:Number}, A::AbstractMatrix{<:N function dot_pullback(Ω̄) ΔΩ = unthunk(Ω̄) dx = @thunk project_x(conj(ΔΩ) .* Ay) - dA = @thunk project_A(ΔΩ .* x .* adjoint(y)) + ay = adjoint(y) + dA = @thunk @~(ΔΩ .* x .* ay) dy = @thunk project_y(ΔΩ .* (adjoint(A) * x)) return (NoTangent(), dx, dA, dy) end