diff --git a/compiler/rustc_ast/src/expand/autodiff_attrs.rs b/compiler/rustc_ast/src/expand/autodiff_attrs.rs index 90f15753e99c9..a3c913436ac76 100644 --- a/compiler/rustc_ast/src/expand/autodiff_attrs.rs +++ b/compiler/rustc_ast/src/expand/autodiff_attrs.rs @@ -1,12 +1,13 @@ //! This crate handles the user facing autodiff macro. For each `#[autodiff(...)]` attribute, -//! we create an [`AutoDiffItem`] which contains the source and target function names. The source +//! we create an `RustcAutodiff` which contains the source and target function names. The source //! is the function to which the autodiff attribute is applied, and the target is the function //! getting generated by us (with a name given by the user as the first autodiff arg). use std::fmt::{self, Display, Formatter}; use std::str::FromStr; -use crate::expand::typetree::TypeTree; +use rustc_span::{Symbol, sym}; + use crate::expand::{Decodable, Encodable, HashStable_Generic}; use crate::{Ty, TyKind}; @@ -31,6 +32,12 @@ pub enum DiffMode { Reverse, } +impl DiffMode { + pub fn all_modes() -> &'static [Symbol] { + &[sym::Source, sym::Forward, sym::Reverse] + } +} + /// Dual and Duplicated (and their Only variants) are getting lowered to the same Enzyme Activity. /// However, under forward mode we overwrite the previous shadow value, while for reverse mode /// we add to the previous shadow value. To not surprise users, we picked different names. @@ -76,43 +83,20 @@ impl DiffActivity { use DiffActivity::*; matches!(self, |Dual| DualOnly | Dualv | DualvOnly | Const) } -} -/// We generate one of these structs for each `#[autodiff(...)]` attribute. -#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] -pub struct AutoDiffItem { - /// The name of the function getting differentiated - pub source: String, - /// The name of the function being generated - pub target: String, - pub attrs: AutoDiffAttrs, - pub inputs: Vec, - pub output: TypeTree, -} -#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] -pub struct AutoDiffAttrs { - /// Conceptually either forward or reverse mode AD, as described in various autodiff papers and - /// e.g. in the [JAX - /// Documentation](https://jax.readthedocs.io/en/latest/_tutorials/advanced-autodiff.html#how-it-s-made-two-foundational-autodiff-functions). - pub mode: DiffMode, - /// A user-provided, batching width. If not given, we will default to 1 (no batching). - /// Calling a differentiated, non-batched function through a loop 100 times is equivalent to: - /// - Calling the function 50 times with a batch size of 2 - /// - Calling the function 25 times with a batch size of 4, - /// etc. A batched function takes more (or longer) arguments, and might be able to benefit from - /// cache locality, better re-usal of primal values, and other optimizations. - /// We will (before LLVM's vectorizer runs) just generate most LLVM-IR instructions `width` - /// times, so this massively increases code size. As such, values like 1024 are unlikely to - /// work. We should consider limiting this to u8 or u16, but will leave it at u32 for - /// experiments for now and focus on documenting the implications of a large width. - pub width: u32, - pub ret_activity: DiffActivity, - pub input_activity: Vec, -} - -impl AutoDiffAttrs { - pub fn has_primal_ret(&self) -> bool { - matches!(self.ret_activity, DiffActivity::Active | DiffActivity::Dual) + pub fn all_activities() -> &'static [Symbol] { + &[ + sym::None, + sym::Active, + sym::ActiveOnly, + sym::Const, + sym::Dual, + sym::Dualv, + sym::DualOnly, + sym::DualvOnly, + sym::Duplicated, + sym::DuplicatedOnly, + ] } } @@ -241,59 +225,3 @@ impl FromStr for DiffActivity { } } } - -impl AutoDiffAttrs { - pub fn has_ret_activity(&self) -> bool { - self.ret_activity != DiffActivity::None - } - pub fn has_active_only_ret(&self) -> bool { - self.ret_activity == DiffActivity::ActiveOnly - } - - pub const fn error() -> Self { - AutoDiffAttrs { - mode: DiffMode::Error, - width: 0, - ret_activity: DiffActivity::None, - input_activity: Vec::new(), - } - } - pub fn source() -> Self { - AutoDiffAttrs { - mode: DiffMode::Source, - width: 0, - ret_activity: DiffActivity::None, - input_activity: Vec::new(), - } - } - - pub fn is_active(&self) -> bool { - self.mode != DiffMode::Error - } - - pub fn is_source(&self) -> bool { - self.mode == DiffMode::Source - } - pub fn apply_autodiff(&self) -> bool { - !matches!(self.mode, DiffMode::Error | DiffMode::Source) - } - - pub fn into_item( - self, - source: String, - target: String, - inputs: Vec, - output: TypeTree, - ) -> AutoDiffItem { - AutoDiffItem { source, target, inputs, output, attrs: self } - } -} - -impl fmt::Display for AutoDiffItem { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "Differentiating {} -> {}", self.source, self.target)?; - write!(f, " with attributes: {:?}", self.attrs)?; - write!(f, " with inputs: {:?}", self.inputs)?; - write!(f, " with output: {:?}", self.output) - } -} diff --git a/compiler/rustc_attr_parsing/src/attributes/autodiff.rs b/compiler/rustc_attr_parsing/src/attributes/autodiff.rs new file mode 100644 index 0000000000000..118a4103b1a96 --- /dev/null +++ b/compiler/rustc_attr_parsing/src/attributes/autodiff.rs @@ -0,0 +1,117 @@ +use std::str::FromStr; + +use rustc_ast::LitKind; +use rustc_ast::expand::autodiff_attrs::{DiffActivity, DiffMode}; +use rustc_feature::{AttributeTemplate, template}; +use rustc_hir::attrs::{AttributeKind, RustcAutodiff}; +use rustc_hir::{MethodKind, Target}; +use rustc_span::{Symbol, sym}; +use thin_vec::ThinVec; + +use crate::attributes::prelude::Allow; +use crate::attributes::{AttributeOrder, OnDuplicate, SingleAttributeParser}; +use crate::context::{AcceptContext, Stage}; +use crate::parser::{ArgParser, MetaItemOrLitParser}; +use crate::target_checking::AllowedTargets; + +pub(crate) struct RustcAutodiffParser; + +impl SingleAttributeParser for RustcAutodiffParser { + const PATH: &[Symbol] = &[sym::rustc_autodiff]; + const ATTRIBUTE_ORDER: AttributeOrder = AttributeOrder::KeepInnermost; + const ON_DUPLICATE: OnDuplicate = OnDuplicate::Error; + const ALLOWED_TARGETS: AllowedTargets = AllowedTargets::AllowList(&[ + Allow(Target::Fn), + Allow(Target::Method(MethodKind::Inherent)), + Allow(Target::Method(MethodKind::Trait { body: true })), + Allow(Target::Method(MethodKind::TraitImpl)), + ]); + const TEMPLATE: AttributeTemplate = template!( + List: &["MODE", "WIDTH", "INPUT_ACTIVITIES", "OUTPUT_ACTIVITY"], + "https://doc.rust-lang.org/std/autodiff/index.html" + ); + + fn convert(cx: &mut AcceptContext<'_, '_, S>, args: &ArgParser) -> Option { + let list = match args { + ArgParser::NoArgs => return Some(AttributeKind::RustcAutodiff(None)), + ArgParser::List(list) => list, + ArgParser::NameValue(_) => { + cx.expected_list_or_no_args(cx.attr_span); + return None; + } + }; + + let mut items = list.mixed().peekable(); + + // Parse name + let Some(mode) = items.next() else { + cx.expected_at_least_one_argument(list.span); + return None; + }; + let Some(mode) = mode.meta_item() else { + cx.expected_identifier(mode.span()); + return None; + }; + let Ok(()) = mode.args().no_args() else { + cx.expected_identifier(mode.span()); + return None; + }; + let Some(mode) = mode.path().word() else { + cx.expected_identifier(mode.span()); + return None; + }; + let Ok(mode) = DiffMode::from_str(mode.as_str()) else { + cx.expected_specific_argument(mode.span, DiffMode::all_modes()); + return None; + }; + + // Parse width + let width = if let Some(width) = items.peek() + && let MetaItemOrLitParser::Lit(width) = width + && let LitKind::Int(width, _) = width.kind + && let Ok(width) = width.0.try_into() + { + _ = items.next(); + width + } else { + 1 + }; + + // Parse activities + let mut activities = ThinVec::new(); + for activity in items { + let MetaItemOrLitParser::MetaItemParser(activity) = activity else { + cx.expected_specific_argument(activity.span(), DiffActivity::all_activities()); + return None; + }; + let Ok(()) = activity.args().no_args() else { + cx.expected_specific_argument(activity.span(), DiffActivity::all_activities()); + return None; + }; + let Some(activity) = activity.path().word() else { + cx.expected_specific_argument(activity.span(), DiffActivity::all_activities()); + return None; + }; + let Ok(activity) = DiffActivity::from_str(activity.as_str()) else { + cx.expected_specific_argument(activity.span, DiffActivity::all_activities()); + return None; + }; + + activities.push(activity); + } + let Some(ret_activity) = activities.pop() else { + cx.expected_specific_argument( + list.span.with_lo(list.span.hi()), + DiffActivity::all_activities(), + ); + return None; + }; + + Some(AttributeKind::RustcAutodiff(Some(Box::new(RustcAutodiff { + mode, + width, + input_activity: activities, + ret_activity, + })))) + } +} diff --git a/compiler/rustc_attr_parsing/src/attributes/mod.rs b/compiler/rustc_attr_parsing/src/attributes/mod.rs index 8ee453d7f4649..223c88972d75e 100644 --- a/compiler/rustc_attr_parsing/src/attributes/mod.rs +++ b/compiler/rustc_attr_parsing/src/attributes/mod.rs @@ -30,6 +30,7 @@ use crate::target_checking::AllowedTargets; mod prelude; pub(crate) mod allow_unstable; +pub(crate) mod autodiff; pub(crate) mod body; pub(crate) mod cfg; pub(crate) mod cfg_select; diff --git a/compiler/rustc_attr_parsing/src/context.rs b/compiler/rustc_attr_parsing/src/context.rs index b82607e7c450d..802ee56f504b0 100644 --- a/compiler/rustc_attr_parsing/src/context.rs +++ b/compiler/rustc_attr_parsing/src/context.rs @@ -19,6 +19,7 @@ use rustc_span::{ErrorGuaranteed, Span, Symbol}; use crate::AttributeParser; // Glob imports to avoid big, bitrotty import lists use crate::attributes::allow_unstable::*; +use crate::attributes::autodiff::*; use crate::attributes::body::*; use crate::attributes::cfi_encoding::*; use crate::attributes::codegen_attrs::*; @@ -204,6 +205,7 @@ attribute_parsers!( Single, Single, Single, + Single, Single, Single, Single, diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 264f797a78250..2a0175043b6f3 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -8,8 +8,7 @@ mod llvm_enzyme { use std::string::String; use rustc_ast::expand::autodiff_attrs::{ - AutoDiffAttrs, DiffActivity, DiffMode, valid_input_activity, valid_ret_activity, - valid_ty_for_activity, + DiffActivity, DiffMode, valid_input_activity, valid_ret_activity, valid_ty_for_activity, }; use rustc_ast::token::{Lit, LitKind, Token, TokenKind}; use rustc_ast::tokenstream::*; @@ -20,6 +19,7 @@ mod llvm_enzyme { MetaItemInner, MgcaDisambiguation, PatKind, Path, PathSegment, TyKind, Visibility, }; use rustc_expand::base::{Annotatable, ExtCtxt}; + use rustc_hir::attrs::RustcAutodiff; use rustc_span::{Ident, Span, Symbol, sym}; use thin_vec::{ThinVec, thin_vec}; use tracing::{debug, trace}; @@ -87,7 +87,7 @@ mod llvm_enzyme { meta_item: &ThinVec, has_ret: bool, mode: DiffMode, - ) -> AutoDiffAttrs { + ) -> RustcAutodiff { let dcx = ecx.sess.dcx(); // Now we check, whether the user wants autodiff in batch/vector mode, or scalar mode. @@ -105,7 +105,7 @@ mod llvm_enzyme { span: meta_item[1].span(), width: x, }); - return AutoDiffAttrs::error(); + return RustcAutodiff::error(); } } } else { @@ -129,7 +129,7 @@ mod llvm_enzyme { }; } if errors { - return AutoDiffAttrs::error(); + return RustcAutodiff::error(); } // If a return type exist, we need to split the last activity, @@ -145,11 +145,11 @@ mod llvm_enzyme { (&DiffActivity::None, activities.as_slice()) }; - AutoDiffAttrs { + RustcAutodiff { mode, width, ret_activity: *ret_activity, - input_activity: input_activity.to_vec(), + input_activity: input_activity.iter().cloned().collect(), } } @@ -309,7 +309,7 @@ mod llvm_enzyme { ts.pop(); let ts: TokenStream = TokenStream::from_iter(ts); - let x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret, mode); + let x: RustcAutodiff = from_ast(ecx, &meta_item_vec, has_ret, mode); if !x.is_active() { // We encountered an error, so we return the original item. // This allows us to potentially parse other attributes. @@ -603,7 +603,7 @@ mod llvm_enzyme { fn gen_enzyme_decl( ecx: &ExtCtxt<'_>, sig: &ast::FnSig, - x: &AutoDiffAttrs, + x: &RustcAutodiff, span: Span, ) -> ast::FnSig { let dcx = ecx.sess.dcx(); diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs index 4b433e2b63616..f8f6439a7b0ea 100644 --- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs +++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs @@ -1,9 +1,11 @@ use std::ptr; -use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode}; +use rustc_ast::expand::autodiff_attrs::{DiffActivity, DiffMode}; use rustc_ast::expand::typetree::FncTree; use rustc_codegen_ssa::common::TypeKind; use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods}; +use rustc_data_structures::thin_vec::ThinVec; +use rustc_hir::attrs::RustcAutodiff; use rustc_middle::ty::{Instance, PseudoCanonicalInput, TyCtxt, TypingEnv}; use rustc_middle::{bug, ty}; use rustc_target::callconv::PassMode; @@ -18,7 +20,7 @@ pub(crate) fn adjust_activity_to_abi<'tcx>( tcx: TyCtxt<'tcx>, instance: Instance<'tcx>, typing_env: TypingEnv<'tcx>, - da: &mut Vec, + da: &mut ThinVec, ) { let fn_ty = instance.ty(tcx, typing_env); @@ -295,7 +297,7 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>( outer_name: &str, ret_ty: &'ll Type, fn_args: &[&'ll Value], - attrs: AutoDiffAttrs, + attrs: &RustcAutodiff, dest: PlaceRef<'tcx, &'ll Value>, fnc_tree: FncTree, ) { diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index 854fbbcea3ee1..6a7ee711ff8a8 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -6,7 +6,6 @@ use rustc_abi::{ Align, BackendRepr, ExternAbi, Float, HasDataLayout, Primitive, Size, WrappingRange, }; use rustc_codegen_ssa::base::{compare_simd_types, wants_msvc_seh, wants_wasm_eh}; -use rustc_codegen_ssa::codegen_attrs::autodiff_attrs; use rustc_codegen_ssa::common::{IntPredicate, TypeKind}; use rustc_codegen_ssa::errors::{ExpectedPointerMutability, InvalidMonomorphization}; use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue}; @@ -15,6 +14,7 @@ use rustc_codegen_ssa::traits::*; use rustc_data_structures::assert_matches; use rustc_hir as hir; use rustc_hir::def_id::LOCAL_CRATE; +use rustc_hir::find_attr; use rustc_middle::mir::BinOp; use rustc_middle::ty::layout::{FnAbiOf, HasTyCtxt, HasTypingEnv, LayoutOf}; use rustc_middle::ty::offload_meta::OffloadMetadata; @@ -1367,7 +1367,9 @@ fn codegen_autodiff<'ll, 'tcx>( let val_arr = get_args_from_tuple(bx, args[2], fn_diff); let diff_symbol = symbol_name_for_instance_in_crate(tcx, fn_diff.clone(), LOCAL_CRATE); - let Some(mut diff_attrs) = autodiff_attrs(tcx, fn_diff.def_id()) else { + let Some(Some(mut diff_attrs)) = + find_attr!(tcx, fn_diff.def_id(), RustcAutodiff(attr) => attr.clone()) + else { bug!("could not find autodiff attrs") }; @@ -1389,7 +1391,7 @@ fn codegen_autodiff<'ll, 'tcx>( &diff_symbol, llret_ty, &val_arr, - diff_attrs.clone(), + &diff_attrs, result, fnc_tree, ); diff --git a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs index 43f039cc5ebfd..1ceb01337b118 100644 --- a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs +++ b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs @@ -1,8 +1,4 @@ -use std::str::FromStr; - use rustc_abi::{Align, ExternAbi}; -use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode}; -use rustc_ast::{LitKind, MetaItem, MetaItemInner}; use rustc_hir::attrs::{ AttributeKind, EiiImplResolution, InlineAttr, Linkage, RtsanSetting, UsedBy, }; @@ -14,7 +10,6 @@ use rustc_middle::middle::codegen_fn_attrs::{ }; use rustc_middle::mir::mono::Visibility; use rustc_middle::query::Providers; -use rustc_middle::span_bug; use rustc_middle::ty::{self as ty, TyCtxt}; use rustc_session::lint; use rustc_session::parse::feature_err; @@ -614,116 +609,6 @@ fn inherited_align<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> Option { tcx.codegen_fn_attrs(tcx.trait_item_of(def_id)?).alignment } -/// We now check the #\[rustc_autodiff\] attributes which we generated from the #[autodiff(...)] -/// macros. There are two forms. The pure one without args to mark primal functions (the functions -/// being differentiated). The other form is #[rustc_autodiff(Mode, ActivityList)] on top of the -/// placeholder functions. We wrote the rustc_autodiff attributes ourself, so this should never -/// panic, unless we introduced a bug when parsing the autodiff macro. -//FIXME(jdonszelmann): put in the main loop. No need to have two..... :/ Let's do that when we make autodiff parsed. -pub fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option { - #[allow(deprecated)] - let attrs = tcx.get_attrs(id, sym::rustc_autodiff); - - let attrs = attrs.filter(|attr| attr.has_name(sym::rustc_autodiff)).collect::>(); - - // check for exactly one autodiff attribute on placeholder functions. - // There should only be one, since we generate a new placeholder per ad macro. - let attr = match &attrs[..] { - [] => return None, - [attr] => attr, - _ => { - span_bug!(attrs[1].span(), "cg_ssa: rustc_autodiff should only exist once per source"); - } - }; - - let list = attr.meta_item_list().unwrap_or_default(); - - // empty autodiff attribute macros (i.e. `#[autodiff]`) are used to mark source functions - if list.is_empty() { - return Some(AutoDiffAttrs::source()); - } - - let [mode, width_meta, input_activities @ .., ret_activity] = &list[..] else { - span_bug!(attr.span(), "rustc_autodiff attribute must contain mode, width and activities"); - }; - let mode = if let MetaItemInner::MetaItem(MetaItem { path: p1, .. }) = mode { - p1.segments.first().unwrap().ident - } else { - span_bug!(attr.span(), "rustc_autodiff attribute must contain mode"); - }; - - // parse mode - let mode = match mode.as_str() { - "Forward" => DiffMode::Forward, - "Reverse" => DiffMode::Reverse, - _ => { - span_bug!(mode.span, "rustc_autodiff attribute contains invalid mode"); - } - }; - - let width: u32 = match width_meta { - MetaItemInner::MetaItem(MetaItem { path: p1, .. }) => { - let w = p1.segments.first().unwrap().ident; - match w.as_str().parse() { - Ok(val) => val, - Err(_) => { - span_bug!(w.span, "rustc_autodiff width should fit u32"); - } - } - } - MetaItemInner::Lit(lit) => { - if let LitKind::Int(val, _) = lit.kind { - match val.get().try_into() { - Ok(val) => val, - Err(_) => { - span_bug!(lit.span, "rustc_autodiff width should fit u32"); - } - } - } else { - span_bug!(lit.span, "rustc_autodiff width should be an integer"); - } - } - }; - - // First read the ret symbol from the attribute - let MetaItemInner::MetaItem(MetaItem { path: p1, .. }) = ret_activity else { - span_bug!(attr.span(), "rustc_autodiff attribute must contain the return activity"); - }; - let ret_symbol = p1.segments.first().unwrap().ident; - - // Then parse it into an actual DiffActivity - let Ok(ret_activity) = DiffActivity::from_str(ret_symbol.as_str()) else { - span_bug!(ret_symbol.span, "invalid return activity"); - }; - - // Now parse all the intermediate (input) activities - let mut arg_activities: Vec = vec![]; - for arg in input_activities { - let arg_symbol = if let MetaItemInner::MetaItem(MetaItem { path: p2, .. }) = arg { - match p2.segments.first() { - Some(x) => x.ident, - None => { - span_bug!( - arg.span(), - "rustc_autodiff attribute must contain the input activity" - ); - } - } - } else { - span_bug!(arg.span(), "rustc_autodiff attribute must contain the input activity"); - }; - - match DiffActivity::from_str(arg_symbol.as_str()) { - Ok(arg_activity) => arg_activities.push(arg_activity), - Err(_) => { - span_bug!(arg_symbol.span, "invalid input activity"); - } - } - } - - Some(AutoDiffAttrs { mode, width, ret_activity, input_activity: arg_activities }) -} - pub(crate) fn provide(providers: &mut Providers) { *providers = Providers { codegen_fn_attrs, diff --git a/compiler/rustc_hir/src/attrs/data_structures.rs b/compiler/rustc_hir/src/attrs/data_structures.rs index 68f5bb94c3fe8..91409108a7533 100644 --- a/compiler/rustc_hir/src/attrs/data_structures.rs +++ b/compiler/rustc_hir/src/attrs/data_structures.rs @@ -1,9 +1,12 @@ use std::borrow::Cow; +use std::fmt; use std::path::PathBuf; pub use ReprAttr::*; use rustc_abi::Align; pub use rustc_ast::attr::data_structures::*; +use rustc_ast::expand::autodiff_attrs::{DiffActivity, DiffMode}; +use rustc_ast::expand::typetree::TypeTree; use rustc_ast::token::DocFragmentKind; use rustc_ast::{AttrStyle, Path, ast}; use rustc_data_structures::fx::FxIndexMap; @@ -794,6 +797,103 @@ pub struct RustcCleanQueries { pub span: Span, } +#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(HashStable_Generic, Encodable, Decodable, PrintAttribute)] +pub struct RustcAutodiff { + /// Conceptually either forward or reverse mode AD, as described in various autodiff papers and + /// e.g. in the [JAX + /// Documentation](https://jax.readthedocs.io/en/latest/_tutorials/advanced-autodiff.html#how-it-s-made-two-foundational-autodiff-functions). + pub mode: DiffMode, + /// A user-provided, batching width. If not given, we will default to 1 (no batching). + /// Calling a differentiated, non-batched function through a loop 100 times is equivalent to: + /// - Calling the function 50 times with a batch size of 2 + /// - Calling the function 25 times with a batch size of 4, + /// etc. A batched function takes more (or longer) arguments, and might be able to benefit from + /// cache locality, better re-usal of primal values, and other optimizations. + /// We will (before LLVM's vectorizer runs) just generate most LLVM-IR instructions `width` + /// times, so this massively increases code size. As such, values like 1024 are unlikely to + /// work. We should consider limiting this to u8 or u16, but will leave it at u32 for + /// experiments for now and focus on documenting the implications of a large width. + pub width: u32, + pub input_activity: ThinVec, + pub ret_activity: DiffActivity, +} + +impl RustcAutodiff { + pub fn has_primal_ret(&self) -> bool { + matches!(self.ret_activity, DiffActivity::Active | DiffActivity::Dual) + } +} + +impl RustcAutodiff { + pub fn has_ret_activity(&self) -> bool { + self.ret_activity != DiffActivity::None + } + pub fn has_active_only_ret(&self) -> bool { + self.ret_activity == DiffActivity::ActiveOnly + } + + pub fn error() -> Self { + RustcAutodiff { + mode: DiffMode::Error, + width: 0, + ret_activity: DiffActivity::None, + input_activity: ThinVec::new(), + } + } + + pub fn source() -> Self { + RustcAutodiff { + mode: DiffMode::Source, + width: 0, + ret_activity: DiffActivity::None, + input_activity: ThinVec::new(), + } + } + + pub fn is_active(&self) -> bool { + self.mode != DiffMode::Error + } + + pub fn is_source(&self) -> bool { + self.mode == DiffMode::Source + } + pub fn apply_autodiff(&self) -> bool { + !matches!(self.mode, DiffMode::Error | DiffMode::Source) + } + + pub fn into_item( + self, + source: String, + target: String, + inputs: Vec, + output: TypeTree, + ) -> AutoDiffItem { + AutoDiffItem { source, target, inputs, output, attrs: self } + } +} + +/// We generate one of these structs for each `#[autodiff(...)]` attribute. +#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] +pub struct AutoDiffItem { + /// The name of the function getting differentiated + pub source: String, + /// The name of the function being generated + pub target: String, + pub attrs: RustcAutodiff, + pub inputs: Vec, + pub output: TypeTree, +} + +impl fmt::Display for AutoDiffItem { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Differentiating {} -> {}", self.source, self.target)?; + write!(f, " with attributes: {:?}", self.attrs)?; + write!(f, " with inputs: {:?}", self.inputs)?; + write!(f, " with output: {:?}", self.output) + } +} + /// Represents parsed *built-in* inert attributes. /// /// ## Overview @@ -1186,6 +1286,9 @@ pub enum AttributeKind { /// Represents `#[rustc_as_ptr]` (used by the `dangling_pointers_from_temporaries` lint). RustcAsPtr(Span), + /// Represents `#[rustc_autodiff]`. + RustcAutodiff(Option>), + /// Represents `#[rustc_default_body_unstable]`. RustcBodyStability { stability: DefaultBodyStability, diff --git a/compiler/rustc_hir/src/attrs/encode_cross_crate.rs b/compiler/rustc_hir/src/attrs/encode_cross_crate.rs index c50d38b6d673a..cd41a2b9b28c7 100644 --- a/compiler/rustc_hir/src/attrs/encode_cross_crate.rs +++ b/compiler/rustc_hir/src/attrs/encode_cross_crate.rs @@ -102,6 +102,7 @@ impl AttributeKind { RustcAllowConstFnUnstable(..) => No, RustcAllowIncoherentImpl(..) => No, RustcAsPtr(..) => Yes, + RustcAutodiff(..) => Yes, RustcBodyStability { .. } => No, RustcBuiltinMacro { .. } => Yes, RustcCaptureAnalysis => No, diff --git a/compiler/rustc_hir/src/attrs/pretty_printing.rs b/compiler/rustc_hir/src/attrs/pretty_printing.rs index 8fce529010150..9d14f9de3078d 100644 --- a/compiler/rustc_hir/src/attrs/pretty_printing.rs +++ b/compiler/rustc_hir/src/attrs/pretty_printing.rs @@ -6,6 +6,7 @@ use rustc_abi::Align; use rustc_ast::ast::{Path, join_path_idents}; use rustc_ast::attr::data_structures::CfgEntry; use rustc_ast::attr::version::RustcVersion; +use rustc_ast::expand::autodiff_attrs::{DiffActivity, DiffMode}; use rustc_ast::token::{CommentKind, DocFragmentKind}; use rustc_ast::{AttrId, AttrStyle, IntTy, UintTy}; use rustc_ast_pretty::pp::Printer; @@ -191,7 +192,7 @@ macro_rules! print_tup { print_tup!(A B C D E F G H); print_skip!(Span, (), ErrorGuaranteed, AttrId); -print_disp!(u8, u16, u128, usize, bool, NonZero, Limit); +print_disp!(u8, u16, u32, u128, usize, bool, NonZero, Limit); print_debug!( Symbol, Ident, @@ -206,4 +207,6 @@ print_debug!( DefId, RustcVersion, CfgEntry, + DiffActivity, + DiffMode, ); diff --git a/compiler/rustc_middle/src/arena.rs b/compiler/rustc_middle/src/arena.rs index 0f254aaa9fa0a..d00ef84725071 100644 --- a/compiler/rustc_middle/src/arena.rs +++ b/compiler/rustc_middle/src/arena.rs @@ -90,7 +90,7 @@ macro_rules! arena_types { [] codegen_unit: rustc_middle::mir::mono::CodegenUnit<'tcx>, [decode] attribute: rustc_hir::Attribute, [] name_set: rustc_data_structures::unord::UnordSet, - [] autodiff_item: rustc_ast::expand::autodiff_attrs::AutoDiffItem, + [] autodiff_item: rustc_hir::attrs::AutoDiffItem, [] ordered_name_set: rustc_data_structures::fx::FxIndexSet, [] valtree: rustc_middle::ty::ValTreeKind>, [] stable_order_of_exportable_impls: diff --git a/compiler/rustc_mir_transform/src/cross_crate_inline.rs b/compiler/rustc_mir_transform/src/cross_crate_inline.rs index 7435fbe8d38a5..19ffffdd1eca7 100644 --- a/compiler/rustc_mir_transform/src/cross_crate_inline.rs +++ b/compiler/rustc_mir_transform/src/cross_crate_inline.rs @@ -8,7 +8,6 @@ use rustc_middle::mir::*; use rustc_middle::query::Providers; use rustc_middle::ty::TyCtxt; use rustc_session::config::{InliningThreshold, OptLevel}; -use rustc_span::sym; use crate::{inline, pass_manager as pm}; @@ -37,11 +36,7 @@ fn cross_crate_inlinable(tcx: TyCtxt<'_>, def_id: LocalDefId) -> bool { } // FIXME(autodiff): replace this as per discussion in https://github.com/rust-lang/rust/pull/149033#discussion_r2535465880 - #[allow(deprecated)] - if tcx.has_attr(def_id, sym::autodiff_forward) - || tcx.has_attr(def_id, sym::autodiff_reverse) - || tcx.has_attr(def_id, sym::rustc_autodiff) - { + if find_attr!(tcx, def_id, RustcAutodiff(..)) { return true; } diff --git a/compiler/rustc_passes/src/check_attr.rs b/compiler/rustc_passes/src/check_attr.rs index 89bd39b77e64b..dacb02afe1612 100644 --- a/compiler/rustc_passes/src/check_attr.rs +++ b/compiler/rustc_passes/src/check_attr.rs @@ -53,7 +53,6 @@ use rustc_span::{BytePos, DUMMY_SP, Ident, Span, Symbol, sym}; use rustc_trait_selection::error_reporting::InferCtxtErrorExt; use rustc_trait_selection::infer::{TyCtxtInferExt, ValuePairs}; use rustc_trait_selection::traits::ObligationCtxt; -use tracing::debug; use crate::errors; @@ -299,6 +298,7 @@ impl<'tcx> CheckAttrVisitor<'tcx> { | AttributeKind::RustcAllocatorZeroed | AttributeKind::RustcAllocatorZeroedVariant { .. } | AttributeKind::RustcAsPtr(..) + | AttributeKind::RustcAutodiff(..) | AttributeKind::RustcBodyStability { .. } | AttributeKind::RustcBuiltinMacro { .. } | AttributeKind::RustcCaptureAnalysis @@ -390,22 +390,13 @@ impl<'tcx> CheckAttrVisitor<'tcx> { Attribute::Unparsed(attr_item) => { style = Some(attr_item.style); match attr.path().as_slice() { - [sym::autodiff_forward, ..] | [sym::autodiff_reverse, ..] => { - self.check_autodiff(hir_id, attr, span, target) - } [ // ok sym::allow | sym::expect | sym::warn | sym::deny - | sym::forbid - // internal - | sym::rustc_on_unimplemented - | sym::rustc_layout - | sym::rustc_autodiff - // crate-level attrs, are checked below - | sym::feature, + | sym::forbid, .. ] => {} [name, rest@..] => { @@ -1863,18 +1854,6 @@ impl<'tcx> CheckAttrVisitor<'tcx> { } } - /// Checks if `#[autodiff]` is applied to an item other than a function item. - fn check_autodiff(&self, _hir_id: HirId, _attr: &Attribute, span: Span, target: Target) { - debug!("check_autodiff"); - match target { - Target::Fn => {} - _ => { - self.dcx().emit_err(errors::AutoDiffAttr { attr_span: span }); - self.abort.set(true); - } - } - } - fn check_loop_match(&self, hir_id: HirId, attr_span: Span, target: Target) { let node_span = self.tcx.hir_span(hir_id); diff --git a/compiler/rustc_passes/src/errors.rs b/compiler/rustc_passes/src/errors.rs index b9ada150d0301..0cf0d1a5c80ff 100644 --- a/compiler/rustc_passes/src/errors.rs +++ b/compiler/rustc_passes/src/errors.rs @@ -19,14 +19,6 @@ use crate::lang_items::Duplicate; #[diag("`#[diagnostic::do_not_recommend]` can only be placed on trait implementations")] pub(crate) struct IncorrectDoNotRecommendLocation; -#[derive(Diagnostic)] -#[diag("`#[autodiff]` should be applied to a function")] -pub(crate) struct AutoDiffAttr { - #[primary_span] - #[label("not a function")] - pub attr_span: Span, -} - #[derive(Diagnostic)] #[diag("`#[loop_match]` should be applied to a loop")] pub(crate) struct LoopMatchAttr { diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs index 422a15b060cc0..731a838530729 100644 --- a/compiler/rustc_span/src/symbol.rs +++ b/compiler/rustc_span/src/symbol.rs @@ -157,6 +157,8 @@ symbols! { Abi, AcqRel, Acquire, + Active, + ActiveOnly, Alignment, Arc, ArcWeak, @@ -213,6 +215,12 @@ symbols! { Deref, DispatchFromDyn, Display, + Dual, + DualOnly, + Dualv, + DualvOnly, + Duplicated, + DuplicatedOnly, DynTrait, Enum, Eq, @@ -310,6 +318,7 @@ symbols! { Slice, SliceIndex, Some, + Source, SpanCtxt, Str, String,