Skip to content

Commit ea96c2d

Browse files
committed
Correctly propagate adjoints for array literal values
Fixes #81607
1 parent 742a96d commit ea96c2d

File tree

2 files changed

+33
-2
lines changed

2 files changed

+33
-2
lines changed

lib/SILOptimizer/Differentiation/PullbackCloner.cpp

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3649,14 +3649,31 @@ void PullbackCloner::Implementation::
36493649
if (originalValue != dti->getResult(0))
36503650
return;
36513651
// Accumulate the array's adjoint value into the adjoint buffers of its
3652-
// element addresses: `pointer_to_address` and `index_addr` instructions.
3652+
// element addresses: `pointer_to_address` and (optionally) `index_addr`
3653+
// instructions.
3654+
// The input code looks as follows:
3655+
// %17 = integer_literal $Builtin.Word, 1
3656+
// function_ref _allocateUninitializedArray<A>(_:)
3657+
// %18 = function_ref @$ss27_allocateUninitializedArrayySayxG_BptBwlF : $@convention(thin) <τ_0_0> (Builtin.Word) -> (@owned Array<τ_0_0>, Builtin.RawPointer)
3658+
// %19 = apply %18<Float>(%17) : $@convention(thin) <τ_0_0> (Builtin.Word) -> (@owned Array<τ_0_0>, Builtin.RawPointer)
3659+
// (%20, %21) = destructure_tuple %19
3660+
// %22 = mark_dependence %21 on %20
3661+
// %23 = pointer_to_address %22 to [strict] $*Float
3662+
// store %0 to [trivial] %23
3663+
// function_ref _finalizeUninitializedArray<A>(_:)
3664+
// %25 = function_ref @$ss27_finalizeUninitializedArrayySayxGABnlF : $@convention(thin) <τ_0_0> (@owned Array<τ_0_0>) -> @owned Array<τ_0_0>
3665+
// %26 = apply %25<Float>(%20) : $@convention(thin) <τ_0_0> (@owned Array<τ_0_0>) -> @owned Array<τ_0_0> // user: %27
3666+
// Note that %20 and %21 in some sense "alias" each other. Here, our `originalValue` is %20 in the code above.
3667+
// We need to trace from %21 down to %23 and propagate the (decomposed) adjoint of `originalValue` to the adjoint of %23.
3668+
// The generic adjoint propagation code then takes over and propagates %23' to %0'.
3669+
// If we're initializing multiple values, there are additional `index_addr` instructions, but
3670+
// the handling is similar.
36533671
LLVM_DEBUG(getADDebugStream()
36543672
<< "Accumulating adjoint value for array literal into element "
36553673
"address adjoint buffers"
36563674
<< originalValue);
36573675
auto arrayAdjoint = materializeAdjointDirect(arrayAdjointValue, loc);
36583676
builder.setCurrentDebugScope(remapScope(dti->getDebugScope()));
3659-
builder.setInsertionPoint(arrayAdjoint->getParentBlock());
36603677
for (auto use : dti->getResult(1)->getUses()) {
36613678
auto *mdi = dyn_cast<MarkDependenceInst>(use->getUser());
36623679
assert(mdi && "Expected mark_dependence user");
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// RUN: %target-swift-frontend -emit-sil -verify %s
2+
3+
// https://github.com/swiftlang/swift/issues/81607
4+
// Ensure we're propagating the array adjoint in the correct basic block (BB)
5+
6+
import _Differentiation
7+
8+
@differentiable(reverse)
9+
func sum(_ a: Float, _ b: [Float]) -> [Float] {
10+
if b.count != 0 {
11+
return b
12+
}
13+
return [a]
14+
}

0 commit comments

Comments
 (0)