Fix Zygote gradient through scaled concretization #371
Open
AshtonSBradley wants to merge 5 commits into SciML:master
      (λ L)*(v) = λ * L(v)
      """
    - struct ScaledOperator{
    + mutable struct ScaledOperator{
Member
This will cause some allocations; it does not seem like the right solution.
Contributor
Author
I hope this is getting closer
Force-pushed: dd10804 to 10e831a
Summary
- Fixes a Zygote gradient bug where `update_coefficients(...) |> concretize` for a scaled operator could produce a doubled gradient compared with the equivalent matrix expression.
- Keeps `ScaledOperator` immutable, addressing the allocation concern from review.
- Converts updated `ScalarOperator` state into a private immutable scalar snapshot for the out-of-place scaled update path.
- Raises the `ArrayInterface` lower bound so downgrade CI resolves a test-compatible `Adapt` version with `ALLOW_RERESOLVE=false`.

Fixes #305.
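As a point of reference for the doubled-gradient claim, the dense matrix side of the comparison can be checked with plain Julia and no operator machinery. This is only a sketch under stated assumptions: the matrices `A1` and `A2`, the scalar `p0`, the `sum` loss, and the helper names `dense_expr`, `loss`, and `fd_gradient` are all made up for illustration and are not taken from issue #305 or from the test suite.

```julia
# Hypothetical 2x2 matrices standing in for A1 and A2 (not from issue #305).
A1 = [1.0 2.0; 3.0 4.0]
A2 = [0.5 0.0; 0.0 2.0]

# Dense analogue of MatrixOperator(A1) + ScalarOperator(p) * MatrixOperator(A2).
dense_expr(p) = A1 + p * A2

# Reduce to a scalar so the gradient with respect to p is a single number.
loss(p) = sum(dense_expr(p))

# Central finite difference; accurate up to rounding because loss is linear in p.
fd_gradient(f, p; h = 1.0e-6) = (f(p + h) - f(p - h)) / (2h)

p0 = 0.7
reference = sum(A2)               # analytic d(loss)/dp
estimate = fd_gradient(loss, p0)  # numerical check of the same quantity

# A doubled gradient of the kind described above would equal 2 * reference.
println("reference = ", reference, ", finite difference = ", estimate)
```

An AD result through the operator path should agree with `reference`; a value near `2 * reference` would indicate the extra sensitivity contribution this PR targets.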
Explanation
Issue #305 reduces to a small operator expression of the form `MatrixOperator(A1) + ScalarOperator(...) * MatrixOperator(A2)`. The primal value from `update_coefficients(...) |> concretize` matches the equivalent dense matrix expression, but Zygote sees an extra sensitivity contribution through the scaled-operator update path and returns a doubled gradient.

The first version of this PR avoided that by making `ScaledOperator` mutable, but review correctly pointed out that this is allocation-sensitive and not the right tradeoff. The revised fix keeps `ScaledOperator` immutable. Instead, after the scalar coefficient has been updated out-of-place, the updated `ScalarOperator` is converted to a private immutable `_UpdatedScalarOperator` snapshot before reconstructing the `ScaledOperator`.

That preserves the existing mutable `ScalarOperator` behavior for normal in-place coefficient updates while giving Zygote an immutable scalar value in the out-of-place path that is used by concretization. The in-place `update_coefficients!(::ScaledOperator, ...)` behavior is left unchanged, and this does not add new public API or dependencies.

The downgrade CI failure was a resolver conflict before tests ran: the root lower-bound environment selected `Adapt v4.0.0`, while the test dependency stack required compatibility with `Adapt v4.5.2-4`. Raising the package lower bound to `ArrayInterface = "7.24"` makes the downgraded root environment select an `Adapt` version compatible with the tests while keeping the downgrade check meaningful.

Tests
- 1.0. `julia --project=test -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate(); using Test; include("test/zygote.jl"); include("test/ad_semantics.jl")'`
  - `Zygote update_coefficients concretize scaled operator | 2 passed`
  - `AD semantic equivalence | 10 passed`
- `julia --project=@runic -e 'using Runic; exit(Runic.main(ARGS))' -- --check src/basic.jl src/scalar.jl test/runtests.jl test/ad_semantics.jl Project.toml`
- `julia --project=. -e 'using Pkg; Pkg.test()'`
  - `SciMLOperators | 896 passed, 2 broken, 898 total`
- `julia +1.10 --project=. -e 'import Pkg; Pkg.test(; coverage=true, julia_args=["--check-bounds=yes", "--compiled-modules=yes", "--depwarn=yes"], force_latest_compatible_version=false, allow_reresolve=false)'`
  - `SciMLOperators | 896 passed, 2 broken, 898 total`
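The snapshot approach described in the Explanation can be illustrated with a toy sketch. Everything here is hypothetical: `ToyScalar`, `ToyScalarSnapshot`, `ToyScaled`, `toy_update!`, `toy_update`, and `toy_concretize` are invented names that only mimic the shape of the idea (a mutable scalar for in-place updates, an immutable snapshot for the out-of-place path). They are not the SciMLOperators types or the code in this PR.

```julia
# Hypothetical toy types; NOT the SciMLOperators definitions.
mutable struct ToyScalar{T <: Number}
    val::T
end

# Immutable stand-in used only on the out-of-place path.
struct ToyScalarSnapshot{T <: Number}
    val::T
end

# Immutable wrapper pairing a scalar-like object with a matrix.
struct ToyScaled{S, M <: AbstractMatrix}
    λ::S
    L::M
end

# In-place update: mutates the existing scalar and returns it.
function toy_update!(s::ToyScalar, newval)
    s.val = newval
    return s
end

# Out-of-place update: leaves the original untouched and rebuilds the
# wrapper around a fresh immutable snapshot of the new value.
function toy_update(op::ToyScaled, newval)
    snapshot = ToyScalarSnapshot(newval)
    return ToyScaled(snapshot, op.L)
end

# Dense result of the scaled wrapper.
toy_concretize(op::ToyScaled) = op.λ.val * op.L

op = ToyScaled(ToyScalar(1.0), [1.0 0.0; 0.0 1.0])
new_op = toy_update(op, 3.0)
```

In this sketch the wrapper stays an immutable `struct` in both paths, and only the scalar it holds differs between the in-place and out-of-place cases, which is the tradeoff the review discussion was about.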