Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions compiler/rustc_codegen_llvm/src/back/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,59 @@ pub(crate) unsafe fn llvm_optimize(

let llvm_plugins = config.llvm_plugins.join(",");

fn adjust_offload_kernel_abis(m: &llvm::Module, llcx: &llvm::Context) {
unsafe {
// We just add a `ptr %dyn_ptr, ` as the first arg to every kernel_{i} function.
// for function in function
for num in 0..9 {
let name = format!("kernel_{num}");
let c_name = CString::new(name).unwrap();
let kernel = llvm::LLVMGetNamedFunction(m, c_name.as_ptr());
if let Some(old_fn) = kernel {
dbg!("found kernel");
//let old_fn_ty = llvm::LLVMTypeOf(old_fn);
//let ptr_ty = llvm::LLVMPointerTypeInContext(llcx, 0);
//
//let n = unsafe { llvm::LLVMCountParamTypes(old_fn_ty) } as usize;
//
//let mut old_param_tys = Vec::with_capacity(n);
//unsafe { llvm::LLVMGetParamTypes(old_fn_ty, old_param_tys.as_mut_ptr()) };
//let ret_ty = unsafe { llvm::LLVMGetReturnType(old_fn_ty) };// new param list = [ptr] + old params
//let mut new_params = Vec::with_capacity(n + 1);
//new_params.push(ptr_ty);
//for elem in &old_param_tys {
// new_params.push(elem);
//}
//let new_fn_ty = unsafe {
// llvm::LLVMFunctionType(ret_ty, new_params.as_mut_ptr(), new_params.len() as u32, llvm::False)
//};
//let new_fn = unsafe { llvm::LLVMAddFunction(c_name.as_ptr(), new_fn_ty) };
//let a0 = unsafe { llvm::LLVMGetParam(new_fn, 0) };
//unsafe { llvm::LLVMSetValueName2(a0, b"dyn_ptr\0".as_ptr(), "dyn_ptr".len()) };// Move basic blocks
//let mut bb = unsafe { llvm::LLVMGetFirstBasicBlock(old_fn) };
//while !bb.is_null() {
// let next = unsafe { llvm::LLVMGetNextBasicBlock(bb) };
// unsafe { llvm::LLVMAppendExistingBasicBlock(new_fn, bb) };
// bb = next;
//}// Shift argument uses: old %0 -> new %1, old %1 -> new %2, ...
//let old_n = unsafe { llvm::LLVMCountParams(old_fn) };
//for i in 0..old_n {
// let old_arg = unsafe { llvm::LLVMGetParam(old_fn, i) };
// let new_arg = unsafe { llvm::LLVMGetParam(new_fn, i + 1) };
// unsafe { llvm::LLVMReplaceAllUsesWith(old_arg, new_arg) };
//}
//unsafe { llvm::LLVMReplaceAllUsesWith(old_fn, new_fn) };
}
}
}

}
if cgcx.target_arch == "amdgpu" {
adjust_offload_kernel_abis(module.module_llvm.llmod(), &*module.module_llvm.llc);
} else {
dbg!(&cgcx.target_arch);
}

let result = unsafe {
llvm::LLVMRustOptimize(
module.module_llvm.llmod(),
Expand Down
Loading
Loading