@@ -15,7 +15,7 @@ use crate::context::SimpleCx;
15
15
use crate :: declare:: declare_simple_fn;
16
16
use crate :: errors:: { AutoDiffWithoutEnable , LlvmError } ;
17
17
use crate :: llvm:: AttributePlace :: Function ;
18
- use crate :: llvm:: { Metadata , True } ;
18
+ use crate :: llvm:: { Metadata , True , Type } ;
19
19
use crate :: value:: Value ;
20
20
use crate :: { CodegenContext , LlvmCodegenBackend , ModuleLlvm , attributes, llvm} ;
21
21
@@ -30,7 +30,7 @@ fn _get_params(fnc: &Value) -> Vec<&Value> {
30
30
fnc_args
31
31
}
32
32
33
- fn has_sret ( fnc : & Value ) -> bool {
33
+ fn _has_sret ( fnc : & Value ) -> bool {
34
34
let num_args = llvm:: LLVMCountParams ( fnc) as usize ;
35
35
if num_args == 0 {
36
36
false
@@ -56,7 +56,6 @@ fn match_args_from_caller_to_enzyme<'ll, 'tcx>(
56
56
args : & mut Vec < & ' ll llvm:: Value > ,
57
57
inputs : & [ DiffActivity ] ,
58
58
outer_args : & [ & ' ll llvm:: Value ] ,
59
- has_sret : bool ,
60
59
) {
61
60
debug ! ( "matching autodiff arguments" ) ;
62
61
// We now handle the issue that Rust level arguments do not always match the llvm-ir level
@@ -68,14 +67,6 @@ fn match_args_from_caller_to_enzyme<'ll, 'tcx>(
68
67
let mut outer_pos: usize = 0 ;
69
68
let mut activity_pos = 0 ;
70
69
71
- if has_sret {
72
- // Then the first outer arg is the sret pointer. Enzyme doesn't know about sret, so the
73
- // inner function will still return something. We increase our outer_pos by one,
74
- // and once we're done with all other args we will take the return of the inner call and
75
- // update the sret pointer with it
76
- outer_pos = 1 ;
77
- }
78
-
79
70
let enzyme_const = cx. create_metadata ( "enzyme_const" . to_string ( ) ) . unwrap ( ) ;
80
71
let enzyme_out = cx. create_metadata ( "enzyme_out" . to_string ( ) ) . unwrap ( ) ;
81
72
let enzyme_dup = cx. create_metadata ( "enzyme_dup" . to_string ( ) ) . unwrap ( ) ;
@@ -194,92 +185,6 @@ fn match_args_from_caller_to_enzyme<'ll, 'tcx>(
194
185
}
195
186
}
196
187
197
- // On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input
198
- // arguments. We do however need to declare them with their correct return type.
199
- // We already figured the correct return type out in our frontend, when generating the outer_fn,
200
- // so we can now just go ahead and use that. This is not always trivial, e.g. because sret.
201
- // Beyond sret, this article describes our challenges nicely:
202
- // <https://yorickpeterse.com/articles/the-mess-that-is-handling-structure-arguments-and-returns-in-llvm/>
203
- // I.e. (i32, f32) will get merged into i64, but we don't handle that yet.
204
- fn compute_enzyme_fn_ty < ' ll > (
205
- cx : & SimpleCx < ' ll > ,
206
- attrs : & AutoDiffAttrs ,
207
- fn_to_diff : & ' ll Value ,
208
- outer_fn : & ' ll Value ,
209
- ) -> & ' ll llvm:: Type {
210
- let fn_ty = cx. get_type_of_global ( outer_fn) ;
211
- let mut ret_ty = cx. get_return_type ( fn_ty) ;
212
-
213
- let has_sret = has_sret ( outer_fn) ;
214
-
215
- if has_sret {
216
- // Now we don't just forward the return type, so we have to figure it out based on the
217
- // primal return type, in combination with the autodiff settings.
218
- let fn_ty = cx. get_type_of_global ( fn_to_diff) ;
219
- let inner_ret_ty = cx. get_return_type ( fn_ty) ;
220
-
221
- let void_ty = unsafe { llvm:: LLVMVoidTypeInContext ( cx. llcx ) } ;
222
- if inner_ret_ty == void_ty {
223
- // This indicates that even the inner function has an sret.
224
- // Right now I only look for an sret in the outer function.
225
- // This *probably* needs some extra handling, but I never ran
226
- // into such a case. So I'll wait for user reports to have a test case.
227
- bug ! ( "sret in inner function" ) ;
228
- }
229
-
230
- if attrs. width == 1 {
231
- // Enzyme returns a struct of style:
232
- // `{ original_ret(if requested), float, float, ... }`
233
- let mut struct_elements = vec ! [ ] ;
234
- if attrs. has_primal_ret ( ) {
235
- struct_elements. push ( inner_ret_ty) ;
236
- }
237
- // Next, we push the list of active floats, since they will be lowered to `enzyme_out`,
238
- // and therefore part of the return struct.
239
- let param_tys = cx. func_params_types ( fn_ty) ;
240
- for ( act, param_ty) in attrs. input_activity . iter ( ) . zip ( param_tys) {
241
- if matches ! ( act, DiffActivity :: Active ) {
242
- // Now find the float type at position i based on the fn_ty,
243
- // to know what (f16/f32/f64/...) to add to the struct.
244
- struct_elements. push ( param_ty) ;
245
- }
246
- }
247
- ret_ty = cx. type_struct ( & struct_elements, false ) ;
248
- } else {
249
- // First we check if we also have to deal with the primal return.
250
- match attrs. mode {
251
- DiffMode :: Forward => match attrs. ret_activity {
252
- DiffActivity :: Dual => {
253
- let arr_ty =
254
- unsafe { llvm:: LLVMArrayType2 ( inner_ret_ty, attrs. width as u64 + 1 ) } ;
255
- ret_ty = arr_ty;
256
- }
257
- DiffActivity :: DualOnly => {
258
- let arr_ty =
259
- unsafe { llvm:: LLVMArrayType2 ( inner_ret_ty, attrs. width as u64 ) } ;
260
- ret_ty = arr_ty;
261
- }
262
- DiffActivity :: Const => {
263
- todo ! ( "Not sure, do we need to do something here?" ) ;
264
- }
265
- _ => {
266
- bug ! ( "unreachable" ) ;
267
- }
268
- } ,
269
- DiffMode :: Reverse => {
270
- todo ! ( "Handle sret for reverse mode" ) ;
271
- }
272
- _ => {
273
- bug ! ( "unreachable" ) ;
274
- }
275
- }
276
- }
277
- }
278
-
279
- // LLVM can figure out the input types on it's own, so we take a shortcut here.
280
- unsafe { llvm:: LLVMFunctionType ( ret_ty, ptr:: null ( ) , 0 , True ) }
281
- }
282
-
283
188
/// When differentiating `fn_to_diff`, take an `outer_fn` and generate another
284
189
/// function with expected naming and calling conventions[^1] which will be
285
190
/// discovered by the enzyme LLVM pass and its body populated with the differentiated
@@ -293,7 +198,8 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
293
198
builder : & mut Builder < ' _ , ' ll , ' tcx > ,
294
199
cx : & SimpleCx < ' ll > ,
295
200
fn_to_diff : & ' ll Value ,
296
- outer_fn : & ' ll Value ,
201
+ outer_name : & str ,
202
+ ret_ty : & ' ll Type ,
297
203
fn_args : & [ OperandRef < ' tcx , & ' ll Value > ] ,
298
204
attrs : AutoDiffAttrs ,
299
205
dest : PlaceRef < ' tcx , & ' ll Value > ,
@@ -306,11 +212,9 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
306
212
}
307
213
. to_string ( ) ;
308
214
309
- // add outer_fn name to ad_name to make it unique, in case users apply autodiff to multiple
215
+ // add outer_name to ad_name to make it unique, in case users apply autodiff to multiple
310
216
// functions. Unwrap will only panic, if LLVM gave us an invalid string.
311
- let name = llvm:: get_value_name ( outer_fn) ;
312
- let outer_fn_name = std:: str:: from_utf8 ( name) . unwrap ( ) ;
313
- ad_name. push_str ( outer_fn_name) ;
217
+ ad_name. push_str ( outer_name) ;
314
218
315
219
// Let us assume the user wrote the following function square:
316
220
//
@@ -344,92 +248,54 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
344
248
// ret double %0
345
249
// }
346
250
// ```
347
- unsafe {
348
- let enzyme_ty = compute_enzyme_fn_ty ( cx, & attrs, fn_to_diff, outer_fn) ;
349
-
350
- // FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
351
- // think a bit more about what should go here.
352
- let cc = llvm:: LLVMGetFunctionCallConv ( outer_fn) ;
353
- let ad_fn = declare_simple_fn (
354
- cx,
355
- & ad_name,
356
- llvm:: CallConv :: try_from ( cc) . expect ( "invalid callconv" ) ,
357
- llvm:: UnnamedAddr :: No ,
358
- llvm:: Visibility :: Default ,
359
- enzyme_ty,
360
- ) ;
361
-
362
- // Otherwise LLVM might inline our temporary code before the enzyme pass has a chance to
363
- // do it's work.
364
- let attr = llvm:: AttributeKind :: NoInline . create_attr ( cx. llcx ) ;
365
- attributes:: apply_to_llfn ( ad_fn, Function , & [ attr] ) ;
366
-
367
- // We add a made-up attribute just such that we can recognize it after AD to update
368
- // (no)-inline attributes. We'll then also remove this attribute.
369
- let enzyme_marker_attr = llvm:: CreateAttrString ( cx. llcx , "enzyme_marker" ) ;
370
- attributes:: apply_to_llfn ( outer_fn, Function , & [ enzyme_marker_attr] ) ;
371
-
372
- let num_args = llvm:: LLVMCountParams ( & fn_to_diff) ;
373
- let mut args = Vec :: with_capacity ( num_args as usize + 1 ) ;
374
- args. push ( fn_to_diff) ;
375
-
376
- let enzyme_primal_ret = cx. create_metadata ( "enzyme_primal_return" . to_string ( ) ) . unwrap ( ) ;
377
- if matches ! ( attrs. ret_activity, DiffActivity :: Dual | DiffActivity :: Active ) {
378
- args. push ( cx. get_metadata_value ( enzyme_primal_ret) ) ;
379
- }
380
- if attrs. width > 1 {
381
- let enzyme_width = cx. create_metadata ( "enzyme_width" . to_string ( ) ) . unwrap ( ) ;
382
- args. push ( cx. get_metadata_value ( enzyme_width) ) ;
383
- args. push ( cx. get_const_int ( cx. type_i64 ( ) , attrs. width as u64 ) ) ;
384
- }
385
-
386
- let has_sret = has_sret ( outer_fn) ;
387
- let outer_args: Vec < & llvm:: Value > = fn_args. iter ( ) . map ( |op| op. immediate ( ) ) . collect ( ) ;
388
- match_args_from_caller_to_enzyme (
389
- & cx,
390
- builder,
391
- attrs. width ,
392
- & mut args,
393
- & attrs. input_activity ,
394
- & outer_args,
395
- has_sret,
396
- ) ;
397
-
398
- let call = builder. call ( enzyme_ty, None , None , ad_fn, & args, None , None ) ;
399
-
400
- builder. store_to_place ( call, dest. val ) ;
251
+ let enzyme_ty = unsafe { llvm:: LLVMFunctionType ( ret_ty, ptr:: null ( ) , 0 , True ) } ;
252
+
253
+ // FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
254
+ // think a bit more about what should go here.
255
+ // FIXME(Sa4dUs): have to find a way to get the cc, using `FastCallConv` for now
256
+ let cc = 8 ;
257
+ let ad_fn = declare_simple_fn (
258
+ cx,
259
+ & ad_name,
260
+ llvm:: CallConv :: try_from ( cc) . expect ( "invalid callconv" ) ,
261
+ llvm:: UnnamedAddr :: No ,
262
+ llvm:: Visibility :: Default ,
263
+ enzyme_ty,
264
+ ) ;
265
+
266
+ // Otherwise LLVM might inline our temporary code before the enzyme pass has a chance to
267
+ // do its work.
268
+ let attr = llvm:: AttributeKind :: NoInline . create_attr ( cx. llcx ) ;
269
+ attributes:: apply_to_llfn ( ad_fn, Function , & [ attr] ) ;
270
+
271
+ let num_args = llvm:: LLVMCountParams ( & fn_to_diff) ;
272
+ let mut args = Vec :: with_capacity ( num_args as usize + 1 ) ;
273
+ args. push ( fn_to_diff) ;
274
+
275
+ let enzyme_primal_ret = cx. create_metadata ( "enzyme_primal_return" . to_string ( ) ) . unwrap ( ) ;
276
+ if matches ! ( attrs. ret_activity, DiffActivity :: Dual | DiffActivity :: Active ) {
277
+ args. push ( cx. get_metadata_value ( enzyme_primal_ret) ) ;
278
+ }
279
+ if attrs. width > 1 {
280
+ let enzyme_width = cx. create_metadata ( "enzyme_width" . to_string ( ) ) . unwrap ( ) ;
281
+ args. push ( cx. get_metadata_value ( enzyme_width) ) ;
282
+ args. push ( cx. get_const_int ( cx. type_i64 ( ) , attrs. width as u64 ) ) ;
283
+ }
401
284
402
- if cx. val_ty ( call) == cx. type_void ( ) || has_sret {
403
- if has_sret {
404
- // This is what we already have in our outer_fn (shortened):
405
- // define void @_foo(ptr <..> sret([32 x i8]) initializes((0, 32)) %0, <...>) {
406
- // %7 = call [4 x double] (...) @__enzyme_fwddiff_foo(ptr @square, metadata !"enzyme_width", i64 4, <...>)
407
- // <Here we are, we want to add the following two lines>
408
- // store [4 x double] %7, ptr %0, align 8
409
- // ret void
410
- // }
285
+ let outer_args: Vec < & llvm:: Value > = fn_args. iter ( ) . map ( |op| op. immediate ( ) ) . collect ( ) ;
411
286
412
- // now store the result of the enzyme call into the sret pointer.
413
- let sret_ptr = outer_args[ 0 ] ;
414
- let call_ty = cx. val_ty ( call) ;
415
- if attrs. width == 1 {
416
- assert_eq ! ( cx. type_kind( call_ty) , TypeKind :: Struct ) ;
417
- } else {
418
- assert_eq ! ( cx. type_kind( call_ty) , TypeKind :: Array ) ;
419
- }
420
- llvm:: LLVMBuildStore ( & builder. llbuilder , call, sret_ptr) ;
421
- }
422
- builder. ret_void ( ) ;
423
- }
287
+ match_args_from_caller_to_enzyme (
288
+ & cx,
289
+ builder,
290
+ attrs. width ,
291
+ & mut args,
292
+ & attrs. input_activity ,
293
+ & outer_args,
294
+ ) ;
424
295
425
- builder. store_to_place ( call, dest . val ) ;
296
+ let call = builder. call ( enzyme_ty , None , None , ad_fn , & args , None , None ) ;
426
297
427
- // Let's crash in case that we messed something up above and generated invalid IR.
428
- llvm:: LLVMRustVerifyFunction (
429
- outer_fn,
430
- llvm:: LLVMRustVerifierFailureAction :: LLVMAbortProcessAction ,
431
- ) ;
432
- }
298
+ builder. store_to_place ( call, dest. val ) ;
433
299
}
434
300
435
301
pub ( crate ) fn differentiate < ' ll > (
0 commit comments