From 6108a2dc3536302f77241c001d109b69c50640c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Wed, 4 Jun 2025 08:10:51 +0000 Subject: [PATCH 01/15] Lower autodiff functions using intrinsics --- compiler/rustc_codegen_llvm/src/intrinsic.rs | 3 +++ compiler/rustc_hir_analysis/src/check/intrinsic.rs | 13 +++++++++++++ 2 files changed, 16 insertions(+) diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index f7f062849a8b5..9d16a799d56b5 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -187,6 +187,9 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> { &[ptr, args[1].immediate()], ) } + _ if tcx.has_attr(def_id, sym::rustc_autodiff) => { + return Err(ty::Instance::new_raw(def_id, instance.args)); + } sym::is_val_statically_known => { if let OperandValue::Immediate(imm) = args[0].val { self.call_intrinsic( diff --git a/compiler/rustc_hir_analysis/src/check/intrinsic.rs b/compiler/rustc_hir_analysis/src/check/intrinsic.rs index 060fc51b7bda7..f9fb0b4e2b964 100644 --- a/compiler/rustc_hir_analysis/src/check/intrinsic.rs +++ b/compiler/rustc_hir_analysis/src/check/intrinsic.rs @@ -170,6 +170,8 @@ pub(crate) fn check_intrinsic_type( } }; + let has_autodiff = tcx.has_attr(intrinsic_id, sym::rustc_autodiff); + let bound_vars = tcx.mk_bound_variable_kinds(&[ ty::BoundVariableKind::Region(ty::BoundRegionKind::Anon), ty::BoundVariableKind::Region(ty::BoundRegionKind::Anon), @@ -197,6 +199,17 @@ pub(crate) fn check_intrinsic_type( let safety = intrinsic_operation_unsafety(tcx, intrinsic_id); let n_lts = 0; let (n_tps, n_cts, inputs, output) = match intrinsic_name { + _ if has_autodiff => { + let sig = tcx.fn_sig(intrinsic_id.to_def_id()); + let sig = sig.skip_binder(); + let n_tps = generics.own_counts().types; + let n_cts = generics.own_counts().consts; + + let inputs = sig.skip_binder().inputs().to_vec(); 
let output = sig.skip_binder().output(); + + (n_tps, n_cts, inputs, output) + } sym::abort => (0, 0, vec![], tcx.types.never), sym::unreachable => (0, 0, vec![], tcx.types.never), sym::breakpoint => (0, 0, vec![], tcx.types.unit), From dc56f2b5c7ef1d42740d7cceb2cc3f41520e5a3c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Thu, 5 Jun 2025 17:46:21 +0000 Subject: [PATCH 02/15] Macro expansion with `rustc_intrinsic` WARNING: ad function defined in traits are broken --- compiler/rustc_builtin_macros/src/autodiff.rs | 20 ++- tests/pretty/autodiff/autodiff_forward.pp | 136 +++++------------- tests/pretty/autodiff/autodiff_forward.rs | 1 + tests/pretty/autodiff/autodiff_reverse.pp | 43 ++---- tests/pretty/autodiff/autodiff_reverse.rs | 5 +- tests/pretty/autodiff/inherent_impl.pp | 3 +- tests/pretty/autodiff/inherent_impl.rs | 1 + 7 files changed, 67 insertions(+), 142 deletions(-) diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index df1b1eb60e18f..9d7a42c4079b8 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -330,7 +330,9 @@ mod llvm_enzyme { .filter(|a| **a == DiffActivity::Active || **a == DiffActivity::ActiveOnly) .count() as u32; let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span); - let d_body = gen_enzyme_body( + + // UNUSED + let _d_body = gen_enzyme_body( ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored, &generics, ); @@ -342,7 +344,7 @@ mod llvm_enzyme { ident: first_ident(&meta_item_vec[0]), generics, contract: None, - body: Some(d_body), + body: None, define_opaque: None, }); let mut rustc_ad_attr = @@ -429,12 +431,18 @@ mod llvm_enzyme { tokens: ts, }); + let rustc_intrinsic_attr = + P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_intrinsic))); + let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id(); + let intrinsic_attr = 
outer_normal_attr(&rustc_intrinsic_attr, new_id, span); + + let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id(); let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span); let d_annotatable = match &item { Annotatable::AssocItem(_, _) => { let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf); let d_fn = P(ast::AssocItem { - attrs: thin_vec![d_attr, inline_never], + attrs: thin_vec![d_attr, intrinsic_attr], id: ast::DUMMY_NODE_ID, span, vis, @@ -444,13 +452,15 @@ mod llvm_enzyme { Annotatable::AssocItem(d_fn, Impl { of_trait: false }) } Annotatable::Item(_) => { - let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf)); + let mut d_fn = + ecx.item(span, thin_vec![d_attr, intrinsic_attr], ItemKind::Fn(asdf)); d_fn.vis = vis; Annotatable::Item(d_fn) } Annotatable::Stmt(_) => { - let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf)); + let mut d_fn = + ecx.item(span, thin_vec![d_attr, intrinsic_attr], ItemKind::Fn(asdf)); d_fn.vis = vis; Annotatable::Stmt(P(ast::Stmt { diff --git a/tests/pretty/autodiff/autodiff_forward.pp b/tests/pretty/autodiff/autodiff_forward.pp index a2525abc83207..787c2e517492c 100644 --- a/tests/pretty/autodiff/autodiff_forward.pp +++ b/tests/pretty/autodiff/autodiff_forward.pp @@ -3,6 +3,7 @@ //@ needs-enzyme #![feature(autodiff)] +#![feature(intrinsics)] #[prelude_import] use ::std::prelude::rust_2015::*; #[macro_use] @@ -36,78 +37,44 @@ ::core::panicking::panic("not implemented") } #[rustc_autodiff(Forward, 1, Dual, Const, Dual)] -#[inline(never)] -pub fn df1(x: &[f64], bx_0: &[f64], y: f64) -> (f64, f64) { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f1(x, y)); - ::core::hint::black_box((bx_0,)); - ::core::hint::black_box(<(f64, f64)>::default()) -} +#[rustc_intrinsic] +pub fn df1(x: &[f64], bx_0: &[f64], y: f64) -> (f64, f64); #[rustc_autodiff] #[inline(never)] pub fn f2(x: &[f64], y: f64) -> f64 { ::core::panicking::panic("not 
implemented") } #[rustc_autodiff(Forward, 1, Dual, Const, Const)] -#[inline(never)] -pub fn df2(x: &[f64], bx_0: &[f64], y: f64) -> f64 { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f2(x, y)); - ::core::hint::black_box((bx_0,)); - ::core::hint::black_box(f2(x, y)) -} +#[rustc_intrinsic] +pub fn df2(x: &[f64], bx_0: &[f64], y: f64) -> f64; #[rustc_autodiff] #[inline(never)] pub fn f3(x: &[f64], y: f64) -> f64 { ::core::panicking::panic("not implemented") } #[rustc_autodiff(Forward, 1, Dual, Const, Const)] -#[inline(never)] -pub fn df3(x: &[f64], bx_0: &[f64], y: f64) -> f64 { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f3(x, y)); - ::core::hint::black_box((bx_0,)); - ::core::hint::black_box(f3(x, y)) -} +#[rustc_intrinsic] +pub fn df3(x: &[f64], bx_0: &[f64], y: f64) -> f64; #[rustc_autodiff] #[inline(never)] pub fn f4() {} #[rustc_autodiff(Forward, 1, None)] -#[inline(never)] -pub fn df4() -> () { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f4()); - ::core::hint::black_box(()); -} +#[rustc_intrinsic] +pub fn df4() -> (); #[rustc_autodiff] #[inline(never)] pub fn f5(x: &[f64], y: f64) -> f64 { ::core::panicking::panic("not implemented") } #[rustc_autodiff(Forward, 1, Const, Dual, Const)] -#[inline(never)] -pub fn df5_y(x: &[f64], y: f64, by_0: f64) -> f64 { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f5(x, y)); - ::core::hint::black_box((by_0,)); - ::core::hint::black_box(f5(x, y)) -} +#[rustc_intrinsic] +pub fn df5_y(x: &[f64], y: f64, by_0: f64) -> f64; #[rustc_autodiff(Forward, 1, Dual, Const, Const)] -#[inline(never)] -pub fn df5_x(x: &[f64], bx_0: &[f64], y: f64) -> f64 { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f5(x, y)); - ::core::hint::black_box((bx_0,)); - ::core::hint::black_box(f5(x, y)) -} +#[rustc_intrinsic] +pub fn df5_x(x: &[f64], bx_0: &[f64], y: f64) -> f64; #[rustc_autodiff(Reverse, 1, 
Duplicated, Const, Active)] -#[inline(never)] -pub fn df5_rev(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f5(x, y)); - ::core::hint::black_box((dx_0, dret)); - ::core::hint::black_box(f5(x, y)) -} +#[rustc_intrinsic] +pub fn df5_rev(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64; struct DoesNotImplDefault; #[rustc_autodiff] #[inline(never)] @@ -115,84 +82,47 @@ ::core::panicking::panic("not implemented") } #[rustc_autodiff(Forward, 1, Const)] -#[inline(never)] -pub fn df6() -> DoesNotImplDefault { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f6()); - ::core::hint::black_box(()); - ::core::hint::black_box(f6()) -} +#[rustc_intrinsic] +pub fn df6() -> DoesNotImplDefault; #[rustc_autodiff] #[inline(never)] pub fn f7(x: f32) -> () {} #[rustc_autodiff(Forward, 1, Const, None)] -#[inline(never)] -pub fn df7(x: f32) -> () { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f7(x)); - ::core::hint::black_box(()); -} +#[rustc_intrinsic] +pub fn df7(x: f32) -> (); #[no_mangle] #[rustc_autodiff] #[inline(never)] fn f8(x: &f32) -> f32 { ::core::panicking::panic("not implemented") } #[rustc_autodiff(Forward, 4, Dual, Dual)] -#[inline(never)] +#[rustc_intrinsic] fn f8_3(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32, bx_3: &f32) - -> [f32; 5usize] { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f8(x)); - ::core::hint::black_box((bx_0, bx_1, bx_2, bx_3)); - ::core::hint::black_box(<[f32; 5usize]>::default()) -} +-> [f32; 5usize]; #[rustc_autodiff(Forward, 4, Dual, DualOnly)] -#[inline(never)] +#[rustc_intrinsic] fn f8_2(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32, bx_3: &f32) - -> [f32; 4usize] { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f8(x)); - ::core::hint::black_box((bx_0, bx_1, bx_2, bx_3)); - ::core::hint::black_box(<[f32; 4usize]>::default()) -} +-> [f32; 4usize]; 
#[rustc_autodiff(Forward, 1, Dual, DualOnly)] -#[inline(never)] -fn f8_1(x: &f32, bx_0: &f32) -> f32 { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f8(x)); - ::core::hint::black_box((bx_0,)); - ::core::hint::black_box(::default()) -} +#[rustc_intrinsic] +fn f8_1(x: &f32, bx_0: &f32) -> f32; pub fn f9() { #[rustc_autodiff] #[inline(never)] fn inner(x: f32) -> f32 { x * x } #[rustc_autodiff(Forward, 1, Dual, Dual)] - #[inline(never)] - fn d_inner_2(x: f32, bx_0: f32) -> (f32, f32) { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(inner(x)); - ::core::hint::black_box((bx_0,)); - ::core::hint::black_box(<(f32, f32)>::default()) - } + #[rustc_intrinsic] + fn d_inner_2(x: f32, bx_0: f32) + -> (f32, f32); #[rustc_autodiff(Forward, 1, Dual, DualOnly)] - #[inline(never)] - fn d_inner_1(x: f32, bx_0: f32) -> f32 { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(inner(x)); - ::core::hint::black_box((bx_0,)); - ::core::hint::black_box(::default()) - } + #[rustc_intrinsic] + fn d_inner_1(x: f32, bx_0: f32) + -> f32; } #[rustc_autodiff] #[inline(never)] pub fn f10 + Copy>(x: &T) -> T { *x * *x } #[rustc_autodiff(Reverse, 1, Duplicated, Active)] -#[inline(never)] +#[rustc_intrinsic] pub fn d_square + - Copy>(x: &T, dx_0: &mut T, dret: T) -> T { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f10::(x)); - ::core::hint::black_box((dx_0, dret)); - ::core::hint::black_box(f10::(x)) -} +Copy>(x: &T, dx_0: &mut T, dret: T) -> T; fn main() {} diff --git a/tests/pretty/autodiff/autodiff_forward.rs b/tests/pretty/autodiff/autodiff_forward.rs index e23a1b3e241e9..b003d87dccfa7 100644 --- a/tests/pretty/autodiff/autodiff_forward.rs +++ b/tests/pretty/autodiff/autodiff_forward.rs @@ -1,6 +1,7 @@ //@ needs-enzyme #![feature(autodiff)] +#![feature(intrinsics)] //@ pretty-mode:expanded //@ pretty-compare-only //@ pp-exact:autodiff_forward.pp diff --git 
a/tests/pretty/autodiff/autodiff_reverse.pp b/tests/pretty/autodiff/autodiff_reverse.pp index e67c3443ddef1..6f368c74f1a26 100644 --- a/tests/pretty/autodiff/autodiff_reverse.pp +++ b/tests/pretty/autodiff/autodiff_reverse.pp @@ -3,6 +3,7 @@ //@ needs-enzyme #![feature(autodiff)] +#![feature(intrinsics)] #[prelude_import] use ::std::prelude::rust_2015::*; #[macro_use] @@ -29,58 +30,36 @@ ::core::panicking::panic("not implemented") } #[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)] -#[inline(never)] -pub fn df1(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f1(x, y)); - ::core::hint::black_box((dx_0, dret)); - ::core::hint::black_box(f1(x, y)) -} +#[rustc_intrinsic] +pub fn df1(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64; #[rustc_autodiff] #[inline(never)] pub fn f2() {} #[rustc_autodiff(Reverse, 1, None)] -#[inline(never)] -pub fn df2() { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f2()); - ::core::hint::black_box(()); -} +#[rustc_intrinsic] +pub fn df2(); #[rustc_autodiff] #[inline(never)] pub fn f3(x: &[f64], y: f64) -> f64 { ::core::panicking::panic("not implemented") } #[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)] -#[inline(never)] -pub fn df3(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f3(x, y)); - ::core::hint::black_box((dx_0, dret)); - ::core::hint::black_box(f3(x, y)) -} +#[rustc_intrinsic] +pub fn df3(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64; enum Foo { Reverse, } use Foo::Reverse; #[rustc_autodiff] #[inline(never)] pub fn f4(x: f32) { ::core::panicking::panic("not implemented") } #[rustc_autodiff(Reverse, 1, Const, None)] -#[inline(never)] -pub fn df4(x: f32) { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f4(x)); - ::core::hint::black_box(()); -} +#[rustc_intrinsic] +pub fn df4(x: 
f32); #[rustc_autodiff] #[inline(never)] pub fn f5(x: *const f32, y: &f32) { ::core::panicking::panic("not implemented") } #[rustc_autodiff(Reverse, 1, DuplicatedOnly, Duplicated, None)] -#[inline(never)] -pub unsafe fn df5(x: *const f32, dx_0: *mut f32, y: &f32, dy_0: &mut f32) { - unsafe { asm!("NOP", options(pure, nomem)); }; - ::core::hint::black_box(f5(x, y)); - ::core::hint::black_box((dx_0, dy_0)); -} +#[rustc_intrinsic] +pub unsafe fn df5(x: *const f32, dx_0: *mut f32, y: &f32, dy_0: &mut f32); fn main() {} diff --git a/tests/pretty/autodiff/autodiff_reverse.rs b/tests/pretty/autodiff/autodiff_reverse.rs index d37e5e3eb4cec..fc95ba2e5a63e 100644 --- a/tests/pretty/autodiff/autodiff_reverse.rs +++ b/tests/pretty/autodiff/autodiff_reverse.rs @@ -1,6 +1,7 @@ //@ needs-enzyme #![feature(autodiff)] +#![feature(intrinsics)] //@ pretty-mode:expanded //@ pretty-compare-only //@ pp-exact:autodiff_reverse.pp @@ -23,7 +24,9 @@ pub fn f3(x: &[f64], y: f64) -> f64 { unimplemented!() } -enum Foo { Reverse } +enum Foo { + Reverse, +} use Foo::Reverse; // What happens if we already have Reverse in type (enum variant decl) and value (enum variant // constructor) namespace? > It's expected to work normally. 
diff --git a/tests/pretty/autodiff/inherent_impl.pp b/tests/pretty/autodiff/inherent_impl.pp index d18061b2dbdef..4bc8dac0dc758 100644 --- a/tests/pretty/autodiff/inherent_impl.pp +++ b/tests/pretty/autodiff/inherent_impl.pp @@ -3,6 +3,7 @@ //@ needs-enzyme #![feature(autodiff)] +#![feature(intrinsics)] #[prelude_import] use ::std::prelude::rust_2015::*; #[macro_use] @@ -31,7 +32,7 @@ self.a * 0.25 * (x * x - 1.0 - 2.0 * x.ln()) } #[rustc_autodiff(Reverse, 1, Const, Active, Active)] - #[inline(never)] + #[rustc_intrinsic] fn df(&self, x: f64, dret: f64) -> (f64, f64) { unsafe { asm!("NOP", options(pure, nomem)); }; ::core::hint::black_box(self.f(x)); diff --git a/tests/pretty/autodiff/inherent_impl.rs b/tests/pretty/autodiff/inherent_impl.rs index 11ff209f9d89e..9f00ff5eb02c1 100644 --- a/tests/pretty/autodiff/inherent_impl.rs +++ b/tests/pretty/autodiff/inherent_impl.rs @@ -1,6 +1,7 @@ //@ needs-enzyme #![feature(autodiff)] +#![feature(intrinsics)] //@ pretty-mode:expanded //@ pretty-compare-only //@ pp-exact:inherent_impl.pp From fc86ddacc157e76737b770462247ef247ddb1e37 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Tue, 17 Jun 2025 19:52:00 +0000 Subject: [PATCH 03/15] Lowering draft --- compiler/rustc_builtin_macros/src/autodiff.rs | 2 +- compiler/rustc_codegen_llvm/src/intrinsic.rs | 15 +++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 9d7a42c4079b8..c9885bc12c6b9 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -344,7 +344,7 @@ mod llvm_enzyme { ident: first_ident(&meta_item_vec[0]), generics, contract: None, - body: None, + body: None, // This leads to an error when the ad function is inside a traits define_opaque: None, }); let mut rustc_ad_attr = diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs 
b/compiler/rustc_codegen_llvm/src/intrinsic.rs index 9d16a799d56b5..50296440de147 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -188,6 +188,21 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> { ) } _ if tcx.has_attr(def_id, sym::rustc_autodiff) => { + // NOTE(Sa4dUs): This is a hacky way to get the autodiff items + // so we can focus on the lowering of the intrinsic call + + // `diff_items` is empty even when autodiff is enabled, and if we're here, + // it's because some function was marked as intrinsic and had the `rustc_autodiff` attr + let diff_items = tcx.collect_and_partition_mono_items(()).autodiff_items; + + // this shouldn't happen? + if diff_items.is_empty() { + bug!("no autodiff items found for {def_id:?}"); + } + + // TODO(Sa4dUs): generate the enzyme call itself, based on the logic in `builder.rs` + + // Just gen the fallback body for now return Err(ty::Instance::new_raw(def_id, instance.args)); } sym::is_val_statically_known => { From 10a7c45412f94b4f727cc8f7fd894fe5e06415da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Mon, 23 Jun 2025 12:17:53 +0000 Subject: [PATCH 04/15] Naive impl of intrinsic codegen Note(Sa4dUs): Most tests are still broken due to `sret` and how funcs are searched in the current logic --- .../src/builder/autodiff.rs | 62 ++++++------------- compiler/rustc_codegen_llvm/src/intrinsic.rs | 62 +++++++++++++++---- tests/codegen/autodiff/scalar.rs | 1 + 3 files changed, 70 insertions(+), 55 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs index b07d9a5cfca8c..92597cc1e835a 100644 --- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs +++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs @@ -4,13 +4,13 @@ use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivit use rustc_codegen_ssa::ModuleCodegen; use 
rustc_codegen_ssa::back::write::ModuleConfig; use rustc_codegen_ssa::common::TypeKind; -use rustc_codegen_ssa::traits::BaseTypeCodegenMethods; +use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods}; use rustc_errors::FatalError; use rustc_middle::bug; use tracing::{debug, trace}; use crate::back::write::llvm_err; -use crate::builder::{SBuilder, UNNAMED}; +use crate::builder::{Builder, OperandRef, PlaceRef, UNNAMED}; use crate::context::SimpleCx; use crate::declare::declare_simple_fn; use crate::errors::{AutoDiffWithoutEnable, LlvmError}; @@ -19,7 +19,7 @@ use crate::llvm::{Metadata, True}; use crate::value::Value; use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, llvm}; -fn get_params(fnc: &Value) -> Vec<&Value> { +fn _get_params(fnc: &Value) -> Vec<&Value> { let param_num = llvm::LLVMCountParams(fnc) as usize; let mut fnc_args: Vec<&Value> = vec![]; fnc_args.reserve(param_num); @@ -49,9 +49,9 @@ fn has_sret(fnc: &Value) -> bool { // need to match those. // FIXME(ZuseZ4): This logic is a bit more complicated than it should be, can we simplify it // using iterators and peek()? -fn match_args_from_caller_to_enzyme<'ll>( +fn match_args_from_caller_to_enzyme<'ll, 'tcx>( cx: &SimpleCx<'ll>, - builder: &SBuilder<'ll, 'll>, + builder: &mut Builder<'_, 'll, 'tcx>, width: u32, args: &mut Vec<&'ll llvm::Value>, inputs: &[DiffActivity], @@ -289,11 +289,14 @@ fn compute_enzyme_fn_ty<'ll>( /// [^1]: // FIXME(ZuseZ4): `outer_fn` should include upstream safety checks to // cover some assumptions of enzyme/autodiff, which could lead to UB otherwise. -fn generate_enzyme_call<'ll>( +pub(crate) fn generate_enzyme_call<'ll, 'tcx>( + builder: &mut Builder<'_, 'll, 'tcx>, cx: &SimpleCx<'ll>, fn_to_diff: &'ll Value, outer_fn: &'ll Value, + fn_args: &[OperandRef<'tcx, &'ll Value>], attrs: AutoDiffAttrs, + dest: PlaceRef<'tcx, &'ll Value>, ) { // We have to pick the name depending on whether we want forward or reverse mode autodiff. 
let mut ad_name: String = match attrs.mode { @@ -366,14 +369,6 @@ fn generate_enzyme_call<'ll>( let enzyme_marker_attr = llvm::CreateAttrString(cx.llcx, "enzyme_marker"); attributes::apply_to_llfn(outer_fn, Function, &[enzyme_marker_attr]); - // first, remove all calls from fnc - let entry = llvm::LLVMGetFirstBasicBlock(outer_fn); - let br = llvm::LLVMRustGetTerminator(entry); - llvm::LLVMRustEraseInstFromParent(br); - - let last_inst = llvm::LLVMRustGetLastInstruction(entry).unwrap(); - let mut builder = SBuilder::build(cx, entry); - let num_args = llvm::LLVMCountParams(&fn_to_diff); let mut args = Vec::with_capacity(num_args as usize + 1); args.push(fn_to_diff); @@ -389,10 +384,10 @@ fn generate_enzyme_call<'ll>( } let has_sret = has_sret(outer_fn); - let outer_args: Vec<&llvm::Value> = get_params(outer_fn); + let outer_args: Vec<&llvm::Value> = fn_args.iter().map(|op| op.immediate()).collect(); match_args_from_caller_to_enzyme( &cx, - &builder, + builder, attrs.width, &mut args, &attrs.input_activity, @@ -400,29 +395,9 @@ fn generate_enzyme_call<'ll>( has_sret, ); - let call = builder.call(enzyme_ty, ad_fn, &args, None); - - // This part is a bit iffy. LLVM requires that a call to an inlineable function has some - // metadata attached to it, but we just created this code oota. Given that the - // differentiated function already has partly confusing metadata, and given that this - // affects nothing but the auttodiff IR, we take a shortcut and just steal metadata from the - // dummy code which we inserted at a higher level. - // FIXME(ZuseZ4): Work with Enzyme core devs to clarify what debug metadata issues we have, - // and how to best improve it for enzyme core and rust-enzyme. 
- let md_ty = cx.get_md_kind_id("dbg"); - if llvm::LLVMRustHasMetadata(last_inst, md_ty) { - let md = llvm::LLVMRustDIGetInstMetadata(last_inst) - .expect("failed to get instruction metadata"); - let md_todiff = cx.get_metadata_value(md); - llvm::LLVMSetMetadata(call, md_ty, md_todiff); - } else { - // We don't panic, since depending on whether we are in debug or release mode, we might - // have no debug info to copy, which would then be ok. - trace!("no dbg info"); - } + let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None); - // Now that we copied the metadata, get rid of dummy code. - llvm::LLVMRustEraseInstUntilInclusive(entry, last_inst); + builder.store_to_place(call, dest.val); if cx.val_ty(call) == cx.type_void() || has_sret { if has_sret { @@ -445,10 +420,10 @@ fn generate_enzyme_call<'ll>( llvm::LLVMBuildStore(&builder.llbuilder, call, sret_ptr); } builder.ret_void(); - } else { - builder.ret(call); } + builder.store_to_place(call, dest.val); + // Let's crash in case that we messed something up above and generated invalid IR. 
llvm::LLVMRustVerifyFunction( outer_fn, @@ -463,6 +438,7 @@ pub(crate) fn differentiate<'ll>( diff_items: Vec, _config: &ModuleConfig, ) -> Result<(), FatalError> { + // TODO(Sa4dUs): delete all this logic for item in &diff_items { trace!("{}", item); } @@ -482,7 +458,7 @@ pub(crate) fn differentiate<'ll>( for item in diff_items.iter() { let name = item.source.clone(); let fn_def: Option<&llvm::Value> = cx.get_function(&name); - let Some(fn_def) = fn_def else { + let Some(_fn_def) = fn_def else { return Err(llvm_err( diag_handler.handle(), LlvmError::PrepareAutoDiff { @@ -494,7 +470,7 @@ pub(crate) fn differentiate<'ll>( }; debug!(?item.target); let fn_target: Option<&llvm::Value> = cx.get_function(&item.target); - let Some(fn_target) = fn_target else { + let Some(_fn_target) = fn_target else { return Err(llvm_err( diag_handler.handle(), LlvmError::PrepareAutoDiff { @@ -505,7 +481,7 @@ pub(crate) fn differentiate<'ll>( )); }; - generate_enzyme_call(&cx, fn_def, fn_target, item.attrs.clone()); + // generate_enzyme_call(&cx, fn_def, fn_target, item.attrs.clone()); } // FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index 50296440de147..df1b268fbb600 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -9,17 +9,19 @@ use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue}; use rustc_codegen_ssa::mir::place::{PlaceRef, PlaceValue}; use rustc_codegen_ssa::traits::*; use rustc_hir as hir; +use rustc_hir::def_id::LOCAL_CRATE; use rustc_middle::mir::BinOp; use rustc_middle::ty::layout::{FnAbiOf, HasTyCtxt, HasTypingEnv, LayoutOf}; -use rustc_middle::ty::{self, GenericArgsRef, Ty}; +use rustc_middle::ty::{self, GenericArgsRef, Instance, Ty}; use rustc_middle::{bug, span_bug}; use rustc_span::{Span, Symbol, sym}; -use rustc_symbol_mangling::mangle_internal_symbol; +use 
rustc_symbol_mangling::{mangle_internal_symbol, symbol_name_for_instance_in_crate}; use rustc_target::spec::PanicStrategy; use tracing::debug; use crate::abi::FnAbiLlvmExt; use crate::builder::Builder; +use crate::builder::autodiff::generate_enzyme_call; use crate::context::CodegenCx; use crate::llvm::{self, Metadata}; use crate::type_::Type; @@ -187,23 +189,59 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> { &[ptr, args[1].immediate()], ) } - _ if tcx.has_attr(def_id, sym::rustc_autodiff) => { + _ if tcx.has_attr(instance.def_id(), sym::rustc_autodiff) => { // NOTE(Sa4dUs): This is a hacky way to get the autodiff items // so we can focus on the lowering of the intrinsic call + let mut source_id = None; + let mut diff_attrs = None; + let items: Vec<_> = tcx.hir_body_owners().map(|i| i.to_def_id()).collect(); + + // Hacky way of getting primal-diff pair, only works for code with 1 autodiff call + for target_id in &items { + let Some(target_attrs) = &tcx.codegen_fn_attrs(target_id).autodiff_item else { + continue; + }; - // `diff_items` is empty even when autodiff is enabled, and if we're here, - // it's because some function was marked as intrinsic and had the `rustc_autodiff` attr - let diff_items = tcx.collect_and_partition_mono_items(()).autodiff_items; + if target_attrs.is_source() { + source_id = Some(*target_id); + } else { + diff_attrs = Some(target_attrs); + } + } - // this shouldn't happen? 
- if diff_items.is_empty() { - bug!("no autodiff items found for {def_id:?}"); + if source_id.is_none() || diff_attrs.is_none() { + bug!("could not find source_id={source_id:?} or diff_attrs={diff_attrs:?}"); } - // TODO(Sa4dUs): generate the enzyme call itself, based on the logic in `builder.rs` + let diff_attrs = diff_attrs.unwrap().clone(); + + // Get source fn + let source_id = source_id.unwrap(); + let fn_source = Instance::mono(tcx, source_id); + let source_symbol = + symbol_name_for_instance_in_crate(tcx, fn_source.clone(), LOCAL_CRATE); + let fn_to_diff: Option<&'ll llvm::Value> = self.cx.get_function(&source_symbol); + let Some(fn_to_diff) = fn_to_diff else { bug!("could not find source function") }; + + // Declare target fn + let target_symbol = + symbol_name_for_instance_in_crate(tcx, instance.clone(), LOCAL_CRATE); + let fn_abi = self.cx.fn_abi_of_instance(instance, ty::List::empty()); + let outer_fn: &'ll Value = + self.cx.declare_fn(&target_symbol, fn_abi, Some(instance)); + + // Build body + generate_enzyme_call( + self, + self.cx, + fn_to_diff, + outer_fn, + args, // This argument was not in the original `generate_enzyme_call`, now it's included because `get_params` is not working anymore + diff_attrs.clone(), + result, + ); - // Just gen the fallback body for now - return Err(ty::Instance::new_raw(def_id, instance.args)); + return Ok(()); } sym::is_val_statically_known => { if let OperandValue::Immediate(imm) = args[0].val { diff --git a/tests/codegen/autodiff/scalar.rs b/tests/codegen/autodiff/scalar.rs index 096b4209e84ad..c2bca7e9c81ef 100644 --- a/tests/codegen/autodiff/scalar.rs +++ b/tests/codegen/autodiff/scalar.rs @@ -2,6 +2,7 @@ //@ no-prefer-dynamic //@ needs-enzyme #![feature(autodiff)] +#![feature(intrinsics)] use std::autodiff::autodiff_reverse; From 447c75a20f370a8a68893e190980fc6cb2e84c99 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Tue, 24 Jun 2025 14:00:26 +0000 Subject: [PATCH 05/15] Feature intrinsics 
in cg tests --- tests/codegen/autodiff/batched.rs | 1 + tests/codegen/autodiff/generic.rs | 1 + tests/codegen/autodiff/identical_fnc.rs | 1 + tests/codegen/autodiff/inline.rs | 1 + tests/codegen/autodiff/sret.rs | 28 ++++++++++++------------- 5 files changed, 18 insertions(+), 14 deletions(-) diff --git a/tests/codegen/autodiff/batched.rs b/tests/codegen/autodiff/batched.rs index d27aed50e6cc4..88a1de9994c8a 100644 --- a/tests/codegen/autodiff/batched.rs +++ b/tests/codegen/autodiff/batched.rs @@ -10,6 +10,7 @@ // reduce this test to only match the first lines and the ret instructions. #![feature(autodiff)] +#![feature(intrinsics)] use std::autodiff::autodiff_forward; diff --git a/tests/codegen/autodiff/generic.rs b/tests/codegen/autodiff/generic.rs index 2f674079be021..af9706c621208 100644 --- a/tests/codegen/autodiff/generic.rs +++ b/tests/codegen/autodiff/generic.rs @@ -2,6 +2,7 @@ //@ no-prefer-dynamic //@ needs-enzyme #![feature(autodiff)] +#![feature(intrinsics)] use std::autodiff::autodiff_reverse; diff --git a/tests/codegen/autodiff/identical_fnc.rs b/tests/codegen/autodiff/identical_fnc.rs index 1c25b3d09ab0d..ff8e6c74a6b34 100644 --- a/tests/codegen/autodiff/identical_fnc.rs +++ b/tests/codegen/autodiff/identical_fnc.rs @@ -10,6 +10,7 @@ // We also explicetly test that we keep running merge_function after AD, by checking for two // identical function calls in the LLVM-IR, while having two different calls in the Rust code. 
#![feature(autodiff)] +#![feature(intrinsics)] use std::autodiff::autodiff_reverse; diff --git a/tests/codegen/autodiff/inline.rs b/tests/codegen/autodiff/inline.rs index 65bed170207cc..5db69b960343c 100644 --- a/tests/codegen/autodiff/inline.rs +++ b/tests/codegen/autodiff/inline.rs @@ -3,6 +3,7 @@ //@ needs-enzyme #![feature(autodiff)] +#![feature(intrinsics)] use std::autodiff::autodiff_reverse; diff --git a/tests/codegen/autodiff/sret.rs b/tests/codegen/autodiff/sret.rs index d2fa85e3e3787..67f68fc053cc4 100644 --- a/tests/codegen/autodiff/sret.rs +++ b/tests/codegen/autodiff/sret.rs @@ -8,6 +8,7 @@ // We therefore use this test to verify some of our sret handling. #![feature(autodiff)] +#![feature(intrinsics)] use std::autodiff::autodiff_reverse; @@ -17,26 +18,25 @@ fn primal(x: f32, y: f32) -> f64 { (x * x * y) as f64 } -// CHECK:define internal fastcc void @_ZN4sret2df17h93be4316dd8ea006E(ptr dead_on_unwind noalias nocapture noundef nonnull writable writeonly align 8 dereferenceable(16) initializes((0, 16)) %_0, float noundef %x, float noundef %y) -// CHECK-NEXT:start: -// CHECK-NEXT: %0 = tail call fastcc { double, float, float } @diffeprimal(float %x, float %y) -// CHECK-NEXT: %.elt = extractvalue { double, float, float } %0, 0 -// CHECK-NEXT: store double %.elt, ptr %_0, align 8 -// CHECK-NEXT: %_0.repack1 = getelementptr inbounds nuw i8, ptr %_0, i64 8 -// CHECK-NEXT: %.elt2 = extractvalue { double, float, float } %0, 1 -// CHECK-NEXT: store float %.elt2, ptr %_0.repack1, align 8 -// CHECK-NEXT: %_0.repack3 = getelementptr inbounds nuw i8, ptr %_0, i64 12 -// CHECK-NEXT: %.elt4 = extractvalue { double, float, float } %0, 2 -// CHECK-NEXT: store float %.elt4, ptr %_0.repack3, align 4 -// CHECK-NEXT: ret void -// CHECK-NEXT:} +// CHECK: define internal fastcc { double, float, float } @diffeprimal(float noundef %x, float noundef %y) +// CHECK-NEXT: invertstart: +// CHECK-NEXT: %_4 = fmul float %x, %x +// CHECK-NEXT: %_3 = fmul float %_4, %y +// CHECK-NEXT: 
%_0 = fpext float %_3 to double +// CHECK-NEXT: %0 = fadd fast float %y, %y +// CHECK-NEXT: %1 = fmul fast float %0, %x +// CHECK-NEXT: %2 = insertvalue { double, float, float } undef, double %_0, 0 +// CHECK-NEXT: %3 = insertvalue { double, float, float } %2, float %1, 1 +// CHECK-NEXT: %4 = insertvalue { double, float, float } %3, float %_4, 2 +// CHECK-NEXT: ret { double, float, float } %4 +// CHECK-NEXT: } fn main() { let x = std::hint::black_box(3.0); let y = std::hint::black_box(2.5); let scalar = std::hint::black_box(1.0); let (r1, r2, r3) = df(x, y, scalar); - // 3*3*1.5 = 22.5 + // 3*3*2.5 = 22.5 assert_eq!(r1, 22.5); // 2*x*y = 2*3*2.5 = 15.0 assert_eq!(r2, 15.0); From fad0b0c1484fa45f670f4296771ec11e82cf8913 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Sun, 29 Jun 2025 18:13:11 +0000 Subject: [PATCH 06/15] Remove `sret` logic --- .../src/builder/autodiff.rs | 234 ++++-------------- compiler/rustc_codegen_llvm/src/context.rs | 2 +- compiler/rustc_codegen_llvm/src/intrinsic.rs | 21 +- .../rustc_hir_analysis/src/check/intrinsic.rs | 7 +- 4 files changed, 68 insertions(+), 196 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs index 92597cc1e835a..6af9f5738214e 100644 --- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs +++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs @@ -15,7 +15,7 @@ use crate::context::SimpleCx; use crate::declare::declare_simple_fn; use crate::errors::{AutoDiffWithoutEnable, LlvmError}; use crate::llvm::AttributePlace::Function; -use crate::llvm::{Metadata, True}; +use crate::llvm::{Metadata, True, Type}; use crate::value::Value; use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, llvm}; @@ -30,7 +30,7 @@ fn _get_params(fnc: &Value) -> Vec<&Value> { fnc_args } -fn has_sret(fnc: &Value) -> bool { +fn _has_sret(fnc: &Value) -> bool { let num_args = llvm::LLVMCountParams(fnc) as usize; if num_args 
== 0 { false @@ -56,7 +56,6 @@ fn match_args_from_caller_to_enzyme<'ll, 'tcx>( args: &mut Vec<&'ll llvm::Value>, inputs: &[DiffActivity], outer_args: &[&'ll llvm::Value], - has_sret: bool, ) { debug!("matching autodiff arguments"); // We now handle the issue that Rust level arguments not always match the llvm-ir level @@ -68,14 +67,6 @@ fn match_args_from_caller_to_enzyme<'ll, 'tcx>( let mut outer_pos: usize = 0; let mut activity_pos = 0; - if has_sret { - // Then the first outer arg is the sret pointer. Enzyme doesn't know about sret, so the - // inner function will still return something. We increase our outer_pos by one, - // and once we're done with all other args we will take the return of the inner call and - // update the sret pointer with it - outer_pos = 1; - } - let enzyme_const = cx.create_metadata("enzyme_const".to_string()).unwrap(); let enzyme_out = cx.create_metadata("enzyme_out".to_string()).unwrap(); let enzyme_dup = cx.create_metadata("enzyme_dup".to_string()).unwrap(); @@ -194,92 +185,6 @@ fn match_args_from_caller_to_enzyme<'ll, 'tcx>( } } -// On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input -// arguments. We do however need to declare them with their correct return type. -// We already figured the correct return type out in our frontend, when generating the outer_fn, -// so we can now just go ahead and use that. This is not always trivial, e.g. because sret. -// Beyond sret, this article describes our challenges nicely: -// -// I.e. (i32, f32) will get merged into i64, but we don't handle that yet. 
-fn compute_enzyme_fn_ty<'ll>( - cx: &SimpleCx<'ll>, - attrs: &AutoDiffAttrs, - fn_to_diff: &'ll Value, - outer_fn: &'ll Value, -) -> &'ll llvm::Type { - let fn_ty = cx.get_type_of_global(outer_fn); - let mut ret_ty = cx.get_return_type(fn_ty); - - let has_sret = has_sret(outer_fn); - - if has_sret { - // Now we don't just forward the return type, so we have to figure it out based on the - // primal return type, in combination with the autodiff settings. - let fn_ty = cx.get_type_of_global(fn_to_diff); - let inner_ret_ty = cx.get_return_type(fn_ty); - - let void_ty = unsafe { llvm::LLVMVoidTypeInContext(cx.llcx) }; - if inner_ret_ty == void_ty { - // This indicates that even the inner function has an sret. - // Right now I only look for an sret in the outer function. - // This *probably* needs some extra handling, but I never ran - // into such a case. So I'll wait for user reports to have a test case. - bug!("sret in inner function"); - } - - if attrs.width == 1 { - // Enzyme returns a struct of style: - // `{ original_ret(if requested), float, float, ... }` - let mut struct_elements = vec![]; - if attrs.has_primal_ret() { - struct_elements.push(inner_ret_ty); - } - // Next, we push the list of active floats, since they will be lowered to `enzyme_out`, - // and therefore part of the return struct. - let param_tys = cx.func_params_types(fn_ty); - for (act, param_ty) in attrs.input_activity.iter().zip(param_tys) { - if matches!(act, DiffActivity::Active) { - // Now find the float type at position i based on the fn_ty, - // to know what (f16/f32/f64/...) to add to the struct. - struct_elements.push(param_ty); - } - } - ret_ty = cx.type_struct(&struct_elements, false); - } else { - // First we check if we also have to deal with the primal return. 
- match attrs.mode { - DiffMode::Forward => match attrs.ret_activity { - DiffActivity::Dual => { - let arr_ty = - unsafe { llvm::LLVMArrayType2(inner_ret_ty, attrs.width as u64 + 1) }; - ret_ty = arr_ty; - } - DiffActivity::DualOnly => { - let arr_ty = - unsafe { llvm::LLVMArrayType2(inner_ret_ty, attrs.width as u64) }; - ret_ty = arr_ty; - } - DiffActivity::Const => { - todo!("Not sure, do we need to do something here?"); - } - _ => { - bug!("unreachable"); - } - }, - DiffMode::Reverse => { - todo!("Handle sret for reverse mode"); - } - _ => { - bug!("unreachable"); - } - } - } - } - - // LLVM can figure out the input types on it's own, so we take a shortcut here. - unsafe { llvm::LLVMFunctionType(ret_ty, ptr::null(), 0, True) } -} - /// When differentiating `fn_to_diff`, take a `outer_fn` and generate another /// function with expected naming and calling conventions[^1] which will be /// discovered by the enzyme LLVM pass and its body populated with the differentiated @@ -293,7 +198,8 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>( builder: &mut Builder<'_, 'll, 'tcx>, cx: &SimpleCx<'ll>, fn_to_diff: &'ll Value, - outer_fn: &'ll Value, + outer_name: &str, + ret_ty: &'ll Type, fn_args: &[OperandRef<'tcx, &'ll Value>], attrs: AutoDiffAttrs, dest: PlaceRef<'tcx, &'ll Value>, @@ -306,11 +212,9 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>( } .to_string(); - // add outer_fn name to ad_name to make it unique, in case users apply autodiff to multiple + // add outer_name to ad_name to make it unique, in case users apply autodiff to multiple // functions. Unwrap will only panic, if LLVM gave us an invalid string. 
- let name = llvm::get_value_name(outer_fn); - let outer_fn_name = std::str::from_utf8(name).unwrap(); - ad_name.push_str(outer_fn_name); + ad_name.push_str(outer_name); // Let us assume the user wrote the following function square: // @@ -344,92 +248,54 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>( // ret double %0 // } // ``` - unsafe { - let enzyme_ty = compute_enzyme_fn_ty(cx, &attrs, fn_to_diff, outer_fn); - - // FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and - // think a bit more about what should go here. - let cc = llvm::LLVMGetFunctionCallConv(outer_fn); - let ad_fn = declare_simple_fn( - cx, - &ad_name, - llvm::CallConv::try_from(cc).expect("invalid callconv"), - llvm::UnnamedAddr::No, - llvm::Visibility::Default, - enzyme_ty, - ); - - // Otherwise LLVM might inline our temporary code before the enzyme pass has a chance to - // do it's work. - let attr = llvm::AttributeKind::NoInline.create_attr(cx.llcx); - attributes::apply_to_llfn(ad_fn, Function, &[attr]); - - // We add a made-up attribute just such that we can recognize it after AD to update - // (no)-inline attributes. We'll then also remove this attribute. 
- let enzyme_marker_attr = llvm::CreateAttrString(cx.llcx, "enzyme_marker"); - attributes::apply_to_llfn(outer_fn, Function, &[enzyme_marker_attr]); - - let num_args = llvm::LLVMCountParams(&fn_to_diff); - let mut args = Vec::with_capacity(num_args as usize + 1); - args.push(fn_to_diff); - - let enzyme_primal_ret = cx.create_metadata("enzyme_primal_return".to_string()).unwrap(); - if matches!(attrs.ret_activity, DiffActivity::Dual | DiffActivity::Active) { - args.push(cx.get_metadata_value(enzyme_primal_ret)); - } - if attrs.width > 1 { - let enzyme_width = cx.create_metadata("enzyme_width".to_string()).unwrap(); - args.push(cx.get_metadata_value(enzyme_width)); - args.push(cx.get_const_int(cx.type_i64(), attrs.width as u64)); - } - - let has_sret = has_sret(outer_fn); - let outer_args: Vec<&llvm::Value> = fn_args.iter().map(|op| op.immediate()).collect(); - match_args_from_caller_to_enzyme( - &cx, - builder, - attrs.width, - &mut args, - &attrs.input_activity, - &outer_args, - has_sret, - ); - - let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None); - - builder.store_to_place(call, dest.val); + let enzyme_ty = unsafe { llvm::LLVMFunctionType(ret_ty, ptr::null(), 0, True) }; + + // FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and + // think a bit more about what should go here. + // FIXME(Sa4dUs): have to find a way to get the cc, using `FastCallConv` for now + let cc = 8; + let ad_fn = declare_simple_fn( + cx, + &ad_name, + llvm::CallConv::try_from(cc).expect("invalid callconv"), + llvm::UnnamedAddr::No, + llvm::Visibility::Default, + enzyme_ty, + ); + + // Otherwise LLVM might inline our temporary code before the enzyme pass has a chance to + // do it's work. 
+ let attr = llvm::AttributeKind::NoInline.create_attr(cx.llcx); + attributes::apply_to_llfn(ad_fn, Function, &[attr]); + + let num_args = llvm::LLVMCountParams(&fn_to_diff); + let mut args = Vec::with_capacity(num_args as usize + 1); + args.push(fn_to_diff); + + let enzyme_primal_ret = cx.create_metadata("enzyme_primal_return".to_string()).unwrap(); + if matches!(attrs.ret_activity, DiffActivity::Dual | DiffActivity::Active) { + args.push(cx.get_metadata_value(enzyme_primal_ret)); + } + if attrs.width > 1 { + let enzyme_width = cx.create_metadata("enzyme_width".to_string()).unwrap(); + args.push(cx.get_metadata_value(enzyme_width)); + args.push(cx.get_const_int(cx.type_i64(), attrs.width as u64)); + } - if cx.val_ty(call) == cx.type_void() || has_sret { - if has_sret { - // This is what we already have in our outer_fn (shortened): - // define void @_foo(ptr <..> sret([32 x i8]) initializes((0, 32)) %0, <...>) { - // %7 = call [4 x double] (...) @__enzyme_fwddiff_foo(ptr @square, metadata !"enzyme_width", i64 4, <...>) - // - // store [4 x double] %7, ptr %0, align 8 - // ret void - // } + let outer_args: Vec<&llvm::Value> = fn_args.iter().map(|op| op.immediate()).collect(); - // now store the result of the enzyme call into the sret pointer. - let sret_ptr = outer_args[0]; - let call_ty = cx.val_ty(call); - if attrs.width == 1 { - assert_eq!(cx.type_kind(call_ty), TypeKind::Struct); - } else { - assert_eq!(cx.type_kind(call_ty), TypeKind::Array); - } - llvm::LLVMBuildStore(&builder.llbuilder, call, sret_ptr); - } - builder.ret_void(); - } + match_args_from_caller_to_enzyme( + &cx, + builder, + attrs.width, + &mut args, + &attrs.input_activity, + &outer_args, + ); - builder.store_to_place(call, dest.val); + let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None); - // Let's crash in case that we messed something up above and generated invalid IR. 
- llvm::LLVMRustVerifyFunction( - outer_fn, - llvm::LLVMRustVerifierFailureAction::LLVMAbortProcessAction, - ); - } + builder.store_to_place(call, dest.val); } pub(crate) fn differentiate<'ll>( diff --git a/compiler/rustc_codegen_llvm/src/context.rs b/compiler/rustc_codegen_llvm/src/context.rs index 0324dff6ff256..e5aabaa8b76a8 100644 --- a/compiler/rustc_codegen_llvm/src/context.rs +++ b/compiler/rustc_codegen_llvm/src/context.rs @@ -652,7 +652,7 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> { } } impl<'ll> SimpleCx<'ll> { - pub(crate) fn get_return_type(&self, ty: &'ll Type) -> &'ll Type { + pub(crate) fn _get_return_type(&self, ty: &'ll Type) -> &'ll Type { assert_eq!(self.type_kind(ty), TypeKind::Function); unsafe { llvm::LLVMGetReturnType(ty) } } diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index df1b268fbb600..0b06d354f5a30 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -174,10 +174,17 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> { span: Span, ) -> Result<(), ty::Instance<'tcx>> { let tcx = self.tcx; + let callee_ty = instance.ty(tcx, self.typing_env()); - let name = tcx.item_name(instance.def_id()); let fn_args = instance.args; + let sig = callee_ty.fn_sig(tcx); + let sig = tcx.normalize_erasing_late_bound_regions(self.typing_env(), sig); + let ret_ty = sig.output(); + let name = tcx.item_name(instance.def_id()); + + let llret_ty = self.layout_of(ret_ty).llvm_type(self); + let simple = call_simple_intrinsic(self, name, args); let llval = match name { _ if simple.is_some() => simple.unwrap(), @@ -223,20 +230,14 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> { let fn_to_diff: Option<&'ll llvm::Value> = self.cx.get_function(&source_symbol); let Some(fn_to_diff) = fn_to_diff else { bug!("could not find source function") }; - // Declare target fn - let target_symbol = - 
symbol_name_for_instance_in_crate(tcx, instance.clone(), LOCAL_CRATE); - let fn_abi = self.cx.fn_abi_of_instance(instance, ty::List::empty()); - let outer_fn: &'ll Value = - self.cx.declare_fn(&target_symbol, fn_abi, Some(instance)); - // Build body generate_enzyme_call( self, self.cx, fn_to_diff, - outer_fn, - args, // This argument was not in the original `generate_enzyme_call`, now it's included because `get_params` is not working anymore + name.as_str(), + llret_ty, + args, diff_attrs.clone(), result, ); diff --git a/compiler/rustc_hir_analysis/src/check/intrinsic.rs b/compiler/rustc_hir_analysis/src/check/intrinsic.rs index f9fb0b4e2b964..bfa2beb81ce67 100644 --- a/compiler/rustc_hir_analysis/src/check/intrinsic.rs +++ b/compiler/rustc_hir_analysis/src/check/intrinsic.rs @@ -196,7 +196,12 @@ pub(crate) fn check_intrinsic_type( (Ty::new_ref(tcx, env_region, va_list_ty, mutbl), va_list_ty) }; - let safety = intrinsic_operation_unsafety(tcx, intrinsic_id); + // FIXME(Sa4dUs): Get the actual safety level of the diff function + let safety = if has_autodiff { + hir::Safety::Safe + } else { + intrinsic_operation_unsafety(tcx, intrinsic_id) + }; let n_lts = 0; let (n_tps, n_cts, inputs, output) = match intrinsic_name { _ if has_autodiff => { From 470e4cabeec0e23196a6a3004aadd36df086d3f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Mon, 7 Jul 2025 17:24:17 +0000 Subject: [PATCH 07/15] Move logic to a dedicated `enzyme_autodiff` intrinsic --- compiler/rustc_builtin_macros/src/autodiff.rs | 143 ++++++++++++++++-- .../src/builder/autodiff.rs | 8 +- compiler/rustc_codegen_llvm/src/intrinsic.rs | 67 ++++---- .../rustc_codegen_ssa/src/codegen_attrs.rs | 2 +- .../rustc_hir_analysis/src/check/intrinsic.rs | 2 + compiler/rustc_span/src/symbol.rs | 1 + library/core/src/intrinsics/mod.rs | 4 + 7 files changed, 181 insertions(+), 46 deletions(-) diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs 
index c9885bc12c6b9..05189020b9aea 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -331,20 +331,23 @@ mod llvm_enzyme { .count() as u32; let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span); - // UNUSED + // TODO(Sa4dUs): Remove this and all the related logic let _d_body = gen_enzyme_body( ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored, &generics, ); + let d_body = + call_enzyme_autodiff(ecx, primal, first_ident(&meta_item_vec[0]), span, &d_sig); + // The first element of it is the name of the function to be generated let asdf = Box::new(ast::Fn { defaultness: ast::Defaultness::Final, sig: d_sig, ident: first_ident(&meta_item_vec[0]), - generics, + generics: generics.clone(), contract: None, - body: None, // This leads to an error when the ad function is inside a traits + body: Some(d_body), define_opaque: None, }); let mut rustc_ad_attr = @@ -431,10 +434,7 @@ mod llvm_enzyme { tokens: ts, }); - let rustc_intrinsic_attr = - P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_intrinsic))); - let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id(); - let intrinsic_attr = outer_normal_attr(&rustc_intrinsic_attr, new_id, span); + let vis_clone = vis.clone(); let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id(); let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span); @@ -442,7 +442,7 @@ mod llvm_enzyme { Annotatable::AssocItem(_, _) => { let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf); let d_fn = P(ast::AssocItem { - attrs: thin_vec![d_attr, intrinsic_attr], + attrs: thin_vec![d_attr], id: ast::DUMMY_NODE_ID, span, vis, @@ -452,15 +452,13 @@ mod llvm_enzyme { Annotatable::AssocItem(d_fn, Impl { of_trait: false }) } Annotatable::Item(_) => { - let mut d_fn = - ecx.item(span, thin_vec![d_attr, intrinsic_attr], ItemKind::Fn(asdf)); + let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(asdf)); 
d_fn.vis = vis; Annotatable::Item(d_fn) } Annotatable::Stmt(_) => { - let mut d_fn = - ecx.item(span, thin_vec![d_attr, intrinsic_attr], ItemKind::Fn(asdf)); + let mut d_fn = ecx.item(span, thin_vec![d_attr], ItemKind::Fn(asdf)); d_fn.vis = vis; Annotatable::Stmt(P(ast::Stmt { @@ -474,7 +472,9 @@ mod llvm_enzyme { } }; - return vec![orig_annotatable, d_annotatable]; + let dummy_const_annotatable = gen_dummy_const(ecx, span, primal, sig, generics, vis_clone); + + return vec![orig_annotatable, dummy_const_annotatable, d_annotatable]; } // shadow arguments (the extra ones which were not in the original (primal) function), in reverse mode must be @@ -495,6 +495,123 @@ mod llvm_enzyme { ty } + // Generate `enzyme_autodiff` intrinsic call + // ``` + // std::intrinsics::enzyme_autodiff(source, diff, (args)) + // ``` + fn call_enzyme_autodiff( + ecx: &ExtCtxt<'_>, + primal: Ident, + diff: Ident, + span: Span, + d_sig: &FnSig, + ) -> P { + let primal_path_expr = ecx.expr_path(ecx.path_ident(span, primal)); + let diff_path_expr = ecx.expr_path(ecx.path_ident(span, diff)); + + let tuple_expr = ecx.expr_tuple( + span, + d_sig + .decl + .inputs + .iter() + .map(|arg| match arg.pat.kind { + PatKind::Ident(_, ident, _) => ecx.expr_path(ecx.path_ident(span, ident)), + _ => todo!(), + }) + .collect::>() + .into(), + ); + + let enzyme_path = ecx.path( + span, + vec![ + Ident::from_str("std"), + Ident::from_str("intrinsics"), + Ident::from_str("enzyme_autodiff"), + ], + ); + let call_expr = ecx.expr_call( + span, + ecx.expr_path(enzyme_path), + vec![primal_path_expr, diff_path_expr, tuple_expr].into(), + ); + + let block = ecx.block_expr(call_expr); + + block + } + + // Generate dummy const to prevent primal function + // from being optimized away before applying enzyme + // ``` + // const _: () = + // { + // #[used] + // pub static DUMMY_PTR: fn_type = primal_fn; + // }; + // ``` + fn gen_dummy_const( + ecx: &ExtCtxt<'_>, + span: Span, + primal: Ident, + sig: FnSig, + generics: 
Generics, + vis: Visibility, + ) -> Annotatable { + // #[used] + let used_attr = P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::used))); + let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id(); + let used_attr = outer_normal_attr(&used_attr, new_id, span); + + // static DUMMY_PTR: = + let static_ident = Ident::from_str_and_span("DUMMY_PTR", span); + let fn_ptr_ty = ast::TyKind::BareFn(Box::new(ast::BareFnTy { + safety: sig.header.safety, + ext: sig.header.ext, + generic_params: generics.params, + decl: sig.decl, + decl_span: sig.span, + })); + let static_ty = ecx.ty(span, fn_ptr_ty); + + let static_expr = ecx.expr_path(ecx.path(span, vec![primal])); + let static_item_kind = ast::ItemKind::Static(Box::new(ast::StaticItem { + ident: static_ident, + ty: static_ty, + safety: ast::Safety::Default, + mutability: ast::Mutability::Not, + expr: Some(static_expr), + define_opaque: None, + })); + + let static_item = ast::Item { + attrs: thin_vec![used_attr], + id: ast::DUMMY_NODE_ID, + span, + vis, + kind: static_item_kind, + tokens: None, + }; + + let block_expr = ecx.expr_block(Box::new(ast::Block { + stmts: thin_vec![ecx.stmt_item(span, P(static_item))], + id: ast::DUMMY_NODE_ID, + rules: ast::BlockCheckMode::Default, + span, + tokens: None, + })); + + let const_item = ecx.item_const( + span, + Ident::from_str_and_span("_", span), + ecx.ty(span, ast::TyKind::Tup(thin_vec![])), + block_expr, + ); + + Annotatable::Item(const_item) + } + // Will generate a body of the type: // ``` // { diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs index 6af9f5738214e..752411012340a 100644 --- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs +++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs @@ -10,7 +10,7 @@ use rustc_middle::bug; use tracing::{debug, trace}; use crate::back::write::llvm_err; -use crate::builder::{Builder, OperandRef, PlaceRef, UNNAMED}; +use crate::builder::{Builder, 
PlaceRef, UNNAMED}; use crate::context::SimpleCx; use crate::declare::declare_simple_fn; use crate::errors::{AutoDiffWithoutEnable, LlvmError}; @@ -200,7 +200,7 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>( fn_to_diff: &'ll Value, outer_name: &str, ret_ty: &'ll Type, - fn_args: &[OperandRef<'tcx, &'ll Value>], + fn_args: &[&'ll Value], attrs: AutoDiffAttrs, dest: PlaceRef<'tcx, &'ll Value>, ) { @@ -282,15 +282,13 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>( args.push(cx.get_const_int(cx.type_i64(), attrs.width as u64)); } - let outer_args: Vec<&llvm::Value> = fn_args.iter().map(|op| op.immediate()).collect(); - match_args_from_caller_to_enzyme( &cx, builder, attrs.width, &mut args, &attrs.input_activity, - &outer_args, + fn_args, ); let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None); diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index 0b06d354f5a30..7122f28b53575 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -3,6 +3,7 @@ use std::cmp::Ordering; use rustc_abi::{Align, BackendRepr, ExternAbi, Float, HasDataLayout, Primitive, Size}; use rustc_codegen_ssa::base::{compare_simd_types, wants_msvc_seh, wants_wasm_eh}; +use rustc_codegen_ssa::codegen_attrs::autodiff_attrs; use rustc_codegen_ssa::common::{IntPredicate, TypeKind}; use rustc_codegen_ssa::errors::{ExpectedPointerMutability, InvalidMonomorphization}; use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue}; @@ -196,48 +197,60 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> { &[ptr, args[1].immediate()], ) } - _ if tcx.has_attr(instance.def_id(), sym::rustc_autodiff) => { - // NOTE(Sa4dUs): This is a hacky way to get the autodiff items - // so we can focus on the lowering of the intrinsic call - let mut source_id = None; - let mut diff_attrs = None; - let items: Vec<_> = tcx.hir_body_owners().map(|i| 
i.to_def_id()).collect(); - - // Hacky way of getting primal-diff pair, only works for code with 1 autodiff call - for target_id in &items { - let Some(target_attrs) = &tcx.codegen_fn_attrs(target_id).autodiff_item else { - continue; - }; + sym::enzyme_autodiff => { + let val_arr: Vec<&'ll Value> = match args[2].val { + crate::intrinsic::OperandValue::Ref(ref place_value) => { + let mut ret_arr = vec![]; + let tuple_place = PlaceRef { val: *place_value, layout: args[2].layout }; - if target_attrs.is_source() { - source_id = Some(*target_id); - } else { - diff_attrs = Some(target_attrs); - } - } + for i in 0..tuple_place.layout.layout.0.fields.count() { + let field_place = tuple_place.project_field(self, i); + let field_layout = tuple_place.layout.field(self, i); + let llvm_ty = field_layout.llvm_type(self.cx); - if source_id.is_none() || diff_attrs.is_none() { - bug!("could not find source_id={source_id:?} or diff_attrs={diff_attrs:?}"); - } + let field_val = + self.load(llvm_ty, field_place.val.llval, field_place.val.align); + + ret_arr.push(field_val) + } - let diff_attrs = diff_attrs.unwrap().clone(); + ret_arr + } + crate::intrinsic::OperandValue::Pair(v1, v2) => vec![v1, v2], + OperandValue::Immediate(v) => vec![v], + OperandValue::ZeroSized => bug!("unexpected `ZeroSized` arg"), + }; - // Get source fn - let source_id = source_id.unwrap(); - let fn_source = Instance::mono(tcx, source_id); + // Get source, diff, and attrs + let source_id = match fn_args.into_type_list(tcx)[0].kind() { + ty::FnDef(def_id, _) => def_id, + _ => bug!("invalid args"), + }; + let fn_source = Instance::mono(tcx, *source_id); let source_symbol = symbol_name_for_instance_in_crate(tcx, fn_source.clone(), LOCAL_CRATE); let fn_to_diff: Option<&'ll llvm::Value> = self.cx.get_function(&source_symbol); let Some(fn_to_diff) = fn_to_diff else { bug!("could not find source function") }; + let diff_id = match fn_args.into_type_list(tcx)[1].kind() { + ty::FnDef(def_id, _) => def_id, + _ => 
bug!("invalid args"), + }; + let fn_diff = Instance::mono(tcx, *diff_id); + let diff_symbol = + symbol_name_for_instance_in_crate(tcx, fn_diff.clone(), LOCAL_CRATE); + + let diff_attrs = autodiff_attrs(tcx, *diff_id); + let Some(diff_attrs) = diff_attrs else { bug!("could not find autodiff attrs") }; + // Build body generate_enzyme_call( self, self.cx, fn_to_diff, - name.as_str(), + &diff_symbol, llret_ty, - args, + &val_arr, diff_attrs.clone(), result, ); diff --git a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs index 7bd27eb3ef1cd..fc753af6fe8c2 100644 --- a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs +++ b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs @@ -720,7 +720,7 @@ impl<'a> MixedExportNameAndNoMangleState<'a> { /// being differentiated). The other form is #[rustc_autodiff(Mode, ActivityList)] on top of the /// placeholder functions. We wrote the rustc_autodiff attributes ourself, so this should never /// panic, unless we introduced a bug when parsing the autodiff macro. 
-fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option { +pub fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option { let attrs = tcx.get_attrs(id, sym::rustc_autodiff); let attrs = attrs.filter(|attr| attr.has_name(sym::rustc_autodiff)).collect::>(); diff --git a/compiler/rustc_hir_analysis/src/check/intrinsic.rs b/compiler/rustc_hir_analysis/src/check/intrinsic.rs index bfa2beb81ce67..8d8543c83cf25 100644 --- a/compiler/rustc_hir_analysis/src/check/intrinsic.rs +++ b/compiler/rustc_hir_analysis/src/check/intrinsic.rs @@ -134,6 +134,7 @@ fn intrinsic_operation_unsafety(tcx: TyCtxt<'_>, intrinsic_id: LocalDefId) -> hi | sym::round_ties_even_f32 | sym::round_ties_even_f64 | sym::round_ties_even_f128 + | sym::enzyme_autodiff | sym::const_eval_select => hir::Safety::Safe, _ => hir::Safety::Unsafe, }; @@ -215,6 +216,7 @@ pub(crate) fn check_intrinsic_type( (n_tps, n_cts, inputs, output) } + sym::enzyme_autodiff => (4, 0, vec![param(0), param(1), param(2)], param(3)), sym::abort => (0, 0, vec![], tcx.types.never), sym::unreachable => (0, 0, vec![], tcx.types.never), sym::breakpoint => (0, 0, vec![], tcx.types.unit), diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs index c9262d24a1717..f7f95ac7ee512 100644 --- a/compiler/rustc_span/src/symbol.rs +++ b/compiler/rustc_span/src/symbol.rs @@ -915,6 +915,7 @@ symbols! 
{ enumerate_method, env, env_CFG_RELEASE: env!("CFG_RELEASE"), + enzyme_autodiff, eprint_macro, eprintln_macro, eq, diff --git a/library/core/src/intrinsics/mod.rs b/library/core/src/intrinsics/mod.rs index b5c3e91d04687..1aecc62ac315e 100644 --- a/library/core/src/intrinsics/mod.rs +++ b/library/core/src/intrinsics/mod.rs @@ -3114,6 +3114,10 @@ pub const unsafe fn copysignf64(x: f64, y: f64) -> f64; #[rustc_intrinsic] pub const unsafe fn copysignf128(x: f128, y: f128) -> f128; +#[rustc_nounwind] +#[rustc_intrinsic] +pub const fn enzyme_autodiff(f: F, df: G, args: T) -> R; + /// Inform Miri that a given pointer definitely has a certain alignment. #[cfg(miri)] #[rustc_allow_const_fn_unstable(const_eval_select)] From 38b27a2658099df115ef4f8693ef89eb7f011608 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Mon, 7 Jul 2025 17:29:42 +0000 Subject: [PATCH 08/15] Remove attr checking from hir_analysis --- .../rustc_hir_analysis/src/check/intrinsic.rs | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/compiler/rustc_hir_analysis/src/check/intrinsic.rs b/compiler/rustc_hir_analysis/src/check/intrinsic.rs index 8d8543c83cf25..26da283070115 100644 --- a/compiler/rustc_hir_analysis/src/check/intrinsic.rs +++ b/compiler/rustc_hir_analysis/src/check/intrinsic.rs @@ -197,25 +197,8 @@ pub(crate) fn check_intrinsic_type( (Ty::new_ref(tcx, env_region, va_list_ty, mutbl), va_list_ty) }; - // FIXME(Sa4dUs): Get the actual safety level of the diff function - let safety = if has_autodiff { - hir::Safety::Safe - } else { - intrinsic_operation_unsafety(tcx, intrinsic_id) - }; let n_lts = 0; let (n_tps, n_cts, inputs, output) = match intrinsic_name { - _ if has_autodiff => { - let sig = tcx.fn_sig(intrinsic_id.to_def_id()); - let sig = sig.skip_binder(); - let n_tps = generics.own_counts().types; - let n_cts = generics.own_counts().consts; - - let inputs = sig.skip_binder().inputs().to_vec(); - let output = sig.skip_binder().output(); - - (n_tps, 
n_cts, inputs, output) - } sym::enzyme_autodiff => (4, 0, vec![param(0), param(1), param(2)], param(3)), sym::abort => (0, 0, vec![], tcx.types.never), sym::unreachable => (0, 0, vec![], tcx.types.never), From b7fdb7b32d542f160bc6d185bef3cd695adb4f60 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Tue, 8 Jul 2025 14:20:24 +0000 Subject: [PATCH 09/15] Fix generics error when passing fn as param to intrinsic --- compiler/rustc_builtin_macros/src/autodiff.rs | 51 ++++++++++++++++--- .../rustc_hir_analysis/src/check/intrinsic.rs | 3 +- 2 files changed, 46 insertions(+), 8 deletions(-) diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 05189020b9aea..b9a9689607980 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -16,8 +16,9 @@ mod llvm_enzyme { use rustc_ast::tokenstream::*; use rustc_ast::visit::AssocCtxt::*; use rustc_ast::{ - self as ast, AssocItemKind, BindingMode, ExprKind, FnRetTy, FnSig, Generics, ItemKind, - MetaItemInner, PatKind, QSelf, TyKind, Visibility, + self as ast, AngleBracketedArg, AngleBracketedArgs, AssocItemKind, BindingMode, ExprKind, + FnRetTy, FnSig, GenericArg, GenericArgs, Generics, ItemKind, MetaItemInner, PatKind, Path, + PathSegment, QSelf, TyKind, Visibility, }; use rustc_expand::base::{Annotatable, ExtCtxt}; use rustc_span::{Ident, Span, Symbol, kw, sym}; @@ -337,8 +338,14 @@ mod llvm_enzyme { &generics, ); - let d_body = - call_enzyme_autodiff(ecx, primal, first_ident(&meta_item_vec[0]), span, &d_sig); + let d_body = call_enzyme_autodiff( + ecx, + primal, + first_ident(&meta_item_vec[0]), + span, + &d_sig, + &generics, + ); // The first element of it is the name of the function to be generated let asdf = Box::new(ast::Fn { @@ -505,9 +512,10 @@ mod llvm_enzyme { diff: Ident, span: Span, d_sig: &FnSig, + generics: &Generics, ) -> P { - let primal_path_expr =
ecx.expr_path(ecx.path_ident(span, primal)); - let diff_path_expr = ecx.expr_path(ecx.path_ident(span, diff)); + let primal_path_expr = gen_turbofish_expr(ecx, primal, generics, span); + let diff_path_expr = gen_turbofish_expr(ecx, diff, generics, span); let tuple_expr = ecx.expr_tuple( span, @@ -542,6 +550,37 @@ mod llvm_enzyme { block } + // Generate turbofish expression from fn name and generics + // Given `foo` and ``, gen `foo::` + fn gen_turbofish_expr( + ecx: &ExtCtxt<'_>, + ident: Ident, + generics: &Generics, + span: Span, + ) -> P { + let generic_args = generics + .params + .iter() + .map(|p| { + let path = ast::Path::from_ident(p.ident); + let ty = ecx.ty_path(path); + AngleBracketedArg::Arg(GenericArg::Type(ty)) + }) + .collect::>(); + + let args = AngleBracketedArgs { span, args: generic_args }; + + let segment = PathSegment { + ident, + id: ast::DUMMY_NODE_ID, + args: Some(P(GenericArgs::AngleBracketed(args))), + }; + + let path = Path { span, segments: thin_vec![segment], tokens: None }; + + ecx.expr_path(path) + } + // Generate dummy const to prevent primal function // from being optimized away before applying enzyme // ``` diff --git a/compiler/rustc_hir_analysis/src/check/intrinsic.rs b/compiler/rustc_hir_analysis/src/check/intrinsic.rs index 26da283070115..305a215544ce2 100644 --- a/compiler/rustc_hir_analysis/src/check/intrinsic.rs +++ b/compiler/rustc_hir_analysis/src/check/intrinsic.rs @@ -171,8 +171,6 @@ pub(crate) fn check_intrinsic_type( } }; - let has_autodiff = tcx.has_attr(intrinsic_id, sym::rustc_autodiff); - let bound_vars = tcx.mk_bound_variable_kinds(&[ ty::BoundVariableKind::Region(ty::BoundRegionKind::Anon), ty::BoundVariableKind::Region(ty::BoundRegionKind::Anon), @@ -197,6 +195,7 @@ pub(crate) fn check_intrinsic_type( (Ty::new_ref(tcx, env_region, va_list_ty, mutbl), va_list_ty) }; + let safety = intrinsic_operation_unsafety(tcx, intrinsic_id); let n_lts = 0; let (n_tps, n_cts, inputs, output) = match intrinsic_name { 
sym::enzyme_autodiff => (4, 0, vec![param(0), param(1), param(2)], param(3)), From f896696148360fe88bd7eccb7eaa79733fdf8a0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Tue, 8 Jul 2025 17:38:54 +0000 Subject: [PATCH 10/15] Use Instance::new_raw instead of Instance::mono Note(Sa4dUs): `cg/generic.rs` test is passing with some tweaks --- compiler/rustc_codegen_llvm/src/intrinsic.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index 7122f28b53575..c8278ada39cbd 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -222,21 +222,21 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> { }; // Get source, diff, and attrs - let source_id = match fn_args.into_type_list(tcx)[0].kind() { - ty::FnDef(def_id, _) => def_id, + let (source_id, source_args) = match fn_args.into_type_list(tcx)[0].kind() { + ty::FnDef(def_id, source_params) => (def_id, source_params), _ => bug!("invalid args"), }; - let fn_source = Instance::mono(tcx, *source_id); + let fn_source = Instance::new_raw(*source_id, source_args); let source_symbol = symbol_name_for_instance_in_crate(tcx, fn_source.clone(), LOCAL_CRATE); let fn_to_diff: Option<&'ll llvm::Value> = self.cx.get_function(&source_symbol); let Some(fn_to_diff) = fn_to_diff else { bug!("could not find source function") }; - let diff_id = match fn_args.into_type_list(tcx)[1].kind() { - ty::FnDef(def_id, _) => def_id, + let (diff_id, diff_args) = match fn_args.into_type_list(tcx)[1].kind() { + ty::FnDef(def_id, diff_args) => (def_id, diff_args), _ => bug!("invalid args"), }; - let fn_diff = Instance::mono(tcx, *diff_id); + let fn_diff = Instance::new_raw(*diff_id, diff_args); let diff_symbol = symbol_name_for_instance_in_crate(tcx, fn_diff.clone(), LOCAL_CRATE); From 1e0f3759842e2b928e716d8504d3a42e20d893ce Mon 
Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Thu, 10 Jul 2025 18:26:38 +0000 Subject: [PATCH 11/15] Hacky fix for issues at trait calls --- compiler/rustc_builtin_macros/src/autodiff.rs | 300 ++++-------------- .../src/builder/autodiff.rs | 3 +- compiler/rustc_codegen_llvm/src/intrinsic.rs | 149 +++++---- 3 files changed, 146 insertions(+), 306 deletions(-) diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index b9a9689607980..7dd683ac39a8f 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -16,9 +16,9 @@ mod llvm_enzyme { use rustc_ast::tokenstream::*; use rustc_ast::visit::AssocCtxt::*; use rustc_ast::{ - self as ast, AngleBracketedArg, AngleBracketedArgs, AssocItemKind, BindingMode, ExprKind, - FnRetTy, FnSig, GenericArg, GenericArgs, Generics, ItemKind, MetaItemInner, PatKind, Path, - PathSegment, QSelf, TyKind, Visibility, + self as ast, AngleBracketedArg, AngleBracketedArgs, AssocItemKind, BindingMode, FnRetTy, + FnSig, GenericArg, GenericArgs, Generics, ItemKind, MetaItemInner, PatKind, Path, + PathSegment, TyKind, Visibility, }; use rustc_expand::base::{Annotatable, ExtCtxt}; use rustc_span::{Ident, Span, Symbol, kw, sym}; @@ -74,10 +74,12 @@ mod llvm_enzyme { } // Get information about the function the macro is applied to - fn extract_item_info(iitem: &P) -> Option<(Visibility, FnSig, Ident, Generics)> { + fn extract_item_info( + iitem: &P, + ) -> Option<(Visibility, FnSig, Ident, Generics, bool)> { match &iitem.kind { ItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => { - Some((iitem.vis.clone(), sig.clone(), ident.clone(), generics.clone())) + Some((iitem.vis.clone(), sig.clone(), ident.clone(), generics.clone(), false)) } _ => None, } @@ -229,16 +231,20 @@ mod llvm_enzyme { // first get information about the annotable item: visibility, signature, name and generic // parameters. 
// these will be used to generate the differentiated version of the function - let Some((vis, sig, primal, generics)) = (match &item { + let Some((vis, sig, primal, generics, impl_of_trait)) = (match &item { Annotatable::Item(iitem) => extract_item_info(iitem), Annotatable::Stmt(stmt) => match &stmt.kind { ast::StmtKind::Item(iitem) => extract_item_info(iitem), _ => None, }, - Annotatable::AssocItem(assoc_item, Impl { .. }) => match &assoc_item.kind { - ast::AssocItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => { - Some((assoc_item.vis.clone(), sig.clone(), ident.clone(), generics.clone())) - } + Annotatable::AssocItem(assoc_item, Impl { of_trait }) => match &assoc_item.kind { + ast::AssocItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => Some(( + assoc_item.vis.clone(), + sig.clone(), + ident.clone(), + generics.clone(), + *of_trait, + )), _ => None, }, _ => None, @@ -333,18 +339,21 @@ mod llvm_enzyme { let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span); // TODO(Sa4dUs): Remove this and all the related logic - let _d_body = gen_enzyme_body( - ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored, - &generics, - ); - - let d_body = call_enzyme_autodiff( + let d_body = gen_enzyme_body( ecx, + &x, + n_active, + &sig, + &d_sig, primal, - first_ident(&meta_item_vec[0]), + &new_args, span, - &d_sig, + sig_span, + idents, + errored, + first_ident(&meta_item_vec[0]), &generics, + impl_of_trait, ); // The first element of it is the name of the function to be generated @@ -441,8 +450,6 @@ mod llvm_enzyme { tokens: ts, }); - let vis_clone = vis.clone(); - let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id(); let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span); let d_annotatable = match &item { @@ -479,9 +486,7 @@ mod llvm_enzyme { } }; - let dummy_const_annotatable = gen_dummy_const(ecx, span, primal, sig, generics, vis_clone); - - return vec![orig_annotatable, dummy_const_annotatable, 
d_annotatable]; + return vec![orig_annotatable, d_annotatable]; } // shadow arguments (the extra ones which were not in the original (primal) function), in reverse mode must be @@ -513,9 +518,10 @@ mod llvm_enzyme { span: Span, d_sig: &FnSig, generics: &Generics, - ) -> P { - let primal_path_expr = gen_turbofish_expr(ecx, primal, generics, span); - let diff_path_expr = gen_turbofish_expr(ecx, diff, generics, span); + is_impl: bool, + ) -> rustc_ast::Stmt { + let primal_path_expr = gen_turbofish_expr(ecx, primal, generics, span, is_impl); + let diff_path_expr = gen_turbofish_expr(ecx, diff, generics, span, is_impl); let tuple_expr = ecx.expr_tuple( span, @@ -545,9 +551,7 @@ mod llvm_enzyme { vec![primal_path_expr, diff_path_expr, tuple_expr].into(), ); - let block = ecx.block_expr(call_expr); - - block + ecx.stmt_expr(call_expr) } // Generate turbofish expression from fn name and generics @@ -557,6 +561,7 @@ mod llvm_enzyme { ident: Ident, generics: &Generics, span: Span, + is_impl: bool, ) -> P { let generic_args = generics .params @@ -568,7 +573,7 @@ mod llvm_enzyme { }) .collect::>(); - let args = AngleBracketedArgs { span, args: generic_args }; + let args: AngleBracketedArgs = AngleBracketedArgs { span, args: generic_args }; let segment = PathSegment { ident, @@ -576,79 +581,18 @@ mod llvm_enzyme { args: Some(P(GenericArgs::AngleBracketed(args))), }; - let path = Path { span, segments: thin_vec![segment], tokens: None }; - - ecx.expr_path(path) - } - - // Generate dummy const to prevent primal function - // from being optimized away before applying enzyme - // ``` - // const _: () = - // { - // #[used] - // pub static DUMMY_PTR: fn_type = primal_fn; - // }; - // ``` - fn gen_dummy_const( - ecx: &ExtCtxt<'_>, - span: Span, - primal: Ident, - sig: FnSig, - generics: Generics, - vis: Visibility, - ) -> Annotatable { - // #[used] - let used_attr = P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::used))); - let new_id = 
ecx.sess.psess.attr_id_generator.mk_attr_id(); - let used_attr = outer_normal_attr(&used_attr, new_id, span); - - // static DUMMY_PTR: = - let static_ident = Ident::from_str_and_span("DUMMY_PTR", span); - let fn_ptr_ty = ast::TyKind::BareFn(Box::new(ast::BareFnTy { - safety: sig.header.safety, - ext: sig.header.ext, - generic_params: generics.params, - decl: sig.decl, - decl_span: sig.span, - })); - let static_ty = ecx.ty(span, fn_ptr_ty); - - let static_expr = ecx.expr_path(ecx.path(span, vec![primal])); - let static_item_kind = ast::ItemKind::Static(Box::new(ast::StaticItem { - ident: static_ident, - ty: static_ty, - safety: ast::Safety::Default, - mutability: ast::Mutability::Not, - expr: Some(static_expr), - define_opaque: None, - })); - - let static_item = ast::Item { - attrs: thin_vec![used_attr], - id: ast::DUMMY_NODE_ID, - span, - vis, - kind: static_item_kind, - tokens: None, + let segments = if is_impl { + thin_vec![ + PathSegment { ident: Ident::from_str("Foo"), id: ast::DUMMY_NODE_ID, args: None }, + segment, + ] + } else { + thin_vec![segment] }; - let block_expr = ecx.expr_block(Box::new(ast::Block { - stmts: thin_vec![ecx.stmt_item(span, P(static_item))], - id: ast::DUMMY_NODE_ID, - rules: ast::BlockCheckMode::Default, - span, - tokens: None, - })); - - let const_item = ecx.item_const( - span, - Ident::from_str_and_span("_", span), - ecx.ty(span, ast::TyKind::Tup(thin_vec![])), - block_expr, - ); + let path = Path { span, segments, tokens: None }; - Annotatable::Item(const_item) + ecx.expr_path(path) } // Will generate a body of the type: @@ -666,33 +610,14 @@ mod llvm_enzyme { ecx: &ExtCtxt<'_>, span: Span, primal: Ident, - new_names: &[String], - sig_span: Span, + _new_names: &[String], + _sig_span: Span, new_decl_span: Span, idents: &[Ident], errored: bool, generics: &Generics, ) -> (P, P, P, P) { let blackbox_path = ecx.std_path(&[sym::hint, sym::black_box]); - let noop = ast::InlineAsm { - asm_macro: ast::AsmMacro::Asm, - template: 
vec![ast::InlineAsmTemplatePiece::String("NOP".into())], - template_strs: Box::new([]), - operands: vec![], - clobber_abis: vec![], - options: ast::InlineAsmOptions::PURE | ast::InlineAsmOptions::NOMEM, - line_spans: vec![], - }; - let noop_expr = ecx.expr_asm(span, P(noop)); - let unsf = ast::BlockCheckMode::Unsafe(ast::UnsafeSource::CompilerGenerated); - let unsf_block = ast::Block { - stmts: thin_vec![ecx.stmt_semi(noop_expr)], - id: ast::DUMMY_NODE_ID, - tokens: None, - rules: unsf, - span, - }; - let unsf_expr = ecx.expr_block(P(unsf_block)); let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path)); let primal_call = gen_primal_call(ecx, span, primal, idents, generics); let black_box_primal_call = ecx.expr_call( @@ -700,25 +625,13 @@ mod llvm_enzyme { blackbox_call_expr.clone(), thin_vec![primal_call.clone()], ); - let tup_args = new_names - .iter() - .map(|arg| ecx.expr_path(ecx.path_ident(span, Ident::from_str(arg)))) - .collect(); - - let black_box_remaining_args = ecx.expr_call( - sig_span, - blackbox_call_expr.clone(), - thin_vec![ecx.expr_tuple(sig_span, tup_args)], - ); let mut body = ecx.block(span, ThinVec::new()); - body.stmts.push(ecx.stmt_semi(unsf_expr)); // This uses primal args which won't be available if we errored before if !errored { body.stmts.push(ecx.stmt_semi(black_box_primal_call.clone())); } - body.stmts.push(ecx.stmt_semi(black_box_remaining_args)); (body, primal_call, black_box_primal_call, blackbox_call_expr) } @@ -733,9 +646,9 @@ mod llvm_enzyme { /// from optimizing any arguments away. 
fn gen_enzyme_body( ecx: &ExtCtxt<'_>, - x: &AutoDiffAttrs, - n_active: u32, - sig: &ast::FnSig, + _x: &AutoDiffAttrs, + _n_active: u32, + _sig: &ast::FnSig, d_sig: &ast::FnSig, primal: Ident, new_names: &[String], @@ -743,19 +656,15 @@ mod llvm_enzyme { sig_span: Span, idents: Vec, errored: bool, + diff_ident: Ident, generics: &Generics, + is_impl: bool, ) -> P { let new_decl_span = d_sig.span; - // Just adding some default inline-asm and black_box usages to prevent early inlining - // and optimizations which alter the function signature. - // - // The bb_primal_call is the black_box call of the primal function. We keep it around, - // since it has the convenient property of returning the type of the primal function, - // Remember, we only care to match types here. - // No matter which return we pick, we always wrap it into a std::hint::black_box call, - // to prevent rustc from propagating it into the caller. - let (mut body, primal_call, bb_primal_call, bb_call_expr) = init_body_helper( + // Add a call to the primal function to prevent it from being inlined + // and call `enzyme_autodiff` intrinsic (this also covers the return type) + let (mut body, _primal_call, _bb_primal_call, _bb_call_expr) = init_body_helper( ecx, span, primal, @@ -767,98 +676,15 @@ mod llvm_enzyme { generics, ); - if !has_ret(&d_sig.decl.output) { - // there is no return type that we have to match, () works fine. - return body; - } - - // Everything from here onwards just tries to fullfil the return type. Fun! - - // having an active-only return means we'll drop the original return type. - // So that can be treated identical to not having one in the first place. - let primal_ret = has_ret(&sig.decl.output) && !x.has_active_only_ret(); - - if primal_ret && n_active == 0 && x.mode.is_rev() { - // We only have the primal ret. - body.stmts.push(ecx.stmt_expr(bb_primal_call)); - return body; - } - - if !primal_ret && n_active == 1 { - // Again no tuple return, so return default float val. 
- let ty = match d_sig.decl.output { - FnRetTy::Ty(ref ty) => ty.clone(), - FnRetTy::Default(span) => { - panic!("Did not expect Default ret ty: {:?}", span); - } - }; - let arg = ty.kind.is_simple_path().unwrap(); - let tmp = ecx.def_site_path(&[arg, kw::Default]); - let default_call_expr = ecx.expr_path(ecx.path(span, tmp)); - let default_call_expr = ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]); - body.stmts.push(ecx.stmt_expr(default_call_expr)); - return body; - } - - let mut exprs: P = primal_call; - let d_ret_ty = match d_sig.decl.output { - FnRetTy::Ty(ref ty) => ty.clone(), - FnRetTy::Default(span) => { - panic!("Did not expect Default ret ty: {:?}", span); - } - }; - if x.mode.is_fwd() { - // Fwd mode is easy. If the return activity is Const, we support arbitrary types. - // Otherwise, we only support a scalar, a pair of scalars, or an array of scalars. - // We checked that (on a best-effort base) in the preceding gen_enzyme_decl function. - // In all three cases, we can return `std::hint::black_box(::default())`. - if x.ret_activity == DiffActivity::Const { - // Here we call the primal function, since our dummy function has the same return - // type due to the Const return activity. - exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![exprs]); - } else { - let q = QSelf { ty: d_ret_ty, path_span: span, position: 0 }; - let y = ExprKind::Path( - Some(P(q)), - ecx.path_ident(span, Ident::with_dummy_span(kw::Default)), - ); - let default_call_expr = ecx.expr(span, y); - let default_call_expr = - ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]); - exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![default_call_expr]); - } - } else if x.mode.is_rev() { - if x.width == 1 { - // We either have `-> ArbitraryType` or `-> (ArbitraryType, repeated_float_scalars)`. - match d_ret_ty.kind { - TyKind::Tup(ref args) => { - // We have a tuple return type. 
We need to create a tuple of the same size - // and fill it with default values. - let mut exprs2 = thin_vec![exprs]; - for arg in args.iter().skip(1) { - let arg = arg.kind.is_simple_path().unwrap(); - let tmp = ecx.def_site_path(&[arg, kw::Default]); - let default_call_expr = ecx.expr_path(ecx.path(span, tmp)); - let default_call_expr = - ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]); - exprs2.push(default_call_expr); - } - exprs = ecx.expr_tuple(new_decl_span, exprs2); - } - _ => { - // Interestingly, even the `-> ArbitraryType` case - // ends up getting matched and handled correctly above, - // so we don't have to handle any other case for now. - panic!("Unsupported return type: {:?}", d_ret_ty); - } - } - } - exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![exprs]); - } else { - unreachable!("Unsupported mode: {:?}", x.mode); - } - - body.stmts.push(ecx.stmt_expr(exprs)); + body.stmts.push(call_enzyme_autodiff( + ecx, + primal, + diff_ident, + new_decl_span, + d_sig, + generics, + is_impl, + )); body } diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs index 752411012340a..19d89526675b3 100644 --- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs +++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs @@ -252,8 +252,7 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>( // FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and // think a bit more about what should go here. 
- // FIXME(Sa4dUs): have to find a way to get the cc, using `FastCallConv` for now - let cc = 8; + let cc = unsafe { llvm::LLVMGetFunctionCallConv(fn_to_diff) }; let ad_fn = declare_simple_fn( cx, &ad_name, diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index c8278ada39cbd..28ff9e515725e 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -9,11 +9,11 @@ use rustc_codegen_ssa::errors::{ExpectedPointerMutability, InvalidMonomorphizati use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue}; use rustc_codegen_ssa::mir::place::{PlaceRef, PlaceValue}; use rustc_codegen_ssa::traits::*; -use rustc_hir as hir; use rustc_hir::def_id::LOCAL_CRATE; +use rustc_hir::{self as hir}; use rustc_middle::mir::BinOp; use rustc_middle::ty::layout::{FnAbiOf, HasTyCtxt, HasTypingEnv, LayoutOf}; -use rustc_middle::ty::{self, GenericArgsRef, Instance, Ty}; +use rustc_middle::ty::{self, GenericArgsRef, Instance, Ty, TyCtxt}; use rustc_middle::{bug, span_bug}; use rustc_span::{Span, Symbol, sym}; use rustc_symbol_mangling::{mangle_internal_symbol, symbol_name_for_instance_in_crate}; @@ -175,16 +175,9 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> { span: Span, ) -> Result<(), ty::Instance<'tcx>> { let tcx = self.tcx; - let callee_ty = instance.ty(tcx, self.typing_env()); - let fn_args = instance.args; - - let sig = callee_ty.fn_sig(tcx); - let sig = tcx.normalize_erasing_late_bound_regions(self.typing_env(), sig); - let ret_ty = sig.output(); let name = tcx.item_name(instance.def_id()); - - let llret_ty = self.layout_of(ret_ty).llvm_type(self); + let fn_args = instance.args; let simple = call_simple_intrinsic(self, name, args); let llval = match name { @@ -198,63 +191,7 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> { ) } sym::enzyme_autodiff => { - let val_arr: Vec<&'ll Value> = match args[2].val { - 
crate::intrinsic::OperandValue::Ref(ref place_value) => { - let mut ret_arr = vec![]; - let tuple_place = PlaceRef { val: *place_value, layout: args[2].layout }; - - for i in 0..tuple_place.layout.layout.0.fields.count() { - let field_place = tuple_place.project_field(self, i); - let field_layout = tuple_place.layout.field(self, i); - let llvm_ty = field_layout.llvm_type(self.cx); - - let field_val = - self.load(llvm_ty, field_place.val.llval, field_place.val.align); - - ret_arr.push(field_val) - } - - ret_arr - } - crate::intrinsic::OperandValue::Pair(v1, v2) => vec![v1, v2], - OperandValue::Immediate(v) => vec![v], - OperandValue::ZeroSized => bug!("unexpected `ZeroSized` arg"), - }; - - // Get source, diff, and attrs - let (source_id, source_args) = match fn_args.into_type_list(tcx)[0].kind() { - ty::FnDef(def_id, source_params) => (def_id, source_params), - _ => bug!("invalid args"), - }; - let fn_source = Instance::new_raw(*source_id, source_args); - let source_symbol = - symbol_name_for_instance_in_crate(tcx, fn_source.clone(), LOCAL_CRATE); - let fn_to_diff: Option<&'ll llvm::Value> = self.cx.get_function(&source_symbol); - let Some(fn_to_diff) = fn_to_diff else { bug!("could not find source function") }; - - let (diff_id, diff_args) = match fn_args.into_type_list(tcx)[1].kind() { - ty::FnDef(def_id, diff_args) => (def_id, diff_args), - _ => bug!("invalid args"), - }; - let fn_diff = Instance::new_raw(*diff_id, diff_args); - let diff_symbol = - symbol_name_for_instance_in_crate(tcx, fn_diff.clone(), LOCAL_CRATE); - - let diff_attrs = autodiff_attrs(tcx, *diff_id); - let Some(diff_attrs) = diff_attrs else { bug!("could not find autodiff attrs") }; - - // Build body - generate_enzyme_call( - self, - self.cx, - fn_to_diff, - &diff_symbol, - llret_ty, - &val_arr, - diff_attrs.clone(), - result, - ); - + codegen_enzyme_autodiff(self, tcx, instance, args, result); return Ok(()); } sym::is_val_statically_known => { @@ -1191,6 +1128,84 @@ fn get_rust_try_fn<'a, 'll, 
'tcx>( rust_try } +fn codegen_enzyme_autodiff<'ll, 'tcx>( + bx: &mut Builder<'_, 'll, 'tcx>, + tcx: TyCtxt<'tcx>, + instance: ty::Instance<'tcx>, + args: &[OperandRef<'tcx, &'ll Value>], + result: PlaceRef<'tcx, &'ll Value>, +) { + let fn_args = instance.args; + let callee_ty = instance.ty(tcx, bx.typing_env()); + + let sig = callee_ty.fn_sig(tcx); + let sig = tcx.normalize_erasing_late_bound_regions(bx.typing_env(), sig); + + let ret_ty = sig.output(); + let llret_ty = bx.layout_of(ret_ty).llvm_type(bx); + + let val_arr: Vec<&'ll Value> = get_args_from_tuple(bx, args[2]); + + // Get source, diff, and attrs + let (source_id, source_args) = match fn_args.into_type_list(tcx)[0].kind() { + ty::FnDef(def_id, source_params) => (def_id, source_params), + _ => bug!("invalid args"), + }; + let fn_source = Instance::new_raw(*source_id, source_args); + let source_symbol = symbol_name_for_instance_in_crate(tcx, fn_source.clone(), LOCAL_CRATE); + let fn_to_diff: Option<&'ll llvm::Value> = bx.cx.get_function(&source_symbol); + let Some(fn_to_diff) = fn_to_diff else { bug!("could not find source function") }; + + let (diff_id, diff_args) = match fn_args.into_type_list(tcx)[1].kind() { + ty::FnDef(def_id, diff_args) => (def_id, diff_args), + _ => bug!("invalid args"), + }; + let fn_diff = Instance::new_raw(*diff_id, diff_args); + let diff_symbol = symbol_name_for_instance_in_crate(tcx, fn_diff.clone(), LOCAL_CRATE); + + let diff_attrs = autodiff_attrs(tcx, *diff_id); + let Some(diff_attrs) = diff_attrs else { bug!("could not find autodiff attrs") }; + + // Build body + generate_enzyme_call( + bx, + bx.cx, + fn_to_diff, + &diff_symbol, + llret_ty, + &val_arr, + diff_attrs.clone(), + result, + ); +} + +fn get_args_from_tuple<'ll, 'tcx>( + bx: &mut Builder<'_, 'll, 'tcx>, + op: OperandRef<'tcx, &'ll Value>, +) -> Vec<&'ll Value> { + match op.val { + OperandValue::Ref(ref place_value) => { + let mut ret_arr = vec![]; + let tuple_place = PlaceRef { val: *place_value, layout: op.layout 
}; + + for i in 0..tuple_place.layout.layout.0.fields.count() { + let field_place = tuple_place.project_field(bx, i); + let field_layout = tuple_place.layout.field(bx, i); + let llvm_ty = field_layout.llvm_type(bx.cx); + + let field_val = bx.load(llvm_ty, field_place.val.llval, field_place.val.align); + + ret_arr.push(field_val) + } + + ret_arr + } + OperandValue::Pair(v1, v2) => vec![v1, v2], + OperandValue::Immediate(v) => vec![v], + OperandValue::ZeroSized => bug!("unexpected `ZeroSized` arg"), + } +} + fn generic_simd_intrinsic<'ll, 'tcx>( bx: &mut Builder<'_, 'll, 'tcx>, name: Symbol, From 146546e938024a0006d65d8fd7ea73a1263c8a5d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Fri, 11 Jul 2025 14:54:49 +0000 Subject: [PATCH 12/15] Fix how fns were being retrieved at intrinsic cg --- compiler/rustc_codegen_llvm/src/intrinsic.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index 28ff9e515725e..2724b834fa820 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -1151,7 +1151,8 @@ fn codegen_enzyme_autodiff<'ll, 'tcx>( ty::FnDef(def_id, source_params) => (def_id, source_params), _ => bug!("invalid args"), }; - let fn_source = Instance::new_raw(*source_id, source_args); + let fn_source = + Instance::try_resolve(tcx, bx.cx.typing_env(), *source_id, source_args).unwrap().unwrap(); let source_symbol = symbol_name_for_instance_in_crate(tcx, fn_source.clone(), LOCAL_CRATE); let fn_to_diff: Option<&'ll llvm::Value> = bx.cx.get_function(&source_symbol); let Some(fn_to_diff) = fn_to_diff else { bug!("could not find source function") }; @@ -1160,10 +1161,11 @@ fn codegen_enzyme_autodiff<'ll, 'tcx>( ty::FnDef(def_id, diff_args) => (def_id, diff_args), _ => bug!("invalid args"), }; - let fn_diff = Instance::new_raw(*diff_id, diff_args); + let fn_diff = +
Instance::try_resolve(tcx, bx.cx.typing_env(), *diff_id, diff_args).unwrap().unwrap(); let diff_symbol = symbol_name_for_instance_in_crate(tcx, fn_diff.clone(), LOCAL_CRATE); - let diff_attrs = autodiff_attrs(tcx, *diff_id); + let diff_attrs = autodiff_attrs(tcx, fn_diff.def_id()); let Some(diff_attrs) = diff_attrs else { bug!("could not find autodiff attrs") }; // Build body From f0348d91e4cd7e9ad18d87c8ce4a246a4d8b3885 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Fri, 11 Jul 2025 18:10:45 +0000 Subject: [PATCH 13/15] Use Self instead of Foo placeholder --- compiler/rustc_builtin_macros/src/autodiff.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 7dd683ac39a8f..b73e6133fdcde 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -583,7 +583,7 @@ mod llvm_enzyme { let segments = if is_impl { thin_vec![ - PathSegment { ident: Ident::from_str("Foo"), id: ast::DUMMY_NODE_ID, args: None }, + PathSegment { ident: Ident::from_str("Self"), id: ast::DUMMY_NODE_ID, args: None }, segment, ] } else { @@ -630,7 +630,7 @@ mod llvm_enzyme { // This uses primal args which won't be available if we errored before if !errored { - body.stmts.push(ecx.stmt_semi(black_box_primal_call.clone())); + body.stmts.push(ecx.stmt_semi(primal_call.clone())); } (body, primal_call, black_box_primal_call, blackbox_call_expr) From e58ebacf3b208c6e0119b1bd57b5f0bec7a1ec2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Sat, 12 Jul 2025 10:28:06 +0000 Subject: [PATCH 14/15] Remove unused code --- compiler/rustc_builtin_macros/src/autodiff.rs | 69 +++----------- .../src/builder/autodiff.rs | 90 +------------------ compiler/rustc_codegen_llvm/src/errors.rs | 3 + compiler/rustc_codegen_llvm/src/lib.rs | 18 +--- compiler/rustc_codegen_ssa/src/back/lto.rs | 19 ---- 
compiler/rustc_codegen_ssa/src/back/write.rs | 6 +- .../rustc_codegen_ssa/src/traits/write.rs | 7 -- 7 files changed, 22 insertions(+), 190 deletions(-) diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index b73e6133fdcde..4fd5525c57b0a 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -262,7 +262,6 @@ mod llvm_enzyme { }; let has_ret = has_ret(&sig.decl.output); - let sig_span = ecx.with_call_site_ctxt(sig.span); // create TokenStream from vec elemtents: // meta_item doesn't have a .tokens field @@ -331,24 +330,13 @@ mod llvm_enzyme { } let span = ecx.with_def_site_ctxt(expand_span); - let n_active: u32 = x - .input_activity - .iter() - .filter(|a| **a == DiffActivity::Active || **a == DiffActivity::ActiveOnly) - .count() as u32; - let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span); + let (d_sig, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span); - // TODO(Sa4dUs): Remove this and all the related logic let d_body = gen_enzyme_body( ecx, - &x, - n_active, - &sig, &d_sig, primal, - &new_args, span, - sig_span, idents, errored, first_ident(&meta_item_vec[0]), @@ -361,7 +349,7 @@ mod llvm_enzyme { defaultness: ast::Defaultness::Final, sig: d_sig, ident: first_ident(&meta_item_vec[0]), - generics: generics.clone(), + generics, contract: None, body: Some(d_body), define_opaque: None, @@ -542,7 +530,7 @@ mod llvm_enzyme { vec![ Ident::from_str("std"), Ident::from_str("intrinsics"), - Ident::from_str("enzyme_autodiff"), + Ident::with_dummy_span(sym::enzyme_autodiff), ], ); let call_expr = ecx.expr_call( @@ -555,7 +543,7 @@ mod llvm_enzyme { } // Generate turbofish expression from fn name and generics - // Given `foo` and ``, gen `foo::` + // Given `foo` and `` params, gen `foo::` fn gen_turbofish_expr( ecx: &ExtCtxt<'_>, ident: Ident, @@ -597,35 +585,19 @@ mod llvm_enzyme { // Will generate a body of the type: // ``` - 
// { - // unsafe { - // asm!("NOP"); - // } - // ::core::hint::black_box(primal(args)); - // ::core::hint::black_box((args, ret)); - // + // primal(args); + // std::intrinsics::enzyme_autodiff(primal, diff, (args)) // } // ``` fn init_body_helper( ecx: &ExtCtxt<'_>, span: Span, primal: Ident, - _new_names: &[String], - _sig_span: Span, - new_decl_span: Span, idents: &[Ident], errored: bool, generics: &Generics, - ) -> (P, P, P, P) { - let blackbox_path = ecx.std_path(&[sym::hint, sym::black_box]); - let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path)); + ) -> P { let primal_call = gen_primal_call(ecx, span, primal, idents, generics); - let black_box_primal_call = ecx.expr_call( - new_decl_span, - blackbox_call_expr.clone(), - thin_vec![primal_call.clone()], - ); - let mut body = ecx.block(span, ThinVec::new()); // This uses primal args which won't be available if we errored before @@ -633,7 +605,7 @@ mod llvm_enzyme { body.stmts.push(ecx.stmt_semi(primal_call.clone())); } - (body, primal_call, black_box_primal_call, blackbox_call_expr) + body } /// We only want this function to type-check, since we will replace the body @@ -646,14 +618,9 @@ mod llvm_enzyme { /// from optimizing any arguments away. 
fn gen_enzyme_body( ecx: &ExtCtxt<'_>, - _x: &AutoDiffAttrs, - _n_active: u32, - _sig: &ast::FnSig, d_sig: &ast::FnSig, primal: Ident, - new_names: &[String], span: Span, - sig_span: Span, idents: Vec, errored: bool, diff_ident: Ident, @@ -664,17 +631,7 @@ mod llvm_enzyme { // Add a call to the primal function to prevent it from being inlined // and call `enzyme_autodiff` intrinsic (this also covers the return type) - let (mut body, _primal_call, _bb_primal_call, _bb_call_expr) = init_body_helper( - ecx, - span, - primal, - new_names, - sig_span, - new_decl_span, - &idents, - errored, - generics, - ); + let mut body = init_body_helper(ecx, span, primal, &idents, errored, generics); body.stmts.push(call_enzyme_autodiff( ecx, @@ -771,7 +728,7 @@ mod llvm_enzyme { sig: &ast::FnSig, x: &AutoDiffAttrs, span: Span, - ) -> (ast::FnSig, Vec, Vec, bool) { + ) -> (ast::FnSig, Vec, bool) { let dcx = ecx.sess.dcx(); let has_ret = has_ret(&sig.decl.output); let sig_args = sig.decl.inputs.len() + if has_ret { 1 } else { 0 }; @@ -783,7 +740,7 @@ mod llvm_enzyme { found: num_activities, }); // This is not the right signature, but we can continue parsing. - return (sig.clone(), vec![], vec![], true); + return (sig.clone(), vec![], true); } assert!(sig.decl.inputs.len() == x.input_activity.len()); assert!(has_ret == x.has_ret_activity()); @@ -826,7 +783,7 @@ mod llvm_enzyme { if errors { // This is not the right signature, but we can continue parsing. 
- return (sig.clone(), new_inputs, idents, true); + return (sig.clone(), idents, true); } let unsafe_activities = x @@ -1034,7 +991,7 @@ mod llvm_enzyme { } let d_sig = FnSig { header: d_header, decl: d_decl, span }; trace!("Generated signature: {:?}", d_sig); - (d_sig, new_inputs, idents, false) + (d_sig, idents, false) } } diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs index 19d89526675b3..4d7515b669906 100644 --- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs +++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs @@ -1,43 +1,18 @@ use std::ptr; -use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivity, DiffMode}; -use rustc_codegen_ssa::ModuleCodegen; -use rustc_codegen_ssa::back::write::ModuleConfig; +use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode}; use rustc_codegen_ssa::common::TypeKind; use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods}; -use rustc_errors::FatalError; use rustc_middle::bug; -use tracing::{debug, trace}; +use tracing::debug; -use crate::back::write::llvm_err; use crate::builder::{Builder, PlaceRef, UNNAMED}; use crate::context::SimpleCx; use crate::declare::declare_simple_fn; -use crate::errors::{AutoDiffWithoutEnable, LlvmError}; use crate::llvm::AttributePlace::Function; use crate::llvm::{Metadata, True, Type}; use crate::value::Value; -use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, llvm}; - -fn _get_params(fnc: &Value) -> Vec<&Value> { - let param_num = llvm::LLVMCountParams(fnc) as usize; - let mut fnc_args: Vec<&Value> = vec![]; - fnc_args.reserve(param_num); - unsafe { - llvm::LLVMGetParams(fnc, fnc_args.as_mut_ptr()); - fnc_args.set_len(param_num); - } - fnc_args -} - -fn _has_sret(fnc: &Value) -> bool { - let num_args = llvm::LLVMCountParams(fnc) as usize; - if num_args == 0 { - false - } else { - unsafe { llvm::LLVMRustHasAttributeAtIndex(fnc, 0, 
llvm::AttributeKind::StructRet) } - } -} +use crate::{attributes, llvm}; // When we call the `__enzyme_autodiff` or `__enzyme_fwddiff` function, we need to pass all the // original inputs, as well as metadata and the additional shadow arguments. @@ -294,62 +269,3 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>( builder.store_to_place(call, dest.val); } - -pub(crate) fn differentiate<'ll>( - module: &'ll ModuleCodegen, - cgcx: &CodegenContext, - diff_items: Vec, - _config: &ModuleConfig, -) -> Result<(), FatalError> { - // TODO(Sa4dUs): delete all this logic - for item in &diff_items { - trace!("{}", item); - } - - let diag_handler = cgcx.create_dcx(); - - let cx = SimpleCx::new(module.module_llvm.llmod(), module.module_llvm.llcx, cgcx.pointer_size); - - // First of all, did the user try to use autodiff without using the -Zautodiff=Enable flag? - if !diff_items.is_empty() - && !cgcx.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::Enable) - { - return Err(diag_handler.handle().emit_almost_fatal(AutoDiffWithoutEnable)); - } - - // Here we replace the placeholder code with the actual autodiff code, which calls Enzyme. 
- for item in diff_items.iter() { - let name = item.source.clone(); - let fn_def: Option<&llvm::Value> = cx.get_function(&name); - let Some(_fn_def) = fn_def else { - return Err(llvm_err( - diag_handler.handle(), - LlvmError::PrepareAutoDiff { - src: item.source.clone(), - target: item.target.clone(), - error: "could not find source function".to_owned(), - }, - )); - }; - debug!(?item.target); - let fn_target: Option<&llvm::Value> = cx.get_function(&item.target); - let Some(_fn_target) = fn_target else { - return Err(llvm_err( - diag_handler.handle(), - LlvmError::PrepareAutoDiff { - src: item.source.clone(), - target: item.target.clone(), - error: "could not find target function".to_owned(), - }, - )); - }; - - // generate_enzyme_call(&cx, fn_def, fn_target, item.attrs.clone()); - } - - // FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts - - trace!("done with differentiate()"); - - Ok(()) -} diff --git a/compiler/rustc_codegen_llvm/src/errors.rs b/compiler/rustc_codegen_llvm/src/errors.rs index d50ad8a1a9cb4..d2ef3039107af 100644 --- a/compiler/rustc_codegen_llvm/src/errors.rs +++ b/compiler/rustc_codegen_llvm/src/errors.rs @@ -37,6 +37,8 @@ impl Diagnostic<'_, G> for ParseTargetMachineConfig<'_> { } } +// TODO(Sa4dUs): we will need to reintroduce these errors somewhere +/* #[derive(Diagnostic)] #[diag(codegen_llvm_autodiff_without_lto)] pub(crate) struct AutoDiffWithoutLTO; @@ -44,6 +46,7 @@ pub(crate) struct AutoDiffWithoutLTO; #[derive(Diagnostic)] #[diag(codegen_llvm_autodiff_without_enable)] pub(crate) struct AutoDiffWithoutEnable; +*/ #[derive(Diagnostic)] #[diag(codegen_llvm_lto_disallowed)] diff --git a/compiler/rustc_codegen_llvm/src/lib.rs b/compiler/rustc_codegen_llvm/src/lib.rs index cdfffbe47bfa5..6682f44dc72c4 100644 --- a/compiler/rustc_codegen_llvm/src/lib.rs +++ b/compiler/rustc_codegen_llvm/src/lib.rs @@ -26,10 +26,9 @@ use std::mem::ManuallyDrop; use back::owned_target_machine::OwnedTargetMachine; use 
back::write::{create_informational_target_machine, create_target_machine}; use context::SimpleCx; -use errors::{AutoDiffWithoutLTO, ParseTargetMachineConfig}; +use errors::ParseTargetMachineConfig; use llvm_util::target_config; use rustc_ast::expand::allocator::AllocatorKind; -use rustc_ast::expand::autodiff_attrs::AutoDiffItem; use rustc_codegen_ssa::back::lto::{LtoModuleCodegen, SerializedModule, ThinModule}; use rustc_codegen_ssa::back::write::{ CodegenContext, FatLtoInput, ModuleConfig, TargetMachineFactoryConfig, TargetMachineFactoryFn, @@ -43,7 +42,7 @@ use rustc_middle::dep_graph::{WorkProduct, WorkProductId}; use rustc_middle::ty::TyCtxt; use rustc_middle::util::Providers; use rustc_session::Session; -use rustc_session::config::{Lto, OptLevel, OutputFilenames, PrintKind, PrintRequest}; +use rustc_session::config::{OptLevel, OutputFilenames, PrintKind, PrintRequest}; use rustc_span::Symbol; mod back { @@ -227,19 +226,6 @@ impl WriteBackendMethods for LlvmCodegenBackend { fn serialize_module(module: ModuleCodegen) -> (String, Self::ModuleBuffer) { (module.name, back::lto::ModuleBuffer::new(module.module_llvm.llmod())) } - /// Generate autodiff rules - fn autodiff( - cgcx: &CodegenContext, - module: &ModuleCodegen, - diff_fncs: Vec, - config: &ModuleConfig, - ) -> Result<(), FatalError> { - if cgcx.lto != Lto::Fat { - let dcx = cgcx.create_dcx(); - return Err(dcx.handle().emit_almost_fatal(AutoDiffWithoutLTO)); - } - builder::autodiff::differentiate(module, cgcx, diff_fncs, config) - } } impl LlvmCodegenBackend { diff --git a/compiler/rustc_codegen_ssa/src/back/lto.rs b/compiler/rustc_codegen_ssa/src/back/lto.rs index ce6fe8a191b3b..5815f15c0fa83 100644 --- a/compiler/rustc_codegen_ssa/src/back/lto.rs +++ b/compiler/rustc_codegen_ssa/src/back/lto.rs @@ -1,13 +1,11 @@ use std::ffi::CString; use std::sync::Arc; -use rustc_ast::expand::autodiff_attrs::AutoDiffItem; use rustc_data_structures::memmap::Mmap; use rustc_errors::FatalError; use 
super::write::CodegenContext; use crate::ModuleCodegen; -use crate::back::write::ModuleConfig; use crate::traits::*; pub struct ThinModule { @@ -78,23 +76,6 @@ impl LtoModuleCodegen { LtoModuleCodegen::Thin(ref m) => m.cost(), } } - - /// Run autodiff on Fat LTO module - pub fn autodiff( - self, - cgcx: &CodegenContext, - diff_fncs: Vec, - config: &ModuleConfig, - ) -> Result, FatalError> { - match &self { - LtoModuleCodegen::Fat(module) => { - B::autodiff(cgcx, &module, diff_fncs, config)?; - } - _ => panic!("autodiff called with non-fat LTO module"), - } - - Ok(self) - } } pub enum SerializedModule { diff --git a/compiler/rustc_codegen_ssa/src/back/write.rs b/compiler/rustc_codegen_ssa/src/back/write.rs index c3bfe4c13cdf7..588904921e22f 100644 --- a/compiler/rustc_codegen_ssa/src/back/write.rs +++ b/compiler/rustc_codegen_ssa/src/back/write.rs @@ -408,12 +408,8 @@ fn generate_lto_work( if !needs_fat_lto.is_empty() { assert!(needs_thin_lto.is_empty()); - let mut module = + let module = B::run_fat_lto(cgcx, needs_fat_lto, import_only_modules).unwrap_or_else(|e| e.raise()); - if cgcx.lto == Lto::Fat && !autodiff.is_empty() { - let config = cgcx.config(ModuleKind::Regular); - module = module.autodiff(cgcx, autodiff, config).unwrap_or_else(|e| e.raise()); - } // We are adding a single work item, so the cost doesn't matter. 
vec![(WorkItem::LTO(module), 0)] } else { diff --git a/compiler/rustc_codegen_ssa/src/traits/write.rs b/compiler/rustc_codegen_ssa/src/traits/write.rs index 07a0609fda1a1..40fd70fa5ad60 100644 --- a/compiler/rustc_codegen_ssa/src/traits/write.rs +++ b/compiler/rustc_codegen_ssa/src/traits/write.rs @@ -1,4 +1,3 @@ -use rustc_ast::expand::autodiff_attrs::AutoDiffItem; use rustc_errors::{DiagCtxtHandle, FatalError}; use rustc_middle::dep_graph::WorkProduct; @@ -62,12 +61,6 @@ pub trait WriteBackendMethods: Clone + 'static { want_summary: bool, ) -> (String, Self::ThinBuffer); fn serialize_module(module: ModuleCodegen) -> (String, Self::ModuleBuffer); - fn autodiff( - cgcx: &CodegenContext, - module: &ModuleCodegen, - diff_fncs: Vec, - config: &ModuleConfig, - ) -> Result<(), FatalError>; } pub trait ThinBufferMethods: Send + Sync { From 5c1369269626abe628a6a0c815b6a339db230719 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Mon, 14 Jul 2025 09:01:56 +0000 Subject: [PATCH 15/15] Remove primal call and collect it in mono instead --- compiler/rustc_builtin_macros/src/autodiff.rs | 12 ++---- compiler/rustc_monomorphize/src/collector.rs | 7 ++++ .../src/collector/autodiff.rs | 38 +++++++++++++++++++ 3 files changed, 48 insertions(+), 9 deletions(-) create mode 100644 compiler/rustc_monomorphize/src/collector/autodiff.rs diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 4fd5525c57b0a..a4a038d876e23 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -594,17 +594,11 @@ mod llvm_enzyme { span: Span, primal: Ident, idents: &[Ident], - errored: bool, + _errored: bool, generics: &Generics, ) -> P { - let primal_call = gen_primal_call(ecx, span, primal, idents, generics); - let mut body = ecx.block(span, ThinVec::new()); - - // This uses primal args which won't be available if we errored before - if !errored { - 
body.stmts.push(ecx.stmt_semi(primal_call.clone())); - } - + let _primal_call = gen_primal_call(ecx, span, primal, idents, generics); + let body = ecx.block(span, ThinVec::new()); body } diff --git a/compiler/rustc_monomorphize/src/collector.rs b/compiler/rustc_monomorphize/src/collector.rs index e90e32ebebb9f..e1ec4dadee0f0 100644 --- a/compiler/rustc_monomorphize/src/collector.rs +++ b/compiler/rustc_monomorphize/src/collector.rs @@ -205,6 +205,8 @@ //! this is not implemented however: a mono item will be produced //! regardless of whether it is actually needed or not. +mod autodiff; + use std::cell::OnceCell; use std::path::PathBuf; @@ -237,6 +239,8 @@ use rustc_span::source_map::{Spanned, dummy_spanned, respan}; use rustc_span::{DUMMY_SP, Span}; use tracing::{debug, instrument, trace}; +#[cfg(llvm_enzyme)] +use crate::collector::autodiff::collect_enzyme_autodiff_source_fn; use crate::errors::{self, EncounteredErrorWhileInstantiating, NoOptimizedMir, RecursionLimit}; #[derive(PartialEq)] @@ -916,6 +920,9 @@ fn visit_instance_use<'tcx>( return; } if let Some(intrinsic) = tcx.intrinsic(instance.def_id()) { + #[cfg(llvm_enzyme)] + collect_enzyme_autodiff_source_fn(tcx, instance, intrinsic, output); + if let Some(_requirement) = ValidityRequirement::from_intrinsic(intrinsic.name) { // The intrinsics assert_inhabited, assert_zero_valid, and assert_mem_uninitialized_valid will // be lowered in codegen to nothing or a call to panic_nounwind. 
So if we encounter any diff --git a/compiler/rustc_monomorphize/src/collector/autodiff.rs b/compiler/rustc_monomorphize/src/collector/autodiff.rs new file mode 100644 index 0000000000000..d062302ae53a6 --- /dev/null +++ b/compiler/rustc_monomorphize/src/collector/autodiff.rs @@ -0,0 +1,38 @@ +use rustc_middle::bug; +use rustc_middle::ty::{self, IntrinsicDef, TyCtxt}; +use tracing::debug; + +use crate::collector::{MonoItems, create_fn_mono_item}; + +pub(crate) fn collect_enzyme_autodiff_source_fn<'tcx>( + tcx: TyCtxt<'tcx>, + instance: ty::Instance<'tcx>, + intrinsic: IntrinsicDef, + output: &mut MonoItems<'tcx>, +) { + if intrinsic.name != rustc_span::sym::enzyme_autodiff { + return; + }; + + debug!("enzyme_autodiff found"); + let (primal, span) = match instance.args[0].kind() { + rustc_middle::infer::canonical::ir::GenericArgKind::Type(ty) => match ty.kind() { + ty::FnDef(def_id, substs) => { + let span = tcx.def_span(def_id); + let instance = ty::Instance::expect_resolve( + tcx, + ty::TypingEnv::non_body_analysis(tcx, def_id), + *def_id, + substs, + span, + ); + + (instance, span) + } + _ => bug!("expected function"), + }, + _ => bug!("expected type"), + }; + + output.push(create_fn_mono_item(tcx, primal, span)); +}