@@ -18,18 +18,36 @@ pub(crate) fn handle_gpu_code<'ll>(
18
18
// The offload memory transfer type for each kernel
19
19
let mut o_types = vec ! [ ] ;
20
20
let mut kernels = vec ! [ ] ;
21
+ let mut region_ids = vec ! [ ] ;
21
22
let offload_entry_ty = add_tgt_offload_entry ( & cx) ;
22
23
for num in 0 ..9 {
23
24
let kernel = cx. get_function ( & format ! ( "kernel_{num}" ) ) ;
24
25
if let Some ( kernel) = kernel {
25
- o_types. push ( gen_define_handling ( & cx, kernel, offload_entry_ty, num) ) ;
26
+ let ( o, k) = gen_define_handling ( & cx, kernel, offload_entry_ty, num) ;
27
+ o_types. push ( o) ;
28
+ region_ids. push ( k) ;
26
29
kernels. push ( kernel) ;
27
30
}
28
31
}
29
- gen_call_handling ( & cx, & kernels, & o_types) ;
32
+ gen_call_handling ( & cx, & kernels, & o_types, & region_ids ) ;
30
33
crate :: builder:: gpu_wrapper:: gen_image_wrapper_module ( & cgcx) ;
31
34
}
32
35
36
+ // ; Function Attrs: nounwind
37
+ // declare i32 @__tgt_target_kernel(ptr, i64, i32, i32, ptr, ptr) #2
38
+ fn generate_launcher < ' ll > ( cx : & ' ll SimpleCx < ' _ > ) -> ( & ' ll llvm:: Value , & ' ll llvm:: Type ) {
39
+ let tptr = cx. type_ptr ( ) ;
40
+ let ti64 = cx. type_i64 ( ) ;
41
+ let ti32 = cx. type_i32 ( ) ;
42
+ let args = vec ! [ tptr, ti64, ti32, ti32, tptr, tptr] ;
43
+ let tgt_fn_ty = cx. type_func ( & args, ti32) ;
44
+ let name = "__tgt_target_kernel" ;
45
+ let tgt_decl = declare_offload_fn ( & cx, name, tgt_fn_ty) ;
46
+ let nounwind = llvm:: AttributeKind :: NoUnwind . create_attr ( cx. llcx ) ;
47
+ attributes:: apply_to_llfn ( tgt_decl, Function , & [ nounwind] ) ;
48
+ ( tgt_decl, tgt_fn_ty)
49
+ }
50
+
33
51
// What is our @1 here? A magic global, used in our data_{begin/update/end}_mapper:
34
52
// @0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1
35
53
// @1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @0 }, align 8
@@ -83,7 +101,7 @@ pub(crate) fn add_tgt_offload_entry<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Ty
83
101
offload_entry_ty
84
102
}
85
103
86
- fn gen_tgt_kernel_global < ' ll > ( cx : & ' ll SimpleCx < ' _ > ) {
104
+ fn gen_tgt_kernel_global < ' ll > ( cx : & ' ll SimpleCx < ' _ > ) -> ( & ' ll llvm :: Type , Vec < & ' ll llvm :: Type > ) {
87
105
let kernel_arguments_ty = cx. type_named_struct ( "struct.__tgt_kernel_arguments" ) ;
88
106
let tptr = cx. type_ptr ( ) ;
89
107
let ti64 = cx. type_i64 ( ) ;
@@ -118,9 +136,10 @@ fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) {
118
136
vec ! [ ti32, ti32, tptr, tptr, tptr, tptr, tptr, tptr, ti64, ti64, tarr, tarr, ti32] ;
119
137
120
138
cx. set_struct_body ( kernel_arguments_ty, & kernel_elements, false ) ;
139
+ ( kernel_arguments_ty, kernel_elements)
121
140
// For now we don't handle kernels, so for now we just add a global dummy
122
141
// to make sure that the __tgt_offload_entry is defined and handled correctly.
123
- cx. declare_global ( "my_struct_global2" , kernel_arguments_ty) ;
142
+ // cx.declare_global("my_struct_global2", kernel_arguments_ty);
124
143
}
125
144
126
145
fn gen_tgt_data_mappers < ' ll > (
@@ -187,7 +206,7 @@ fn gen_define_handling<'ll>(
187
206
kernel : & ' ll llvm:: Value ,
188
207
offload_entry_ty : & ' ll llvm:: Type ,
189
208
num : i64 ,
190
- ) -> & ' ll llvm:: Value {
209
+ ) -> ( & ' ll llvm:: Value , & ' ll llvm :: Value ) {
191
210
let types = cx. func_params_types ( cx. get_type_of_global ( kernel) ) ;
192
211
// It seems like non-pointer values are automatically mapped. So here, we focus on pointer (or
193
212
// reference) types.
@@ -205,10 +224,11 @@ fn gen_define_handling<'ll>(
205
224
// or both to and from the gpu (=3). Other values shouldn't affect us for now.
206
225
// A non-mutable reference or pointer will be 1, an array that's not read, but fully overwritten
207
226
// will be 2. For now, everything is 3, until we have our frontend set up.
227
+ // 1+2+32: 1 (MapTo), 2 (MapFrom), 32 (Add extra input ptr once, idk, figure out later)
208
228
let o_types =
209
- add_priv_unnamed_arr ( & cx, & format ! ( ".offload_maptypes.{num}" ) , & vec ! [ 3 ; num_ptr_types] ) ;
229
+ add_priv_unnamed_arr ( & cx, & format ! ( ".offload_maptypes.{num}" ) , & vec ! [ 1 + 2 + 32 ; num_ptr_types] ) ;
210
230
// Next: For each function, generate these three entries. A weak constant,
211
- // the llvm.rodata entry name, and the omp_offloading_entries value
231
+ // the llvm.rodata entry name, and the llvm_offload_entries value
212
232
213
233
let name = format ! ( ".kernel_{num}.region_id" ) ;
214
234
let initializer = cx. get_const_i8 ( 0 ) ;
@@ -243,12 +263,12 @@ fn gen_define_handling<'ll>(
243
263
llvm:: set_linkage ( llglobal, WeakAnyLinkage ) ;
244
264
llvm:: set_initializer ( llglobal, initializer) ;
245
265
llvm:: set_alignment ( llglobal, Align :: ONE ) ;
246
- let c_section_name = CString :: new ( ".omp_offloading_entries " ) . unwrap ( ) ;
266
+ let c_section_name = CString :: new ( "llvm_offload_entries " ) . unwrap ( ) ;
247
267
llvm:: set_section ( llglobal, & c_section_name) ;
248
- o_types
268
+ ( o_types, region_id )
249
269
}
250
270
251
- fn declare_offload_fn < ' ll > (
271
+ pub ( crate ) fn declare_offload_fn < ' ll > (
252
272
cx : & ' ll SimpleCx < ' _ > ,
253
273
name : & str ,
254
274
ty : & ' ll llvm:: Type ,
@@ -287,15 +307,17 @@ fn gen_call_handling<'ll>(
287
307
cx : & ' ll SimpleCx < ' _ > ,
288
308
_kernels : & [ & ' ll llvm:: Value ] ,
289
309
o_types : & [ & ' ll llvm:: Value ] ,
310
+ region_ids : & [ & ' ll llvm:: Value ] ,
290
311
) {
312
+ let ( tgt_decl, tgt_target_kernel_ty) = generate_launcher ( & cx) ;
291
313
// %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
292
314
let tptr = cx. type_ptr ( ) ;
293
315
let ti32 = cx. type_i32 ( ) ;
294
316
let tgt_bin_desc_ty = vec ! [ ti32, tptr, tptr, tptr] ;
295
317
let tgt_bin_desc = cx. type_named_struct ( "struct.__tgt_bin_desc" ) ;
296
318
cx. set_struct_body ( tgt_bin_desc, & tgt_bin_desc_ty, false ) ;
297
319
298
- gen_tgt_kernel_global ( & cx) ;
320
+ let ( tgt_kernel_decl , tgt_kernel_types ) = gen_tgt_kernel_global ( & cx) ;
299
321
let ( begin_mapper_decl, _, end_mapper_decl, fn_ty) = gen_tgt_data_mappers ( & cx) ;
300
322
301
323
let main_fn = cx. get_function ( "main" ) ;
@@ -329,29 +351,33 @@ fn gen_call_handling<'ll>(
329
351
// These represent the sizes in bytes, e.g. the entry for `&[f64; 16]` will be 8*16.
330
352
let ty2 = cx. type_array ( cx. type_i64 ( ) , num_args) ;
331
353
let a4 = builder. direct_alloca ( ty2, Align :: EIGHT , ".offload_sizes" ) ;
354
+
355
+ //%kernel_args = alloca %struct.__tgt_kernel_arguments, align 8
356
+ let a5 = builder. direct_alloca ( tgt_kernel_decl, Align :: EIGHT , "kernel_args" ) ;
357
+
358
+ // Step 1)
359
+ unsafe { llvm:: LLVMRustPositionBefore ( builder. llbuilder , kernel_call) } ;
360
+ builder. memset ( tgt_bin_desc_alloca, cx. get_const_i8 ( 0 ) , cx. get_const_i64 ( 32 ) , Align :: EIGHT ) ;
361
+
332
362
// Now we allocate once per function param, a copy to be passed to one of our maps.
333
363
let mut vals = vec ! [ ] ;
334
364
let mut geps = vec ! [ ] ;
335
365
let i32_0 = cx. get_const_i32 ( 0 ) ;
336
366
for ( index, in_ty) in types. iter ( ) . enumerate ( ) {
337
367
// get function arg, store it into the alloca, and read it.
338
- let p = llvm:: get_param ( called, index as u32 ) ;
339
- let name = llvm:: get_value_name ( p) ;
340
- let name = str:: from_utf8 ( & name) . unwrap ( ) ;
341
- let arg_name = format ! ( "{name}.addr" ) ;
342
- let alloca = builder. direct_alloca ( in_ty, Align :: EIGHT , & arg_name) ;
343
-
344
- builder . store ( p , alloca , Align :: EIGHT ) ;
345
- let val = builder. load ( in_ty , alloca , Align :: EIGHT ) ;
346
- let gep = builder . inbounds_gep ( cx . type_f32 ( ) , val , & [ i32_0 ] ) ;
347
- vals. push ( val) ;
368
+ // let p = llvm::get_param(called, index as u32);
369
+ // let name = llvm::get_value_name(p);
370
+ // let name = str::from_utf8(&name).unwrap();
371
+ // let arg_name = format!("{name}.addr");
372
+ // let alloca = builder.direct_alloca(in_ty, Align::EIGHT, &arg_name);
373
+
374
+ let v = unsafe { llvm :: LLVMGetOperand ( kernel_call , index as u32 ) . unwrap ( ) } ;
375
+ let gep = builder. inbounds_gep ( cx . type_f32 ( ) , v , & [ i32_0 ] ) ;
376
+ vals . push ( v ) ;
377
+ // vals.push(val);
348
378
geps. push ( gep) ;
349
379
}
350
380
351
- // Step 1)
352
- unsafe { llvm:: LLVMRustPositionBefore ( builder. llbuilder , kernel_call) } ;
353
- builder. memset ( tgt_bin_desc_alloca, cx. get_const_i8 ( 0 ) , cx. get_const_i64 ( 32 ) , Align :: EIGHT ) ;
354
-
355
381
let mapper_fn_ty = cx. type_func ( & [ cx. type_ptr ( ) ] , cx. type_void ( ) ) ;
356
382
let register_lib_decl = declare_offload_fn ( & cx, "__tgt_register_lib" , mapper_fn_ty) ;
357
383
let unregister_lib_decl = declare_offload_fn ( & cx, "__tgt_unregister_lib" , mapper_fn_ty) ;
@@ -421,16 +447,87 @@ fn gen_call_handling<'ll>(
421
447
422
448
// Step 3)
423
449
// Here we will add code for the actual kernel launches in a follow-up PR.
450
+ //%28 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 0
451
+ //store i32 3, ptr %28, align 4
452
+ //%29 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 1
453
+ //store i32 3, ptr %29, align 4
454
+ //%30 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 2
455
+ //store ptr %26, ptr %30, align 8
456
+ //%31 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 3
457
+ //store ptr %27, ptr %31, align 8
458
+ //%32 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 4
459
+ //store ptr @.offload_sizes, ptr %32, align 8
460
+ //%33 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 5
461
+ //store ptr @.offload_maptypes, ptr %33, align 8
462
+ //%34 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 6
463
+ //store ptr null, ptr %34, align 8
464
+ //%35 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 7
465
+ //store ptr null, ptr %35, align 8
466
+ //%36 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 8
467
+ //store i64 0, ptr %36, align 8
468
+ //%37 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 9
469
+ //store i64 0, ptr %37, align 8
470
+ //%38 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 10
471
+ //store [3 x i32] [i32 2097152, i32 0, i32 0], ptr %38, align 4
472
+ //%39 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 11
473
+ //store [3 x i32] [i32 256, i32 0, i32 0], ptr %39, align 4
474
+ //%40 = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %kernel_args, i32 0, i32 12
475
+ //store i32 0, ptr %40, align 4
424
476
// FIXME(offload): launch kernels
477
+ let mut values = vec ! [ ] ;
478
+ values. push ( ( 4 , cx. get_const_i32 ( 3 ) ) ) ;
479
+ values. push ( ( 4 , cx. get_const_i32 ( num_args) ) ) ;
480
+ values. push ( ( 8 , geps. 0 ) ) ;
481
+ values. push ( ( 8 , geps. 1 ) ) ;
482
+ values. push ( ( 8 , geps. 2 ) ) ;
483
+ values. push ( ( 8 , o_types[ 0 ] ) ) ;
484
+ values. push ( ( 8 , cx. const_null ( cx. type_ptr ( ) ) ) ) ;
485
+ values. push ( ( 8 , cx. const_null ( cx. type_ptr ( ) ) ) ) ;
486
+ values. push ( ( 8 , cx. get_const_i64 ( 0 ) ) ) ;
487
+ values. push ( ( 8 , cx. get_const_i64 ( 0 ) ) ) ;
488
+ let ti32 = cx. type_i32 ( ) ;
489
+ let ci32_0 = cx. get_const_i32 ( 0 ) ;
490
+ values. push ( ( 8 , cx. const_array ( ti32, & vec ! [ cx. get_const_i32( 2097152 ) , ci32_0, ci32_0] ) ) ) ;
491
+ values. push ( ( 8 , cx. const_array ( ti32, & vec ! [ cx. get_const_i32( 256 ) , ci32_0, ci32_0] ) ) ) ;
492
+ values. push ( ( 4 , cx. get_const_i32 ( 0 ) ) ) ;
493
+
494
+ for ( i, value) in values. iter ( ) . enumerate ( ) {
495
+ let ptr = builder. inbounds_gep ( tgt_kernel_decl, a5, & [ i32_0, cx. get_const_i32 ( i as u64 ) ] ) ;
496
+ builder. store ( value. 1 , ptr, Align :: from_bytes ( value. 0 ) . unwrap ( ) ) ;
497
+ }
498
+
499
+ let args = vec ! [
500
+ s_ident_t,
501
+ // MAX == -1
502
+ cx. get_const_i64( u64 :: MAX ) ,
503
+ cx. get_const_i32( 2097152 ) ,
504
+ cx. get_const_i32( 256 ) ,
505
+ region_ids[ 0 ] ,
506
+ a5,
507
+ ] ;
508
+ let offload_success = builder. call ( tgt_target_kernel_ty, tgt_decl, & args, None ) ;
509
+ // %41 = call i32 @__tgt_target_kernel(ptr @1, i64 -1, i32 2097152, i32 256, ptr @.kernel_1.region_id, ptr %kernel_args)
510
+ unsafe {
511
+ let next = llvm:: LLVMGetNextInstruction ( offload_success) . unwrap ( ) ;
512
+ dbg ! ( & next) ;
513
+ llvm:: LLVMRustPositionAfter ( builder. llbuilder , next) ;
514
+ let called_kernel = llvm:: LLVMGetCalledValue ( next) . unwrap ( ) ;
515
+ llvm:: LLVMInstructionEraseFromParent ( next) ;
516
+ dbg ! ( & called_kernel) ;
517
+ }
425
518
426
519
// Step 4)
427
- unsafe { llvm:: LLVMRustPositionAfter ( builder. llbuilder , kernel_call) } ;
520
+ // unsafe { llvm::LLVMRustPositionAfter(builder.llbuilder, kernel_call) };
428
521
429
522
let geps = get_geps ( & mut builder, & cx, ty, ty2, a1, a2, a4) ;
430
523
generate_mapper_call ( & mut builder, & cx, geps, o, end_mapper_decl, fn_ty, num_args, s_ident_t) ;
431
524
432
525
builder. call ( mapper_fn_ty, unregister_lib_decl, & [ tgt_bin_desc_alloca] , None ) ;
433
526
527
+ drop ( builder) ;
528
+ unsafe { llvm:: LLVMDeleteFunction ( called) } ;
529
+ dbg ! ( "survived" ) ;
530
+
434
531
// With this we generated the following begin and end mappers. We could easily generate the
435
532
// update mapper in an update.
436
533
// call void @__tgt_target_data_begin_mapper(ptr @1, i64 -1, i32 3, ptr %27, ptr %28, ptr %29, ptr @.offload_maptypes, ptr null, ptr null)
0 commit comments