1- // Reference: PTX Writer's Guide to Interoperability
2- // https://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability
3-
4- use crate :: abi:: call:: { ArgAbi , FnAbi } ;
1+ use crate :: abi:: call:: { ArgAbi , FnAbi , PassMode , Reg , Size , Uniform } ;
2+ use crate :: abi:: { HasDataLayout , TyAbiInterface } ;
53
64fn classify_ret < Ty > ( ret : & mut ArgAbi < ' _ , Ty > ) {
75 if ret. layout . is_aggregate ( ) && ret. layout . size . bits ( ) > 64 {
86 ret. make_indirect ( ) ;
9- } else {
10- ret. extend_integer_width_to ( 64 ) ;
117 }
128}
139
1410fn classify_arg < Ty > ( arg : & mut ArgAbi < ' _ , Ty > ) {
1511 if arg. layout . is_aggregate ( ) && arg. layout . size . bits ( ) > 64 {
1612 arg. make_indirect ( ) ;
17- } else {
18- arg. extend_integer_width_to ( 64 ) ;
13+ }
14+ }
15+
16+ fn classify_arg_kernel < ' a , Ty , C > ( _cx : & C , arg : & mut ArgAbi < ' a , Ty > )
17+ where
18+ Ty : TyAbiInterface < ' a , C > + Copy ,
19+ C : HasDataLayout ,
20+ {
21+ if matches ! ( arg. mode, PassMode :: Pair ( ..) ) && ( arg. layout . is_adt ( ) || arg. layout . is_tuple ( ) ) {
22+ let align_bytes = arg. layout . align . abi . bytes ( ) ;
23+
24+ let unit = match align_bytes {
25+ 1 => Reg :: i8 ( ) ,
26+ 2 => Reg :: i16 ( ) ,
27+ 4 => Reg :: i32 ( ) ,
28+ 8 => Reg :: i64 ( ) ,
29+ 16 => Reg :: i128 ( ) ,
30+ _ => unreachable ! ( "Align is given as power of 2 no larger than 16 bytes" ) ,
31+ } ;
32+ arg. cast_to ( Uniform { unit, total : Size :: from_bytes ( 2 * align_bytes) } ) ;
1933 }
2034}
2135
@@ -31,3 +45,20 @@ pub fn compute_abi_info<Ty>(fn_abi: &mut FnAbi<'_, Ty>) {
3145 classify_arg ( arg) ;
3246 }
3347}
48+
49+ pub fn compute_ptx_kernel_abi_info < ' a , Ty , C > ( cx : & C , fn_abi : & mut FnAbi < ' a , Ty > )
50+ where
51+ Ty : TyAbiInterface < ' a , C > + Copy ,
52+ C : HasDataLayout ,
53+ {
54+ if !fn_abi. ret . layout . is_unit ( ) && !fn_abi. ret . layout . is_never ( ) {
55+ panic ! ( "Kernels should not return anything other than () or !" ) ;
56+ }
57+
58+ for arg in & mut fn_abi. args {
59+ if arg. is_ignore ( ) {
60+ continue ;
61+ }
62+ classify_arg_kernel ( cx, arg) ;
63+ }
64+ }
0 commit comments