@@ -73,10 +73,10 @@ mod llvm_enzyme {
73
73
}
74
74
75
75
// Get information about the function the macro is applied to
76
- fn extract_item_info ( iitem : & P < ast:: Item > ) -> Option < ( Visibility , FnSig , Ident ) > {
76
+ fn extract_item_info ( iitem : & P < ast:: Item > ) -> Option < ( Visibility , FnSig , Ident , Generics ) > {
77
77
match & iitem. kind {
78
- ItemKind :: Fn ( box ast:: Fn { sig, ident, .. } ) => {
79
- Some ( ( iitem. vis . clone ( ) , sig. clone ( ) , ident. clone ( ) ) )
78
+ ItemKind :: Fn ( box ast:: Fn { sig, ident, generics , .. } ) => {
79
+ Some ( ( iitem. vis . clone ( ) , sig. clone ( ) , ident. clone ( ) , generics . clone ( ) ) )
80
80
}
81
81
_ => None ,
82
82
}
@@ -210,16 +210,18 @@ mod llvm_enzyme {
210
210
}
211
211
let dcx = ecx. sess . dcx ( ) ;
212
212
213
- // first get information about the annotable item:
214
- let Some ( ( vis, sig, primal) ) = ( match & item {
213
+ // first get information about the annotable item: visibility, signature, name and generic
214
+ // parameters.
215
+ // these will be used to generate the differentiated version of the function
216
+ let Some ( ( vis, sig, primal, generics) ) = ( match & item {
215
217
Annotatable :: Item ( iitem) => extract_item_info ( iitem) ,
216
218
Annotatable :: Stmt ( stmt) => match & stmt. kind {
217
219
ast:: StmtKind :: Item ( iitem) => extract_item_info ( iitem) ,
218
220
_ => None ,
219
221
} ,
220
222
Annotatable :: AssocItem ( assoc_item, Impl { .. } ) => match & assoc_item. kind {
221
- ast:: AssocItemKind :: Fn ( box ast:: Fn { sig, ident, .. } ) => {
222
- Some ( ( assoc_item. vis . clone ( ) , sig. clone ( ) , ident. clone ( ) ) )
223
+ ast:: AssocItemKind :: Fn ( box ast:: Fn { sig, ident, generics , .. } ) => {
224
+ Some ( ( assoc_item. vis . clone ( ) , sig. clone ( ) , ident. clone ( ) , generics . clone ( ) ) )
223
225
}
224
226
_ => None ,
225
227
} ,
@@ -303,14 +305,15 @@ mod llvm_enzyme {
303
305
let ( d_sig, new_args, idents, errored) = gen_enzyme_decl ( ecx, & sig, & x, span) ;
304
306
let d_body = gen_enzyme_body (
305
307
ecx, & x, n_active, & sig, & d_sig, primal, & new_args, span, sig_span, idents, errored,
308
+ & generics,
306
309
) ;
307
310
308
311
// The first element of it is the name of the function to be generated
309
312
let asdf = Box :: new ( ast:: Fn {
310
313
defaultness : ast:: Defaultness :: Final ,
311
314
sig : d_sig,
312
315
ident : first_ident ( & meta_item_vec[ 0 ] ) ,
313
- generics : Generics :: default ( ) ,
316
+ generics,
314
317
contract : None ,
315
318
body : Some ( d_body) ,
316
319
define_opaque : None ,
@@ -475,6 +478,7 @@ mod llvm_enzyme {
475
478
new_decl_span : Span ,
476
479
idents : & [ Ident ] ,
477
480
errored : bool ,
481
+ generics : & Generics ,
478
482
) -> ( P < ast:: Block > , P < ast:: Expr > , P < ast:: Expr > , P < ast:: Expr > ) {
479
483
let blackbox_path = ecx. std_path ( & [ sym:: hint, sym:: black_box] ) ;
480
484
let noop = ast:: InlineAsm {
@@ -497,7 +501,7 @@ mod llvm_enzyme {
497
501
} ;
498
502
let unsf_expr = ecx. expr_block ( P ( unsf_block) ) ;
499
503
let blackbox_call_expr = ecx. expr_path ( ecx. path ( span, blackbox_path) ) ;
500
- let primal_call = gen_primal_call ( ecx, span, primal, idents) ;
504
+ let primal_call = gen_primal_call ( ecx, span, primal, idents, generics ) ;
501
505
let black_box_primal_call = ecx. expr_call (
502
506
new_decl_span,
503
507
blackbox_call_expr. clone ( ) ,
@@ -546,6 +550,7 @@ mod llvm_enzyme {
546
550
sig_span : Span ,
547
551
idents : Vec < Ident > ,
548
552
errored : bool ,
553
+ generics : & Generics ,
549
554
) -> P < ast:: Block > {
550
555
let new_decl_span = d_sig. span ;
551
556
@@ -566,6 +571,7 @@ mod llvm_enzyme {
566
571
new_decl_span,
567
572
& idents,
568
573
errored,
574
+ generics,
569
575
) ;
570
576
571
577
if !has_ret ( & d_sig. decl . output ) {
@@ -608,7 +614,6 @@ mod llvm_enzyme {
608
614
panic ! ( "Did not expect Default ret ty: {:?}" , span) ;
609
615
}
610
616
} ;
611
-
612
617
if x. mode . is_fwd ( ) {
613
618
// Fwd mode is easy. If the return activity is Const, we support arbitrary types.
614
619
// Otherwise, we only support a scalar, a pair of scalars, or an array of scalars.
@@ -668,8 +673,10 @@ mod llvm_enzyme {
668
673
span : Span ,
669
674
primal : Ident ,
670
675
idents : & [ Ident ] ,
676
+ generics : & Generics ,
671
677
) -> P < ast:: Expr > {
672
678
let has_self = idents. len ( ) > 0 && idents[ 0 ] . name == kw:: SelfLower ;
679
+
673
680
if has_self {
674
681
let args: ThinVec < _ > =
675
682
idents[ 1 ..] . iter ( ) . map ( |arg| ecx. expr_path ( ecx. path_ident ( span, * arg) ) ) . collect ( ) ;
@@ -678,7 +685,51 @@ mod llvm_enzyme {
678
685
} else {
679
686
let args: ThinVec < _ > =
680
687
idents. iter ( ) . map ( |arg| ecx. expr_path ( ecx. path_ident ( span, * arg) ) ) . collect ( ) ;
681
- let primal_call_expr = ecx. expr_path ( ecx. path_ident ( span, primal) ) ;
688
+ let mut primal_path = ecx. path_ident ( span, primal) ;
689
+
690
+ let is_generic = !generics. params . is_empty ( ) ;
691
+
692
+ match ( is_generic, primal_path. segments . last_mut ( ) ) {
693
+ ( true , Some ( function_path) ) => {
694
+ let primal_generic_types = generics
695
+ . params
696
+ . iter ( )
697
+ . filter ( |param| matches ! ( param. kind, ast:: GenericParamKind :: Type { .. } ) ) ;
698
+
699
+ let generated_generic_types = primal_generic_types
700
+ . map ( |type_param| {
701
+ let generic_param = TyKind :: Path (
702
+ None ,
703
+ ast:: Path {
704
+ span,
705
+ segments : thin_vec ! [ ast:: PathSegment {
706
+ ident: type_param. ident,
707
+ args: None ,
708
+ id: ast:: DUMMY_NODE_ID ,
709
+ } ] ,
710
+ tokens : None ,
711
+ } ,
712
+ ) ;
713
+
714
+ ast:: AngleBracketedArg :: Arg ( ast:: GenericArg :: Type ( P ( ast:: Ty {
715
+ id : type_param. id ,
716
+ span,
717
+ kind : generic_param,
718
+ tokens : None ,
719
+ } ) ) )
720
+ } )
721
+ . collect ( ) ;
722
+
723
+ function_path. args =
724
+ Some ( P ( ast:: GenericArgs :: AngleBracketed ( ast:: AngleBracketedArgs {
725
+ span,
726
+ args : generated_generic_types,
727
+ } ) ) ) ;
728
+ }
729
+ _ => { }
730
+ }
731
+
732
+ let primal_call_expr = ecx. expr_path ( primal_path) ;
682
733
ecx. expr_call ( span, primal_call_expr, args)
683
734
}
684
735
}
0 commit comments