@@ -15,7 +15,7 @@ use crate::context::SimpleCx;
1515use crate :: declare:: declare_simple_fn;
1616use crate :: errors:: { AutoDiffWithoutEnable , LlvmError } ;
1717use crate :: llvm:: AttributePlace :: Function ;
18- use crate :: llvm:: { Metadata , True } ;
18+ use crate :: llvm:: { Metadata , True , Type } ;
1919use crate :: value:: Value ;
2020use crate :: { CodegenContext , LlvmCodegenBackend , ModuleLlvm , attributes, llvm} ;
2121
@@ -30,7 +30,7 @@ fn _get_params(fnc: &Value) -> Vec<&Value> {
3030 fnc_args
3131}
3232
33- fn has_sret ( fnc : & Value ) -> bool {
33+ fn _has_sret ( fnc : & Value ) -> bool {
3434 let num_args = llvm:: LLVMCountParams ( fnc) as usize ;
3535 if num_args == 0 {
3636 false
@@ -56,7 +56,6 @@ fn match_args_from_caller_to_enzyme<'ll, 'tcx>(
5656 args : & mut Vec < & ' ll llvm:: Value > ,
5757 inputs : & [ DiffActivity ] ,
5858 outer_args : & [ & ' ll llvm:: Value ] ,
59- has_sret : bool ,
6059) {
6160 debug ! ( "matching autodiff arguments" ) ;
6261 // We now handle the issue that Rust level arguments do not always match the llvm-ir level
@@ -68,14 +67,6 @@ fn match_args_from_caller_to_enzyme<'ll, 'tcx>(
6867 let mut outer_pos: usize = 0 ;
6968 let mut activity_pos = 0 ;
7069
71- if has_sret {
72- // Then the first outer arg is the sret pointer. Enzyme doesn't know about sret, so the
73- // inner function will still return something. We increase our outer_pos by one,
74- // and once we're done with all other args we will take the return of the inner call and
75- // update the sret pointer with it
76- outer_pos = 1 ;
77- }
78-
7970 let enzyme_const = cx. create_metadata ( "enzyme_const" . to_string ( ) ) . unwrap ( ) ;
8071 let enzyme_out = cx. create_metadata ( "enzyme_out" . to_string ( ) ) . unwrap ( ) ;
8172 let enzyme_dup = cx. create_metadata ( "enzyme_dup" . to_string ( ) ) . unwrap ( ) ;
@@ -194,92 +185,6 @@ fn match_args_from_caller_to_enzyme<'ll, 'tcx>(
194185 }
195186}
196187
197- // On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input
198- // arguments. We do however need to declare them with their correct return type.
199- // We already figured the correct return type out in our frontend, when generating the outer_fn,
200- // so we can now just go ahead and use that. This is not always trivial, e.g. because sret.
201- // Beyond sret, this article describes our challenges nicely:
202- // <https://yorickpeterse.com/articles/the-mess-that-is-handling-structure-arguments-and-returns-in-llvm/>
203- // I.e. (i32, f32) will get merged into i64, but we don't handle that yet.
204- fn compute_enzyme_fn_ty < ' ll > (
205- cx : & SimpleCx < ' ll > ,
206- attrs : & AutoDiffAttrs ,
207- fn_to_diff : & ' ll Value ,
208- outer_fn : & ' ll Value ,
209- ) -> & ' ll llvm:: Type {
210- let fn_ty = cx. get_type_of_global ( outer_fn) ;
211- let mut ret_ty = cx. get_return_type ( fn_ty) ;
212-
213- let has_sret = has_sret ( outer_fn) ;
214-
215- if has_sret {
216- // Now we don't just forward the return type, so we have to figure it out based on the
217- // primal return type, in combination with the autodiff settings.
218- let fn_ty = cx. get_type_of_global ( fn_to_diff) ;
219- let inner_ret_ty = cx. get_return_type ( fn_ty) ;
220-
221- let void_ty = unsafe { llvm:: LLVMVoidTypeInContext ( cx. llcx ) } ;
222- if inner_ret_ty == void_ty {
223- // This indicates that even the inner function has an sret.
224- // Right now I only look for an sret in the outer function.
225- // This *probably* needs some extra handling, but I never ran
226- // into such a case. So I'll wait for user reports to have a test case.
227- bug ! ( "sret in inner function" ) ;
228- }
229-
230- if attrs. width == 1 {
231- // Enzyme returns a struct of style:
232- // `{ original_ret(if requested), float, float, ... }`
233- let mut struct_elements = vec ! [ ] ;
234- if attrs. has_primal_ret ( ) {
235- struct_elements. push ( inner_ret_ty) ;
236- }
237- // Next, we push the list of active floats, since they will be lowered to `enzyme_out`,
238- // and therefore part of the return struct.
239- let param_tys = cx. func_params_types ( fn_ty) ;
240- for ( act, param_ty) in attrs. input_activity . iter ( ) . zip ( param_tys) {
241- if matches ! ( act, DiffActivity :: Active ) {
242- // Now find the float type at position i based on the fn_ty,
243- // to know what (f16/f32/f64/...) to add to the struct.
244- struct_elements. push ( param_ty) ;
245- }
246- }
247- ret_ty = cx. type_struct ( & struct_elements, false ) ;
248- } else {
249- // First we check if we also have to deal with the primal return.
250- match attrs. mode {
251- DiffMode :: Forward => match attrs. ret_activity {
252- DiffActivity :: Dual => {
253- let arr_ty =
254- unsafe { llvm:: LLVMArrayType2 ( inner_ret_ty, attrs. width as u64 + 1 ) } ;
255- ret_ty = arr_ty;
256- }
257- DiffActivity :: DualOnly => {
258- let arr_ty =
259- unsafe { llvm:: LLVMArrayType2 ( inner_ret_ty, attrs. width as u64 ) } ;
260- ret_ty = arr_ty;
261- }
262- DiffActivity :: Const => {
263- todo ! ( "Not sure, do we need to do something here?" ) ;
264- }
265- _ => {
266- bug ! ( "unreachable" ) ;
267- }
268- } ,
269- DiffMode :: Reverse => {
270- todo ! ( "Handle sret for reverse mode" ) ;
271- }
272- _ => {
273- bug ! ( "unreachable" ) ;
274- }
275- }
276- }
277- }
278-
279- // LLVM can figure out the input types on it's own, so we take a shortcut here.
280- unsafe { llvm:: LLVMFunctionType ( ret_ty, ptr:: null ( ) , 0 , True ) }
281- }
282-
283188/// When differentiating `fn_to_diff`, take an `outer_fn` and generate another
284189/// function with expected naming and calling conventions[^1] which will be
285190/// discovered by the enzyme LLVM pass and its body populated with the differentiated
@@ -293,7 +198,8 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
293198 builder : & mut Builder < ' _ , ' ll , ' tcx > ,
294199 cx : & SimpleCx < ' ll > ,
295200 fn_to_diff : & ' ll Value ,
296- outer_fn : & ' ll Value ,
201+ outer_name : & str ,
202+ ret_ty : & ' ll Type ,
297203 fn_args : & [ OperandRef < ' tcx , & ' ll Value > ] ,
298204 attrs : AutoDiffAttrs ,
299205 dest : PlaceRef < ' tcx , & ' ll Value > ,
@@ -306,11 +212,9 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
306212 }
307213 . to_string ( ) ;
308214
309- // add outer_fn name to ad_name to make it unique, in case users apply autodiff to multiple
215+ // add outer_name to ad_name to make it unique, in case users apply autodiff to multiple
310216 // functions. The name is passed in as a `&str`, so no UTF-8 conversion is needed here.
311- let name = llvm:: get_value_name ( outer_fn) ;
312- let outer_fn_name = std:: str:: from_utf8 ( name) . unwrap ( ) ;
313- ad_name. push_str ( outer_fn_name) ;
217+ ad_name. push_str ( outer_name) ;
314218
315219 // Let us assume the user wrote the following function square:
316220 //
@@ -344,92 +248,54 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
344248 // ret double %0
345249 // }
346250 // ```
347- unsafe {
348- let enzyme_ty = compute_enzyme_fn_ty ( cx, & attrs, fn_to_diff, outer_fn) ;
349-
350- // FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
351- // think a bit more about what should go here.
352- let cc = llvm:: LLVMGetFunctionCallConv ( outer_fn) ;
353- let ad_fn = declare_simple_fn (
354- cx,
355- & ad_name,
356- llvm:: CallConv :: try_from ( cc) . expect ( "invalid callconv" ) ,
357- llvm:: UnnamedAddr :: No ,
358- llvm:: Visibility :: Default ,
359- enzyme_ty,
360- ) ;
361-
362- // Otherwise LLVM might inline our temporary code before the enzyme pass has a chance to
363- // do it's work.
364- let attr = llvm:: AttributeKind :: NoInline . create_attr ( cx. llcx ) ;
365- attributes:: apply_to_llfn ( ad_fn, Function , & [ attr] ) ;
366-
367- // We add a made-up attribute just such that we can recognize it after AD to update
368- // (no)-inline attributes. We'll then also remove this attribute.
369- let enzyme_marker_attr = llvm:: CreateAttrString ( cx. llcx , "enzyme_marker" ) ;
370- attributes:: apply_to_llfn ( outer_fn, Function , & [ enzyme_marker_attr] ) ;
371-
372- let num_args = llvm:: LLVMCountParams ( & fn_to_diff) ;
373- let mut args = Vec :: with_capacity ( num_args as usize + 1 ) ;
374- args. push ( fn_to_diff) ;
375-
376- let enzyme_primal_ret = cx. create_metadata ( "enzyme_primal_return" . to_string ( ) ) . unwrap ( ) ;
377- if matches ! ( attrs. ret_activity, DiffActivity :: Dual | DiffActivity :: Active ) {
378- args. push ( cx. get_metadata_value ( enzyme_primal_ret) ) ;
379- }
380- if attrs. width > 1 {
381- let enzyme_width = cx. create_metadata ( "enzyme_width" . to_string ( ) ) . unwrap ( ) ;
382- args. push ( cx. get_metadata_value ( enzyme_width) ) ;
383- args. push ( cx. get_const_i64 ( attrs. width as u64 ) ) ;
384- }
385-
386- let has_sret = has_sret ( outer_fn) ;
387- let outer_args: Vec < & llvm:: Value > = fn_args. iter ( ) . map ( |op| op. immediate ( ) ) . collect ( ) ;
388- match_args_from_caller_to_enzyme (
389- & cx,
390- builder,
391- attrs. width ,
392- & mut args,
393- & attrs. input_activity ,
394- & outer_args,
395- has_sret,
396- ) ;
397-
398- let call = builder. call ( enzyme_ty, None , None , ad_fn, & args, None , None ) ;
399-
400- builder. store_to_place ( call, dest. val ) ;
251+ let enzyme_ty = unsafe { llvm:: LLVMFunctionType ( ret_ty, ptr:: null ( ) , 0 , True ) } ;
252+
253+ // FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
254+ // think a bit more about what should go here.
255+ // FIXME(Sa4dUs): have to find a way to get the cc, using `FastCallConv` for now
256+ let cc = 8 ;
257+ let ad_fn = declare_simple_fn (
258+ cx,
259+ & ad_name,
260+ llvm:: CallConv :: try_from ( cc) . expect ( "invalid callconv" ) ,
261+ llvm:: UnnamedAddr :: No ,
262+ llvm:: Visibility :: Default ,
263+ enzyme_ty,
264+ ) ;
265+
266+ // Otherwise LLVM might inline our temporary code before the enzyme pass has a chance to
266+ // do its work.
268+ let attr = llvm:: AttributeKind :: NoInline . create_attr ( cx. llcx ) ;
269+ attributes:: apply_to_llfn ( ad_fn, Function , & [ attr] ) ;
270+
271+ let num_args = llvm:: LLVMCountParams ( & fn_to_diff) ;
272+ let mut args = Vec :: with_capacity ( num_args as usize + 1 ) ;
273+ args. push ( fn_to_diff) ;
274+
275+ let enzyme_primal_ret = cx. create_metadata ( "enzyme_primal_return" . to_string ( ) ) . unwrap ( ) ;
276+ if matches ! ( attrs. ret_activity, DiffActivity :: Dual | DiffActivity :: Active ) {
277+ args. push ( cx. get_metadata_value ( enzyme_primal_ret) ) ;
278+ }
279+ if attrs. width > 1 {
280+ let enzyme_width = cx. create_metadata ( "enzyme_width" . to_string ( ) ) . unwrap ( ) ;
281+ args. push ( cx. get_metadata_value ( enzyme_width) ) ;
282+ args. push ( cx. get_const_i64 ( attrs. width as u64 ) ) ;
283+ }
401284
402- if cx. val_ty ( call) == cx. type_void ( ) || has_sret {
403- if has_sret {
404- // This is what we already have in our outer_fn (shortened):
405- // define void @_foo(ptr <..> sret([32 x i8]) initializes((0, 32)) %0, <...>) {
406- // %7 = call [4 x double] (...) @__enzyme_fwddiff_foo(ptr @square, metadata !"enzyme_width", i64 4, <...>)
407- // <Here we are, we want to add the following two lines>
408- // store [4 x double] %7, ptr %0, align 8
409- // ret void
410- // }
285+ let outer_args: Vec < & llvm:: Value > = fn_args. iter ( ) . map ( |op| op. immediate ( ) ) . collect ( ) ;
411286
412- // now store the result of the enzyme call into the sret pointer.
413- let sret_ptr = outer_args[ 0 ] ;
414- let call_ty = cx. val_ty ( call) ;
415- if attrs. width == 1 {
416- assert_eq ! ( cx. type_kind( call_ty) , TypeKind :: Struct ) ;
417- } else {
418- assert_eq ! ( cx. type_kind( call_ty) , TypeKind :: Array ) ;
419- }
420- llvm:: LLVMBuildStore ( & builder. llbuilder , call, sret_ptr) ;
421- }
422- builder. ret_void ( ) ;
423- }
287+ match_args_from_caller_to_enzyme (
288+ & cx,
289+ builder,
290+ attrs. width ,
291+ & mut args,
292+ & attrs. input_activity ,
293+ & outer_args,
294+ ) ;
424295
425- builder. store_to_place ( call, dest . val ) ;
296+ let call = builder. call ( enzyme_ty , None , None , ad_fn , & args , None , None ) ;
426297
427- // Let's crash in case that we messed something up above and generated invalid IR.
428- llvm:: LLVMRustVerifyFunction (
429- outer_fn,
430- llvm:: LLVMRustVerifierFailureAction :: LLVMAbortProcessAction ,
431- ) ;
432- }
298+ builder. store_to_place ( call, dest. val ) ;
433299}
434300
435301pub ( crate ) fn differentiate < ' ll > (
0 commit comments