|
| 1 | +//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat |
| 2 | +//@ no-prefer-dynamic |
| 3 | +//@ needs-enzyme |
| 4 | +// |
| 5 | +// Each autodiff invocation creates a new placeholder function, which we will replace on llvm-ir |
| 6 | +// level. If a user tries to differentiate two identical functions within the same compilation unit, |
| 7 | +// then LLVM might merge them in release mode before AD. In that case we can't rewrite one of the |
| 8 | +// merged placeholder function anymore, and compilation would fail. We prevent this by disabling |
| 9 | +// LLVM's merge_function pass before AD. Here we implicetely test that our solution keeps working. |
| 10 | +// We also explicetly test that we keep running merge_function after AD, by checking for two |
| 11 | +// identical function calls in the LLVM-IR, while having two different calls in the Rust code. |
| 12 | +#![feature(autodiff)] |
| 13 | + |
| 14 | +use std::autodiff::autodiff; |
| 15 | + |
| 16 | +#[autodiff(d_square, Reverse, Duplicated, Active)] |
| 17 | +fn square(x: &f64) -> f64 { |
| 18 | + x * x |
| 19 | +} |
| 20 | + |
| 21 | +#[autodiff(d_square2, Reverse, Duplicated, Active)] |
| 22 | +fn square2(x: &f64) -> f64 { |
| 23 | + x * x |
| 24 | +} |
| 25 | + |
| 26 | +// CHECK:; identical_fnc::main |
| 27 | +// CHECK-NEXT:; Function Attrs: |
| 28 | +// CHECK-NEXT:define internal void @_ZN13identical_fnc4main17hf4dbc69c8d2f9130E() |
| 29 | +// CHECK-NEXT:start: |
| 30 | +// CHECK-NOT:br |
| 31 | +// CHECK-NOT:ret |
| 32 | +// CHECK:; call identical_fnc::d_square |
| 33 | +// CHECK-NEXT: call fastcc void @_ZN13identical_fnc8d_square17h4c364207a2f8e06dE(double %x.val, ptr noalias noundef nonnull align 8 dereferenceable(8) %dx1) |
| 34 | +// CHECK-NEXT:; call identical_fnc::d_square |
| 35 | +// CHECK-NEXT: call fastcc void @_ZN13identical_fnc8d_square17h4c364207a2f8e06dE(double %x.val, ptr noalias noundef nonnull align 8 dereferenceable(8) %dx2) |
| 36 | + |
| 37 | +fn main() { |
| 38 | + let x = std::hint::black_box(3.0); |
| 39 | + let mut dx1 = std::hint::black_box(1.0); |
| 40 | + let mut dx2 = std::hint::black_box(1.0); |
| 41 | + let _ = d_square(&x, &mut dx1, 1.0); |
| 42 | + let _ = d_square2(&x, &mut dx2, 1.0); |
| 43 | + assert_eq!(dx1, 6.0); |
| 44 | + assert_eq!(dx2, 6.0); |
| 45 | +} |
0 commit comments