@@ -3666,14 +3666,31 @@ void PullbackCloner::Implementation::
3666
3666
if (originalValue != dti->getResult (0 ))
3667
3667
return ;
3668
3668
// Accumulate the array's adjoint value into the adjoint buffers of its
3669
- // element addresses: `pointer_to_address` and `index_addr` instructions.
3669
+ // element addresses: `pointer_to_address` and (optionally) `index_addr`
3670
+ // instructions.
3671
+ // The input code looks like as follows:
3672
+ // %17 = integer_literal $Builtin.Word, 1
3673
+ // function_ref _allocateUninitializedArray<A>(_:)
3674
+ // %18 = function_ref @$ss27_allocateUninitializedArrayySayxG_BptBwlF : $@convention(thin) <τ_0_0> (Builtin.Word) -> (@owned Array<τ_0_0>, Builtin.RawPointer)
3675
+ // %19 = apply %18<Float>(%17) : $@convention(thin) <τ_0_0> (Builtin.Word) -> (@owned Array<τ_0_0>, Builtin.RawPointer)
3676
+ // (%20, %21) = destructure_tuple %19
3677
+ // %22 = mark_dependence %21 on %20
3678
+ // %23 = pointer_to_address %22 to [strict] $*Float
3679
+ // store %0 to [trivial] %23
3680
+ // function_ref _finalizeUninitializedArray<A>(_:)
3681
+ // %25 = function_ref @$ss27_finalizeUninitializedArrayySayxGABnlF : $@convention(thin) <τ_0_0> (@owned Array<τ_0_0>) -> @owned Array<τ_0_0>
3682
+ // %26 = apply %25<Float>(%20) : $@convention(thin) <τ_0_0> (@owned Array<τ_0_0>) -> @owned Array<τ_0_0> // user: %27
3683
+ // Note that %20 and %21 are in some sense "aliases" for each other. Here our `originalValue` is %20 in the code above.
3684
+ // We need to trace from %21 down to %23 and propagate (decomposed) adjoint of originalValue to adjoint of %23.
3685
+ // Then the generic adjoint propagation code would do its job to propagate %23' to %0'.
3686
+ // If we're initializing multiple values we're having additional `index_addr` instructions, but
3687
+ // the handling is similar.
3670
3688
LLVM_DEBUG (getADDebugStream ()
3671
3689
<< " Accumulating adjoint value for array literal into element "
3672
3690
" address adjoint buffers"
3673
3691
<< originalValue);
3674
3692
auto arrayAdjoint = materializeAdjointDirect (arrayAdjointValue, loc);
3675
3693
builder.setCurrentDebugScope (remapScope (dti->getDebugScope ()));
3676
- builder.setInsertionPoint (arrayAdjoint->getParentBlock ());
3677
3694
for (auto use : dti->getResult (1 )->getUses ()) {
3678
3695
auto *mdi = dyn_cast<MarkDependenceInst>(use->getUser ());
3679
3696
assert (mdi && " Expected mark_dependence user" );
0 commit comments