Skip to content

Commit 8e88741

Browse files
authored
Unrolled build for rust-lang#140049
Rollup merge of rust-lang#140049 - haenoe:fix-autodiff-generics, r=ZuseZ4 fix autodiff macro on generic functions heloo there! This short PR allows applying the `autodiff` macro to generic functions like this one. It only touches the frontend part, since the `rustc_autodiff` macro can already handle generics. ```rust #[autodiff(d_square, Reverse, Duplicated, Active)] fn square<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T { *x * *x } ``` Thanks to Manuel for creating an issue on this. For more information on this see rust-lang#140032 r? `@ZuseZ4` As always: thanks for any piece of feedback!! Fixes: rust-lang#140032 Tracking issue for autodiff: rust-lang#124509
2 parents e42bbfe + a504759 commit 8e88741

File tree

4 files changed

+124
-11
lines changed

4 files changed

+124
-11
lines changed

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 62 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,10 @@ mod llvm_enzyme {
7373
}
7474

7575
// Get information about the function the macro is applied to
76-
fn extract_item_info(iitem: &P<ast::Item>) -> Option<(Visibility, FnSig, Ident)> {
76+
fn extract_item_info(iitem: &P<ast::Item>) -> Option<(Visibility, FnSig, Ident, Generics)> {
7777
match &iitem.kind {
78-
ItemKind::Fn(box ast::Fn { sig, ident, .. }) => {
79-
Some((iitem.vis.clone(), sig.clone(), ident.clone()))
78+
ItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => {
79+
Some((iitem.vis.clone(), sig.clone(), ident.clone(), generics.clone()))
8080
}
8181
_ => None,
8282
}
@@ -210,16 +210,18 @@ mod llvm_enzyme {
210210
}
211211
let dcx = ecx.sess.dcx();
212212

213-
// first get information about the annotable item:
214-
let Some((vis, sig, primal)) = (match &item {
213+
// first get information about the annotable item: visibility, signature, name and generic
214+
// parameters.
215+
// these will be used to generate the differentiated version of the function
216+
let Some((vis, sig, primal, generics)) = (match &item {
215217
Annotatable::Item(iitem) => extract_item_info(iitem),
216218
Annotatable::Stmt(stmt) => match &stmt.kind {
217219
ast::StmtKind::Item(iitem) => extract_item_info(iitem),
218220
_ => None,
219221
},
220222
Annotatable::AssocItem(assoc_item, Impl { .. }) => match &assoc_item.kind {
221-
ast::AssocItemKind::Fn(box ast::Fn { sig, ident, .. }) => {
222-
Some((assoc_item.vis.clone(), sig.clone(), ident.clone()))
223+
ast::AssocItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => {
224+
Some((assoc_item.vis.clone(), sig.clone(), ident.clone(), generics.clone()))
223225
}
224226
_ => None,
225227
},
@@ -303,14 +305,15 @@ mod llvm_enzyme {
303305
let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span);
304306
let d_body = gen_enzyme_body(
305307
ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored,
308+
&generics,
306309
);
307310

308311
// The first element of it is the name of the function to be generated
309312
let asdf = Box::new(ast::Fn {
310313
defaultness: ast::Defaultness::Final,
311314
sig: d_sig,
312315
ident: first_ident(&meta_item_vec[0]),
313-
generics: Generics::default(),
316+
generics,
314317
contract: None,
315318
body: Some(d_body),
316319
define_opaque: None,
@@ -475,6 +478,7 @@ mod llvm_enzyme {
475478
new_decl_span: Span,
476479
idents: &[Ident],
477480
errored: bool,
481+
generics: &Generics,
478482
) -> (P<ast::Block>, P<ast::Expr>, P<ast::Expr>, P<ast::Expr>) {
479483
let blackbox_path = ecx.std_path(&[sym::hint, sym::black_box]);
480484
let noop = ast::InlineAsm {
@@ -497,7 +501,7 @@ mod llvm_enzyme {
497501
};
498502
let unsf_expr = ecx.expr_block(P(unsf_block));
499503
let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path));
500-
let primal_call = gen_primal_call(ecx, span, primal, idents);
504+
let primal_call = gen_primal_call(ecx, span, primal, idents, generics);
501505
let black_box_primal_call = ecx.expr_call(
502506
new_decl_span,
503507
blackbox_call_expr.clone(),
@@ -546,6 +550,7 @@ mod llvm_enzyme {
546550
sig_span: Span,
547551
idents: Vec<Ident>,
548552
errored: bool,
553+
generics: &Generics,
549554
) -> P<ast::Block> {
550555
let new_decl_span = d_sig.span;
551556

@@ -566,6 +571,7 @@ mod llvm_enzyme {
566571
new_decl_span,
567572
&idents,
568573
errored,
574+
generics,
569575
);
570576

571577
if !has_ret(&d_sig.decl.output) {
@@ -608,7 +614,6 @@ mod llvm_enzyme {
608614
panic!("Did not expect Default ret ty: {:?}", span);
609615
}
610616
};
611-
612617
if x.mode.is_fwd() {
613618
// Fwd mode is easy. If the return activity is Const, we support arbitrary types.
614619
// Otherwise, we only support a scalar, a pair of scalars, or an array of scalars.
@@ -668,8 +673,10 @@ mod llvm_enzyme {
668673
span: Span,
669674
primal: Ident,
670675
idents: &[Ident],
676+
generics: &Generics,
671677
) -> P<ast::Expr> {
672678
let has_self = idents.len() > 0 && idents[0].name == kw::SelfLower;
679+
673680
if has_self {
674681
let args: ThinVec<_> =
675682
idents[1..].iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
@@ -678,7 +685,51 @@ mod llvm_enzyme {
678685
} else {
679686
let args: ThinVec<_> =
680687
idents.iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
681-
let primal_call_expr = ecx.expr_path(ecx.path_ident(span, primal));
688+
let mut primal_path = ecx.path_ident(span, primal);
689+
690+
let is_generic = !generics.params.is_empty();
691+
692+
match (is_generic, primal_path.segments.last_mut()) {
693+
(true, Some(function_path)) => {
694+
let primal_generic_types = generics
695+
.params
696+
.iter()
697+
.filter(|param| matches!(param.kind, ast::GenericParamKind::Type { .. }));
698+
699+
let generated_generic_types = primal_generic_types
700+
.map(|type_param| {
701+
let generic_param = TyKind::Path(
702+
None,
703+
ast::Path {
704+
span,
705+
segments: thin_vec![ast::PathSegment {
706+
ident: type_param.ident,
707+
args: None,
708+
id: ast::DUMMY_NODE_ID,
709+
}],
710+
tokens: None,
711+
},
712+
);
713+
714+
ast::AngleBracketedArg::Arg(ast::GenericArg::Type(P(ast::Ty {
715+
id: type_param.id,
716+
span,
717+
kind: generic_param,
718+
tokens: None,
719+
})))
720+
})
721+
.collect();
722+
723+
function_path.args =
724+
Some(P(ast::GenericArgs::AngleBracketed(ast::AngleBracketedArgs {
725+
span,
726+
args: generated_generic_types,
727+
})));
728+
}
729+
_ => {}
730+
}
731+
732+
let primal_call_expr = ecx.expr_path(primal_path);
682733
ecx.expr_call(span, primal_call_expr, args)
683734
}
684735
}

tests/codegen/autodiff/generic.rs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
//@ compile-flags: -Zautodiff=Enable -Zautodiff=NoPostopt -C opt-level=3 -Clto=fat
2+
//@ no-prefer-dynamic
3+
//@ needs-enzyme
4+
#![feature(autodiff)]
5+
6+
use std::autodiff::autodiff;
7+
8+
#[autodiff(d_square, Reverse, Duplicated, Active)]
9+
fn square<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T {
10+
*x * *x
11+
}
12+
13+
// Ensure that `d_square::<f64>` code is generated even if `square::<f64>` was never called
14+
//
15+
// CHECK: ; generic::square
16+
// CHECK-NEXT: ; Function Attrs:
17+
// CHECK-NEXT: define internal {{.*}} double
18+
// CHECK-NEXT: start:
19+
// CHECK-NOT: ret
20+
// CHECK: fmul double
21+
22+
// Ensure that `d_square::<f32>` code is generated
23+
//
24+
// CHECK: ; generic::square
25+
// CHECK-NEXT: ; Function Attrs: {{.*}}
26+
// CHECK-NEXT: define internal {{.*}} float
27+
// CHECK-NEXT: start:
28+
// CHECK-NOT: ret
29+
// CHECK: fmul float
30+
31+
fn main() {
32+
let xf32: f32 = std::hint::black_box(3.0);
33+
let xf64: f64 = std::hint::black_box(3.0);
34+
35+
let outputf32 = square::<f32>(&xf32);
36+
assert_eq!(9.0, outputf32);
37+
38+
let mut df_dxf64: f64 = std::hint::black_box(0.0);
39+
40+
let output_f64 = d_square::<f64>(&xf64, &mut df_dxf64, 1.0);
41+
assert_eq!(6.0, df_dxf64);
42+
}

tests/pretty/autodiff/autodiff_forward.pp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131

3232
// We want to make sure that we can use the macro for functions defined inside of functions
3333

34+
// Make sure we can handle generics
35+
3436
::core::panicking::panic("not implemented")
3537
}
3638
#[rustc_autodiff(Forward, 1, Dual, Const, Dual)]
@@ -181,4 +183,16 @@
181183
::core::hint::black_box(<f32>::default())
182184
}
183185
}
186+
#[rustc_autodiff]
187+
#[inline(never)]
188+
pub fn f10<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T { *x * *x }
189+
#[rustc_autodiff(Reverse, 1, Duplicated, Active)]
190+
#[inline(never)]
191+
pub fn d_square<T: std::ops::Mul<Output = T> +
192+
Copy>(x: &T, dx_0: &mut T, dret: T) -> T {
193+
unsafe { asm!("NOP", options(pure, nomem)); };
194+
::core::hint::black_box(f10::<T>(x));
195+
::core::hint::black_box((dx_0, dret));
196+
::core::hint::black_box(f10::<T>(x))
197+
}
184198
fn main() {}

tests/pretty/autodiff/autodiff_forward.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,4 +63,10 @@ pub fn f9() {
6363
}
6464
}
6565

66+
// Make sure we can handle generics
67+
#[autodiff(d_square, Reverse, Duplicated, Active)]
68+
pub fn f10<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T {
69+
*x * *x
70+
}
71+
6672
fn main() {}

0 commit comments

Comments
 (0)