Skip to content

Commit fad0b0c

Browse files
committed
Remove sret logic
1 parent 447c75a commit fad0b0c

File tree

4 files changed

+68
-196
lines changed

4 files changed

+68
-196
lines changed

compiler/rustc_codegen_llvm/src/builder/autodiff.rs

Lines changed: 50 additions & 184 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use crate::context::SimpleCx;
1515
use crate::declare::declare_simple_fn;
1616
use crate::errors::{AutoDiffWithoutEnable, LlvmError};
1717
use crate::llvm::AttributePlace::Function;
18-
use crate::llvm::{Metadata, True};
18+
use crate::llvm::{Metadata, True, Type};
1919
use crate::value::Value;
2020
use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, llvm};
2121

@@ -30,7 +30,7 @@ fn _get_params(fnc: &Value) -> Vec<&Value> {
3030
fnc_args
3131
}
3232

33-
fn has_sret(fnc: &Value) -> bool {
33+
fn _has_sret(fnc: &Value) -> bool {
3434
let num_args = llvm::LLVMCountParams(fnc) as usize;
3535
if num_args == 0 {
3636
false
@@ -56,7 +56,6 @@ fn match_args_from_caller_to_enzyme<'ll, 'tcx>(
5656
args: &mut Vec<&'ll llvm::Value>,
5757
inputs: &[DiffActivity],
5858
outer_args: &[&'ll llvm::Value],
59-
has_sret: bool,
6059
) {
6160
debug!("matching autodiff arguments");
6261
// We now handle the issue that Rust level arguments not always match the llvm-ir level
@@ -68,14 +67,6 @@ fn match_args_from_caller_to_enzyme<'ll, 'tcx>(
6867
let mut outer_pos: usize = 0;
6968
let mut activity_pos = 0;
7069

71-
if has_sret {
72-
// Then the first outer arg is the sret pointer. Enzyme doesn't know about sret, so the
73-
// inner function will still return something. We increase our outer_pos by one,
74-
// and once we're done with all other args we will take the return of the inner call and
75-
// update the sret pointer with it
76-
outer_pos = 1;
77-
}
78-
7970
let enzyme_const = cx.create_metadata("enzyme_const".to_string()).unwrap();
8071
let enzyme_out = cx.create_metadata("enzyme_out".to_string()).unwrap();
8172
let enzyme_dup = cx.create_metadata("enzyme_dup".to_string()).unwrap();
@@ -194,92 +185,6 @@ fn match_args_from_caller_to_enzyme<'ll, 'tcx>(
194185
}
195186
}
196187

197-
// On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input
198-
// arguments. We do however need to declare them with their correct return type.
199-
// We already figured the correct return type out in our frontend, when generating the outer_fn,
200-
// so we can now just go ahead and use that. This is not always trivial, e.g. because sret.
201-
// Beyond sret, this article describes our challenges nicely:
202-
// <https://yorickpeterse.com/articles/the-mess-that-is-handling-structure-arguments-and-returns-in-llvm/>
203-
// I.e. (i32, f32) will get merged into i64, but we don't handle that yet.
204-
fn compute_enzyme_fn_ty<'ll>(
205-
cx: &SimpleCx<'ll>,
206-
attrs: &AutoDiffAttrs,
207-
fn_to_diff: &'ll Value,
208-
outer_fn: &'ll Value,
209-
) -> &'ll llvm::Type {
210-
let fn_ty = cx.get_type_of_global(outer_fn);
211-
let mut ret_ty = cx.get_return_type(fn_ty);
212-
213-
let has_sret = has_sret(outer_fn);
214-
215-
if has_sret {
216-
// Now we don't just forward the return type, so we have to figure it out based on the
217-
// primal return type, in combination with the autodiff settings.
218-
let fn_ty = cx.get_type_of_global(fn_to_diff);
219-
let inner_ret_ty = cx.get_return_type(fn_ty);
220-
221-
let void_ty = unsafe { llvm::LLVMVoidTypeInContext(cx.llcx) };
222-
if inner_ret_ty == void_ty {
223-
// This indicates that even the inner function has an sret.
224-
// Right now I only look for an sret in the outer function.
225-
// This *probably* needs some extra handling, but I never ran
226-
// into such a case. So I'll wait for user reports to have a test case.
227-
bug!("sret in inner function");
228-
}
229-
230-
if attrs.width == 1 {
231-
// Enzyme returns a struct of style:
232-
// `{ original_ret(if requested), float, float, ... }`
233-
let mut struct_elements = vec![];
234-
if attrs.has_primal_ret() {
235-
struct_elements.push(inner_ret_ty);
236-
}
237-
// Next, we push the list of active floats, since they will be lowered to `enzyme_out`,
238-
// and therefore part of the return struct.
239-
let param_tys = cx.func_params_types(fn_ty);
240-
for (act, param_ty) in attrs.input_activity.iter().zip(param_tys) {
241-
if matches!(act, DiffActivity::Active) {
242-
// Now find the float type at position i based on the fn_ty,
243-
// to know what (f16/f32/f64/...) to add to the struct.
244-
struct_elements.push(param_ty);
245-
}
246-
}
247-
ret_ty = cx.type_struct(&struct_elements, false);
248-
} else {
249-
// First we check if we also have to deal with the primal return.
250-
match attrs.mode {
251-
DiffMode::Forward => match attrs.ret_activity {
252-
DiffActivity::Dual => {
253-
let arr_ty =
254-
unsafe { llvm::LLVMArrayType2(inner_ret_ty, attrs.width as u64 + 1) };
255-
ret_ty = arr_ty;
256-
}
257-
DiffActivity::DualOnly => {
258-
let arr_ty =
259-
unsafe { llvm::LLVMArrayType2(inner_ret_ty, attrs.width as u64) };
260-
ret_ty = arr_ty;
261-
}
262-
DiffActivity::Const => {
263-
todo!("Not sure, do we need to do something here?");
264-
}
265-
_ => {
266-
bug!("unreachable");
267-
}
268-
},
269-
DiffMode::Reverse => {
270-
todo!("Handle sret for reverse mode");
271-
}
272-
_ => {
273-
bug!("unreachable");
274-
}
275-
}
276-
}
277-
}
278-
279-
// LLVM can figure out the input types on it's own, so we take a shortcut here.
280-
unsafe { llvm::LLVMFunctionType(ret_ty, ptr::null(), 0, True) }
281-
}
282-
283188
/// When differentiating `fn_to_diff`, take a `outer_fn` and generate another
284189
/// function with expected naming and calling conventions[^1] which will be
285190
/// discovered by the enzyme LLVM pass and its body populated with the differentiated
@@ -293,7 +198,8 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
293198
builder: &mut Builder<'_, 'll, 'tcx>,
294199
cx: &SimpleCx<'ll>,
295200
fn_to_diff: &'ll Value,
296-
outer_fn: &'ll Value,
201+
outer_name: &str,
202+
ret_ty: &'ll Type,
297203
fn_args: &[OperandRef<'tcx, &'ll Value>],
298204
attrs: AutoDiffAttrs,
299205
dest: PlaceRef<'tcx, &'ll Value>,
@@ -306,11 +212,9 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
306212
}
307213
.to_string();
308214

309-
// add outer_fn name to ad_name to make it unique, in case users apply autodiff to multiple
215+
// add outer_name to ad_name to make it unique, in case users apply autodiff to multiple
310216
// functions. Unwrap will only panic, if LLVM gave us an invalid string.
311-
let name = llvm::get_value_name(outer_fn);
312-
let outer_fn_name = std::str::from_utf8(name).unwrap();
313-
ad_name.push_str(outer_fn_name);
217+
ad_name.push_str(outer_name);
314218

315219
// Let us assume the user wrote the following function square:
316220
//
@@ -344,92 +248,54 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
344248
// ret double %0
345249
// }
346250
// ```
347-
unsafe {
348-
let enzyme_ty = compute_enzyme_fn_ty(cx, &attrs, fn_to_diff, outer_fn);
349-
350-
// FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
351-
// think a bit more about what should go here.
352-
let cc = llvm::LLVMGetFunctionCallConv(outer_fn);
353-
let ad_fn = declare_simple_fn(
354-
cx,
355-
&ad_name,
356-
llvm::CallConv::try_from(cc).expect("invalid callconv"),
357-
llvm::UnnamedAddr::No,
358-
llvm::Visibility::Default,
359-
enzyme_ty,
360-
);
361-
362-
// Otherwise LLVM might inline our temporary code before the enzyme pass has a chance to
363-
// do it's work.
364-
let attr = llvm::AttributeKind::NoInline.create_attr(cx.llcx);
365-
attributes::apply_to_llfn(ad_fn, Function, &[attr]);
366-
367-
// We add a made-up attribute just such that we can recognize it after AD to update
368-
// (no)-inline attributes. We'll then also remove this attribute.
369-
let enzyme_marker_attr = llvm::CreateAttrString(cx.llcx, "enzyme_marker");
370-
attributes::apply_to_llfn(outer_fn, Function, &[enzyme_marker_attr]);
371-
372-
let num_args = llvm::LLVMCountParams(&fn_to_diff);
373-
let mut args = Vec::with_capacity(num_args as usize + 1);
374-
args.push(fn_to_diff);
375-
376-
let enzyme_primal_ret = cx.create_metadata("enzyme_primal_return".to_string()).unwrap();
377-
if matches!(attrs.ret_activity, DiffActivity::Dual | DiffActivity::Active) {
378-
args.push(cx.get_metadata_value(enzyme_primal_ret));
379-
}
380-
if attrs.width > 1 {
381-
let enzyme_width = cx.create_metadata("enzyme_width".to_string()).unwrap();
382-
args.push(cx.get_metadata_value(enzyme_width));
383-
args.push(cx.get_const_int(cx.type_i64(), attrs.width as u64));
384-
}
385-
386-
let has_sret = has_sret(outer_fn);
387-
let outer_args: Vec<&llvm::Value> = fn_args.iter().map(|op| op.immediate()).collect();
388-
match_args_from_caller_to_enzyme(
389-
&cx,
390-
builder,
391-
attrs.width,
392-
&mut args,
393-
&attrs.input_activity,
394-
&outer_args,
395-
has_sret,
396-
);
397-
398-
let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None);
399-
400-
builder.store_to_place(call, dest.val);
251+
let enzyme_ty = unsafe { llvm::LLVMFunctionType(ret_ty, ptr::null(), 0, True) };
252+
253+
// FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
254+
// think a bit more about what should go here.
255+
// FIXME(Sa4dUs): have to find a way to get the cc, using `FastCallConv` for now
256+
let cc = 8;
257+
let ad_fn = declare_simple_fn(
258+
cx,
259+
&ad_name,
260+
llvm::CallConv::try_from(cc).expect("invalid callconv"),
261+
llvm::UnnamedAddr::No,
262+
llvm::Visibility::Default,
263+
enzyme_ty,
264+
);
265+
266+
// Otherwise LLVM might inline our temporary code before the enzyme pass has a chance to
267+
// do it's work.
268+
let attr = llvm::AttributeKind::NoInline.create_attr(cx.llcx);
269+
attributes::apply_to_llfn(ad_fn, Function, &[attr]);
270+
271+
let num_args = llvm::LLVMCountParams(&fn_to_diff);
272+
let mut args = Vec::with_capacity(num_args as usize + 1);
273+
args.push(fn_to_diff);
274+
275+
let enzyme_primal_ret = cx.create_metadata("enzyme_primal_return".to_string()).unwrap();
276+
if matches!(attrs.ret_activity, DiffActivity::Dual | DiffActivity::Active) {
277+
args.push(cx.get_metadata_value(enzyme_primal_ret));
278+
}
279+
if attrs.width > 1 {
280+
let enzyme_width = cx.create_metadata("enzyme_width".to_string()).unwrap();
281+
args.push(cx.get_metadata_value(enzyme_width));
282+
args.push(cx.get_const_int(cx.type_i64(), attrs.width as u64));
283+
}
401284

402-
if cx.val_ty(call) == cx.type_void() || has_sret {
403-
if has_sret {
404-
// This is what we already have in our outer_fn (shortened):
405-
// define void @_foo(ptr <..> sret([32 x i8]) initializes((0, 32)) %0, <...>) {
406-
// %7 = call [4 x double] (...) @__enzyme_fwddiff_foo(ptr @square, metadata !"enzyme_width", i64 4, <...>)
407-
// <Here we are, we want to add the following two lines>
408-
// store [4 x double] %7, ptr %0, align 8
409-
// ret void
410-
// }
285+
let outer_args: Vec<&llvm::Value> = fn_args.iter().map(|op| op.immediate()).collect();
411286

412-
// now store the result of the enzyme call into the sret pointer.
413-
let sret_ptr = outer_args[0];
414-
let call_ty = cx.val_ty(call);
415-
if attrs.width == 1 {
416-
assert_eq!(cx.type_kind(call_ty), TypeKind::Struct);
417-
} else {
418-
assert_eq!(cx.type_kind(call_ty), TypeKind::Array);
419-
}
420-
llvm::LLVMBuildStore(&builder.llbuilder, call, sret_ptr);
421-
}
422-
builder.ret_void();
423-
}
287+
match_args_from_caller_to_enzyme(
288+
&cx,
289+
builder,
290+
attrs.width,
291+
&mut args,
292+
&attrs.input_activity,
293+
&outer_args,
294+
);
424295

425-
builder.store_to_place(call, dest.val);
296+
let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None);
426297

427-
// Let's crash in case that we messed something up above and generated invalid IR.
428-
llvm::LLVMRustVerifyFunction(
429-
outer_fn,
430-
llvm::LLVMRustVerifierFailureAction::LLVMAbortProcessAction,
431-
);
432-
}
298+
builder.store_to_place(call, dest.val);
433299
}
434300

435301
pub(crate) fn differentiate<'ll>(

compiler/rustc_codegen_llvm/src/context.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -652,7 +652,7 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
652652
}
653653
}
654654
impl<'ll> SimpleCx<'ll> {
655-
pub(crate) fn get_return_type(&self, ty: &'ll Type) -> &'ll Type {
655+
pub(crate) fn _get_return_type(&self, ty: &'ll Type) -> &'ll Type {
656656
assert_eq!(self.type_kind(ty), TypeKind::Function);
657657
unsafe { llvm::LLVMGetReturnType(ty) }
658658
}

compiler/rustc_codegen_llvm/src/intrinsic.rs

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -174,10 +174,17 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
174174
span: Span,
175175
) -> Result<(), ty::Instance<'tcx>> {
176176
let tcx = self.tcx;
177+
let callee_ty = instance.ty(tcx, self.typing_env());
177178

178-
let name = tcx.item_name(instance.def_id());
179179
let fn_args = instance.args;
180180

181+
let sig = callee_ty.fn_sig(tcx);
182+
let sig = tcx.normalize_erasing_late_bound_regions(self.typing_env(), sig);
183+
let ret_ty = sig.output();
184+
let name = tcx.item_name(instance.def_id());
185+
186+
let llret_ty = self.layout_of(ret_ty).llvm_type(self);
187+
181188
let simple = call_simple_intrinsic(self, name, args);
182189
let llval = match name {
183190
_ if simple.is_some() => simple.unwrap(),
@@ -223,20 +230,14 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
223230
let fn_to_diff: Option<&'ll llvm::Value> = self.cx.get_function(&source_symbol);
224231
let Some(fn_to_diff) = fn_to_diff else { bug!("could not find source function") };
225232

226-
// Declare target fn
227-
let target_symbol =
228-
symbol_name_for_instance_in_crate(tcx, instance.clone(), LOCAL_CRATE);
229-
let fn_abi = self.cx.fn_abi_of_instance(instance, ty::List::empty());
230-
let outer_fn: &'ll Value =
231-
self.cx.declare_fn(&target_symbol, fn_abi, Some(instance));
232-
233233
// Build body
234234
generate_enzyme_call(
235235
self,
236236
self.cx,
237237
fn_to_diff,
238-
outer_fn,
239-
args, // This argument was not in the original `generate_enzyme_call`, now it's included because `get_params` is not working anymore
238+
name.as_str(),
239+
llret_ty,
240+
args,
240241
diff_attrs.clone(),
241242
result,
242243
);

compiler/rustc_hir_analysis/src/check/intrinsic.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,12 @@ pub(crate) fn check_intrinsic_type(
196196
(Ty::new_ref(tcx, env_region, va_list_ty, mutbl), va_list_ty)
197197
};
198198

199-
let safety = intrinsic_operation_unsafety(tcx, intrinsic_id);
199+
// FIXME(Sa4dUs): Get the actual safety level of the diff function
200+
let safety = if has_autodiff {
201+
hir::Safety::Safe
202+
} else {
203+
intrinsic_operation_unsafety(tcx, intrinsic_id)
204+
};
200205
let n_lts = 0;
201206
let (n_tps, n_cts, inputs, output) = match intrinsic_name {
202207
_ if has_autodiff => {

0 commit comments

Comments
 (0)