@@ -3649,14 +3649,31 @@ void PullbackCloner::Implementation::
3649
3649
if (originalValue != dti->getResult (0 ))
3650
3650
return ;
3651
3651
// Accumulate the array's adjoint value into the adjoint buffers of its
3652
- // element addresses: `pointer_to_address` and `index_addr` instructions.
3652
+ // element addresses: `pointer_to_address` and (optionally) `index_addr`
3653
+ // instructions.
3654
+ // The input code looks like as follows:
3655
+ // %17 = integer_literal $Builtin.Word, 1
3656
+ // function_ref _allocateUninitializedArray<A>(_:)
3657
+ // %18 = function_ref @$ss27_allocateUninitializedArrayySayxG_BptBwlF : $@convention(thin) <τ_0_0> (Builtin.Word) -> (@owned Array<τ_0_0>, Builtin.RawPointer)
3658
+ // %19 = apply %18<Float>(%17) : $@convention(thin) <τ_0_0> (Builtin.Word) -> (@owned Array<τ_0_0>, Builtin.RawPointer)
3659
+ // (%20, %21) = destructure_tuple %19
3660
+ // %22 = mark_dependence %21 on %20
3661
+ // %23 = pointer_to_address %22 to [strict] $*Float
3662
+ // store %0 to [trivial] %23
3663
+ // function_ref _finalizeUninitializedArray<A>(_:)
3664
+ // %25 = function_ref @$ss27_finalizeUninitializedArrayySayxGABnlF : $@convention(thin) <τ_0_0> (@owned Array<τ_0_0>) -> @owned Array<τ_0_0>
3665
+ // %26 = apply %25<Float>(%20) : $@convention(thin) <τ_0_0> (@owned Array<τ_0_0>) -> @owned Array<τ_0_0> // user: %27
3666
+ // Note that %20 and %21 in some sense "aliases" each other. Here our `originalValue` is %20 in the code above.
3667
+ // We need to trace from %21 down to %23 and propagate (decomposed) adjoint of originalValue to adjoint of %23.
3668
+ // Then the generic adjoint propagation code would do its job to propagate %23' to %0'.
3669
+ // If we're initializing multiple values we're having additional `index_addr` instructions, but
3670
+ // the handling is similar.
3653
3671
LLVM_DEBUG (getADDebugStream ()
3654
3672
<< " Accumulating adjoint value for array literal into element "
3655
3673
" address adjoint buffers"
3656
3674
<< originalValue);
3657
3675
auto arrayAdjoint = materializeAdjointDirect (arrayAdjointValue, loc);
3658
3676
builder.setCurrentDebugScope (remapScope (dti->getDebugScope ()));
3659
- builder.setInsertionPoint (arrayAdjoint->getParentBlock ());
3660
3677
for (auto use : dti->getResult (1 )->getUses ()) {
3661
3678
auto *mdi = dyn_cast<MarkDependenceInst>(use->getUser ());
3662
3679
assert (mdi && " Expected mark_dependence user" );
0 commit comments