Skip to content

Commit 3ce4d9c

Browse files
committed
Macro expansion with rustc_intrinsic
WARNING: ad function defined in traits are broken
1 parent 065172e commit 3ce4d9c

File tree

7 files changed

+67
-142
lines changed

7 files changed

+67
-142
lines changed

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,9 @@ mod llvm_enzyme {
330330
.filter(|a| **a == DiffActivity::Active || **a == DiffActivity::ActiveOnly)
331331
.count() as u32;
332332
let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span);
333-
let d_body = gen_enzyme_body(
333+
334+
// UNUSED
335+
let _d_body = gen_enzyme_body(
334336
ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored,
335337
&generics,
336338
);
@@ -342,7 +344,7 @@ mod llvm_enzyme {
342344
ident: first_ident(&meta_item_vec[0]),
343345
generics,
344346
contract: None,
345-
body: Some(d_body),
347+
body: None,
346348
define_opaque: None,
347349
});
348350
let mut rustc_ad_attr =
@@ -429,12 +431,18 @@ mod llvm_enzyme {
429431
tokens: ts,
430432
});
431433

434+
let rustc_intrinsic_attr =
435+
P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_intrinsic)));
436+
let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
437+
let intrinsic_attr = outer_normal_attr(&rustc_intrinsic_attr, new_id, span);
438+
439+
let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
432440
let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
433441
let d_annotatable = match &item {
434442
Annotatable::AssocItem(_, _) => {
435443
let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf);
436444
let d_fn = P(ast::AssocItem {
437-
attrs: thin_vec![d_attr, inline_never],
445+
attrs: thin_vec![d_attr, intrinsic_attr],
438446
id: ast::DUMMY_NODE_ID,
439447
span,
440448
vis,
@@ -444,13 +452,15 @@ mod llvm_enzyme {
444452
Annotatable::AssocItem(d_fn, Impl { of_trait: false })
445453
}
446454
Annotatable::Item(_) => {
447-
let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
455+
let mut d_fn =
456+
ecx.item(span, thin_vec![d_attr, intrinsic_attr], ItemKind::Fn(asdf));
448457
d_fn.vis = vis;
449458

450459
Annotatable::Item(d_fn)
451460
}
452461
Annotatable::Stmt(_) => {
453-
let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
462+
let mut d_fn =
463+
ecx.item(span, thin_vec![d_attr, intrinsic_attr], ItemKind::Fn(asdf));
454464
d_fn.vis = vis;
455465

456466
Annotatable::Stmt(P(ast::Stmt {

tests/pretty/autodiff/autodiff_forward.pp

Lines changed: 33 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
//@ needs-enzyme
44

55
#![feature(autodiff)]
6+
#![feature(intrinsics)]
67
#[prelude_import]
78
use ::std::prelude::rust_2015::*;
89
#[macro_use]
@@ -36,163 +37,92 @@
3637
::core::panicking::panic("not implemented")
3738
}
3839
#[rustc_autodiff(Forward, 1, Dual, Const, Dual)]
39-
#[inline(never)]
40-
pub fn df1(x: &[f64], bx_0: &[f64], y: f64) -> (f64, f64) {
41-
unsafe { asm!("NOP", options(pure, nomem)); };
42-
::core::hint::black_box(f1(x, y));
43-
::core::hint::black_box((bx_0,));
44-
::core::hint::black_box(<(f64, f64)>::default())
45-
}
40+
#[rustc_intrinsic]
41+
pub fn df1(x: &[f64], bx_0: &[f64], y: f64) -> (f64, f64);
4642
#[rustc_autodiff]
4743
#[inline(never)]
4844
pub fn f2(x: &[f64], y: f64) -> f64 {
4945
::core::panicking::panic("not implemented")
5046
}
5147
#[rustc_autodiff(Forward, 1, Dual, Const, Const)]
52-
#[inline(never)]
53-
pub fn df2(x: &[f64], bx_0: &[f64], y: f64) -> f64 {
54-
unsafe { asm!("NOP", options(pure, nomem)); };
55-
::core::hint::black_box(f2(x, y));
56-
::core::hint::black_box((bx_0,));
57-
::core::hint::black_box(f2(x, y))
58-
}
48+
#[rustc_intrinsic]
49+
pub fn df2(x: &[f64], bx_0: &[f64], y: f64) -> f64;
5950
#[rustc_autodiff]
6051
#[inline(never)]
6152
pub fn f3(x: &[f64], y: f64) -> f64 {
6253
::core::panicking::panic("not implemented")
6354
}
6455
#[rustc_autodiff(Forward, 1, Dual, Const, Const)]
65-
#[inline(never)]
66-
pub fn df3(x: &[f64], bx_0: &[f64], y: f64) -> f64 {
67-
unsafe { asm!("NOP", options(pure, nomem)); };
68-
::core::hint::black_box(f3(x, y));
69-
::core::hint::black_box((bx_0,));
70-
::core::hint::black_box(f3(x, y))
71-
}
56+
#[rustc_intrinsic]
57+
pub fn df3(x: &[f64], bx_0: &[f64], y: f64) -> f64;
7258
#[rustc_autodiff]
7359
#[inline(never)]
7460
pub fn f4() {}
7561
#[rustc_autodiff(Forward, 1, None)]
76-
#[inline(never)]
77-
pub fn df4() -> () {
78-
unsafe { asm!("NOP", options(pure, nomem)); };
79-
::core::hint::black_box(f4());
80-
::core::hint::black_box(());
81-
}
62+
#[rustc_intrinsic]
63+
pub fn df4() -> ();
8264
#[rustc_autodiff]
8365
#[inline(never)]
8466
pub fn f5(x: &[f64], y: f64) -> f64 {
8567
::core::panicking::panic("not implemented")
8668
}
8769
#[rustc_autodiff(Forward, 1, Const, Dual, Const)]
88-
#[inline(never)]
89-
pub fn df5_y(x: &[f64], y: f64, by_0: f64) -> f64 {
90-
unsafe { asm!("NOP", options(pure, nomem)); };
91-
::core::hint::black_box(f5(x, y));
92-
::core::hint::black_box((by_0,));
93-
::core::hint::black_box(f5(x, y))
94-
}
70+
#[rustc_intrinsic]
71+
pub fn df5_y(x: &[f64], y: f64, by_0: f64) -> f64;
9572
#[rustc_autodiff(Forward, 1, Dual, Const, Const)]
96-
#[inline(never)]
97-
pub fn df5_x(x: &[f64], bx_0: &[f64], y: f64) -> f64 {
98-
unsafe { asm!("NOP", options(pure, nomem)); };
99-
::core::hint::black_box(f5(x, y));
100-
::core::hint::black_box((bx_0,));
101-
::core::hint::black_box(f5(x, y))
102-
}
73+
#[rustc_intrinsic]
74+
pub fn df5_x(x: &[f64], bx_0: &[f64], y: f64) -> f64;
10375
#[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)]
104-
#[inline(never)]
105-
pub fn df5_rev(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 {
106-
unsafe { asm!("NOP", options(pure, nomem)); };
107-
::core::hint::black_box(f5(x, y));
108-
::core::hint::black_box((dx_0, dret));
109-
::core::hint::black_box(f5(x, y))
110-
}
76+
#[rustc_intrinsic]
77+
pub fn df5_rev(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64;
11178
struct DoesNotImplDefault;
11279
#[rustc_autodiff]
11380
#[inline(never)]
11481
pub fn f6() -> DoesNotImplDefault {
11582
::core::panicking::panic("not implemented")
11683
}
11784
#[rustc_autodiff(Forward, 1, Const)]
118-
#[inline(never)]
119-
pub fn df6() -> DoesNotImplDefault {
120-
unsafe { asm!("NOP", options(pure, nomem)); };
121-
::core::hint::black_box(f6());
122-
::core::hint::black_box(());
123-
::core::hint::black_box(f6())
124-
}
85+
#[rustc_intrinsic]
86+
pub fn df6() -> DoesNotImplDefault;
12587
#[rustc_autodiff]
12688
#[inline(never)]
12789
pub fn f7(x: f32) -> () {}
12890
#[rustc_autodiff(Forward, 1, Const, None)]
129-
#[inline(never)]
130-
pub fn df7(x: f32) -> () {
131-
unsafe { asm!("NOP", options(pure, nomem)); };
132-
::core::hint::black_box(f7(x));
133-
::core::hint::black_box(());
134-
}
91+
#[rustc_intrinsic]
92+
pub fn df7(x: f32) -> ();
13593
#[no_mangle]
13694
#[rustc_autodiff]
13795
#[inline(never)]
13896
fn f8(x: &f32) -> f32 { ::core::panicking::panic("not implemented") }
13997
#[rustc_autodiff(Forward, 4, Dual, Dual)]
140-
#[inline(never)]
98+
#[rustc_intrinsic]
14199
fn f8_3(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32, bx_3: &f32)
142-
-> [f32; 5usize] {
143-
unsafe { asm!("NOP", options(pure, nomem)); };
144-
::core::hint::black_box(f8(x));
145-
::core::hint::black_box((bx_0, bx_1, bx_2, bx_3));
146-
::core::hint::black_box(<[f32; 5usize]>::default())
147-
}
100+
-> [f32; 5usize];
148101
#[rustc_autodiff(Forward, 4, Dual, DualOnly)]
149-
#[inline(never)]
102+
#[rustc_intrinsic]
150103
fn f8_2(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32, bx_3: &f32)
151-
-> [f32; 4usize] {
152-
unsafe { asm!("NOP", options(pure, nomem)); };
153-
::core::hint::black_box(f8(x));
154-
::core::hint::black_box((bx_0, bx_1, bx_2, bx_3));
155-
::core::hint::black_box(<[f32; 4usize]>::default())
156-
}
104+
-> [f32; 4usize];
157105
#[rustc_autodiff(Forward, 1, Dual, DualOnly)]
158-
#[inline(never)]
159-
fn f8_1(x: &f32, bx_0: &f32) -> f32 {
160-
unsafe { asm!("NOP", options(pure, nomem)); };
161-
::core::hint::black_box(f8(x));
162-
::core::hint::black_box((bx_0,));
163-
::core::hint::black_box(<f32>::default())
164-
}
106+
#[rustc_intrinsic]
107+
fn f8_1(x: &f32, bx_0: &f32) -> f32;
165108
pub fn f9() {
166109
#[rustc_autodiff]
167110
#[inline(never)]
168111
fn inner(x: f32) -> f32 { x * x }
169112
#[rustc_autodiff(Forward, 1, Dual, Dual)]
170-
#[inline(never)]
171-
fn d_inner_2(x: f32, bx_0: f32) -> (f32, f32) {
172-
unsafe { asm!("NOP", options(pure, nomem)); };
173-
::core::hint::black_box(inner(x));
174-
::core::hint::black_box((bx_0,));
175-
::core::hint::black_box(<(f32, f32)>::default())
176-
}
113+
#[rustc_intrinsic]
114+
fn d_inner_2(x: f32, bx_0: f32)
115+
-> (f32, f32);
177116
#[rustc_autodiff(Forward, 1, Dual, DualOnly)]
178-
#[inline(never)]
179-
fn d_inner_1(x: f32, bx_0: f32) -> f32 {
180-
unsafe { asm!("NOP", options(pure, nomem)); };
181-
::core::hint::black_box(inner(x));
182-
::core::hint::black_box((bx_0,));
183-
::core::hint::black_box(<f32>::default())
184-
}
117+
#[rustc_intrinsic]
118+
fn d_inner_1(x: f32, bx_0: f32)
119+
-> f32;
185120
}
186121
#[rustc_autodiff]
187122
#[inline(never)]
188123
pub fn f10<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T { *x * *x }
189124
#[rustc_autodiff(Reverse, 1, Duplicated, Active)]
190-
#[inline(never)]
125+
#[rustc_intrinsic]
191126
pub fn d_square<T: std::ops::Mul<Output = T> +
192-
Copy>(x: &T, dx_0: &mut T, dret: T) -> T {
193-
unsafe { asm!("NOP", options(pure, nomem)); };
194-
::core::hint::black_box(f10::<T>(x));
195-
::core::hint::black_box((dx_0, dret));
196-
::core::hint::black_box(f10::<T>(x))
197-
}
127+
Copy>(x: &T, dx_0: &mut T, dret: T) -> T;
198128
fn main() {}

tests/pretty/autodiff/autodiff_forward.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
//@ needs-enzyme
22

33
#![feature(autodiff)]
4+
#![feature(intrinsics)]
45
//@ pretty-mode:expanded
56
//@ pretty-compare-only
67
//@ pp-exact:autodiff_forward.pp

tests/pretty/autodiff/autodiff_reverse.pp

Lines changed: 11 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
//@ needs-enzyme
44

55
#![feature(autodiff)]
6+
#![feature(intrinsics)]
67
#[prelude_import]
78
use ::std::prelude::rust_2015::*;
89
#[macro_use]
@@ -29,58 +30,36 @@
2930
::core::panicking::panic("not implemented")
3031
}
3132
#[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)]
32-
#[inline(never)]
33-
pub fn df1(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 {
34-
unsafe { asm!("NOP", options(pure, nomem)); };
35-
::core::hint::black_box(f1(x, y));
36-
::core::hint::black_box((dx_0, dret));
37-
::core::hint::black_box(f1(x, y))
38-
}
33+
#[rustc_intrinsic]
34+
pub fn df1(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64;
3935
#[rustc_autodiff]
4036
#[inline(never)]
4137
pub fn f2() {}
4238
#[rustc_autodiff(Reverse, 1, None)]
43-
#[inline(never)]
44-
pub fn df2() {
45-
unsafe { asm!("NOP", options(pure, nomem)); };
46-
::core::hint::black_box(f2());
47-
::core::hint::black_box(());
48-
}
39+
#[rustc_intrinsic]
40+
pub fn df2();
4941
#[rustc_autodiff]
5042
#[inline(never)]
5143
pub fn f3(x: &[f64], y: f64) -> f64 {
5244
::core::panicking::panic("not implemented")
5345
}
5446
#[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)]
55-
#[inline(never)]
56-
pub fn df3(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 {
57-
unsafe { asm!("NOP", options(pure, nomem)); };
58-
::core::hint::black_box(f3(x, y));
59-
::core::hint::black_box((dx_0, dret));
60-
::core::hint::black_box(f3(x, y))
61-
}
47+
#[rustc_intrinsic]
48+
pub fn df3(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64;
6249
enum Foo { Reverse, }
6350
use Foo::Reverse;
6451
#[rustc_autodiff]
6552
#[inline(never)]
6653
pub fn f4(x: f32) { ::core::panicking::panic("not implemented") }
6754
#[rustc_autodiff(Reverse, 1, Const, None)]
68-
#[inline(never)]
69-
pub fn df4(x: f32) {
70-
unsafe { asm!("NOP", options(pure, nomem)); };
71-
::core::hint::black_box(f4(x));
72-
::core::hint::black_box(());
73-
}
55+
#[rustc_intrinsic]
56+
pub fn df4(x: f32);
7457
#[rustc_autodiff]
7558
#[inline(never)]
7659
pub fn f5(x: *const f32, y: &f32) {
7760
::core::panicking::panic("not implemented")
7861
}
7962
#[rustc_autodiff(Reverse, 1, DuplicatedOnly, Duplicated, None)]
80-
#[inline(never)]
81-
pub unsafe fn df5(x: *const f32, dx_0: *mut f32, y: &f32, dy_0: &mut f32) {
82-
unsafe { asm!("NOP", options(pure, nomem)); };
83-
::core::hint::black_box(f5(x, y));
84-
::core::hint::black_box((dx_0, dy_0));
85-
}
63+
#[rustc_intrinsic]
64+
pub unsafe fn df5(x: *const f32, dx_0: *mut f32, y: &f32, dy_0: &mut f32);
8665
fn main() {}

tests/pretty/autodiff/autodiff_reverse.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
//@ needs-enzyme
22

33
#![feature(autodiff)]
4+
#![feature(intrinsics)]
45
//@ pretty-mode:expanded
56
//@ pretty-compare-only
67
//@ pp-exact:autodiff_reverse.pp
@@ -23,7 +24,9 @@ pub fn f3(x: &[f64], y: f64) -> f64 {
2324
unimplemented!()
2425
}
2526

26-
enum Foo { Reverse }
27+
enum Foo {
28+
Reverse,
29+
}
2730
use Foo::Reverse;
2831
// What happens if we already have Reverse in type (enum variant decl) and value (enum variant
2932
// constructor) namespace? > It's expected to work normally.

tests/pretty/autodiff/inherent_impl.pp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
//@ needs-enzyme
44

55
#![feature(autodiff)]
6+
#![feature(intrinsics)]
67
#[prelude_import]
78
use ::std::prelude::rust_2015::*;
89
#[macro_use]
@@ -31,7 +32,7 @@
3132
self.a * 0.25 * (x * x - 1.0 - 2.0 * x.ln())
3233
}
3334
#[rustc_autodiff(Reverse, 1, Const, Active, Active)]
34-
#[inline(never)]
35+
#[rustc_intrinsic]
3536
fn df(&self, x: f64, dret: f64) -> (f64, f64) {
3637
unsafe { asm!("NOP", options(pure, nomem)); };
3738
::core::hint::black_box(self.f(x));

tests/pretty/autodiff/inherent_impl.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
//@ needs-enzyme
22

33
#![feature(autodiff)]
4+
#![feature(intrinsics)]
45
//@ pretty-mode:expanded
56
//@ pretty-compare-only
67
//@ pp-exact:inherent_impl.pp

0 commit comments

Comments
 (0)