Skip to content

Commit 799e0f8

Browse files
authored
[AutoDiff] Correct propagate adjoints for array literal values (#81676)
Fixes #81607
1 parent 828876f commit 799e0f8

File tree

2 files changed

+33
-2
lines changed

2 files changed

+33
-2
lines changed

lib/SILOptimizer/Differentiation/PullbackCloner.cpp

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3666,14 +3666,31 @@ void PullbackCloner::Implementation::
36663666
if (originalValue != dti->getResult(0))
36673667
return;
36683668
// Accumulate the array's adjoint value into the adjoint buffers of its
3669-
// element addresses: `pointer_to_address` and `index_addr` instructions.
3669+
// element addresses: `pointer_to_address` and (optionally) `index_addr`
3670+
// instructions.
3671+
// The input code looks as follows:
3672+
// %17 = integer_literal $Builtin.Word, 1
3673+
// function_ref _allocateUninitializedArray<A>(_:)
3674+
// %18 = function_ref @$ss27_allocateUninitializedArrayySayxG_BptBwlF : $@convention(thin) <τ_0_0> (Builtin.Word) -> (@owned Array<τ_0_0>, Builtin.RawPointer)
3675+
// %19 = apply %18<Float>(%17) : $@convention(thin) <τ_0_0> (Builtin.Word) -> (@owned Array<τ_0_0>, Builtin.RawPointer)
3676+
// (%20, %21) = destructure_tuple %19
3677+
// %22 = mark_dependence %21 on %20
3678+
// %23 = pointer_to_address %22 to [strict] $*Float
3679+
// store %0 to [trivial] %23
3680+
// function_ref _finalizeUninitializedArray<A>(_:)
3681+
// %25 = function_ref @$ss27_finalizeUninitializedArrayySayxGABnlF : $@convention(thin) <τ_0_0> (@owned Array<τ_0_0>) -> @owned Array<τ_0_0>
3682+
// %26 = apply %25<Float>(%20) : $@convention(thin) <τ_0_0> (@owned Array<τ_0_0>) -> @owned Array<τ_0_0> // user: %27
3683+
// Note that %20 and %21 are in some sense "aliases" for each other. Here our `originalValue` is %20 in the code above.
3684+
// We need to trace from %21 down to %23 and propagate the (decomposed) adjoint of `originalValue` to the adjoint of %23.
3685+
// Then the generic adjoint propagation code would do its job to propagate %23' to %0'.
3686+
// If we're initializing multiple values, there are additional `index_addr` instructions, but
3687+
// the handling is similar.
36703688
LLVM_DEBUG(getADDebugStream()
36713689
<< "Accumulating adjoint value for array literal into element "
36723690
"address adjoint buffers"
36733691
<< originalValue);
36743692
auto arrayAdjoint = materializeAdjointDirect(arrayAdjointValue, loc);
36753693
builder.setCurrentDebugScope(remapScope(dti->getDebugScope()));
3676-
builder.setInsertionPoint(arrayAdjoint->getParentBlock());
36773694
for (auto use : dti->getResult(1)->getUses()) {
36783695
auto *mdi = dyn_cast<MarkDependenceInst>(use->getUser());
36793696
assert(mdi && "Expected mark_dependence user");
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// RUN: %target-swift-frontend -emit-sil -verify %s
2+
3+
// https://github.com/swiftlang/swift/issues/81607
4+
// Ensure we're propagating the array adjoint in the correct basic block (BB)
5+
6+
import _Differentiation
7+
8+
@differentiable(reverse)
9+
func sum(_ a: Float, _ b: [Float]) -> [Float] {
10+
if b.count != 0 { // Early exit: a non-empty `b` is returned unchanged.
11+
return b
12+
}
13+
return [a] // Fall-through path: the only place the array literal `[a]` is built.
14+
}

0 commit comments

Comments
 (0)