Skip to content

Commit be22b03

Browse files
committed
Support scalar pair ABI
(T, U) pairs in entrypoints and regular functions are now supported.
1 parent 7358fae commit be22b03

File tree

23 files changed

+492
-46
lines changed

23 files changed

+492
-46
lines changed

crates/rustc_codegen_spirv/src/abi.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -655,7 +655,7 @@ pub fn scalar_pair_element_backend_type<'tcx>(
655655
ty: TyAndLayout<'tcx>,
656656
index: usize,
657657
) -> Word {
658-
let [a, b] = match ty.layout.backend_repr() {
658+
let [a, b] = match ty.backend_repr {
659659
BackendRepr::ScalarPair(a, b) => [a, b],
660660
other => span_bug!(
661661
span,

crates/rustc_codegen_spirv/src/codegen_cx/entry.rs

Lines changed: 80 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@ use rspirv::dr::Operand;
1111
use rspirv::spirv::{
1212
Capability, Decoration, Dim, ExecutionModel, FunctionControl, StorageClass, Word,
1313
};
14-
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods, MiscCodegenMethods as _};
14+
use rustc_codegen_ssa::traits::{
15+
BaseTypeCodegenMethods, BuilderMethods, ConstCodegenMethods, LayoutTypeCodegenMethods,
16+
MiscCodegenMethods as _,
17+
};
1518
use rustc_data_structures::fx::FxHashMap;
1619
use rustc_errors::MultiSpan;
1720
use rustc_hir as hir;
@@ -86,22 +89,7 @@ impl<'tcx> CodegenCx<'tcx> {
8689
};
8790
for (arg_abi, hir_param) in fn_abi.args.iter().zip(hir_params) {
8891
match arg_abi.mode {
89-
PassMode::Direct(_) => {}
90-
PassMode::Pair(..) => {
91-
// FIXME(eddyb) implement `ScalarPair` `Input`s, or change
92-
// the `FnAbi` readjustment to only use `PassMode::Pair` for
93-
// pointers to `!Sized` types, but not other `ScalarPair`s.
94-
if !matches!(arg_abi.layout.ty.kind(), ty::Ref(..)) {
95-
self.tcx.dcx().span_err(
96-
hir_param.ty_span,
97-
format!(
98-
"entry point parameter type not yet supported \
99-
(`{}` has `ScalarPair` ABI but is not a `&T`)",
100-
arg_abi.layout.ty
101-
),
102-
);
103-
}
104-
}
92+
PassMode::Direct(_) | PassMode::Pair(..) => {}
10593
// FIXME(eddyb) support these (by just ignoring them) - if there
10694
// is any validation concern, it should be done on the types.
10795
PassMode::Ignore => self.tcx.dcx().span_fatal(
@@ -442,6 +430,33 @@ impl<'tcx> CodegenCx<'tcx> {
442430
} = self.entry_param_deduce_from_rust_ref_or_value(entry_arg_abi.layout, hir_param, &attrs);
443431
let value_spirv_type = value_layout.spirv_type(hir_param.ty_span, self);
444432

433+
// In compute shaders, user-provided data must come from buffers or push
434+
// constants, i.e. by-reference parameters.
435+
if execution_model == ExecutionModel::GLCompute
436+
&& matches!(entry_arg_abi.mode, PassMode::Direct(_) | PassMode::Pair(..))
437+
&& !matches!(entry_arg_abi.layout.ty.kind(), ty::Ref(..))
438+
&& attrs.builtin.is_none()
439+
{
440+
let param_name = if let hir::PatKind::Binding(_, _, ident, _) = &hir_param.pat.kind {
441+
ident.name.to_string()
442+
} else {
443+
"parameter".to_string()
444+
};
445+
self.tcx
446+
.dcx()
447+
.struct_span_err(
448+
hir_param.ty_span,
449+
format!("compute entry parameter `{param_name}` must be by-reference",),
450+
)
451+
.with_help(format!(
452+
"consider changing the type to `&{}`",
453+
entry_arg_abi.layout.ty
454+
))
455+
.emit();
456+
// Keep this a hard error to stop compilation after emitting help.
457+
self.tcx.dcx().abort_if_errors();
458+
}
459+
445460
let (var_id, spec_const_id) = match storage_class {
446461
// Pre-allocate the module-scoped `OpVariable` *Result* ID.
447462
Ok(_) => (
@@ -491,14 +506,6 @@ impl<'tcx> CodegenCx<'tcx> {
491506
vs layout:\n{value_layout:#?}",
492507
entry_arg_abi.layout.ty
493508
);
494-
if is_pair && !is_unsized {
495-
// If PassMode is Pair, then we need to fill in the second part of the pair with a
496-
// value. We currently only do that with unsized types, so if a type is a pair for some
497-
// other reason (e.g. a tuple), we bail.
498-
self.tcx
499-
.dcx()
500-
.span_fatal(hir_param.ty_span, "pair type not supported yet")
501-
}
502509
// FIXME(eddyb) should this talk about "typed buffers" instead of "interface blocks"?
503510
// FIXME(eddyb) should we talk about "descriptor indexing" or
504511
// actually use more reasonable terms like "resource arrays"?
@@ -621,8 +628,8 @@ impl<'tcx> CodegenCx<'tcx> {
621628
}
622629
}
623630

624-
let value_len = if is_pair {
625-
// We've already emitted an error, fill in a placeholder value
631+
let value_len = if is_pair && is_unsized {
632+
// For wide references (e.g., slices), the second component is a length.
626633
Some(bx.undef(self.type_isize()))
627634
} else {
628635
None
@@ -645,21 +652,54 @@ impl<'tcx> CodegenCx<'tcx> {
645652
_ => unreachable!(),
646653
}
647654
} else {
648-
assert_matches!(entry_arg_abi.mode, PassMode::Direct(_));
649-
650-
let value = match storage_class {
651-
Ok(_) => {
655+
match entry_arg_abi.mode {
656+
PassMode::Direct(_) => {
657+
let value = match storage_class {
658+
Ok(_) => {
659+
assert_eq!(storage_class, Ok(StorageClass::Input));
660+
bx.load(
661+
entry_arg_abi.layout.spirv_type(hir_param.ty_span, bx),
662+
value_ptr.unwrap(),
663+
entry_arg_abi.layout.align.abi,
664+
)
665+
}
666+
Err(SpecConstant { .. }) => {
667+
spec_const_id.unwrap().with_type(value_spirv_type)
668+
}
669+
};
670+
call_args.push(value);
671+
assert_eq!(value_len, None);
672+
}
673+
PassMode::Pair(..) => {
674+
// Load both elements of the scalar pair from the input variable.
652675
assert_eq!(storage_class, Ok(StorageClass::Input));
653-
bx.load(
654-
entry_arg_abi.layout.spirv_type(hir_param.ty_span, bx),
655-
value_ptr.unwrap(),
656-
entry_arg_abi.layout.align.abi,
657-
)
676+
let layout = entry_arg_abi.layout;
677+
let (a, b) = match layout.backend_repr {
678+
rustc_abi::BackendRepr::ScalarPair(a, b) => (a, b),
679+
other => span_bug!(
680+
hir_param.ty_span,
681+
"ScalarPair expected for entry param, found {other:?}"
682+
),
683+
};
684+
let b_offset = a
685+
.primitive()
686+
.size(self)
687+
.align_to(b.primitive().align(self).abi);
688+
689+
let elem0_ty = self.scalar_pair_element_backend_type(layout, 0, false);
690+
let elem1_ty = self.scalar_pair_element_backend_type(layout, 1, false);
691+
692+
let base_ptr = value_ptr.unwrap();
693+
let ptr1 = bx.inbounds_ptradd(base_ptr, self.const_usize(b_offset.bytes()));
694+
695+
let v0 = bx.load(elem0_ty, base_ptr, layout.align.abi);
696+
let v1 = bx.load(elem1_ty, ptr1, layout.align.restrict_for_offset(b_offset));
697+
call_args.push(v0);
698+
call_args.push(v1);
699+
assert_eq!(value_len, None);
658700
}
659-
Err(SpecConstant { .. }) => spec_const_id.unwrap().with_type(value_spirv_type),
660-
};
661-
call_args.push(value);
662-
assert_eq!(value_len, None);
701+
_ => unreachable!(),
702+
}
663703
}
664704

665705
// FIXME(eddyb) check whether the storage class is compatible with the
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// compile-fail
2+
#![no_std]
3+
4+
use spirv_std::spirv;
5+
6+
#[spirv(compute(threads(1)))]
7+
pub fn main(
8+
w: (u32, u32),
9+
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] out: &mut [u32],
10+
) {
11+
out[0] = w.0 + w.1;
12+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
error: compute entry parameter `w` must be by-reference
2+
--> $DIR/compute_value_pair_fail.rs:8:8
3+
|
4+
8 | w: (u32, u32),
5+
| ^^^^^^^^^^
6+
|
7+
= help: consider changing the type to `&(u32, u32)`
8+
9+
error: aborting due to 1 previous error
10+
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// build-pass
2+
#![no_std]
3+
4+
use spirv_std::spirv;
5+
6+
#[spirv(compute(threads(1)))]
7+
pub fn main(
8+
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] out: &mut [u32],
9+
#[spirv(uniform, descriptor_set = 0, binding = 1)] w: &(f32, u32),
10+
) {
11+
out[0] = w.0.to_bits() ^ w.1;
12+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// build-pass
2+
// compile-flags: -C target-feature=+Int64
3+
#![no_std]
4+
5+
use spirv_std::spirv;
6+
7+
#[spirv(compute(threads(1)))]
8+
pub fn main(
9+
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] out: &mut [u32],
10+
#[spirv(uniform, descriptor_set = 0, binding = 1)] w: &(i32, i32),
11+
) {
12+
// Sum and reinterpret as u32 for output
13+
let s = (w.0 as i64 + w.1 as i64) as i32;
14+
out[0] = s as u32;
15+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// build-pass
2+
#![no_std]
3+
4+
use spirv_std::spirv;
5+
6+
#[spirv(compute(threads(1)))]
7+
pub fn main(
8+
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] out: &mut [u32],
9+
#[spirv(uniform, descriptor_set = 0, binding = 1)] w: &(u32, f32),
10+
) {
11+
let a = w.0;
12+
let b_bits = w.1.to_bits();
13+
out[0] = a ^ b_bits;
14+
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// build-pass
2+
#![no_std]
3+
4+
use spirv_std::spirv;
5+
6+
#[repr(transparent)]
7+
pub struct Wrap((u32, u32));
8+
9+
#[spirv(compute(threads(1)))]
10+
pub fn main(
11+
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] out: &mut [u32],
12+
#[spirv(uniform, descriptor_set = 0, binding = 1)] w: &Wrap,
13+
) {
14+
let a = (w.0).0;
15+
let b = (w.0).1;
16+
out[0] = a + b;
17+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// build-pass
2+
#![no_std]
3+
4+
use spirv_std::spirv;
5+
6+
#[repr(transparent)]
7+
pub struct Inner((u32, u32));
8+
9+
#[repr(transparent)]
10+
pub struct Outer(
11+
core::mem::ManuallyDrop<Inner>,
12+
core::marker::PhantomData<()>,
13+
);
14+
15+
#[inline(never)]
16+
fn sum_outer(o: Outer) -> u32 {
17+
// SAFETY: repr(transparent) guarantees same layout as `Inner`.
18+
let i: Inner = unsafe { core::mem::ManuallyDrop::into_inner((o.0)) };
19+
(i.0).0 + (i.0).1
20+
}
21+
22+
#[spirv(compute(threads(1)))]
23+
pub fn main(#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] out: &mut [u32]) {
24+
let i = Inner((5, 7));
25+
let o = Outer(core::mem::ManuallyDrop::new(i), core::marker::PhantomData);
26+
out[0] = sum_outer(o);
27+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// build-pass
2+
#![no_std]
3+
4+
use spirv_std::spirv;
5+
6+
#[spirv(fragment)]
7+
pub fn main(
8+
#[spirv(flat)] pi: (u32, u32),
9+
#[spirv(storage_buffer, descriptor_set = 0, binding = 0)] out: &mut [u32],
10+
) {
11+
out[0] = pi.0.wrapping_add(pi.1);
12+
}

0 commit comments

Comments
 (0)