Skip to content

Commit 5a3207d

Browse files
committed
new code
1 parent f211827 commit 5a3207d

File tree

8 files changed

+198
-78
lines changed

8 files changed

+198
-78
lines changed

compiler/rustc_codegen_llvm/src/back/write.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -926,7 +926,7 @@ pub(crate) fn codegen(
926926
// binaries. So we must clone the module to produce the asm output
927927
// if we are also producing object code.
928928
let llmod = if let EmitObj::ObjectCode(_) = config.emit_obj {
929-
llvm::LLVMCloneModule(llmod)
929+
unsafe { llvm::LLVMCloneModule(llmod) }
930930
} else {
931931
llmod
932932
};

compiler/rustc_codegen_llvm/src/builder/gpu_device.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ fn add_unnamed_global_in_addrspace<'ll>(
1414
addrspace: u32,
1515
) -> &'ll llvm::Value {
1616
let llglobal = add_global_in_addrspace(cx, name, initializer, l, addrspace);
17-
unsafe { llvm::LLVMSetUnnamedAddress(llglobal, llvm::UnnamedAddr::Global) };
17+
llvm::LLVMSetUnnamedAddress(llglobal, llvm::UnnamedAddr::Global);
1818
llglobal
1919
}
2020

compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs

Lines changed: 123 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,36 @@ pub(crate) fn handle_gpu_code<'ll>(
1818
// The offload memory transfer type for each kernel
1919
let mut o_types = vec![];
2020
let mut kernels = vec![];
21+
let mut region_ids = vec![];
2122
let offload_entry_ty = add_tgt_offload_entry(&cx);
2223
for num in 0..9 {
2324
let kernel = cx.get_function(&format!("kernel_{num}"));
2425
if let Some(kernel) = kernel {
25-
o_types.push(gen_define_handling(&cx, kernel, offload_entry_ty, num));
26+
let (o, k) = gen_define_handling(&cx, kernel, offload_entry_ty, num);
27+
o_types.push(o);
28+
region_ids.push(k);
2629
kernels.push(kernel);
2730
}
2831
}
29-
gen_call_handling(&cx, &kernels, &o_types);
32+
gen_call_handling(&cx, &kernels, &o_types, &region_ids);
3033
crate::builder::gpu_wrapper::gen_image_wrapper_module(&cgcx);
3134
}
3235

36+
// ; Function Attrs: nounwind
37+
// declare i32 @__tgt_target_kernel(ptr, i64, i32, i32, ptr, ptr) #2
38+
fn generate_launcher<'ll>(cx: &'ll SimpleCx<'_>) -> (&'ll llvm::Value, &'ll llvm::Type) {
39+
let tptr = cx.type_ptr();
40+
let ti64 = cx.type_i64();
41+
let ti32 = cx.type_i32();
42+
let args = vec![tptr, ti64, ti32, ti32, tptr, tptr];
43+
let tgt_fn_ty = cx.type_func(&args, ti32);
44+
let name = "__tgt_target_kernel";
45+
let tgt_decl = declare_offload_fn(&cx, name, tgt_fn_ty);
46+
let nounwind = llvm::AttributeKind::NoUnwind.create_attr(cx.llcx);
47+
attributes::apply_to_llfn(tgt_decl, Function, &[nounwind]);
48+
(tgt_decl, tgt_fn_ty)
49+
}
50+
3351
// What is our @1 here? A magic global, used in our data_{begin/update/end}_mapper:
3452
// @0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1
3553
// @1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @0 }, align 8
@@ -83,7 +101,7 @@ pub(crate) fn add_tgt_offload_entry<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Ty
83101
offload_entry_ty
84102
}
85103

86-
fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) {
104+
fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) -> (&'ll llvm::Type, Vec<&'ll llvm::Type>) {
87105
let kernel_arguments_ty = cx.type_named_struct("struct.__tgt_kernel_arguments");
88106
let tptr = cx.type_ptr();
89107
let ti64 = cx.type_i64();
@@ -118,9 +136,10 @@ fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) {
118136
vec![ti32, ti32, tptr, tptr, tptr, tptr, tptr, tptr, ti64, ti64, tarr, tarr, ti32];
119137

120138
cx.set_struct_body(kernel_arguments_ty, &kernel_elements, false);
139+
(kernel_arguments_ty, kernel_elements)
121140
// For now we don't handle kernels, so for now we just add a global dummy
122141
// to make sure that the __tgt_offload_entry is defined and handled correctly.
123-
cx.declare_global("my_struct_global2", kernel_arguments_ty);
142+
//cx.declare_global("my_struct_global2", kernel_arguments_ty);
124143
}
125144

126145
fn gen_tgt_data_mappers<'ll>(
@@ -187,7 +206,7 @@ fn gen_define_handling<'ll>(
187206
kernel: &'ll llvm::Value,
188207
offload_entry_ty: &'ll llvm::Type,
189208
num: i64,
190-
) -> &'ll llvm::Value {
209+
) -> (&'ll llvm::Value, &'ll llvm::Value) {
191210
let types = cx.func_params_types(cx.get_type_of_global(kernel));
192211
// It seems like non-pointer values are automatically mapped. So here, we focus on pointer (or
193212
// reference) types.
@@ -205,10 +224,11 @@ fn gen_define_handling<'ll>(
205224
// or both to and from the gpu (=3). Other values shouldn't affect us for now.
206225
// A non-mutable reference or pointer will be 1, an array that's not read, but fully overwritten
207226
// will be 2. For now, everything is 3, until we have our frontend set up.
227+
// 1+2+32: 1 (MapTo), 2 (MapFrom), 32 (Add extra input ptr once, idk, figure out later)
208228
let o_types =
209-
add_priv_unnamed_arr(&cx, &format!(".offload_maptypes.{num}"), &vec![3; num_ptr_types]);
229+
add_priv_unnamed_arr(&cx, &format!(".offload_maptypes.{num}"), &vec![1+2+32; num_ptr_types]);
210230
// Next: For each function, generate these three entries. A weak constant,
211-
// the llvm.rodata entry name, and the omp_offloading_entries value
231+
// the llvm.rodata entry name, and the llvm_offload_entries value
212232

213233
let name = format!(".kernel_{num}.region_id");
214234
let initializer = cx.get_const_i8(0);
@@ -243,12 +263,12 @@ fn gen_define_handling<'ll>(
243263
llvm::set_linkage(llglobal, WeakAnyLinkage);
244264
llvm::set_initializer(llglobal, initializer);
245265
llvm::set_alignment(llglobal, Align::ONE);
246-
let c_section_name = CString::new(".omp_offloading_entries").unwrap();
266+
let c_section_name = CString::new("llvm_offload_entries").unwrap();
247267
llvm::set_section(llglobal, &c_section_name);
248-
o_types
268+
(o_types, region_id)
249269
}
250270

251-
fn declare_offload_fn<'ll>(
271+
pub(crate) fn declare_offload_fn<'ll>(
252272
cx: &'ll SimpleCx<'_>,
253273
name: &str,
254274
ty: &'ll llvm::Type,
@@ -287,15 +307,17 @@ fn gen_call_handling<'ll>(
287307
cx: &'ll SimpleCx<'_>,
288308
_kernels: &[&'ll llvm::Value],
289309
o_types: &[&'ll llvm::Value],
310+
region_ids: &[&'ll llvm::Value],
290311
) {
312+
let (tgt_decl, tgt_target_kernel_ty) = generate_launcher(&cx);
291313
// %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
292314
let tptr = cx.type_ptr();
293315
let ti32 = cx.type_i32();
294316
let tgt_bin_desc_ty = vec![ti32, tptr, tptr, tptr];
295317
let tgt_bin_desc = cx.type_named_struct("struct.__tgt_bin_desc");
296318
cx.set_struct_body(tgt_bin_desc, &tgt_bin_desc_ty, false);
297319

298-
gen_tgt_kernel_global(&cx);
320+
let (tgt_kernel_decl, tgt_kernel_types) = gen_tgt_kernel_global(&cx);
299321
let (begin_mapper_decl, _, end_mapper_decl, fn_ty) = gen_tgt_data_mappers(&cx);
300322

301323
let main_fn = cx.get_function("main");
@@ -329,29 +351,33 @@ fn gen_call_handling<'ll>(
329351
// These represent the sizes in bytes, e.g. the entry for `&[f64; 16]` will be 8*16.
330352
let ty2 = cx.type_array(cx.type_i64(), num_args);
331353
let a4 = builder.direct_alloca(ty2, Align::EIGHT, ".offload_sizes");
354+
355+
//%kernel_args = alloca %struct.__tgt_kernel_arguments, align 8
356+
let a5 = builder.direct_alloca(tgt_kernel_decl, Align::EIGHT, "kernel_args");
357+
358+
// Step 1)
359+
unsafe { llvm::LLVMRustPositionBefore(builder.llbuilder, kernel_call) };
360+
builder.memset(tgt_bin_desc_alloca, cx.get_const_i8(0), cx.get_const_i64(32), Align::EIGHT);
361+
332362
// Now we allocate once per function param, a copy to be passed to one of our maps.
333363
let mut vals = vec![];
334364
let mut geps = vec![];
335365
let i32_0 = cx.get_const_i32(0);
336366
for (index, in_ty) in types.iter().enumerate() {
337367
// get function arg, store it into the alloca, and read it.
338-
let p = llvm::get_param(called, index as u32);
339-
let name = llvm::get_value_name(p);
340-
let name = str::from_utf8(&name).unwrap();
341-
let arg_name = format!("{name}.addr");
342-
let alloca = builder.direct_alloca(in_ty, Align::EIGHT, &arg_name);
343-
344-
builder.store(p, alloca, Align::EIGHT);
345-
let val = builder.load(in_ty, alloca, Align::EIGHT);
346-
let gep = builder.inbounds_gep(cx.type_f32(), val, &[i32_0]);
347-
vals.push(val);
368+
//let p = llvm::get_param(called, index as u32);
369+
//let name = llvm::get_value_name(p);
370+
//let name = str::from_utf8(&name).unwrap();
371+
//let arg_name = format!("{name}.addr");
372+
//let alloca = builder.direct_alloca(in_ty, Align::EIGHT, &arg_name);
373+
374+
let v = unsafe { llvm::LLVMGetOperand(kernel_call, index as u32).unwrap() };
375+
let gep = builder.inbounds_gep(cx.type_f32(), v, &[i32_0]);
376+
vals.push(v);
377+
//vals.push(val);
348378
geps.push(gep);
349379
}
350380

351-
// Step 1)
352-
unsafe { llvm::LLVMRustPositionBefore(builder.llbuilder, kernel_call) };
353-
builder.memset(tgt_bin_desc_alloca, cx.get_const_i8(0), cx.get_const_i64(32), Align::EIGHT);
354-
355381
let mapper_fn_ty = cx.type_func(&[cx.type_ptr()], cx.type_void());
356382
let register_lib_decl = declare_offload_fn(&cx, "__tgt_register_lib", mapper_fn_ty);
357383
let unregister_lib_decl = declare_offload_fn(&cx, "__tgt_unregister_lib", mapper_fn_ty);
@@ -421,16 +447,87 @@ fn gen_call_handling<'ll>(
421447

422448
// Step 3)
423449
// Here we will add code for the actual kernel launches in a follow-up PR.
450+
//%28 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 0
451+
//store i32 3, ptr %28, align 4
452+
//%29 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 1
453+
//store i32 3, ptr %29, align 4
454+
//%30 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 2
455+
//store ptr %26, ptr %30, align 8
456+
//%31 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 3
457+
//store ptr %27, ptr %31, align 8
458+
//%32 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 4
459+
//store ptr @.offload_sizes, ptr %32, align 8
460+
//%33 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 5
461+
//store ptr @.offload_maptypes, ptr %33, align 8
462+
//%34 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 6
463+
//store ptr null, ptr %34, align 8
464+
//%35 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 7
465+
//store ptr null, ptr %35, align 8
466+
//%36 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 8
467+
//store i64 0, ptr %36, align 8
468+
//%37 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 9
469+
//store i64 0, ptr %37, align 8
470+
//%38 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 10
471+
//store [3 x i32] [i32 2097152, i32 0, i32 0], ptr %38, align 4
472+
//%39 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 11
473+
//store [3 x i32] [i32 256, i32 0, i32 0], ptr %39, align 4
474+
//%40 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 12
475+
//store i32 0, ptr %40, align 4
424476
// FIXME(offload): launch kernels
477+
let mut values = vec![];
478+
values.push((4, cx.get_const_i32(3)));
479+
values.push((4, cx.get_const_i32(num_args)));
480+
values.push((8, geps.0));
481+
values.push((8, geps.1));
482+
values.push((8, geps.2));
483+
values.push((8, o_types[0]));
484+
values.push((8, cx.const_null(cx.type_ptr())));
485+
values.push((8, cx.const_null(cx.type_ptr())));
486+
values.push((8, cx.get_const_i64(0)));
487+
values.push((8, cx.get_const_i64(0)));
488+
let ti32 = cx.type_i32();
489+
let ci32_0 = cx.get_const_i32(0);
490+
values.push((8, cx.const_array(ti32, &vec![cx.get_const_i32(2097152), ci32_0, ci32_0])));
491+
values.push((8, cx.const_array(ti32, &vec![cx.get_const_i32(256), ci32_0, ci32_0])));
492+
values.push((4, cx.get_const_i32(0)));
493+
494+
for (i, value) in values.iter().enumerate() {
495+
let ptr = builder.inbounds_gep(tgt_kernel_decl, a5, &[i32_0, cx.get_const_i32(i as u64)]);
496+
builder.store(value.1, ptr, Align::from_bytes(value.0).unwrap());
497+
}
498+
499+
let args = vec![
500+
s_ident_t,
501+
// MAX == -1
502+
cx.get_const_i64(u64::MAX),
503+
cx.get_const_i32(2097152),
504+
cx.get_const_i32(256),
505+
region_ids[0],
506+
a5,
507+
];
508+
let offload_success = builder.call(tgt_target_kernel_ty, tgt_decl, &args, None);
509+
// %41 = call i32 @__tgt_target_kernel(ptr @1, i64 -1, i32 2097152, i32 256, ptr @.kernel_1.region_id, ptr %kernel_args)
510+
unsafe {
511+
let next = llvm::LLVMGetNextInstruction(offload_success).unwrap();
512+
dbg!(&next);
513+
llvm::LLVMRustPositionAfter(builder.llbuilder, next);
514+
let called_kernel = llvm::LLVMGetCalledValue(next).unwrap();
515+
llvm::LLVMInstructionEraseFromParent(next);
516+
dbg!(&called_kernel);
517+
}
425518

426519
// Step 4)
427-
unsafe { llvm::LLVMRustPositionAfter(builder.llbuilder, kernel_call) };
520+
//unsafe { llvm::LLVMRustPositionAfter(builder.llbuilder, kernel_call) };
428521

429522
let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4);
430523
generate_mapper_call(&mut builder, &cx, geps, o, end_mapper_decl, fn_ty, num_args, s_ident_t);
431524

432525
builder.call(mapper_fn_ty, unregister_lib_decl, &[tgt_bin_desc_alloca], None);
433526

527+
drop(builder);
528+
unsafe { llvm::LLVMDeleteFunction(called) };
529+
dbg!("survived");
530+
434531
// With this we generated the following begin and end mappers. We could easily generate the
435532
// update mapper in an update.
436533
// call void @__tgt_target_data_begin_mapper(ptr @1, i64 -1, i32 3, ptr %27, ptr %28, ptr %29, ptr @.offload_maptypes, ptr null, ptr null)

compiler/rustc_codegen_llvm/src/builder/gpu_wrapper.rs

Lines changed: 32 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
use std::ffi::CString;
22

33
use llvm::Linkage::*;
4-
use rustc_abi::Align;
4+
use rustc_abi::{AddressSpace, Align};
55
use rustc_codegen_ssa::back::write::CodegenContext;
66
use rustc_codegen_ssa::traits::BaseTypeCodegenMethods;
77

88
use crate::builder::gpu_offload::*;
9-
use crate::llvm::{self, Visibility};
9+
use crate::llvm::{self, Linkage, Type, Visibility};
1010
use crate::{LlvmCodegenBackend, ModuleLlvm, SimpleCx};
1111

1212
pub(crate) fn create_struct_ty<'ll>(
@@ -22,6 +22,23 @@ pub(crate) fn create_struct_ty<'ll>(
2222
}
2323
}
2424

25+
pub(crate) fn add_global_decl<'ll>(
26+
cx: &SimpleCx<'ll>,
27+
ty: &'ll Type,
28+
name: &str,
29+
l: Linkage,
30+
hidden: bool,
31+
) -> &'ll llvm::Value {
32+
let c_name = CString::new(name).unwrap();
33+
let llglobal: &'ll llvm::Value = llvm::add_global(cx.llmod, ty, &c_name);
34+
llvm::set_global_constant(llglobal, true);
35+
llvm::set_linkage(llglobal, l);
36+
if hidden {
37+
llvm::set_visibility(llglobal, Visibility::Hidden);
38+
}
39+
llglobal
40+
}
41+
2542
// We don't copy types from other functions because we generate a new module and context.
2643
// Bringing in types from other contexts would likely cause issues.
2744
pub(crate) fn gen_image_wrapper_module(cgcx: &CodegenContext<LlvmCodegenBackend>) {
@@ -32,6 +49,7 @@ pub(crate) fn gen_image_wrapper_module(cgcx: &CodegenContext<LlvmCodegenBackend>
3249
ModuleLlvm::new_simple(name, dl_cstr.into_raw(), target_cstr.into_raw(), &cgcx).unwrap();
3350
let cx = SimpleCx::new(m.llmod(), m.llcx, cgcx.pointer_size);
3451
let tptr = cx.type_ptr();
52+
let tptr1 = cx.type_ptr_ext(AddressSpace(1));
3553
let ti64 = cx.type_i64();
3654
let ti32 = cx.type_i32();
3755
let ti16 = cx.type_i16();
@@ -44,28 +62,22 @@ pub(crate) fn gen_image_wrapper_module(cgcx: &CodegenContext<LlvmCodegenBackend>
4462
let offload_entry_ty = add_tgt_offload_entry(&cx);
4563
let offload_entry_arr = cx.type_array(offload_entry_ty, 0);
4664

47-
let c_name = CString::new("__start_omp_offloading_entries").unwrap();
48-
let llglobal = llvm::add_global(cx.llmod, offload_entry_arr, &c_name);
49-
llvm::set_global_constant(llglobal, true);
50-
llvm::set_linkage(llglobal, ExternalLinkage);
51-
llvm::set_visibility(llglobal, Visibility::Hidden);
52-
let c_name = CString::new("__stop_omp_offloading_entries").unwrap();
53-
let llglobal = llvm::add_global(cx.llmod, offload_entry_arr, &c_name);
54-
llvm::set_global_constant(llglobal, true);
55-
llvm::set_linkage(llglobal, ExternalLinkage);
56-
llvm::set_visibility(llglobal, Visibility::Hidden);
65+
let name = "__start_omp_offloading_entries";
66+
add_global_decl(&cx, offload_entry_arr, name, ExternalLinkage, true);
67+
68+
let name = "__stop_omp_offloading_entries";
69+
add_global_decl(&cx, offload_entry_arr, name, ExternalLinkage, true);
70+
71+
let name = "__dummy.omp_offloading_entries";
72+
let llglobal = add_global_decl(&cx, offload_entry_arr, name, InternalLinkage, false);
5773

58-
let c_name = CString::new("__dummy.omp_offloading_entries").unwrap();
59-
let llglobal = llvm::add_global(cx.llmod, offload_entry_arr, &c_name);
60-
llvm::set_global_constant(llglobal, true);
61-
llvm::set_linkage(llglobal, InternalLinkage);
6274
let c_section_name = CString::new("omp_offloading_entries").unwrap();
6375
llvm::set_section(llglobal, &c_section_name);
6476
let zeroinit = cx.const_null(offload_entry_arr);
6577
llvm::set_initializer(llglobal, zeroinit);
6678

6779
CString::new("llvm.compiler.used").unwrap();
68-
let arr_val = cx.const_array(tptr, &[llglobal]);
80+
let arr_val = cx.const_array(tptr1, &[llglobal]);
6981
let c_section_name = CString::new("llvm.metadata").unwrap();
7082
let llglobal = add_global(&cx, "llvm.compiler.used", arr_val, AppendingLinkage);
7183
llvm::set_section(llglobal, &c_section_name);
@@ -74,30 +86,9 @@ pub(crate) fn gen_image_wrapper_module(cgcx: &CodegenContext<LlvmCodegenBackend>
7486
//@llvm.compiler.used = appending global [1 x ptr] [ptr @__dummy.omp_offloading_entries], section "llvm.metadata"
7587

7688
let mapper_fn_ty = cx.type_func(&[tptr], cx.type_void());
77-
crate::declare::declare_simple_fn(
78-
&cx,
79-
&"__tgt_unregister_lib",
80-
llvm::CallConv::CCallConv,
81-
llvm::UnnamedAddr::No,
82-
llvm::Visibility::Default,
83-
mapper_fn_ty,
84-
);
85-
crate::declare::declare_simple_fn(
86-
&cx,
87-
&"__tgt_register_lib",
88-
llvm::CallConv::CCallConv,
89-
llvm::UnnamedAddr::No,
90-
llvm::Visibility::Default,
91-
mapper_fn_ty,
92-
);
93-
crate::declare::declare_simple_fn(
94-
&cx,
95-
&"atexit",
96-
llvm::CallConv::CCallConv,
97-
llvm::UnnamedAddr::No,
98-
llvm::Visibility::Default,
99-
cx.type_func(&[tptr], ti32),
100-
);
89+
declare_offload_fn(&cx, &"__tgt_register_lib", mapper_fn_ty);
90+
declare_offload_fn(&cx, &"__tgt_unregister_lib", mapper_fn_ty);
91+
declare_offload_fn(&cx, &"atexit", cx.type_func(&[tptr], ti32));
10192

10293
let unknown_txt = "11111111111111";
10394
let c_entry_name = CString::new(unknown_txt).unwrap();

0 commit comments

Comments
 (0)