Skip to content

Commit f79a992

Browse files
committed
add tests for merge_function handling
1 parent 5ea9125 commit f79a992

File tree

1 file changed

+45
-0
lines changed

1 file changed

+45
-0
lines changed
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat
2+
//@ no-prefer-dynamic
3+
//@ needs-enzyme
4+
//
5+
// Each autodiff invocation creates a new placeholder function, which we will replace on llvm-ir
6+
// level. If a user tries to differentiate two identical functions within the same compilation unit,
7+
// then LLVM might merge them in release mode before AD. In that case we can't rewrite one of the
8+
// merged placeholder function anymore, and compilation would fail. We prevent this by disabling
9+
// LLVM's merge_function pass before AD. Here we implicetely test that our solution keeps working.
10+
// We also explicetly test that we keep running merge_function after AD, by checking for two
11+
// identical function calls in the LLVM-IR, while having two different calls in the Rust code.
12+
#![feature(autodiff)]
13+
14+
use std::autodiff::autodiff;
15+
16+
#[autodiff(d_square, Reverse, Duplicated, Active)]
17+
fn square(x: &f64) -> f64 {
18+
x * x
19+
}
20+
21+
#[autodiff(d_square2, Reverse, Duplicated, Active)]
22+
fn square2(x: &f64) -> f64 {
23+
x * x
24+
}
25+
26+
// CHECK:; identical_fnc::main
27+
// CHECK-NEXT:; Function Attrs:
28+
// CHECK-NEXT:define internal void @_ZN13identical_fnc4main17hf4dbc69c8d2f9130E()
29+
// CHECK-NEXT:start:
30+
// CHECK-NOT:br
31+
// CHECK-NOT:ret
32+
// CHECK:; call identical_fnc::d_square
33+
// CHECK-NEXT: call fastcc void @_ZN13identical_fnc8d_square17h4c364207a2f8e06dE(double %x.val, ptr noalias noundef nonnull align 8 dereferenceable(8) %dx1)
34+
// CHECK-NEXT:; call identical_fnc::d_square
35+
// CHECK-NEXT: call fastcc void @_ZN13identical_fnc8d_square17h4c364207a2f8e06dE(double %x.val, ptr noalias noundef nonnull align 8 dereferenceable(8) %dx2)
36+
37+
fn main() {
38+
let x = std::hint::black_box(3.0);
39+
let mut dx1 = std::hint::black_box(1.0);
40+
let mut dx2 = std::hint::black_box(1.0);
41+
let _ = d_square(&x, &mut dx1, 1.0);
42+
let _ = d_square2(&x, &mut dx2, 1.0);
43+
assert_eq!(dx1, 6.0);
44+
assert_eq!(dx2, 6.0);
45+
}

0 commit comments

Comments
 (0)