@@ -2,6 +2,7 @@ use std::ffi::CString;
22
33use llvm:: Linkage :: * ;
44use rustc_abi:: Align ;
5+ use rustc_codegen_ssa:: common:: TypeKind ;
56use rustc_codegen_ssa:: mir:: operand:: { OperandRef , OperandValue } ;
67use rustc_codegen_ssa:: traits:: { BaseTypeCodegenMethods , BuilderMethods } ;
78use rustc_middle:: bug;
@@ -361,7 +362,6 @@ pub(crate) fn add_global<'ll>(
361362pub ( crate ) fn gen_define_handling < ' ll > (
362363 cx : & CodegenCx < ' ll , ' _ > ,
363364 metadata : & [ OffloadMetadata ] ,
364- types : & [ & ' ll Type ] ,
365365 symbol : String ,
366366 offload_globals : & OffloadGlobals < ' ll > ,
367367) -> OffloadKernelGlobals < ' ll > {
@@ -371,25 +371,18 @@ pub(crate) fn gen_define_handling<'ll>(
371371
372372 let offload_entry_ty = offload_globals. offload_entry_ty ;
373373
374- // It seems like non-pointer values are automatically mapped. So here, we focus on pointer (or
375- // reference) types.
376- let ptr_meta = types. iter ( ) . zip ( metadata) . filter_map ( |( & x, meta) | match cx. type_kind ( x) {
377- rustc_codegen_ssa:: common:: TypeKind :: Pointer => Some ( meta) ,
378- _ => None ,
379- } ) ;
380-
381374 // FIXME(Sa4dUs): add `OMP_MAP_TARGET_PARAM = 0x20` only if necessary
382- let ( ptr_sizes , ptr_transfer ) : ( Vec < _ > , Vec < _ > ) =
383- ptr_meta . map ( |m| ( m. payload_size , m. mode . bits ( ) | 0x20 ) ) . unzip ( ) ;
375+ let ( sizes , transfer ) : ( Vec < _ > , Vec < _ > ) =
376+ metadata . iter ( ) . map ( |m| ( m. payload_size , m. mode . bits ( ) | 0x20 ) ) . unzip ( ) ;
384377
385- let offload_sizes = add_priv_unnamed_arr ( & cx, & format ! ( ".offload_sizes.{symbol}" ) , & ptr_sizes ) ;
378+ let offload_sizes = add_priv_unnamed_arr ( & cx, & format ! ( ".offload_sizes.{symbol}" ) , & sizes ) ;
386379 // Here we figure out whether something needs to be copied to the gpu (=1), from the gpu (=2),
387380 // or both to and from the gpu (=3). Other values shouldn't affect us for now.
388381 // A non-mutable reference or pointer will be 1, an array that's not read, but fully overwritten
389382 // will be 2. For now, everything is 3, until we have our frontend set up.
390383 // 1+2+32: 1 (MapTo), 2 (MapFrom), 32 (Add one extra input ptr per function, to be used later).
391384 let memtransfer_types =
392- add_priv_unnamed_arr ( & cx, & format ! ( ".offload_maptypes.{symbol}" ) , & ptr_transfer ) ;
385+ add_priv_unnamed_arr ( & cx, & format ! ( ".offload_maptypes.{symbol}" ) , & transfer ) ;
393386
394387 // Next: For each function, generate these three entries. A weak constant,
395388 // the llvm.rodata entry name, and the llvm_offload_entries value
@@ -445,13 +438,25 @@ fn declare_offload_fn<'ll>(
445438 )
446439}
447440
441+ pub ( crate ) fn scalar_width < ' ll > ( cx : & ' ll SimpleCx < ' _ > , ty : & ' ll Type ) -> u64 {
442+ match cx. type_kind ( ty) {
443+ TypeKind :: Half
444+ | TypeKind :: Float
445+ | TypeKind :: Double
446+ | TypeKind :: X86_FP80
447+ | TypeKind :: FP128
448+ | TypeKind :: PPC_FP128 => cx. float_width ( ty) as u64 ,
449+ TypeKind :: Integer => cx. int_width ( ty) ,
450+ other => bug ! ( "scalar_width was called on a non scalar type {other:?}" ) ,
451+ }
452+ }
453+
448454// For each kernel *call*, we now use some of our previously declared globals to move data to and from
449455// the gpu. For now, we only handle the data transfer part of it.
450456// If two consecutive kernels use the same memory, we still move it to the host and back to the gpu.
451457// Since in our frontend users (by default) don't have to specify data transfer, this is something
452- // we should optimize in the future! We also assume that everything should be copied back and forth,
453- // but sometimes we can directly zero-allocate on the device and only move back, or if something is
454- // immutable, we might only copy it to the device, but not back.
458+ // we should optimize in the future! In some cases we can directly zero-allocate on the device and
459+ // only move data back, or if something is immutable, we might only copy it to the device.
455460//
456461// Current steps:
457462// 0. Alloca some variables for the following steps
@@ -538,8 +543,34 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>(
538543 let mut geps = vec ! [ ] ;
539544 let i32_0 = cx. get_const_i32 ( 0 ) ;
540545 for & v in args {
541- let gep = builder. inbounds_gep ( cx. type_f32 ( ) , v, & [ i32_0] ) ;
542- vals. push ( v) ;
546+ let ty = cx. val_ty ( v) ;
547+ let ty_kind = cx. type_kind ( ty) ;
548+ let ( base_val, gep_base) = match ty_kind {
549+ TypeKind :: Pointer => ( v, v) ,
550+ TypeKind :: Half | TypeKind :: Float | TypeKind :: Double | TypeKind :: Integer => {
551+ // FIXME(Sa4dUs): check for `f128` support, latest NVIDIA cards support it
552+ let num_bits = scalar_width ( cx, ty) ;
553+
554+ let bb = builder. llbb ( ) ;
555+ unsafe {
556+ llvm:: LLVMRustPositionBuilderPastAllocas ( builder. llbuilder , builder. llfn ( ) ) ;
557+ }
558+ let addr = builder. direct_alloca ( cx. type_i64 ( ) , Align :: EIGHT , "addr" ) ;
559+ unsafe {
560+ llvm:: LLVMPositionBuilderAtEnd ( builder. llbuilder , bb) ;
561+ }
562+
563+ let cast = builder. bitcast ( v, cx. type_ix ( num_bits) ) ;
564+ let value = builder. zext ( cast, cx. type_i64 ( ) ) ;
565+ builder. store ( value, addr, Align :: EIGHT ) ;
566+ ( value, addr)
567+ }
568+ other => bug ! ( "offload does not support {other:?}" ) ,
569+ } ;
570+
571+ let gep = builder. inbounds_gep ( cx. type_f32 ( ) , gep_base, & [ i32_0] ) ;
572+
573+ vals. push ( base_val) ;
543574 geps. push ( gep) ;
544575 }
545576
0 commit comments