Skip to content

Allow instructions to explicitly specify StorageClasses #236

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion crates/rustc_codegen_spirv/src/abi.rs
Original file line number Diff line number Diff line change
@@ -3,7 +3,7 @@

use crate::attr::{AggregatedSpirvAttributes, IntrinsicType};
use crate::codegen_cx::CodegenCx;
use crate::spirv_type::SpirvType;
use crate::spirv_type::{SpirvType, StorageClassKind};
use itertools::Itertools;
use rspirv::spirv::{Dim, ImageFormat, StorageClass, Word};
use rustc_data_structures::fx::FxHashMap;
@@ -339,6 +339,7 @@ impl<'tcx> RecursivePointeeCache<'tcx> {
PointeeDefState::Defining => {
let id = SpirvType::Pointer {
pointee: pointee_spv,
storage_class: StorageClassKind::Inferred, // TODO(jwollen): Do we need to cache by storage class?
}
.def(span, cx);
entry.insert(PointeeDefState::Defined(id));
@@ -350,6 +351,7 @@ impl<'tcx> RecursivePointeeCache<'tcx> {
entry.insert(PointeeDefState::Defined(id));
SpirvType::Pointer {
pointee: pointee_spv,
storage_class: StorageClassKind::Inferred, // TODO(jwollen): Do we need to cache by storage class?
}
.def_with_id(cx, span, id)
}
46 changes: 33 additions & 13 deletions crates/rustc_codegen_spirv/src/builder/builder_methods.rs
Original file line number Diff line number Diff line change
@@ -407,14 +407,15 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
size: Size,
) -> Option<(SpirvValue, <Self as BackendTypes>::Type)> {
let ptr = ptr.strip_ptrcasts();
let mut leaf_ty = match self.lookup_type(ptr.ty) {
SpirvType::Pointer { pointee } => pointee,
let pointee_ty = match self.lookup_type(ptr.ty) {
SpirvType::Pointer { pointee, .. } => pointee,
other => self.fatal(format!("non-pointer type: {other:?}")),
};

// FIXME(eddyb) this isn't efficient, `recover_access_chain_from_offset`
// could instead be doing all the extra digging itself.
let mut indices = SmallVec::<[_; 8]>::new();
let mut leaf_ty = pointee_ty;
while let Some((inner_indices, inner_ty)) = self.recover_access_chain_from_offset(
leaf_ty,
Size::ZERO,
@@ -429,7 +430,8 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
.then(|| self.type_ptr_to(leaf_ty))?;

let leaf_ptr = if indices.is_empty() {
assert_ty_eq!(self, ptr.ty, leaf_ptr_ty);
// Compare pointee types instead of pointer types as storage class might be different.
assert_ty_eq!(self, pointee_ty, leaf_ty);
ptr
} else {
let indices = indices
@@ -586,7 +588,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
let ptr = ptr.strip_ptrcasts();
let ptr_id = ptr.def(self);
let original_pointee_ty = match self.lookup_type(ptr.ty) {
SpirvType::Pointer { pointee } => pointee,
SpirvType::Pointer { pointee, .. } => pointee,
other => self.fatal(format!("gep called on non-pointer type: {other:?}")),
};

@@ -1926,6 +1928,25 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
return ptr;
}

// No cast is needed if only the storage class mismatches.
let ptr_pointee = match self.lookup_type(ptr.ty) {
SpirvType::Pointer { pointee, .. } => pointee,
other => self.fatal(format!(
"pointercast called on non-pointer source type: {other:?}"
)),
};
let dest_pointee = match self.lookup_type(dest_ty) {
SpirvType::Pointer { pointee, .. } => pointee,
other => self.fatal(format!(
"pointercast called on non-pointer dest type: {other:?}"
)),
};

// FIXME(jwollen) Do we need to choose `dest_ty` if it has a fixed storage class and `ptr` has none?
if ptr_pointee == dest_pointee {
return ptr;
}

// Strip a previous `pointercast`, to reveal the original pointer type.
let ptr = ptr.strip_ptrcasts();

@@ -1934,17 +1955,16 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
}

let ptr_pointee = match self.lookup_type(ptr.ty) {
SpirvType::Pointer { pointee } => pointee,
SpirvType::Pointer { pointee, .. } => pointee,
other => self.fatal(format!(
"pointercast called on non-pointer source type: {other:?}"
)),
};
let dest_pointee = match self.lookup_type(dest_ty) {
SpirvType::Pointer { pointee } => pointee,
other => self.fatal(format!(
"pointercast called on non-pointer dest type: {other:?}"
)),
};

if ptr_pointee == dest_pointee {
return ptr;
}

let dest_pointee_size = self.lookup_type(dest_pointee).sizeof(self);

if let Some((indices, _)) = self.recover_access_chain_from_offset(
@@ -2324,7 +2344,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
.and_then(|size| Some(Size::from_bytes(u64::try_from(size).ok()?)));

let elem_ty = match self.lookup_type(ptr.ty) {
SpirvType::Pointer { pointee } => pointee,
SpirvType::Pointer { pointee, .. } => pointee,
_ => self.fatal(format!(
"memset called on non-pointer type: {}",
self.debug_type(ptr.ty)
@@ -2696,7 +2716,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
(callee.def(self), return_type, arguments)
}

SpirvType::Pointer { pointee } => match self.lookup_type(pointee) {
SpirvType::Pointer { pointee, .. } => match self.lookup_type(pointee) {
SpirvType::Function {
return_type,
arguments,
22 changes: 19 additions & 3 deletions crates/rustc_codegen_spirv/src/builder/mod.rs
Original file line number Diff line number Diff line change
@@ -15,8 +15,8 @@ use crate::maybe_pqp_cg_ssa as rustc_codegen_ssa;
use crate::abi::ConvSpirvType;
use crate::builder_spirv::{BuilderCursor, SpirvValue, SpirvValueExt};
use crate::codegen_cx::CodegenCx;
use crate::spirv_type::SpirvType;
use rspirv::spirv::Word;
use crate::spirv_type::{SpirvType, StorageClassKind};
use rspirv::spirv::{StorageClass, Word};
use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
use rustc_codegen_ssa::mir::place::PlaceRef;
use rustc_codegen_ssa::traits::{
@@ -104,7 +104,23 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {

// HACK(eddyb) like the `CodegenCx` method but with `self.span()` awareness.
pub fn type_ptr_to(&self, ty: Word) -> Word {
SpirvType::Pointer { pointee: ty }.def(self.span(), self)
SpirvType::Pointer {
pointee: ty,
storage_class: StorageClassKind::Inferred,
}
.def(self.span(), self)
}

pub fn type_ptr_with_storage_class_to(
&self,
ty: Word,
storage_class: StorageClassKind,
) -> Word {
SpirvType::Pointer {
pointee: ty,
storage_class,
}
.def(self.span(), self)
}

// TODO: Definitely add tests to make sure this impl is right.
24 changes: 10 additions & 14 deletions crates/rustc_codegen_spirv/src/builder/spirv_asm.rs
Original file line number Diff line number Diff line change
@@ -4,7 +4,7 @@ use crate::maybe_pqp_cg_ssa as rustc_codegen_ssa;
use super::Builder;
use crate::builder_spirv::{BuilderCursor, SpirvValue};
use crate::codegen_cx::CodegenCx;
use crate::spirv_type::SpirvType;
use crate::spirv_type::{SpirvType, StorageClassKind};
use rspirv::dr;
use rspirv::grammar::{LogicalOperand, OperandKind, OperandQuantifier, reflect};
use rspirv::spirv::{
@@ -307,19 +307,14 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
}
.def(self.span(), self),
Op::TypePointer => {
let storage_class = inst.operands[0].unwrap_storage_class();
if storage_class != StorageClass::Generic {
self.struct_err("TypePointer in asm! requires `Generic` storage class")
.with_note(format!(
"`{storage_class:?}` storage class was specified"
))
.with_help(format!(
"the storage class will be inferred automatically (e.g. to `{storage_class:?}`)"
))
.emit();
}
// The storage class can be specified explicitly or inferred later by using StorageClass::Generic.
let storage_class = match inst.operands[0].unwrap_storage_class() {
StorageClass::Generic => StorageClassKind::Inferred,
storage_class => StorageClassKind::Explicit(storage_class),
};
SpirvType::Pointer {
pointee: inst.operands[1].unwrap_id_ref(),
storage_class,
}
.def(self.span(), self)
}
@@ -678,6 +673,7 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {

TyPat::Pointer(_, pat) => SpirvType::Pointer {
pointee: subst_ty_pat(cx, pat, ty_vars, leftover_operands)?,
storage_class: StorageClassKind::Inferred,
}
.def(DUMMY_SP, cx),

@@ -931,7 +927,7 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
Some(match kind {
TypeofKind::Plain => ty,
TypeofKind::Dereference => match self.lookup_type(ty) {
SpirvType::Pointer { pointee } => pointee,
SpirvType::Pointer { pointee, .. } => pointee,
other => {
self.tcx.dcx().span_err(
span,
@@ -953,7 +949,7 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
self.check_reg(span, reg);
if let Some(place) = place {
match self.lookup_type(place.val.llval.ty) {
SpirvType::Pointer { pointee } => Some(pointee),
SpirvType::Pointer { pointee, .. } => Some(pointee),
other => {
self.tcx.dcx().span_err(
span,
2 changes: 1 addition & 1 deletion crates/rustc_codegen_spirv/src/builder_spirv.rs
Original file line number Diff line number Diff line change
@@ -101,7 +101,7 @@ impl SpirvValue {
match entry.val {
SpirvConst::PtrTo { pointee } => {
let ty = match cx.lookup_type(self.ty) {
SpirvType::Pointer { pointee } => pointee,
SpirvType::Pointer { pointee, .. } => pointee,
ty => bug!("load called on value that wasn't a pointer: {:?}", ty),
};
// FIXME(eddyb) deduplicate this `if`-`else` and its other copies.
6 changes: 3 additions & 3 deletions crates/rustc_codegen_spirv/src/codegen_cx/constant.rs
Original file line number Diff line number Diff line change
@@ -239,7 +239,7 @@ impl<'tcx> ConstCodegenMethods<'tcx> for CodegenCx<'tcx> {
let (base_addr, _base_addr_space) = match self.tcx.global_alloc(alloc_id) {
GlobalAlloc::Memory(alloc) => {
let pointee = match self.lookup_type(ty) {
SpirvType::Pointer { pointee } => pointee,
SpirvType::Pointer { pointee, .. } => pointee,
other => self.tcx.dcx().fatal(format!(
"GlobalAlloc::Memory type not implemented: {}",
other.debug(ty, self)
@@ -259,7 +259,7 @@ impl<'tcx> ConstCodegenMethods<'tcx> for CodegenCx<'tcx> {
.global_alloc(self.tcx.vtable_allocation((vty, dyn_ty.principal())))
.unwrap_memory();
let pointee = match self.lookup_type(ty) {
SpirvType::Pointer { pointee } => pointee,
SpirvType::Pointer { pointee, .. } => pointee,
other => self.tcx.dcx().fatal(format!(
"GlobalAlloc::VTable type not implemented: {}",
other.debug(ty, self)
@@ -328,7 +328,7 @@ impl<'tcx> CodegenCx<'tcx> {
if let Some(SpirvConst::ConstDataFromAlloc(alloc)) =
self.builder.lookup_const_by_id(pointee)
{
if let SpirvType::Pointer { pointee } = self.lookup_type(ty) {
if let SpirvType::Pointer { pointee, .. } = self.lookup_type(ty) {
let mut offset = Size::ZERO;
let init = self.read_from_const_alloc(alloc, &mut offset, pointee);
return self.static_addr_of(init, alloc.inner().align, None);
11 changes: 8 additions & 3 deletions crates/rustc_codegen_spirv/src/codegen_cx/declare.rs
Original file line number Diff line number Diff line change
@@ -6,7 +6,7 @@ use crate::abi::ConvSpirvType;
use crate::attr::AggregatedSpirvAttributes;
use crate::builder_spirv::{SpirvConst, SpirvValue, SpirvValueExt};
use crate::custom_decorations::{CustomDecoration, SrcLocDecoration};
use crate::spirv_type::SpirvType;
use crate::spirv_type::{SpirvType, StorageClassKind};
use itertools::Itertools;
use rspirv::spirv::{FunctionControl, LinkageType, StorageClass, Word};
use rustc_attr::InlineAttr;
@@ -267,7 +267,12 @@ impl<'tcx> CodegenCx<'tcx> {
}

fn declare_global(&self, span: Span, ty: Word) -> SpirvValue {
let ptr_ty = SpirvType::Pointer { pointee: ty }.def(span, self);
// Could be explicitly StorageClass::Private but is inferred anyway.
let ptr_ty = SpirvType::Pointer {
pointee: ty,
storage_class: StorageClassKind::Inferred,
}
.def(span, self);
// FIXME(eddyb) figure out what the correct storage class is.
let result = self
.emit_global()
@@ -353,7 +358,7 @@ impl<'tcx> StaticCodegenMethods for CodegenCx<'tcx> {
Err(_) => return,
};
let value_ty = match self.lookup_type(g.ty) {
SpirvType::Pointer { pointee } => pointee,
SpirvType::Pointer { pointee, .. } => pointee,
other => self.tcx.dcx().fatal(format!(
"global had non-pointer type {}",
other.debug(g.ty, self)
2 changes: 1 addition & 1 deletion crates/rustc_codegen_spirv/src/codegen_cx/entry.rs
Original file line number Diff line number Diff line change
@@ -967,7 +967,7 @@ impl<'tcx> CodegenCx<'tcx> {
| SpirvType::Matrix { element, .. }
| SpirvType::Array { element, .. }
| SpirvType::RuntimeArray { element }
| SpirvType::Pointer { pointee: element }
| SpirvType::Pointer { pointee: element, .. }
| SpirvType::InterfaceBlock {
inner_type: element,
} => recurse(cx, element, has_bool, must_be_flat),
15 changes: 12 additions & 3 deletions crates/rustc_codegen_spirv/src/codegen_cx/mod.rs
Original file line number Diff line number Diff line change
@@ -6,7 +6,7 @@ mod type_;
use crate::builder::{ExtInst, InstructionTable};
use crate::builder_spirv::{BuilderCursor, BuilderSpirv, SpirvConst, SpirvValue, SpirvValueKind};
use crate::custom_decorations::{CustomDecoration, SrcLocDecoration, ZombieDecoration};
use crate::spirv_type::{SpirvType, SpirvTypePrinter, TypeCache};
use crate::spirv_type::{SpirvType, SpirvTypePrinter, StorageClassKind, TypeCache};
use crate::symbols::Symbols;
use crate::target::SpirvTarget;

@@ -234,11 +234,19 @@ impl<'tcx> CodegenCx<'tcx> {
}

pub fn type_ptr_to(&self, ty: Word) -> Word {
SpirvType::Pointer { pointee: ty }.def(DUMMY_SP, self)
SpirvType::Pointer {
pointee: ty,
storage_class: StorageClassKind::Inferred,
}
.def(DUMMY_SP, self)
}

pub fn type_ptr_to_ext(&self, ty: Word, _address_space: AddressSpace) -> Word {
SpirvType::Pointer { pointee: ty }.def(DUMMY_SP, self)
SpirvType::Pointer {
pointee: ty,
storage_class: StorageClassKind::Inferred,
}
.def(DUMMY_SP, self)
}

/// Zombie system:
@@ -866,6 +874,7 @@ impl<'tcx> MiscCodegenMethods<'tcx> for CodegenCx<'tcx> {

let ty = SpirvType::Pointer {
pointee: function.ty,
storage_class: StorageClassKind::Inferred,
}
.def(span, self);

2 changes: 1 addition & 1 deletion crates/rustc_codegen_spirv/src/codegen_cx/type_.rs
Original file line number Diff line number Diff line change
@@ -219,7 +219,7 @@ impl<'tcx> BaseTypeCodegenMethods<'tcx> for CodegenCx<'tcx> {
}
fn element_type(&self, ty: Self::Type) -> Self::Type {
match self.lookup_type(ty) {
SpirvType::Pointer { pointee } => pointee,
SpirvType::Pointer { pointee, .. } => pointee,
SpirvType::Vector { element, .. } => element,
spirv_type => self.tcx.dcx().fatal(format!(
"element_type called on invalid type: {spirv_type:?}"
14 changes: 11 additions & 3 deletions crates/rustc_codegen_spirv/src/linker/specializer.rs
Original file line number Diff line number Diff line change
@@ -1615,6 +1615,14 @@ impl<'a, S: Specialization> InferCx<'a, S> {

#[allow(clippy::match_same_arms)]
Ok(match (a.clone(), b.clone()) {
// Concrete result types explicitly created inside functions
// can be assigned to instances.
// FIXME(jwollen) do we need to infere instance generics?
(InferOperand::Instance(_), InferOperand::Concrete(new))
| (InferOperand::Concrete(new), InferOperand::Instance(_)) => {
InferOperand::Concrete(new)
}

// Instances of "generic" globals/functions must be of the same ID,
// and their `generic_args` inference variables must be unified.
(
@@ -1999,13 +2007,13 @@ impl<'a, S: Specialization> InferCx<'a, S> {

if let Some(type_of_result) = type_of_result {
// Keep the (instantiated) *Result Type*, for future instructions to use
// (but only if it has any `InferVar`s at all).
// if it has any `InferVar`s at all or if it was a concrete type.
match type_of_result {
InferOperand::Var(_) | InferOperand::Instance(_) => {
InferOperand::Var(_) | InferOperand::Instance(_) | InferOperand::Concrete(_) => {
self.type_of_result
.insert(inst.result_id.unwrap(), type_of_result);
}
InferOperand::Unknown | InferOperand::Concrete(_) => {}
InferOperand::Unknown => {}
}
}
}
62 changes: 51 additions & 11 deletions crates/rustc_codegen_spirv/src/spirv_type.rs
Original file line number Diff line number Diff line change
@@ -61,6 +61,7 @@ pub enum SpirvType<'tcx> {
},
Pointer {
pointee: Word,
storage_class: StorageClassKind,
},
Function {
return_type: Word,
@@ -90,6 +91,17 @@ pub enum SpirvType<'tcx> {
RayQueryKhr,
}

#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum StorageClassKind {
/// Inferred based on globals and other pointers with explicit storage classes.
/// This corresponds to `StorageClass::Generic` in inline `asm!` and intermediate SPIR-V.
Inferred,

/// Explicitly set by an instruction that needs to create a storage class,
/// regardless of inputs.
Explicit(StorageClass),
}

impl SpirvType<'_> {
/// Note: `Builder::type_*` should be called *nowhere else* but here, to ensure
/// `CodegenCx::type_defs` stays up-to-date
@@ -213,13 +225,18 @@ impl SpirvType<'_> {
);
result
}
Self::Pointer { pointee } => {
Self::Pointer {
pointee,
storage_class,
} => {
// NOTE(eddyb) we emit `StorageClass::Generic` here, but later
// the linker will specialize the entire SPIR-V module to use
// storage classes inferred from `OpVariable`s.
let result = cx
.emit_global()
.type_pointer(id, StorageClass::Generic, pointee);
let storage_class = match storage_class {
StorageClassKind::Inferred => StorageClass::Generic,
StorageClassKind::Explicit(storage_class) => storage_class,
};
let result = cx.emit_global().type_pointer(id, storage_class, pointee);
// no pointers to functions
if let SpirvType::Function { .. } = cx.lookup_type(pointee) {
// FIXME(eddyb) use the `SPV_INTEL_function_pointers` extension.
@@ -286,13 +303,20 @@ impl SpirvType<'_> {
return cached;
}
let result = match self {
Self::Pointer { pointee } => {
Self::Pointer {
pointee,
storage_class,
} => {
// NOTE(eddyb) we emit `StorageClass::Generic` here, but later
// the linker will specialize the entire SPIR-V module to use
// storage classes inferred from `OpVariable`s.
let result =
cx.emit_global()
.type_pointer(Some(id), StorageClass::Generic, pointee);
let storage_class = match storage_class {
StorageClassKind::Inferred => StorageClass::Generic,
StorageClassKind::Explicit(storage_class) => storage_class,
};
let result = cx
.emit_global()
.type_pointer(Some(id), storage_class, pointee);
// no pointers to functions
if let SpirvType::Function { .. } = cx.lookup_type(pointee) {
// FIXME(eddyb) use the `SPV_INTEL_function_pointers` extension.
@@ -412,7 +436,13 @@ impl SpirvType<'_> {
SpirvType::Matrix { element, count } => SpirvType::Matrix { element, count },
SpirvType::Array { element, count } => SpirvType::Array { element, count },
SpirvType::RuntimeArray { element } => SpirvType::RuntimeArray { element },
SpirvType::Pointer { pointee } => SpirvType::Pointer { pointee },
SpirvType::Pointer {
pointee,
storage_class,
} => SpirvType::Pointer {
pointee,
storage_class,
},
SpirvType::Image {
sampled_type,
dim,
@@ -557,10 +587,14 @@ impl fmt::Debug for SpirvTypePrinter<'_, '_> {
.field("id", &self.id)
.field("element", &self.cx.debug_type(element))
.finish(),
SpirvType::Pointer { pointee } => f
SpirvType::Pointer {
pointee,
storage_class,
} => f
.debug_struct("Pointer")
.field("id", &self.id)
.field("pointee", &self.cx.debug_type(pointee))
.field("storage_class", &storage_class)
.finish(),
SpirvType::Function {
return_type,
@@ -710,8 +744,14 @@ impl SpirvTypePrinter<'_, '_> {
ty(self.cx, stack, f, element)?;
f.write_str("]")
}
SpirvType::Pointer { pointee } => {
SpirvType::Pointer {
pointee,
storage_class,
} => {
f.write_str("*")?;
if let StorageClassKind::Explicit(storage_class) = storage_class {
write!(f, "{:?}", storage_class)?;
}
ty(self.cx, stack, f, pointee)
}
SpirvType::Function {
5 changes: 4 additions & 1 deletion crates/rustc_codegen_spirv/src/spirv_type_constraints.rs
Original file line number Diff line number Diff line change
@@ -427,7 +427,10 @@ pub fn instruction_signatures(op: Op) -> Option<&'static [InstSig<'static>]> {
Op::ConvertPtrToU | Op::SatConvertSToU | Op::SatConvertUToS | Op::ConvertUToPtr => {}
Op::PtrCastToGeneric | Op::GenericCastToPtr => sig! { (Pointer(_, T)) -> Pointer(_, T) },
Op::GenericCastToPtrExplicit => sig! { {S} (Pointer(_, T)) -> Pointer(S, T) },
Op::Bitcast => {}
Op::Bitcast => sig! {
(Pointer(S, _)) -> Pointer(S, _) |
(_) -> _
},

// 3.37.12. Composite Instructions
Op::VectorExtractDynamic => sig! { (Vector(T), _) -> T },