From 12d37bc65a8b52191b7a6a66a70dbc18034ac2ee Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Wed, 24 Sep 2025 21:22:00 -0700 Subject: [PATCH 01/10] Switch rspirv to the latest git version --- Cargo.lock | 19 +++++++++++++------ Cargo.toml | 1 + 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5799df8143..8bcc71cb9b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2504,7 +2504,7 @@ dependencies = [ "rspirv", "rustc-hash 1.1.0", "serde", - "spirv", + "spirv 0.3.0+sdk-1.3.268.0", "strum", "thiserror 2.0.17", "unicode-ident", @@ -2543,7 +2543,7 @@ dependencies = [ "ron", "serde", "serde_json", - "spirv", + "spirv 0.3.0+sdk-1.3.268.0", "toml", ] @@ -3546,12 +3546,11 @@ dependencies = [ [[package]] name = "rspirv" -version = "0.12.0+sdk-1.3.268.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69cf3a93856b6e5946537278df0d3075596371b1950ccff012f02b0f7eafec8d" +version = "0.12.0+sdk-1.4.309.0" +source = "git+https://github.com/gfx-rs/rspirv?rev=89ce4d0e64c91b0635f617409dc57cb031749a39#89ce4d0e64c91b0635f617409dc57cb031749a39" dependencies = [ "rustc-hash 1.1.0", - "spirv", + "spirv 0.3.0+sdk-1.4.309.0", ] [[package]] @@ -3906,6 +3905,14 @@ dependencies = [ "serde", ] +[[package]] +name = "spirv" +version = "0.3.0+sdk-1.4.309.0" +source = "git+https://github.com/gfx-rs/rspirv?rev=89ce4d0e64c91b0635f617409dc57cb031749a39#89ce4d0e64c91b0635f617409dc57cb031749a39" +dependencies = [ + "bitflags 2.10.0", +] + [[package]] name = "stable_deref_trait" version = "1.2.1" diff --git a/Cargo.toml b/Cargo.toml index 965556b6b4..92ccdc48af 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -258,6 +258,7 @@ ndk-sys = "0.6" # These overrides allow our examples to explicitly depend on release crates [patch.crates-io] wgpu = { path = "./wgpu" } +rspirv = { git = "https://github.com/gfx-rs/rspirv", rev = "89ce4d0e64c91b0635f617409dc57cb031749a39" } # https://github.com/Xudong-Huang/generator-rs/pull/75 generator = { git = 
"https://github.com/Xudong-Huang/generator-rs", rev = "70b89fdabcc0e82fe84ca17f65cc52ff25e8e6de" } From ab50ab0f289ab6f7a4fd5ec8aeb2e8e16fa810b8 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Tue, 16 Sep 2025 23:32:57 -0700 Subject: [PATCH 02/10] Add Cooperative* type to IR --- naga/src/back/glsl/mod.rs | 3 ++- naga/src/back/msl/writer.rs | 23 ++++++++++++++++ naga/src/back/spv/writer.rs | 2 ++ naga/src/common/wgsl/to_wgsl.rs | 19 ++++++++++++++ naga/src/common/wgsl/types.rs | 13 +++++++++ naga/src/compact/types.rs | 2 ++ naga/src/front/wgsl/lower/conversion.rs | 2 ++ naga/src/ir/mod.rs | 35 +++++++++++++++++++++++++ naga/src/proc/layouter.rs | 18 +++++++++++++ naga/src/proc/type_methods.rs | 9 ++++++- naga/src/valid/handles.rs | 1 + naga/src/valid/mod.rs | 1 + naga/src/valid/type.rs | 19 ++++++++++++++ 13 files changed, 145 insertions(+), 2 deletions(-) diff --git a/naga/src/back/glsl/mod.rs b/naga/src/back/glsl/mod.rs index 521e2dcade..ce2e627b0f 100644 --- a/naga/src/back/glsl/mod.rs +++ b/naga/src/back/glsl/mod.rs @@ -1108,7 +1108,8 @@ impl<'a, W: Write> Writer<'a, W> { TypeInner::Array { base, size, .. } => self.write_array_size(base, size)?, // Write all variants instead of `_` so that if new variants are added a // no exhaustiveness error is thrown - TypeInner::Pointer { .. } + TypeInner::CooperativeMatrix { .. } + | TypeInner::Pointer { .. } | TypeInner::Struct { .. } | TypeInner::Image { .. } | TypeInner::Sampler { .. 
} diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index 484142630d..5071276568 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -235,6 +235,20 @@ impl Display for TypeContext<'_> { rows, scalar, } => put_numeric_type(out, scalar, &[rows, columns]), + crate::TypeInner::CooperativeMatrix { + columns, + rows, + scalar, + } => { + write!( + out, + "{}::simdgroup_{}{}x{}", + NAMESPACE, + scalar.to_msl_name(), + columns as u32, + rows as u32, + ) + } crate::TypeInner::Pointer { base, space } => { let sub = Self { handle: base, @@ -528,6 +542,14 @@ impl crate::Scalar { } } +impl crate::CooperativeScalar { + const fn to_msl_name(self) -> &'static str { + match self { + Self::F32 => "float", + } + } +} + const fn separate(need_separator: bool) -> &'static str { if need_separator { "," @@ -640,6 +662,7 @@ impl crate::Type { Ti::Scalar(_) | Ti::Vector { .. } | Ti::Matrix { .. } + | Ti::CooperativeMatrix { .. } | Ti::Atomic(_) | Ti::Pointer { .. } | Ti::ValuePointer { .. } => self.name.is_some(), diff --git a/naga/src/back/spv/writer.rs b/naga/src/back/spv/writer.rs index 1beb86577c..83cc20c432 100644 --- a/naga/src/back/spv/writer.rs +++ b/naga/src/back/spv/writer.rs @@ -435,6 +435,7 @@ impl Writer { // these cases, so unwrap. LocalType::Numeric(NumericType::from_inner(inner).unwrap()) } + crate::TypeInner::CooperativeMatrix { .. } => return None, crate::TypeInner::Pointer { base, space } => { let base_type_id = self.get_handle_type_id(base); LocalType::Pointer { @@ -1502,6 +1503,7 @@ impl Writer { | crate::TypeInner::Atomic(_) | crate::TypeInner::Vector { .. } | crate::TypeInner::Matrix { .. } + | crate::TypeInner::CooperativeMatrix { .. } | crate::TypeInner::Pointer { .. } | crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Image { .. 
} diff --git a/naga/src/common/wgsl/to_wgsl.rs b/naga/src/common/wgsl/to_wgsl.rs index 25847a5df7..b0318972fe 100644 --- a/naga/src/common/wgsl/to_wgsl.rs +++ b/naga/src/common/wgsl/to_wgsl.rs @@ -305,6 +305,25 @@ impl TryToWgsl for crate::Scalar { } } +impl TryToWgsl for crate::CooperativeScalar { + const DESCRIPTION: &'static str = "cooperative scalar type"; + + fn try_to_wgsl(self) -> Option<&'static str> { + use crate::CooperativeScalar; + + Some(match self { + CooperativeScalar::F32 => "f32", + }) + } + + fn to_wgsl_for_diagnostics(self) -> String { + match self.try_to_wgsl() { + Some(static_string) => static_string.to_string(), + None => unreachable!(), + } + } +} + impl ToWgsl for crate::ImageDimension { fn to_wgsl(self) -> &'static str { use crate::ImageDimension as IDim; diff --git a/naga/src/common/wgsl/types.rs b/naga/src/common/wgsl/types.rs index 82b8eeaa67..93a94205d7 100644 --- a/naga/src/common/wgsl/types.rs +++ b/naga/src/common/wgsl/types.rs @@ -317,6 +317,19 @@ where ctx.write_scalar(scalar, out)?; out.write_str(">")?; } + TypeInner::CooperativeMatrix { + columns, + rows, + scalar, + } => { + write!( + out, + "coop_mat{}x{}<{}>", + columns as u32, + rows as u32, + scalar.try_to_wgsl().unwrap_or_default() + )?; + } TypeInner::Pointer { base, space } => { let (address, maybe_access) = address_space_str(space); // Everything but `AddressSpace::Handle` gives us a `address` name, but diff --git a/naga/src/compact/types.rs b/naga/src/compact/types.rs index 0a1db16f9f..d06558b182 100644 --- a/naga/src/compact/types.rs +++ b/naga/src/compact/types.rs @@ -16,6 +16,7 @@ impl TypeTracer<'_> { Ti::Scalar { .. } | Ti::Vector { .. } | Ti::Matrix { .. } + | Ti::CooperativeMatrix { .. } | Ti::Atomic { .. } | Ti::ValuePointer { .. } | Ti::Image { .. } @@ -66,6 +67,7 @@ impl ModuleMap { Ti::Scalar(_) | Ti::Vector { .. } | Ti::Matrix { .. } + | Ti::CooperativeMatrix { .. } | Ti::Atomic(_) | Ti::ValuePointer { .. } | Ti::Image { .. 
} diff --git a/naga/src/front/wgsl/lower/conversion.rs b/naga/src/front/wgsl/lower/conversion.rs index b22692a3cd..9e03ed5c9e 100644 --- a/naga/src/front/wgsl/lower/conversion.rs +++ b/naga/src/front/wgsl/lower/conversion.rs @@ -350,6 +350,7 @@ impl crate::TypeInner { Ti::Scalar(scalar) | Ti::Vector { scalar, .. } | Ti::Matrix { scalar, .. } => { Some(scalar) } + Ti::CooperativeMatrix { .. } => None, Ti::Array { base, .. } => types[base].inner.automatically_convertible_scalar(types), Ti::Atomic(_) | Ti::Pointer { .. } @@ -375,6 +376,7 @@ impl crate::TypeInner { Ti::Scalar(scalar) | Ti::Vector { scalar, .. } | Ti::Matrix { scalar, .. } => { Some(scalar) } + Ti::CooperativeMatrix { .. } => None, Ti::Atomic(_) => None, Ti::Pointer { base, .. } | Ti::Array { base, .. } => { types[base].inner.automatically_convertible_scalar(types) diff --git a/naga/src/ir/mod.rs b/naga/src/ir/mod.rs index 4093d823b4..28fa1dac08 100644 --- a/naga/src/ir/mod.rs +++ b/naga/src/ir/mod.rs @@ -486,6 +486,16 @@ impl From for u32 { } } +/// Number of components in a cooperative vector. +#[repr(u8)] +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum CooperativeVectorSize { + Eight = 8, +} + /// Primitive type for a scalar. #[repr(u8)] #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] @@ -513,6 +523,24 @@ pub enum ScalarKind { AbstractFloat, } +/// Primitive type for a scalar. 
+#[repr(u8)] +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum CooperativeScalar { + F32, +} + +impl CooperativeScalar { + pub const fn width(&self) -> Bytes { + match *self { + Self::F32 => 4, + } + } +} + /// Characteristics of a scalar type. #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] #[cfg_attr(feature = "serialize", derive(Serialize))] @@ -761,6 +789,13 @@ pub enum TypeInner { rows: VectorSize, scalar: Scalar, }, + /// Matrix that is cooperatively processed by all the threads + /// in an opaque mapping. + CooperativeMatrix { + columns: CooperativeVectorSize, + rows: CooperativeVectorSize, + scalar: CooperativeScalar, + }, /// Atomic scalar. Atomic(Scalar), /// Pointer to another type. diff --git a/naga/src/proc/layouter.rs b/naga/src/proc/layouter.rs index 204a523c91..5e7aed8a0f 100644 --- a/naga/src/proc/layouter.rs +++ b/naga/src/proc/layouter.rs @@ -86,6 +86,12 @@ impl From for Alignment { } } +impl From for Alignment { + fn from(size: crate::CooperativeVectorSize) -> Self { + Self(unsafe { NonZeroU32::new_unchecked(size as u32) }) + } +} + /// Size and alignment information for a type. #[derive(Clone, Copy, Debug, Hash, PartialEq)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] @@ -212,6 +218,18 @@ impl Layouter { alignment: Alignment::from(rows) * alignment, } } + Ti::CooperativeMatrix { + columns: _, + rows, + scalar, + } => { + let alignment = Alignment::new(scalar.width() as u32) + .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?; + TypeLayout { + size, + alignment: Alignment::from(rows) * alignment, + } + } Ti::Pointer { .. } | Ti::ValuePointer { .. 
} => TypeLayout { size, alignment: Alignment::ONE, diff --git a/naga/src/proc/type_methods.rs b/naga/src/proc/type_methods.rs index c59d524f13..24a14868f9 100644 --- a/naga/src/proc/type_methods.rs +++ b/naga/src/proc/type_methods.rs @@ -202,6 +202,11 @@ impl crate::TypeInner { rows, scalar, } => Some(super::Alignment::from(rows) * scalar.width as u32 * columns as u32), + Self::CooperativeMatrix { + columns, + rows, + scalar, + } => Some(columns as u32 * rows as u32 * scalar.width() as u32), Self::Pointer { .. } | Self::ValuePointer { .. } => Some(POINTER_SPAN), Self::Array { base: _, @@ -361,6 +366,7 @@ impl crate::TypeInner { crate::TypeInner::Scalar(scalar) => Some((None, scalar)), crate::TypeInner::Vector { size, scalar } => Some((Some(size), scalar)), crate::TypeInner::Matrix { .. } + | crate::TypeInner::CooperativeMatrix { .. } | crate::TypeInner::Atomic(_) | crate::TypeInner::Pointer { .. } | crate::TypeInner::ValuePointer { .. } @@ -385,7 +391,8 @@ impl crate::TypeInner { | crate::TypeInner::Matrix { scalar, .. } | crate::TypeInner::Atomic(scalar) => scalar.is_abstract(), crate::TypeInner::Array { base, .. } => types[base].inner.is_abstract(types), - crate::TypeInner::ValuePointer { .. } + crate::TypeInner::CooperativeMatrix { .. } + | crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Pointer { .. } | crate::TypeInner::Struct { .. } | crate::TypeInner::Image { .. } diff --git a/naga/src/valid/handles.rs b/naga/src/valid/handles.rs index adb9f355c1..303b6cf193 100644 --- a/naga/src/valid/handles.rs +++ b/naga/src/valid/handles.rs @@ -393,6 +393,7 @@ impl super::Validator { crate::TypeInner::Scalar { .. } | crate::TypeInner::Vector { .. } | crate::TypeInner::Matrix { .. } + | crate::TypeInner::CooperativeMatrix { .. } | crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Atomic { .. } | crate::TypeInner::Image { .. 
} diff --git a/naga/src/valid/mod.rs b/naga/src/valid/mod.rs index a5ec5affce..af3e8af3a2 100644 --- a/naga/src/valid/mod.rs +++ b/naga/src/valid/mod.rs @@ -457,6 +457,7 @@ impl crate::TypeInner { Self::Scalar { .. } | Self::Vector { .. } | Self::Matrix { .. } + | Self::CooperativeMatrix { .. } | Self::Array { size: crate::ArraySize::Constant(_), .. diff --git a/naga/src/valid/type.rs b/naga/src/valid/type.rs index aa0633e185..62c81b8c74 100644 --- a/naga/src/valid/type.rs +++ b/naga/src/valid/type.rs @@ -418,6 +418,25 @@ impl super::Validator { type_info.push_constant_compatibility = push_constant_compatibility; type_info } + Ti::CooperativeMatrix { + columns: _, + rows: _, + scalar, + } => { + if scalar != crate::CooperativeScalar::F32 { + return Err(TypeError::MatrixElementNotFloat); + } + TypeInfo::new( + TypeFlags::DATA + | TypeFlags::SIZED + | TypeFlags::COPY + | TypeFlags::HOST_SHAREABLE + | TypeFlags::ARGUMENT + | TypeFlags::CONSTRUCTIBLE + | TypeFlags::CREATION_RESOLVED, + Alignment::from_width(scalar.width()), + ) + } Ti::Atomic(scalar) => { match scalar { crate::Scalar { From d6a3cf43ff2506ab70202c34d0c020a5b7943b8a Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Wed, 17 Sep 2025 00:00:42 -0700 Subject: [PATCH 03/10] coop: first bits of Vulkan support for the type --- naga/src/back/spv/instructions.rs | 16 +++++++++++ naga/src/back/spv/mod.rs | 28 +++++++++++++++++++ naga/src/back/spv/writer.rs | 46 +++++++++++++++++++++++++++---- naga/src/valid/mod.rs | 2 ++ naga/src/valid/type.rs | 1 + 5 files changed, 88 insertions(+), 5 deletions(-) diff --git a/naga/src/back/spv/instructions.rs b/naga/src/back/spv/instructions.rs index 788c3bc119..5e8c22e62d 100644 --- a/naga/src/back/spv/instructions.rs +++ b/naga/src/back/spv/instructions.rs @@ -281,6 +281,22 @@ impl super::Instruction { instruction } + pub(super) fn type_coop_matrix( + id: Word, + scalar_type_id: Word, + row_count: crate::CooperativeVectorSize, + column_count: crate::CooperativeVectorSize, 
+ ) -> Self { + let mut instruction = Self::new(Op::TypeCooperativeMatrixKHR); + instruction.set_result(id); + instruction.add_operand(scalar_type_id); + instruction.add_operand(spirv::Scope::Subgroup as u32); + instruction.add_operand(column_count as u32); + instruction.add_operand(row_count as u32); + instruction.add_operand(spirv::CooperativeMatrixUse::MatrixAKHR as u32); //TODO: configure or expose + instruction + } + #[allow(clippy::too_many_arguments)] pub(super) fn type_image( id: Word, diff --git a/naga/src/back/spv/mod.rs b/naga/src/back/spv/mod.rs index 4690dc7195..2b01a5a4c6 100644 --- a/naga/src/back/spv/mod.rs +++ b/naga/src/back/spv/mod.rs @@ -339,6 +339,33 @@ impl NumericType { } } +/// A cooperative type, for use in [`LocalType`]. +#[derive(Debug, PartialEq, Hash, Eq, Copy, Clone)] +enum CooperativeType { + Matrix { + columns: crate::CooperativeVectorSize, + rows: crate::CooperativeVectorSize, + scalar: crate::CooperativeScalar, + }, +} + +impl CooperativeType { + const fn from_inner(inner: &crate::TypeInner) -> Option { + match *inner { + crate::TypeInner::CooperativeMatrix { + columns, + rows, + scalar, + } => Some(Self::Matrix { + columns, + rows, + scalar, + }), + _ => None, + } + } +} + /// A SPIR-V type constructed during code generation. /// /// This is the variant of [`LookupType`] used to represent types that might not @@ -388,6 +415,7 @@ impl NumericType { enum LocalType { /// A numeric type. 
Numeric(NumericType), + Cooperative(CooperativeType), Pointer { base: Word, class: spirv::StorageClass, diff --git a/naga/src/back/spv/writer.rs b/naga/src/back/spv/writer.rs index 83cc20c432..91ef89c072 100644 --- a/naga/src/back/spv/writer.rs +++ b/naga/src/back/spv/writer.rs @@ -6,10 +6,11 @@ use spirv::Word; use super::{ block::DebugInfoInner, helpers::{contains_builtin, global_needs_wrapper, map_storage_class}, - Block, BlockContext, CachedConstant, CachedExpressions, DebugInfo, EntryPointContext, Error, - Function, FunctionArgument, GlobalVariable, IdGenerator, Instruction, LocalImageType, - LocalType, LocalVariable, LogicalLayout, LookupFunctionType, LookupType, NumericType, Options, - PhysicalLayout, PipelineOptions, ResultMember, Writer, WriterFlags, BITS_PER_BYTE, + Block, BlockContext, CachedConstant, CachedExpressions, CooperativeType, DebugInfo, + EntryPointContext, Error, Function, FunctionArgument, GlobalVariable, IdGenerator, Instruction, + LocalImageType, LocalType, LocalVariable, LogicalLayout, LookupFunctionType, LookupType, + NumericType, Options, PhysicalLayout, PipelineOptions, ResultMember, Writer, WriterFlags, + BITS_PER_BYTE, }; use crate::{ arena::{Handle, HandleVec, UniqueArena}, @@ -374,6 +375,12 @@ impl Writer { }) } + pub(super) fn get_cooperative_type_id(&mut self, scalar: crate::CooperativeScalar) -> Word { + match scalar { + crate::CooperativeScalar::F32 => self.get_f32_type_id(), + } + } + pub(super) fn get_f32_pointer_type_id(&mut self, class: spirv::StorageClass) -> Word { let f32_id = self.get_f32_type_id(); self.get_pointer_type_id(f32_id, class) @@ -435,7 +442,9 @@ impl Writer { // these cases, so unwrap. LocalType::Numeric(NumericType::from_inner(inner).unwrap()) } - crate::TypeInner::CooperativeMatrix { .. } => return None, + crate::TypeInner::CooperativeMatrix { .. 
} => { + LocalType::Cooperative(CooperativeType::from_inner(inner).unwrap()) + } crate::TypeInner::Pointer { base, space } => { let base_type_id = self.get_handle_type_id(base); LocalType::Pointer { @@ -1355,6 +1364,14 @@ impl Writer { self.require_any("16 bit floating-point", &[spirv::Capability::Float16])?; self.use_extension("SPV_KHR_16bit_storage"); } + // Cooperative types and ops + crate::TypeInner::CooperativeMatrix { .. } => { + self.require_any( + "cooperative matrix", + &[spirv::Capability::CooperativeMatrixKHR], + )?; + self.use_extension("SPV_KHR_cooperative_matrix"); + } _ => {} } Ok(()) @@ -1381,12 +1398,31 @@ impl Writer { instruction.to_words(&mut self.logical_layout.declarations); } + fn write_cooperative_type_declaration_local(&mut self, id: Word, coop: CooperativeType) { + let instruction = match coop { + CooperativeType::Matrix { + columns, + rows, + scalar, + } => { + let scalar_id = self.get_cooperative_type_id(scalar); + Instruction::type_coop_matrix(id, scalar_id, rows, columns) + } + }; + + instruction.to_words(&mut self.logical_layout.declarations); + } + fn write_type_declaration_local(&mut self, id: Word, local_ty: LocalType) { let instruction = match local_ty { LocalType::Numeric(numeric) => { self.write_numeric_type_declaration_local(id, numeric); return; } + LocalType::Cooperative(coop) => { + self.write_cooperative_type_declaration_local(id, coop); + return; + } LocalType::Pointer { base, class } => Instruction::type_pointer(id, class, base), LocalType::Image(image) => { let local_type = LocalType::Numeric(NumericType::Scalar(image.sampled_type)); diff --git a/naga/src/valid/mod.rs b/naga/src/valid/mod.rs index af3e8af3a2..3af000d6d4 100644 --- a/naga/src/valid/mod.rs +++ b/naga/src/valid/mod.rs @@ -190,6 +190,8 @@ bitflags::bitflags! 
{ const SHADER_BARYCENTRICS = 1 << 29; /// Support for task shaders, mesh shaders, and per-primitive fragment inputs const MESH_SHADER = 1 << 30; + /// Support for cooperative matrix types and operations + const COOPERATIVE_MATRIX = 1 << 31; } } diff --git a/naga/src/valid/type.rs b/naga/src/valid/type.rs index 62c81b8c74..49022d3069 100644 --- a/naga/src/valid/type.rs +++ b/naga/src/valid/type.rs @@ -423,6 +423,7 @@ impl super::Validator { rows: _, scalar, } => { + self.require_type_capability(Capabilities::COOPERATIVE_MATRIX)?; if scalar != crate::CooperativeScalar::F32 { return Err(TypeError::MatrixElementNotFloat); } From 3f7c47ecd346ede409edd2a3a16f4db3ba8d47cd Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Fri, 19 Sep 2025 22:38:29 -0700 Subject: [PATCH 04/10] coop: wgsl parsing, IR role --- naga/src/back/msl/writer.rs | 1 + naga/src/back/spv/instructions.rs | 17 +++- naga/src/back/spv/mod.rs | 7 +- naga/src/back/spv/writer.rs | 3 +- naga/src/common/wgsl/to_wgsl.rs | 20 +++-- naga/src/common/wgsl/types.rs | 6 +- naga/src/front/wgsl/error.rs | 12 +++ naga/src/front/wgsl/lower/construction.rs | 26 ++++++ naga/src/front/wgsl/lower/mod.rs | 22 +++++ naga/src/front/wgsl/parse/ast.rs | 22 +++++ naga/src/front/wgsl/parse/lexer.rs | 12 +++ naga/src/front/wgsl/parse/mod.rs | 58 +++++++++++++ naga/src/ir/mod.rs | 21 ++++- naga/src/proc/layouter.rs | 5 +- naga/src/proc/type_methods.rs | 1 + naga/src/valid/type.rs | 1 + naga/tests/in/wgsl/cooperative-matrix.toml | 2 + naga/tests/in/wgsl/cooperative-matrix.wgsl | 7 ++ .../analysis/wgsl-cooperative-matrix.info.ron | 78 +++++++++++++++++ .../ir/wgsl-cooperative-matrix.compact.ron | 84 +++++++++++++++++++ naga/tests/out/ir/wgsl-cooperative-matrix.ron | 84 +++++++++++++++++++ .../out/spv/wgsl-cooperative-matrix.spvasm | 17 ++++ 22 files changed, 486 insertions(+), 20 deletions(-) create mode 100644 naga/tests/in/wgsl/cooperative-matrix.toml create mode 100644 naga/tests/in/wgsl/cooperative-matrix.wgsl create mode 
100644 naga/tests/out/analysis/wgsl-cooperative-matrix.info.ron create mode 100644 naga/tests/out/ir/wgsl-cooperative-matrix.compact.ron create mode 100644 naga/tests/out/ir/wgsl-cooperative-matrix.ron create mode 100644 naga/tests/out/spv/wgsl-cooperative-matrix.spvasm diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index 5071276568..c2b1dcb4dc 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -239,6 +239,7 @@ impl Display for TypeContext<'_> { columns, rows, scalar, + role: _, } => { write!( out, diff --git a/naga/src/back/spv/instructions.rs b/naga/src/back/spv/instructions.rs index 5e8c22e62d..bb559606d9 100644 --- a/naga/src/back/spv/instructions.rs +++ b/naga/src/back/spv/instructions.rs @@ -284,8 +284,9 @@ impl super::Instruction { pub(super) fn type_coop_matrix( id: Word, scalar_type_id: Word, - row_count: crate::CooperativeVectorSize, - column_count: crate::CooperativeVectorSize, + row_count: crate::CooperativeSize, + column_count: crate::CooperativeSize, + role: spirv::CooperativeMatrixUse, ) -> Self { let mut instruction = Self::new(Op::TypeCooperativeMatrixKHR); instruction.set_result(id); @@ -293,7 +294,7 @@ impl super::Instruction { instruction.add_operand(spirv::Scope::Subgroup as u32); instruction.add_operand(column_count as u32); instruction.add_operand(row_count as u32); - instruction.add_operand(spirv::CooperativeMatrixUse::MatrixAKHR as u32); //TODO: configure or expose + instruction.add_operand(role as u32); instruction } @@ -1305,3 +1306,13 @@ impl From for spirv::Dim { } } } + +impl From for spirv::CooperativeMatrixUse { + fn from(role: crate::CooperativeRole) -> Self { + match role { + crate::CooperativeRole::A => Self::MatrixAKHR, + crate::CooperativeRole::B => Self::MatrixBKHR, + crate::CooperativeRole::C => Self::MatrixAccumulatorKHR, + } + } +} diff --git a/naga/src/back/spv/mod.rs b/naga/src/back/spv/mod.rs index 2b01a5a4c6..f897dc5192 100644 --- a/naga/src/back/spv/mod.rs +++ 
b/naga/src/back/spv/mod.rs @@ -343,9 +343,10 @@ impl NumericType { #[derive(Debug, PartialEq, Hash, Eq, Copy, Clone)] enum CooperativeType { Matrix { - columns: crate::CooperativeVectorSize, - rows: crate::CooperativeVectorSize, + columns: crate::CooperativeSize, + rows: crate::CooperativeSize, scalar: crate::CooperativeScalar, + role: crate::CooperativeRole, }, } @@ -356,10 +357,12 @@ impl CooperativeType { columns, rows, scalar, + role, } => Some(Self::Matrix { columns, rows, scalar, + role, }), _ => None, } diff --git a/naga/src/back/spv/writer.rs b/naga/src/back/spv/writer.rs index 91ef89c072..293737ce82 100644 --- a/naga/src/back/spv/writer.rs +++ b/naga/src/back/spv/writer.rs @@ -1404,9 +1404,10 @@ impl Writer { columns, rows, scalar, + role, } => { let scalar_id = self.get_cooperative_type_id(scalar); - Instruction::type_coop_matrix(id, scalar_id, rows, columns) + Instruction::type_coop_matrix(id, scalar_id, rows, columns, role.into()) } }; diff --git a/naga/src/common/wgsl/to_wgsl.rs b/naga/src/common/wgsl/to_wgsl.rs index b0318972fe..d8988c7f34 100644 --- a/naga/src/common/wgsl/to_wgsl.rs +++ b/naga/src/common/wgsl/to_wgsl.rs @@ -324,15 +324,23 @@ impl TryToWgsl for crate::CooperativeScalar { } } -impl ToWgsl for crate::ImageDimension { +impl ToWgsl for crate::CooperativeRole { fn to_wgsl(self) -> &'static str { - use crate::ImageDimension as IDim; + match self { + Self::A => "A", + Self::B => "B", + Self::C => "C", + } + } +} +impl ToWgsl for crate::ImageDimension { + fn to_wgsl(self) -> &'static str { match self { - IDim::D1 => "1d", - IDim::D2 => "2d", - IDim::D3 => "3d", - IDim::Cube => "cube", + Self::D1 => "1d", + Self::D2 => "2d", + Self::D3 => "3d", + Self::Cube => "cube", } } } diff --git a/naga/src/common/wgsl/types.rs b/naga/src/common/wgsl/types.rs index 93a94205d7..a678a617f7 100644 --- a/naga/src/common/wgsl/types.rs +++ b/naga/src/common/wgsl/types.rs @@ -321,13 +321,15 @@ where columns, rows, scalar, + role, } => { write!( out, - 
"coop_mat{}x{}<{}>", + "coop_mat{}x{}<{},{}>", columns as u32, rows as u32, - scalar.try_to_wgsl().unwrap_or_default() + scalar.try_to_wgsl().unwrap_or_default(), + role.to_wgsl(), )?; } TypeInner::Pointer { base, space } => { diff --git a/naga/src/front/wgsl/error.rs b/naga/src/front/wgsl/error.rs index 17dab5cb0e..8c749acc73 100644 --- a/naga/src/front/wgsl/error.rs +++ b/naga/src/front/wgsl/error.rs @@ -412,6 +412,8 @@ pub(crate) enum Error<'a> { TypeTooLarge { span: Span, }, + UnderspecifiedCooperativeMatrix, + UnknownCooperativeScalar(Span), } impl From for Error<'_> { @@ -1386,6 +1388,16 @@ impl<'a> Error<'a> { crate::valid::MAX_TYPE_SIZE )], }, + Error::UnderspecifiedCooperativeMatrix => ParseError { + message: "cooperative matrix constructor is underspecified".into(), + labels: vec![], + notes: vec![format!("must be F32")], + }, + Error::UnknownCooperativeScalar(span) => ParseError { + message: "unknown cooperative scalar type".into(), + labels: vec![(span, "type needs the scalar type specified".into())], + notes: vec![format!("must be F32")], + }, } } } diff --git a/naga/src/front/wgsl/lower/construction.rs b/naga/src/front/wgsl/lower/construction.rs index 997d5a3123..9ac11bfc98 100644 --- a/naga/src/front/wgsl/lower/construction.rs +++ b/naga/src/front/wgsl/lower/construction.rs @@ -638,6 +638,32 @@ impl<'source> Lowerer<'source, '_> { }; Constructor::Type(ty) } + ast::ConstructorType::PartialCooperativeMatrix { .. 
} => { + return Err(Box::new(Error::UnderspecifiedCooperativeMatrix)); + } + ast::ConstructorType::CooperativeMatrix { + rows, + columns, + ty, + ty_span, + role, + } => { + let ty = self.resolve_ast_type(ty, &mut ctx.as_const())?; + let scalar = match ctx.module.types[ty].inner { + crate::TypeInner::Scalar(crate::Scalar { + kind: crate::ScalarKind::Float, + width: 4, + }) => crate::CooperativeScalar::F32, + _ => return Err(Box::new(Error::UnknownCooperativeScalar(ty_span))), + }; + let ty = ctx.ensure_type_exists(crate::TypeInner::CooperativeMatrix { + columns, + rows, + scalar, + role, + }); + Constructor::Type(ty) + } ast::ConstructorType::PartialArray => Constructor::PartialArray, ast::ConstructorType::Array { base, size } => { let base = self.resolve_ast_type(base, &mut ctx.as_const())?; diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index 2066d7cf2c..022aa3c4bb 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -3957,6 +3957,28 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { _ => return Err(Box::new(Error::BadMatrixScalarKind(ty_span, scalar))), } } + ast::Type::CooperativeMatrix { + columns, + rows, + ty, + ty_span, + role, + } => { + let ty = self.resolve_ast_type(ty, ctx)?; + let scalar = match ctx.module.types[ty].inner { + ir::TypeInner::Scalar(crate::Scalar { + kind: crate::ScalarKind::Float, + width: 4, + }) => crate::CooperativeScalar::F32, + _ => return Err(Box::new(Error::UnknownCooperativeScalar(ty_span))), + }; + ir::TypeInner::CooperativeMatrix { + columns, + rows, + scalar, + role, + } + } ast::Type::Atomic(scalar) => scalar.to_inner_atomic(), ast::Type::Pointer { base, space } => { let base = self.resolve_ast_type(base, ctx)?; diff --git a/naga/src/front/wgsl/parse/ast.rs b/naga/src/front/wgsl/parse/ast.rs index 345e9c4c48..af05a84110 100644 --- a/naga/src/front/wgsl/parse/ast.rs +++ b/naga/src/front/wgsl/parse/ast.rs @@ -235,6 +235,13 @@ pub enum Type<'a> { ty: Handle>, 
ty_span: Span, }, + CooperativeMatrix { + columns: crate::CooperativeSize, + rows: crate::CooperativeSize, + ty: Handle>, + ty_span: Span, + role: crate::CooperativeRole, + }, Atomic(Scalar), Pointer { base: Handle>, @@ -385,6 +392,21 @@ pub enum ConstructorType<'a> { ty_span: Span, }, + /// A cooperative matrix construction base `coop_mat8x8(...)`. + PartialCooperativeMatrix { + columns: crate::CooperativeSize, + rows: crate::CooperativeSize, + }, + + /// A full cooperative matrix construction `coop_mat8x8(...)`. + CooperativeMatrix { + columns: crate::CooperativeSize, + rows: crate::CooperativeSize, + ty: Handle>, + ty_span: Span, + role: crate::CooperativeRole, + }, + /// An array whose component type and size are inferred from the arguments: /// `array(3,4,5)`. PartialArray, diff --git a/naga/src/front/wgsl/parse/lexer.rs b/naga/src/front/wgsl/parse/lexer.rs index d0a8033987..ed87e37100 100644 --- a/naga/src/front/wgsl/parse/lexer.rs +++ b/naga/src/front/wgsl/parse/lexer.rs @@ -584,6 +584,18 @@ impl<'a> Lexer<'a> { }) } + pub(in crate::front::wgsl) fn next_cooperative_role( + &mut self, + ) -> Result<'a, crate::CooperativeRole> { + let (ident, span) = self.next_ident_with_span()?; + match ident { + "A" => Ok(crate::CooperativeRole::A), + "B" => Ok(crate::CooperativeRole::B), + "C" => Ok(crate::CooperativeRole::C), + _ => Err(Box::new(Error::UnknownAccess(span))), + } + } + pub(in crate::front::wgsl) fn open_arguments(&mut self) -> Result<'a, ()> { self.expect(Token::Paren('(')) } diff --git a/naga/src/front/wgsl/parse/mod.rs b/naga/src/front/wgsl/parse/mod.rs index c01ba4de30..49d7eaab25 100644 --- a/naga/src/front/wgsl/parse/mod.rs +++ b/naga/src/front/wgsl/parse/mod.rs @@ -658,6 +658,12 @@ impl Parser { ty_span: Span::UNDEFINED, })) } + "coop_mat8x8" => { + return Ok(Some(ast::ConstructorType::PartialCooperativeMatrix { + columns: crate::CooperativeSize::Eight, + rows: crate::CooperativeSize::Eight, + })) + } "array" => ast::ConstructorType::PartialArray, 
"atomic" | "binding_array" @@ -701,6 +707,19 @@ impl Parser { ty_span, })) } + ( + Token::Paren('<'), + ast::ConstructorType::PartialCooperativeMatrix { columns, rows }, + ) => { + let (ty, ty_span, role) = self.cooperative_scalar_and_role(lexer, ctx)?; + Ok(Some(ast::ConstructorType::CooperativeMatrix { + columns, + rows, + ty, + ty_span, + role, + })) + } (Token::Paren('<'), ast::ConstructorType::PartialArray) => { lexer.expect_generic_paren('<')?; let base = self.type_decl(lexer, ctx)?; @@ -1437,6 +1456,22 @@ impl Parser { Ok((ty, span)) } + /// Parses ``, returning (T, span of T, R, span of R) + fn cooperative_scalar_and_role<'a>( + &mut self, + lexer: &mut Lexer<'a>, + ctx: &mut ExpressionContext<'a, '_, '_>, + ) -> Result<'a, (Handle>, Span, crate::CooperativeRole)> { + lexer.expect_generic_paren('<')?; + let start = lexer.start_byte_offset(); + let ty = self.type_decl(lexer, ctx)?; + let ty_span = lexer.span_from(start); + lexer.expect(Token::Separator(','))?; + let role = lexer.next_cooperative_role()?; + lexer.expect_generic_paren('>')?; + Ok((ty, ty_span, role)) + } + fn matrix_with_type<'a>( &mut self, lexer: &mut Lexer<'a>, @@ -1453,6 +1488,23 @@ impl Parser { }) } + fn cooperative_matrix_with_type<'a>( + &mut self, + lexer: &mut Lexer<'a>, + ctx: &mut ExpressionContext<'a, '_, '_>, + columns: crate::CooperativeSize, + rows: crate::CooperativeSize, + ) -> Result<'a, ast::Type<'a>> { + let (ty, ty_span, role) = self.cooperative_scalar_and_role(lexer, ctx)?; + Ok(ast::Type::CooperativeMatrix { + columns, + rows, + ty, + ty_span, + role, + }) + } + fn type_decl_impl<'a>( &mut self, lexer: &mut Lexer<'a>, @@ -1684,6 +1736,12 @@ impl Parser { ty: ctx.new_scalar(Scalar::F16), ty_span: Span::UNDEFINED, }, + "coop_mat8x8" => self.cooperative_matrix_with_type( + lexer, + ctx, + crate::CooperativeSize::Eight, + crate::CooperativeSize::Eight, + )?, "atomic" => { let scalar = lexer.next_scalar_generic()?; ast::Type::Atomic(scalar) diff --git a/naga/src/ir/mod.rs 
b/naga/src/ir/mod.rs index 28fa1dac08..31c3d5f8f2 100644 --- a/naga/src/ir/mod.rs +++ b/naga/src/ir/mod.rs @@ -492,7 +492,7 @@ impl From for u32 { #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] -pub enum CooperativeVectorSize { +pub enum CooperativeSize { Eight = 8, } @@ -523,7 +523,7 @@ pub enum ScalarKind { AbstractFloat, } -/// Primitive type for a scalar. +/// Primitive type for a cooperative scalar. #[repr(u8)] #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] @@ -541,6 +541,18 @@ impl CooperativeScalar { } } +/// Role of a cooperative variable in the equation "A * B + C" +#[repr(u8)] +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum CooperativeRole { + A, + B, + C, +} + /// Characteristics of a scalar type. #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] #[cfg_attr(feature = "serialize", derive(Serialize))] @@ -792,9 +804,10 @@ pub enum TypeInner { /// Matrix that is cooperatively processed by all the threads /// in an opaque mapping. CooperativeMatrix { - columns: CooperativeVectorSize, - rows: CooperativeVectorSize, + columns: CooperativeSize, + rows: CooperativeSize, scalar: CooperativeScalar, + role: CooperativeRole, }, /// Atomic scalar. 
Atomic(Scalar), diff --git a/naga/src/proc/layouter.rs b/naga/src/proc/layouter.rs index 5e7aed8a0f..7f9380d766 100644 --- a/naga/src/proc/layouter.rs +++ b/naga/src/proc/layouter.rs @@ -86,8 +86,8 @@ impl From for Alignment { } } -impl From for Alignment { - fn from(size: crate::CooperativeVectorSize) -> Self { +impl From for Alignment { + fn from(size: crate::CooperativeSize) -> Self { Self(unsafe { NonZeroU32::new_unchecked(size as u32) }) } } @@ -222,6 +222,7 @@ impl Layouter { columns: _, rows, scalar, + role: _, } => { let alignment = Alignment::new(scalar.width() as u32) .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?; diff --git a/naga/src/proc/type_methods.rs b/naga/src/proc/type_methods.rs index 24a14868f9..54fec19a89 100644 --- a/naga/src/proc/type_methods.rs +++ b/naga/src/proc/type_methods.rs @@ -206,6 +206,7 @@ impl crate::TypeInner { columns, rows, scalar, + role: _, } => Some(columns as u32 * rows as u32 * scalar.width() as u32), Self::Pointer { .. } | Self::ValuePointer { .. 
} => Some(POINTER_SPAN), Self::Array { diff --git a/naga/src/valid/type.rs b/naga/src/valid/type.rs index 49022d3069..e222811ac2 100644 --- a/naga/src/valid/type.rs +++ b/naga/src/valid/type.rs @@ -422,6 +422,7 @@ impl super::Validator { columns: _, rows: _, scalar, + role: _, } => { self.require_type_capability(Capabilities::COOPERATIVE_MATRIX)?; if scalar != crate::CooperativeScalar::F32 { diff --git a/naga/tests/in/wgsl/cooperative-matrix.toml b/naga/tests/in/wgsl/cooperative-matrix.toml new file mode 100644 index 0000000000..1bfa633ff0 --- /dev/null +++ b/naga/tests/in/wgsl/cooperative-matrix.toml @@ -0,0 +1,2 @@ +targets = "SPIRV" +god_mode = true diff --git a/naga/tests/in/wgsl/cooperative-matrix.wgsl b/naga/tests/in/wgsl/cooperative-matrix.wgsl new file mode 100644 index 0000000000..335034818f --- /dev/null +++ b/naga/tests/in/wgsl/cooperative-matrix.wgsl @@ -0,0 +1,7 @@ +var a: coop_mat8x8; +var b: coop_mat8x8; + +@compute @workgroup_size(8, 8, 1) +fn main() { + //let c = a * b; +} diff --git a/naga/tests/out/analysis/wgsl-cooperative-matrix.info.ron b/naga/tests/out/analysis/wgsl-cooperative-matrix.info.ron new file mode 100644 index 0000000000..f806c3f3dd --- /dev/null +++ b/naga/tests/out/analysis/wgsl-cooperative-matrix.info.ron @@ -0,0 +1,78 @@ +( + type_flags: [ + ("DATA | SIZED | COPY | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ], + functions: [], + entry_points: [ + ( + flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + may_kill: false, + sampling_set: [], + global_uses: [ + ("READ"), + ], + expressions: [ + ( + uniformity: ( + non_uniform_result: Some(0), + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(0), + ty: Value(Pointer( + base: 0, + space: Private, + )), + ), + ( + uniformity: ( + non_uniform_result: Some(0), + requirements: 
(""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(0), + ), + ( + uniformity: ( + non_uniform_result: Some(2), + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(0), + ty: Value(Pointer( + base: 0, + space: Private, + )), + ), + ( + uniformity: ( + non_uniform_result: Some(2), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(0), + ), + ( + uniformity: ( + non_uniform_result: Some(0), + requirements: (""), + ), + ref_count: 0, + assignable_global: None, + ty: Handle(0), + ), + ], + sampling: [], + dual_source_blending: false, + diagnostic_filter_leaf: None, + ), + ], + const_expression_types: [], +) \ No newline at end of file diff --git a/naga/tests/out/ir/wgsl-cooperative-matrix.compact.ron b/naga/tests/out/ir/wgsl-cooperative-matrix.compact.ron new file mode 100644 index 0000000000..1298f69e2c --- /dev/null +++ b/naga/tests/out/ir/wgsl-cooperative-matrix.compact.ron @@ -0,0 +1,84 @@ +( + types: [ + ( + name: None, + inner: CooperativeMatrix( + columns: Eight, + rows: Eight, + scalar: F32, + role: A, + ), + ), + ], + special_types: ( + ray_desc: None, + ray_intersection: None, + ray_vertex_return: None, + external_texture_params: None, + external_texture_transfer_function: None, + predeclared_types: {}, + ), + constants: [], + overrides: [], + global_variables: [ + ( + name: Some("a"), + space: Private, + binding: None, + ty: 0, + init: None, + ), + ], + global_expressions: [], + functions: [], + entry_points: [ + ( + name: "main", + stage: Compute, + early_depth_test: None, + workgroup_size: (8, 8, 1), + workgroup_size_overrides: None, + function: ( + name: Some("main"), + arguments: [], + result: None, + local_variables: [], + expressions: [ + GlobalVariable(0), + Load( + pointer: 0, + ), + GlobalVariable(0), + Load( + pointer: 2, + ), + Binary( + op: Add, + left: 1, + right: 3, + ), + ], + named_expressions: { + 4: "a2", + }, + body: [ + Emit(( + start: 1, + end: 2, + )), + Emit(( + start: 3, + 
end: 5, + )), + Return( + value: None, + ), + ], + diagnostic_filter_leaf: None, + ), + ), + ], + diagnostic_filters: [], + diagnostic_filter_leaf: None, + doc_comments: None, +) \ No newline at end of file diff --git a/naga/tests/out/ir/wgsl-cooperative-matrix.ron b/naga/tests/out/ir/wgsl-cooperative-matrix.ron new file mode 100644 index 0000000000..1298f69e2c --- /dev/null +++ b/naga/tests/out/ir/wgsl-cooperative-matrix.ron @@ -0,0 +1,84 @@ +( + types: [ + ( + name: None, + inner: CooperativeMatrix( + columns: Eight, + rows: Eight, + scalar: F32, + role: A, + ), + ), + ], + special_types: ( + ray_desc: None, + ray_intersection: None, + ray_vertex_return: None, + external_texture_params: None, + external_texture_transfer_function: None, + predeclared_types: {}, + ), + constants: [], + overrides: [], + global_variables: [ + ( + name: Some("a"), + space: Private, + binding: None, + ty: 0, + init: None, + ), + ], + global_expressions: [], + functions: [], + entry_points: [ + ( + name: "main", + stage: Compute, + early_depth_test: None, + workgroup_size: (8, 8, 1), + workgroup_size_overrides: None, + function: ( + name: Some("main"), + arguments: [], + result: None, + local_variables: [], + expressions: [ + GlobalVariable(0), + Load( + pointer: 0, + ), + GlobalVariable(0), + Load( + pointer: 2, + ), + Binary( + op: Add, + left: 1, + right: 3, + ), + ], + named_expressions: { + 4: "a2", + }, + body: [ + Emit(( + start: 1, + end: 2, + )), + Emit(( + start: 3, + end: 5, + )), + Return( + value: None, + ), + ], + diagnostic_filter_leaf: None, + ), + ), + ], + diagnostic_filters: [], + diagnostic_filter_leaf: None, + doc_comments: None, +) \ No newline at end of file diff --git a/naga/tests/out/spv/wgsl-cooperative-matrix.spvasm b/naga/tests/out/spv/wgsl-cooperative-matrix.spvasm new file mode 100644 index 0000000000..33e7477e5d --- /dev/null +++ b/naga/tests/out/spv/wgsl-cooperative-matrix.spvasm @@ -0,0 +1,17 @@ +; SPIR-V +; Version: 1.1 +; Generator: rspirv +; Bound: 7 
+OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %4 "main" +OpExecutionMode %4 LocalSize 8 8 1 +%2 = OpTypeVoid +%5 = OpTypeFunction %2 +%4 = OpFunction %2 None %5 +%3 = OpLabel +OpBranch %6 +%6 = OpLabel +OpReturn +OpFunctionEnd \ No newline at end of file From e13b90dd2f000a99494153c35235ece6851da070 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Sat, 20 Sep 2025 00:20:32 -0700 Subject: [PATCH 05/10] coop: handle simple ops, end-to-end with SPIRV --- naga/src/back/spv/block.rs | 16 +++++++- naga/src/back/spv/instructions.rs | 15 ++++---- naga/src/back/spv/mod.rs | 1 + naga/src/back/spv/writer.rs | 18 ++++++++- naga/src/ir/mod.rs | 9 +++++ naga/src/proc/type_methods.rs | 1 + naga/src/proc/typifier.rs | 25 +++++++++++++ naga/src/valid/expression.rs | 43 ++++++++++++++++++++-- naga/tests/in/wgsl/cooperative-matrix.toml | 4 ++ naga/tests/in/wgsl/cooperative-matrix.wgsl | 4 +- 10 files changed, 120 insertions(+), 16 deletions(-) diff --git a/naga/src/back/spv/block.rs b/naga/src/back/spv/block.rs index d0556acdc5..31224a9a5c 100644 --- a/naga/src/back/spv/block.rs +++ b/naga/src/back/spv/block.rs @@ -19,6 +19,7 @@ fn get_dimension(type_inner: &crate::TypeInner) -> Dimension { crate::TypeInner::Scalar(_) => Dimension::Scalar, crate::TypeInner::Vector { .. } => Dimension::Vector, crate::TypeInner::Matrix { .. } => Dimension::Matrix, + crate::TypeInner::CooperativeMatrix { .. } => Dimension::CooperativeMatrix, _ => unreachable!(), } } @@ -766,6 +767,7 @@ impl BlockContext<'_> { rows, scalar, } => { + //TODO: why not just rely on `Fadd` for matrices? self.write_matrix_matrix_column_op( block, id, @@ -781,6 +783,7 @@ impl BlockContext<'_> { self.cached[expr_handle] = id; return Ok(()); } + crate::TypeInner::CooperativeMatrix { .. 
} => spirv::Op::FAdd, _ => unimplemented!(), }, crate::BinaryOperator::Subtract => match *left_ty_inner { @@ -809,6 +812,7 @@ impl BlockContext<'_> { self.cached[expr_handle] = id; return Ok(()); } + crate::TypeInner::CooperativeMatrix { .. } => spirv::Op::FSub, _ => unimplemented!(), }, crate::BinaryOperator::Multiply => { @@ -842,10 +846,12 @@ impl BlockContext<'_> { (Dimension::Vector, Dimension::Matrix) => { spirv::Op::VectorTimesMatrix } - (Dimension::Matrix, Dimension::Scalar) => { + (Dimension::Matrix, Dimension::Scalar) + | (Dimension::CooperativeMatrix, Dimension::Scalar) => { spirv::Op::MatrixTimesScalar } - (Dimension::Scalar, Dimension::Matrix) => { + (Dimension::Scalar, Dimension::Matrix) + | (Dimension::Scalar, Dimension::CooperativeMatrix) => { reverse_operands = true; spirv::Op::MatrixTimesScalar } @@ -864,6 +870,12 @@ impl BlockContext<'_> { } (Dimension::Vector, Dimension::Vector) | (Dimension::Scalar, Dimension::Scalar) => spirv::Op::IMul, + (Dimension::CooperativeMatrix, Dimension::CooperativeMatrix) + //Note: technically can do `FMul` but IR doesn't have matrix per-component multiplication + | (Dimension::CooperativeMatrix, _) + | (_, Dimension::CooperativeMatrix) => { + unimplemented!() + } } } crate::BinaryOperator::Divide => match left_ty_inner.scalar_kind() { diff --git a/naga/src/back/spv/instructions.rs b/naga/src/back/spv/instructions.rs index bb559606d9..9e542917f3 100644 --- a/naga/src/back/spv/instructions.rs +++ b/naga/src/back/spv/instructions.rs @@ -284,17 +284,18 @@ impl super::Instruction { pub(super) fn type_coop_matrix( id: Word, scalar_type_id: Word, - row_count: crate::CooperativeSize, - column_count: crate::CooperativeSize, - role: spirv::CooperativeMatrixUse, + scope_id: Word, + row_count_id: Word, + column_count_id: Word, + matrix_use_id: Word, ) -> Self { let mut instruction = Self::new(Op::TypeCooperativeMatrixKHR); instruction.set_result(id); instruction.add_operand(scalar_type_id); - 
instruction.add_operand(spirv::Scope::Subgroup as u32); - instruction.add_operand(column_count as u32); - instruction.add_operand(row_count as u32); - instruction.add_operand(role as u32); + instruction.add_operand(scope_id); + instruction.add_operand(row_count_id); + instruction.add_operand(column_count_id); + instruction.add_operand(matrix_use_id); instruction } diff --git a/naga/src/back/spv/mod.rs b/naga/src/back/spv/mod.rs index f897dc5192..0dc4faa288 100644 --- a/naga/src/back/spv/mod.rs +++ b/naga/src/back/spv/mod.rs @@ -481,6 +481,7 @@ enum Dimension { Scalar, Vector, Matrix, + CooperativeMatrix, } /// Key used to look up an operation which we have wrapped in a helper diff --git a/naga/src/back/spv/writer.rs b/naga/src/back/spv/writer.rs index 293737ce82..5a04f5dca5 100644 --- a/naga/src/back/spv/writer.rs +++ b/naga/src/back/spv/writer.rs @@ -1370,7 +1370,9 @@ impl Writer { "cooperative matrix", &[spirv::Capability::CooperativeMatrixKHR], )?; + self.require_any("memory model", &[spirv::Capability::VulkanMemoryModel])?; self.use_extension("SPV_KHR_cooperative_matrix"); + self.use_extension("SPV_KHR_vulkan_memory_model"); } _ => {} } @@ -1407,7 +1409,12 @@ impl Writer { role, } => { let scalar_id = self.get_cooperative_type_id(scalar); - Instruction::type_coop_matrix(id, scalar_id, rows, columns, role.into()) + let scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); + let columns_id = self.get_index_constant(columns as u32); + let rows_id = self.get_index_constant(rows as u32); + let role_id = + self.get_index_constant(spirv::CooperativeMatrixUse::from(role) as u32); + Instruction::type_coop_matrix(id, scalar_id, scope_id, rows_id, columns_id, role_id) } }; @@ -2683,7 +2690,14 @@ impl Writer { } let addressing_model = spirv::AddressingModel::Logical; - let memory_model = spirv::MemoryModel::GLSL450; + let memory_model = if self + .capabilities_used + .contains(&spirv::Capability::VulkanMemoryModel) + { + spirv::MemoryModel::Vulkan + } else { + 
spirv::MemoryModel::GLSL450 + }; //self.check(addressing_model.required_capabilities())?; //self.check(memory_model.required_capabilities())?; diff --git a/naga/src/ir/mod.rs b/naga/src/ir/mod.rs index 31c3d5f8f2..e65e04fbf6 100644 --- a/naga/src/ir/mod.rs +++ b/naga/src/ir/mod.rs @@ -539,6 +539,15 @@ impl CooperativeScalar { Self::F32 => 4, } } + + pub const fn to_scalar(&self) -> Scalar { + match *self { + Self::F32 => Scalar { + kind: ScalarKind::Float, + width: 4, + }, + } + } } /// Role of a cooperative variable in the equation "A * B + C" diff --git a/naga/src/proc/type_methods.rs b/naga/src/proc/type_methods.rs index 54fec19a89..c4a9091c74 100644 --- a/naga/src/proc/type_methods.rs +++ b/naga/src/proc/type_methods.rs @@ -115,6 +115,7 @@ impl crate::TypeInner { match *self { Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => Some(scalar), Ti::Matrix { scalar, .. } => Some(scalar), + Ti::CooperativeMatrix { scalar, .. } => Some(scalar.to_scalar()), _ => None, } } diff --git a/naga/src/proc/typifier.rs b/naga/src/proc/typifier.rs index 79b4f95e10..89599e079c 100644 --- a/naga/src/proc/typifier.rs +++ b/naga/src/proc/typifier.rs @@ -143,6 +143,17 @@ impl Clone for TypeResolution { columns, scalar, }, + Ti::CooperativeMatrix { + columns, + rows, + scalar, + role, + } => Ti::CooperativeMatrix { + columns, + rows, + scalar, + role, + }, Ti::Pointer { base, space } => Ti::Pointer { base, space }, Ti::ValuePointer { size, @@ -587,6 +598,20 @@ impl<'a> ResolveContext<'a> { (&Ti::Scalar { .. }, _) => res_right.clone(), (_, &Ti::Scalar { .. }) => res_left.clone(), (&Ti::Vector { .. }, &Ti::Vector { .. }) => res_left.clone(), + ( + &Ti::CooperativeMatrix { + columns: _, + rows, + scalar, + role, + }, + &Ti::CooperativeMatrix { columns, .. 
}, + ) => TypeResolution::Value(Ti::CooperativeMatrix { + columns, + rows, + scalar, + role, + }), (tl, tr) => { return Err(ResolveError::IncompatibleOperands(format!( "{tl:?} * {tr:?}" diff --git a/naga/src/valid/expression.rs b/naga/src/valid/expression.rs index 68023b5bf0..8bb9af142b 100644 --- a/naga/src/valid/expression.rs +++ b/naga/src/valid/expression.rs @@ -788,7 +788,9 @@ impl super::Validator { Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner, Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false, }, - Ti::Matrix { .. } => left_inner == right_inner, + Ti::Matrix { .. } | Ti::CooperativeMatrix { .. } => { + left_inner == right_inner + } _ => false, }, Bo::Divide | Bo::Modulo => match *left_inner { @@ -818,7 +820,7 @@ impl super::Validator { scalar: scalar2, .. }, ) => scalar1 == scalar2, - // Scalar/matrix. + // Scalar * matrix. ( &Ti::Scalar(Sc { kind: Sk::Float, .. @@ -831,7 +833,7 @@ impl super::Validator { kind: Sk::Float, .. }), ) => true, - // Vector/vector. + // Vector * vector. ( &Ti::Vector { size: size1, @@ -864,9 +866,44 @@ impl super::Validator { }, &Ti::Matrix { rows, .. }, ) => size == rows, + // Matrix * matrix. (&Ti::Matrix { columns, .. }, &Ti::Matrix { rows, .. }) => { columns == rows } + // Coop matrix * coop matrix. + ( + &Ti::CooperativeMatrix { + columns, + scalar: scalar1, + role: role1, + .. + }, + &Ti::CooperativeMatrix { + rows, + scalar: scalar2, + role: role2, + .. + }, + ) => columns == rows && scalar1 == scalar2 && role1 == role2, + // Scalar * coop matrix. + ( + &Ti::Scalar(Sc { + kind: Sk::Float, .. + }), + &Ti::CooperativeMatrix { + scalar: crate::CooperativeScalar::F32, + .. + }, + ) + | ( + &Ti::CooperativeMatrix { + scalar: crate::CooperativeScalar::F32, + .. + }, + &Ti::Scalar(Sc { + kind: Sk::Float, .. 
+ }), + ) => true, _ => false, }; let left_width = left_inner.scalar_width().unwrap_or(0); diff --git a/naga/tests/in/wgsl/cooperative-matrix.toml b/naga/tests/in/wgsl/cooperative-matrix.toml index 1bfa633ff0..c06d67d21a 100644 --- a/naga/tests/in/wgsl/cooperative-matrix.toml +++ b/naga/tests/in/wgsl/cooperative-matrix.toml @@ -1,2 +1,6 @@ targets = "SPIRV" god_mode = true + +[spv] +debug = true +version = [1, 4] diff --git a/naga/tests/in/wgsl/cooperative-matrix.wgsl b/naga/tests/in/wgsl/cooperative-matrix.wgsl index 335034818f..91e371b9fb 100644 --- a/naga/tests/in/wgsl/cooperative-matrix.wgsl +++ b/naga/tests/in/wgsl/cooperative-matrix.wgsl @@ -1,7 +1,7 @@ var a: coop_mat8x8; -var b: coop_mat8x8; +//var b: coop_mat8x8; @compute @workgroup_size(8, 8, 1) fn main() { - //let c = a * b; + let a2 = a + a; } From 1cfb7151425083c430977d1e0783be3b93255fda Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Sun, 21 Sep 2025 23:47:35 -0700 Subject: [PATCH 06/10] coop: mulAdd instruction --- naga/src/back/dot/mod.rs | 6 ++++ naga/src/back/glsl/mod.rs | 3 +- naga/src/back/hlsl/writer.rs | 4 ++- naga/src/back/msl/writer.rs | 7 +++++ naga/src/back/pipeline_constants.rs | 9 ++++++ naga/src/back/spv/block.rs | 11 +++++++ naga/src/back/spv/instructions.rs | 12 ++++++++ naga/src/back/wgsl/writer.rs | 9 ++++++ naga/src/compact/expressions.rs | 12 ++++++++ naga/src/front/wgsl/lower/mod.rs | 11 +++++-- naga/src/ir/mod.rs | 9 ++++++ naga/src/proc/constant_evaluator.rs | 3 ++ naga/src/proc/typifier.rs | 1 + naga/src/valid/analyzer.rs | 5 ++++ naga/src/valid/expression.rs | 35 ++++++++++++++++++++++ naga/src/valid/function.rs | 5 ++-- naga/src/valid/handles.rs | 3 ++ naga/tests/in/wgsl/cooperative-matrix.wgsl | 5 ++-- 18 files changed, 141 insertions(+), 9 deletions(-) diff --git a/naga/src/back/dot/mod.rs b/naga/src/back/dot/mod.rs index 1f1396eccf..2305248630 100644 --- a/naga/src/back/dot/mod.rs +++ b/naga/src/back/dot/mod.rs @@ -761,6 +761,12 @@ fn write_function_expressions( let 
ty = if committed { "Committed" } else { "Candidate" }; (format!("get{ty}HitVertexPositions").into(), 4) } + E::MulAdd { a, b, c } => { + edges.insert("a", a); + edges.insert("b", b); + edges.insert("c", c); + ("MulAdd".into(), 6) + } }; // give uniform expressions an outline diff --git a/naga/src/back/glsl/mod.rs b/naga/src/back/glsl/mod.rs index ce2e627b0f..dc945d0d75 100644 --- a/naga/src/back/glsl/mod.rs +++ b/naga/src/back/glsl/mod.rs @@ -4352,7 +4352,8 @@ impl<'a, W: Write> Writer<'a, W> { } // not supported yet Expression::RayQueryGetIntersection { .. } - | Expression::RayQueryVertexPositions { .. } => unreachable!(), + | Expression::RayQueryVertexPositions { .. } + | Expression::MulAdd { .. } => unreachable!(), } Ok(()) diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index f13601bff9..dd1e1d4d1b 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -4298,7 +4298,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } } // Not supported yet - Expression::RayQueryVertexPositions { .. } => unreachable!(), + Expression::RayQueryVertexPositions { .. } | Expression::MulAdd { .. } => { + unreachable!() + } // Nothing to do here, since call expression already cached Expression::CallResult(_) | Expression::AtomicResult { .. 
} diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index c2b1dcb4dc..f368df52da 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -2845,6 +2845,13 @@ impl Writer { } write!(self.out, "}}")?; } + crate::Expression::MulAdd { a, b, c } => { + self.put_expression(a, context, false)?; + write!(self.out, " * ")?; + self.put_expression(b, context, false)?; + write!(self.out, " + ")?; + self.put_expression(c, context, false)?; + } } Ok(()) } diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index 109cc591e7..df18311b1d 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -658,6 +658,15 @@ fn adjust_expr(new_pos: &HandleVec>, expr: &mut E } => { adjust(query); } + Expression::MulAdd { + ref mut a, + ref mut b, + ref mut c, + } => { + adjust(a); + adjust(b); + adjust(c); + } } } diff --git a/naga/src/back/spv/block.rs b/naga/src/back/spv/block.rs index 31224a9a5c..404ef936cf 100644 --- a/naga/src/back/spv/block.rs +++ b/naga/src/back/spv/block.rs @@ -1805,6 +1805,17 @@ impl BlockContext<'_> { )?; self.write_ray_query_return_vertex_position(query, block, committed) } + crate::Expression::MulAdd { a, b, c } => { + let id = self.gen_id(); + block.body.push(Instruction::coop_mul_add( + result_type_id, + id, + self.cached[a], + self.cached[b], + self.cached[c], + )); + id + } }; self.cached[expr_handle] = id; diff --git a/naga/src/back/spv/instructions.rs b/naga/src/back/spv/instructions.rs index 9e542917f3..3091b6cfee 100644 --- a/naga/src/back/spv/instructions.rs +++ b/naga/src/back/spv/instructions.rs @@ -1245,6 +1245,18 @@ impl super::Instruction { instruction } + + // Cooperative operations + pub(super) fn coop_mul_add(result_type_id: Word, id: Word, a: Word, b: Word, c: Word) -> Self { + let mut instruction = Self::new(Op::CooperativeMatrixMulAddKHR); + instruction.set_type(result_type_id); + instruction.set_result(id); + 
instruction.add_operand(a); + instruction.add_operand(b); + instruction.add_operand(c); + + instruction + } } impl From for spirv::ImageFormat { diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index d1ebf62e6e..5f9f52984d 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -1686,6 +1686,15 @@ impl Writer { write!(self.out, ")")? } + Expression::MulAdd { a, b, c } => { + write!(self.out, "mulAdd(")?; + self.write_expr(module, a, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, b, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, c, func_ctx)?; + write!(self.out, ")")? + } // Not supported yet Expression::RayQueryGetIntersection { .. } | Expression::RayQueryVertexPositions { .. } => unreachable!(), diff --git a/naga/src/compact/expressions.rs b/naga/src/compact/expressions.rs index f36d747a93..98f3bbc3c9 100644 --- a/naga/src/compact/expressions.rs +++ b/naga/src/compact/expressions.rs @@ -253,6 +253,9 @@ impl ExpressionTracer<'_> { } => { self.expressions_used.insert(query); } + Ex::MulAdd { a, b, c } => { + self.expressions_used.insert_iter([a, b, c]); + } } } } @@ -419,6 +422,15 @@ impl ModuleMap { ref mut query, committed: _, } => adjust(query), + Ex::MulAdd { + ref mut a, + ref mut b, + ref mut c, + } => { + adjust(a); + adjust(b); + adjust(c); + } } } diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index 022aa3c4bb..41e7b84c26 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -3084,7 +3084,6 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ); return Ok(Some(result)); } - "quadSwapY" => { let mut args = ctx.prepare_args(arguments, 1, span); @@ -3108,7 +3107,6 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ); return Ok(Some(result)); } - "quadSwapDiagonal" => { let mut args = ctx.prepare_args(arguments, 1, span); @@ -3132,6 +3130,15 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ); return 
Ok(Some(result)); } + "coopMulAdd" => { + let mut args = ctx.prepare_args(arguments, 3, span); + let a = self.expression(args.next()?, ctx)?; + let b = self.expression(args.next()?, ctx)?; + let c = self.expression(args.next()?, ctx)?; + args.finish()?; + + ir::Expression::MulAdd { a, b, c } + } _ => { return Err(Box::new(Error::UnknownIdent(function.span, function.name))) } diff --git a/naga/src/ir/mod.rs b/naga/src/ir/mod.rs index e65e04fbf6..1325434425 100644 --- a/naga/src/ir/mod.rs +++ b/naga/src/ir/mod.rs @@ -1860,6 +1860,15 @@ pub enum Expression { /// [`SubgroupCollectiveOperation`]: Statement::SubgroupCollectiveOperation /// [`SubgroupGather`]: Statement::SubgroupGather SubgroupOperationResult { ty: Handle }, + + /// Return a * b + c. + /// Currently only supported for [`TypeInner::CooperativeMatrix`] types, + /// where it's only valid in uniform control flow. + MulAdd { + a: Handle, + b: Handle, + c: Handle, + }, } /// The value of the switch case. diff --git a/naga/src/proc/constant_evaluator.rs b/naga/src/proc/constant_evaluator.rs index 1e0c5ac15a..8b3c414e4d 100644 --- a/naga/src/proc/constant_evaluator.rs +++ b/naga/src/proc/constant_evaluator.rs @@ -970,6 +970,8 @@ pub enum ConstantEvaluatorError { "Expected reject and accept args. to be scalars of vectors of the same type, got something else", )] SelectAcceptRejectTypeMismatch, + #[error("Cooperative operations can't be constant")] + CooperativeOperation, } impl<'a> ConstantEvaluator<'a> { @@ -1357,6 +1359,7 @@ impl<'a> ConstantEvaluator<'a> { Expression::SubgroupOperationResult { .. } => { Err(ConstantEvaluatorError::SubgroupExpression) } + Expression::MulAdd { .. 
} => Err(ConstantEvaluatorError::CooperativeOperation), } } diff --git a/naga/src/proc/typifier.rs b/naga/src/proc/typifier.rs index 89599e079c..f90f36de68 100644 --- a/naga/src/proc/typifier.rs +++ b/naga/src/proc/typifier.rs @@ -801,6 +801,7 @@ impl<'a> ResolveContext<'a> { scalar: crate::Scalar::U32, size: crate::VectorSize::Quad, }), + crate::Expression::MulAdd { a, b: _, c: _ } => past(a)?.clone(), }) } } diff --git a/naga/src/valid/analyzer.rs b/naga/src/valid/analyzer.rs index 6ef2ca0988..6f00ffda6f 100644 --- a/naga/src/valid/analyzer.rs +++ b/naga/src/valid/analyzer.rs @@ -29,6 +29,7 @@ bitflags::bitflags! { const WORK_GROUP_BARRIER = 0x1; const DERIVATIVE = if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { 0 } else { 0x2 }; const IMPLICIT_LEVEL = if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { 0 } else { 0x4 }; + const COOP_OPS = 0x8; } } @@ -864,6 +865,10 @@ impl FunctionInfo { non_uniform_result: self.add_ref(query), requirements: UniformityRequirements::empty(), }, + E::MulAdd { a, b, c } => Uniformity { + non_uniform_result: self.add_ref(a).or(self.add_ref(b).or(self.add_ref(c))), + requirements: UniformityRequirements::COOP_OPS, + }, }; let ty = resolve_context.resolve(expression, |h| Ok(&self[h].ty))?; diff --git a/naga/src/valid/expression.rs b/naga/src/valid/expression.rs index 8bb9af142b..01541ca65e 100644 --- a/naga/src/valid/expression.rs +++ b/naga/src/valid/expression.rs @@ -141,6 +141,8 @@ pub enum ExpressionError { Literal(#[from] LiteralError), #[error("{0:?} is not supported for Width {2} {1:?} arguments yet, see https://github.com/gfx-rs/wgpu/issues/5276")] UnsupportedWidth(crate::MathFunction, crate::ScalarKind, crate::Bytes), + #[error("Invalid operand for MulAdd")] + InvalidMulAddOperand, } #[derive(Clone, Debug, thiserror::Error)] @@ -1267,6 +1269,39 @@ impl super::Validator { } }, E::SubgroupBallotResult | E::SubgroupOperationResult { .. 
} => self.subgroup_stages, + E::MulAdd { a, b, c } => { + match resolver[a] { + Ti::CooperativeMatrix { + role: crate::CooperativeRole::A, + .. + } => {} + ref other => { + log::error!("A operand type: {other:?}"); + return Err(ExpressionError::InvalidMulAddOperand); + } + } + match resolver[b] { + Ti::CooperativeMatrix { + role: crate::CooperativeRole::B, + .. + } => {} + ref other => { + log::error!("B operand type: {other:?}"); + return Err(ExpressionError::InvalidMulAddOperand); + } + } + match resolver[c] { + Ti::CooperativeMatrix { + role: crate::CooperativeRole::C, + .. + } => {} + ref other => { + log::error!("C operand type: {other:?}"); + return Err(ExpressionError::InvalidMulAddOperand); + } + } + ShaderStages::COMPUTE + } }; Ok(stages) } diff --git a/naga/src/valid/function.rs b/naga/src/valid/function.rs index 0216c6ef7f..8af0ba8e98 100644 --- a/naga/src/valid/function.rs +++ b/naga/src/valid/function.rs @@ -804,9 +804,8 @@ impl super::Validator { | Ex::As { .. } | Ex::ArrayLength(_) | Ex::RayQueryGetIntersection { .. } - | Ex::RayQueryVertexPositions { .. } => { - self.emit_expression(handle, context)? - } + | Ex::RayQueryVertexPositions { .. } + | Ex::MulAdd { .. } => self.emit_expression(handle, context)?, Ex::CallResult(_) | Ex::AtomicResult { .. } | Ex::WorkGroupUniformLoadResult { .. 
} diff --git a/naga/src/valid/handles.rs b/naga/src/valid/handles.rs index 303b6cf193..e2a3325f6d 100644 --- a/naga/src/valid/handles.rs +++ b/naga/src/valid/handles.rs @@ -662,6 +662,9 @@ impl super::Validator { } => { handle.check_dep(query)?; } + crate::Expression::MulAdd { a, b, c } => { + handle.check_dep(a)?.check_dep(b)?.check_dep(c)?; + } } Ok(()) } diff --git a/naga/tests/in/wgsl/cooperative-matrix.wgsl b/naga/tests/in/wgsl/cooperative-matrix.wgsl index 91e371b9fb..2380046fbc 100644 --- a/naga/tests/in/wgsl/cooperative-matrix.wgsl +++ b/naga/tests/in/wgsl/cooperative-matrix.wgsl @@ -1,7 +1,8 @@ var a: coop_mat8x8; -//var b: coop_mat8x8; +var b: coop_mat8x8; +var c: coop_mat8x8; @compute @workgroup_size(8, 8, 1) fn main() { - let a2 = a + a; + let d = coopMulAdd(a, b, c); } From e485df6608a2e3852e57e302291e0c6f0471caa2 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Tue, 23 Sep 2025 00:22:46 -0700 Subject: [PATCH 07/10] coop: Implement Load/Store statement --- naga/src/back/dot/mod.rs | 22 ++- naga/src/back/glsl/mod.rs | 3 +- naga/src/back/hlsl/writer.rs | 4 +- naga/src/back/mod.rs | 12 -- naga/src/back/msl/writer.rs | 131 ++++++++++++++++-- naga/src/back/pipeline_constants.rs | 15 +- naga/src/back/spv/block.rs | 51 ++++++- naga/src/back/spv/instructions.rs | 34 +++++ naga/src/back/spv/mod.rs | 2 +- naga/src/back/spv/writer.rs | 9 +- naga/src/back/wgsl/writer.rs | 37 +++-- naga/src/common/wgsl/to_wgsl.rs | 19 --- naga/src/compact/expressions.rs | 8 +- naga/src/compact/statements.rs | 26 ++++ naga/src/front/spv/mod.rs | 1 + naga/src/front/wgsl/error.rs | 6 +- naga/src/front/wgsl/lower/construction.rs | 7 +- naga/src/front/wgsl/lower/mod.rs | 51 +++++-- naga/src/front/wgsl/parse/mod.rs | 10 +- naga/src/ir/mod.rs | 43 ++---- naga/src/proc/constant_evaluator.rs | 4 +- naga/src/proc/layouter.rs | 2 +- naga/src/proc/terminator.rs | 3 +- naga/src/proc/type_methods.rs | 19 ++- naga/src/proc/typifier.rs | 18 ++- naga/src/valid/analyzer.rs | 20 ++- 
naga/src/valid/expression.rs | 34 ++--- naga/src/valid/function.rs | 70 +++++++++- naga/src/valid/handles.rs | 14 +- naga/src/valid/type.rs | 4 +- naga/tests/in/wgsl/cooperative-matrix.toml | 2 +- naga/tests/in/wgsl/cooperative-matrix.wgsl | 8 +- .../ir/wgsl-cooperative-matrix.compact.ron | 129 ++++++++++++++--- naga/tests/out/ir/wgsl-cooperative-matrix.ron | 129 ++++++++++++++--- .../out/spv/wgsl-cooperative-matrix.spvasm | 85 ++++++++++-- 35 files changed, 798 insertions(+), 234 deletions(-) diff --git a/naga/src/back/dot/mod.rs b/naga/src/back/dot/mod.rs index 2305248630..9f5f2f3dab 100644 --- a/naga/src/back/dot/mod.rs +++ b/naga/src/back/dot/mod.rs @@ -422,6 +422,24 @@ impl StatementGraph { }, } } + S::CooperativeLoadStore { + store, + target, + pointer, + stride, + row_major: _, + } => { + self.dependencies.push((id, target, "target")); + self.dependencies.push((id, pointer, "pointer")); + if let Some(stride) = stride { + self.dependencies.push((id, stride, "stride")); + } + if store { + "Store" + } else { + "Load" + } + } }; // Set the last node to the merge node last_node = merge_id; @@ -761,11 +779,11 @@ fn write_function_expressions( let ty = if committed { "Committed" } else { "Candidate" }; (format!("get{ty}HitVertexPositions").into(), 4) } - E::MulAdd { a, b, c } => { + E::CooperativeMultiplyAdd { a, b, c } => { edges.insert("a", a); edges.insert("b", b); edges.insert("c", c); - ("MulAdd".into(), 6) + ("cooperativeMultiplyAdd".into(), 4) } }; diff --git a/naga/src/back/glsl/mod.rs b/naga/src/back/glsl/mod.rs index dc945d0d75..b609aa137b 100644 --- a/naga/src/back/glsl/mod.rs +++ b/naga/src/back/glsl/mod.rs @@ -2816,6 +2816,7 @@ impl<'a, W: Write> Writer<'a, W> { } writeln!(self.out, ");")?; } + Statement::CooperativeLoadStore { .. } => unimplemented!(), } Ok(()) @@ -4353,7 +4354,7 @@ impl<'a, W: Write> Writer<'a, W> { // not supported yet Expression::RayQueryGetIntersection { .. } | Expression::RayQueryVertexPositions { .. } - | Expression::MulAdd { .. 
} => unreachable!(), + | Expression::CooperativeMultiplyAdd { .. } => unreachable!(), } Ok(()) diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index dd1e1d4d1b..a778deac98 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -2769,6 +2769,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } writeln!(self.out, ");")?; } + Statement::CooperativeLoadStore { .. } => unimplemented!(), } Ok(()) @@ -4298,7 +4299,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } } // Not supported yet - Expression::RayQueryVertexPositions { .. } | Expression::MulAdd { .. } => { + Expression::RayQueryVertexPositions { .. } + | Expression::CooperativeMultiplyAdd { .. } => { unreachable!() } // Nothing to do here, since call expression already cached diff --git a/naga/src/back/mod.rs b/naga/src/back/mod.rs index 8be763234e..ef9d829969 100644 --- a/naga/src/back/mod.rs +++ b/naga/src/back/mod.rs @@ -311,18 +311,6 @@ pub const fn binary_operation_str(op: crate::BinaryOperator) -> &'static str { } } -impl crate::TypeInner { - /// Returns true if this is a handle to a type rather than the type directly. - pub const fn is_handle(&self) -> bool { - match *self { - crate::TypeInner::Image { .. } - | crate::TypeInner::Sampler { .. } - | crate::TypeInner::AccelerationStructure { .. } => true, - _ => false, - } - } -} - impl crate::Statement { /// Returns true if the statement directly terminates the current block. /// diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index f368df52da..e6de2165ed 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -78,6 +78,7 @@ pub(crate) const ARGUMENT_BUFFER_WRAPPER_STRUCT: &str = "NagaArgumentBufferWrapp /// allowing them to be conveniently passed to user-defined or wrapper /// functions. The struct is declared in [`Writer::write_type_defs`]. 
pub(crate) const EXTERNAL_TEXTURE_WRAPPER_STRUCT: &str = "NagaExternalTextureWrapper"; +pub(crate) const COOPERATIVE_MULTIPLY_ADD_FUNCTION: &str = "NagaCooperativeMultiplyAdd"; /// Write the Metal name for a Naga numeric type: scalar, vector, or matrix. /// @@ -483,6 +484,12 @@ enum WrappedFunction { ImageQuerySize { class: crate::ImageClass, }, + CooperativeMultiplyAdd { + columns: crate::CooperativeSize, + rows: crate::CooperativeSize, + intermediate: crate::CooperativeSize, + scalar: crate::Scalar, + }, } pub struct Writer { @@ -543,14 +550,6 @@ impl crate::Scalar { } } -impl crate::CooperativeScalar { - const fn to_msl_name(self) -> &'static str { - match self { - Self::F32 => "float", - } - } -} - const fn separate(need_separator: bool) -> &'static str { if need_separator { "," @@ -2845,12 +2844,14 @@ impl Writer { } write!(self.out, "}}")?; } - crate::Expression::MulAdd { a, b, c } => { - self.put_expression(a, context, false)?; - write!(self.out, " * ")?; - self.put_expression(b, context, false)?; - write!(self.out, " + ")?; - self.put_expression(c, context, false)?; + crate::Expression::CooperativeMultiplyAdd { a, b, c } => { + write!(self.out, "{COOPERATIVE_MULTIPLY_ADD_FUNCTION}(")?; + self.put_expression(a, context, true)?; + write!(self.out, ", ")?; + self.put_expression(b, context, true)?; + write!(self.out, ", ")?; + self.put_expression(c, context, true)?; + write!(self.out, ")")?; } } Ok(()) @@ -4241,6 +4242,49 @@ impl Writer { } writeln!(self.out, ");")?; } + crate::Statement::CooperativeLoadStore { + store, + target, + pointer, + stride, + row_major, + } => { + let op_str = if store { "store" } else { "load" }; + write!(self.out, "{level}{NAMESPACE}::simdgroup_{op_str}(")?; + self.put_expression(target, &context.expression, true)?; + write!(self.out, ", ")?; + self.put_expression(pointer, &context.expression, true)?; + if stride.is_some() || row_major { + write!(self.out, ", ")?; + match stride { + Some(expression) => { + 
self.put_expression(expression, &context.expression, true)?; + } + None => { + let default_stride = match *context.expression.resolve_type(target) + { + crate::TypeInner::CooperativeMatrix { + columns, rows, .. + } => { + if row_major { + columns as u32 + } else { + rows as u32 + } + } + _ => 0, + }; + write!(self.out, "{default_stride}")?; + } + } + } + if row_major { + let matrix_origin = "0"; + let transpose = true; + write!(self.out, ", {matrix_origin}, {transpose}")?; + } + writeln!(self.out, ");")?; + } } } @@ -6297,6 +6341,62 @@ template Ok(()) } + fn write_wrapped_cooperative_multiply_add( + &mut self, + module: &crate::Module, + func_ctx: &back::FunctionCtx, + a: Handle, + b: Handle, + ) -> BackendResult { + let (a_c, a_r, scalar) = match *func_ctx.resolve_type(a, &module.types) { + crate::TypeInner::CooperativeMatrix { + columns, + rows, + scalar, + .. + } => (columns, rows, scalar), + _ => unreachable!(), + }; + let (b_c, b_r) = match *func_ctx.resolve_type(b, &module.types) { + crate::TypeInner::CooperativeMatrix { columns, rows, .. 
} => (columns, rows), + _ => unreachable!(), + }; + let wrapped = WrappedFunction::CooperativeMultiplyAdd { + columns: b_c, + rows: a_r, + intermediate: a_c, + scalar, + }; + if !self.wrapped_functions.insert(wrapped) { + return Ok(()); + } + let scalar_name = match scalar.width { + 2 => "half", + 4 => "float", + 8 => "double", + _ => unreachable!(), + }; + writeln!( + self.out, + "{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_MULTIPLY_ADD_FUNCTION}(const {NAMESPACE}::simdgroup_{scalar_name}{}x{}& a, const {NAMESPACE}::simdgroup_{scalar_name}{}x{}& b, const {NAMESPACE}::simdgroup_{scalar_name}{}x{}& c) {{", + b_c as u32, a_r as u32, a_c as u32, a_r as u32, b_c as u32, b_r as u32, b_c as u32, a_r as u32, + )?; + let l1 = back::Level(1); + writeln!( + self.out, + "{l1}{NAMESPACE}::simdgroup_{scalar_name}{}x{} d;", + b_c as u32, a_r as u32 + )?; + writeln!( + self.out, + "{l1}{NAMESPACE}::simdgroup_multiply_accumulate(d,a,b,c);" + )?; + writeln!(self.out, "{l1}return d;")?; + writeln!(self.out, "}}")?; + writeln!(self.out)?; + Ok(()) + } + pub(super) fn write_wrapped_functions( &mut self, module: &crate::Module, @@ -6371,6 +6471,9 @@ template crate::Expression::ImageQuery { image, query } => { self.write_wrapped_image_query(module, func_ctx, image, query)?; } + crate::Expression::CooperativeMultiplyAdd { a, b, c: _ } => { + self.write_wrapped_cooperative_multiply_add(module, func_ctx, a, b)?; + } _ => {} } } diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index df18311b1d..4d97e21cc7 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -658,7 +658,7 @@ fn adjust_expr(new_pos: &HandleVec>, expr: &mut E } => { adjust(query); } - Expression::MulAdd { + Expression::CooperativeMultiplyAdd { ref mut a, ref mut b, ref mut c, @@ -889,6 +889,19 @@ fn adjust_stmt(new_pos: &HandleVec>, stmt: &mut S adjust(index); adjust(value); } + Statement::CooperativeLoadStore { + store: _, + ref mut 
target, + ref mut pointer, + ref mut stride, + row_major: _, + } => { + adjust(target); + adjust(pointer); + if let Some(ref mut stride) = *stride { + adjust(stride); + } + } Statement::Break | Statement::Continue | Statement::Kill diff --git a/naga/src/back/spv/block.rs b/naga/src/back/spv/block.rs index 404ef936cf..479de9e425 100644 --- a/naga/src/back/spv/block.rs +++ b/naga/src/back/spv/block.rs @@ -1805,14 +1805,21 @@ impl BlockContext<'_> { )?; self.write_ray_query_return_vertex_position(query, block, committed) } - crate::Expression::MulAdd { a, b, c } => { + crate::Expression::CooperativeMultiplyAdd { a, b, c } => { + self.writer.require_any( + "CooperativeMatrix", + &[spirv::Capability::CooperativeMatrixKHR], + )?; + let a_id = self.cached[a]; + let b_id = self.cached[b]; + let c_id = self.cached[c]; let id = self.gen_id(); block.body.push(Instruction::coop_mul_add( result_type_id, id, - self.cached[a], - self.cached[b], - self.cached[c], + a_id, + b_id, + c_id, )); id } @@ -3679,6 +3686,42 @@ impl BlockContext<'_> { self.write_subgroup_gather(mode, argument, result, &mut block)?; } Statement::MeshFunction(_) => unreachable!(), + Statement::CooperativeLoadStore { + store, + target, + pointer, + stride, + row_major, + } => { + let layout = if row_major { + spirv::CooperativeMatrixLayout::RowMajorKHR + } else { + spirv::CooperativeMatrixLayout::ColumnMajorKHR + }; + let layout_id = self.get_index_constant(layout as u32); + let stride_id = stride.map(|exp| self.cached[exp]); + if store { + block.body.push(Instruction::coop_store( + self.cached[target], + self.cached[pointer], + layout_id, + stride_id, + )); + } else { + let result_type_id = self.get_expression_type_id(&self.fun_info[target].ty); + let id = self.gen_id(); + block.body.push(Instruction::coop_load( + result_type_id, + id, + self.cached[pointer], + layout_id, + stride_id, + )); + block + .body + .push(Instruction::store(self.cached[target], id, None)); + } + } } } diff --git 
a/naga/src/back/spv/instructions.rs b/naga/src/back/spv/instructions.rs index 3091b6cfee..419c276fc4 100644 --- a/naga/src/back/spv/instructions.rs +++ b/naga/src/back/spv/instructions.rs @@ -1247,6 +1247,40 @@ impl super::Instruction { } // Cooperative operations + pub(super) fn coop_load( + result_type_id: Word, + id: Word, + pointer_id: Word, + layout_id: Word, + stride_id: Option, + ) -> Self { + let mut instruction = Self::new(Op::CooperativeMatrixLoadKHR); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(pointer_id); + instruction.add_operand(layout_id); + if let Some(stride_id) = stride_id { + instruction.add_operand(stride_id); + } + + instruction + } + pub(super) fn coop_store( + id: Word, + pointer_id: Word, + layout_id: Word, + stride_id: Option, + ) -> Self { + let mut instruction = Self::new(Op::CooperativeMatrixStoreKHR); + instruction.add_operand(pointer_id); + instruction.add_operand(id); + instruction.add_operand(layout_id); + if let Some(stride_id) = stride_id { + instruction.add_operand(stride_id); + } + + instruction + } pub(super) fn coop_mul_add(result_type_id: Word, id: Word, a: Word, b: Word, c: Word) -> Self { let mut instruction = Self::new(Op::CooperativeMatrixMulAddKHR); instruction.set_type(result_type_id); diff --git a/naga/src/back/spv/mod.rs b/naga/src/back/spv/mod.rs index 0dc4faa288..6738c776cd 100644 --- a/naga/src/back/spv/mod.rs +++ b/naga/src/back/spv/mod.rs @@ -345,7 +345,7 @@ enum CooperativeType { Matrix { columns: crate::CooperativeSize, rows: crate::CooperativeSize, - scalar: crate::CooperativeScalar, + scalar: crate::Scalar, role: crate::CooperativeRole, }, } diff --git a/naga/src/back/spv/writer.rs b/naga/src/back/spv/writer.rs index 5a04f5dca5..8e9b3553a0 100644 --- a/naga/src/back/spv/writer.rs +++ b/naga/src/back/spv/writer.rs @@ -375,12 +375,6 @@ impl Writer { }) } - pub(super) fn get_cooperative_type_id(&mut self, scalar: crate::CooperativeScalar) -> Word { - match 
scalar { - crate::CooperativeScalar::F32 => self.get_f32_type_id(), - } - } - pub(super) fn get_f32_pointer_type_id(&mut self, class: spirv::StorageClass) -> Word { let f32_id = self.get_f32_type_id(); self.get_pointer_type_id(f32_id, class) @@ -1408,7 +1402,8 @@ impl Writer { scalar, role, } => { - let scalar_id = self.get_cooperative_type_id(scalar); + let scalar_id = + self.get_localtype_id(LocalType::Numeric(NumericType::Scalar(scalar))); let scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); let columns_id = self.get_index_constant(columns as u32); let rows_id = self.get_index_constant(rows as u32); diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index 5f9f52984d..ad766b0f6b 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -985,6 +985,25 @@ impl Writer { } writeln!(self.out, ");")?; } + Statement::CooperativeLoadStore { + store, + target, + pointer, + stride, + row_major, + } => { + let op_str = if store { "Store" } else { "Load" }; + let suffix = if row_major { "T" } else { "" }; + write!(self.out, "coop{op_str}{suffix}(")?; + self.write_expr(module, target, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, pointer, func_ctx)?; + if let Some(stride) = stride { + write!(self.out, ", ")?; + self.write_expr(module, stride, func_ctx)?; + } + write!(self.out, ")")? + } } Ok(()) @@ -1686,15 +1705,6 @@ impl Writer { write!(self.out, ")")? } - Expression::MulAdd { a, b, c } => { - write!(self.out, "mulAdd(")?; - self.write_expr(module, a, func_ctx)?; - write!(self.out, ", ")?; - self.write_expr(module, b, func_ctx)?; - write!(self.out, ", ")?; - self.write_expr(module, c, func_ctx)?; - write!(self.out, ")")? - } // Not supported yet Expression::RayQueryGetIntersection { .. } | Expression::RayQueryVertexPositions { .. } => unreachable!(), @@ -1705,6 +1715,15 @@ impl Writer { | Expression::SubgroupBallotResult | Expression::SubgroupOperationResult { .. 
} | Expression::WorkGroupUniformLoadResult { .. } => {} + Expression::CooperativeMultiplyAdd { a, b, c } => { + write!(self.out, "coopMultiplyAdd(")?; + self.write_expr(module, a, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, b, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, c, func_ctx)?; + write!(self.out, ")")?; + } } Ok(()) diff --git a/naga/src/common/wgsl/to_wgsl.rs b/naga/src/common/wgsl/to_wgsl.rs index d8988c7f34..8fade7179b 100644 --- a/naga/src/common/wgsl/to_wgsl.rs +++ b/naga/src/common/wgsl/to_wgsl.rs @@ -305,25 +305,6 @@ impl TryToWgsl for crate::Scalar { } } -impl TryToWgsl for crate::CooperativeScalar { - const DESCRIPTION: &'static str = "cooperative scalar type"; - - fn try_to_wgsl(self) -> Option<&'static str> { - use crate::CooperativeScalar; - - Some(match self { - CooperativeScalar::F32 => "f32", - }) - } - - fn to_wgsl_for_diagnostics(self) -> String { - match self.try_to_wgsl() { - Some(static_string) => static_string.to_string(), - None => unreachable!(), - } - } -} - impl ToWgsl for crate::CooperativeRole { fn to_wgsl(self) -> &'static str { match self { diff --git a/naga/src/compact/expressions.rs b/naga/src/compact/expressions.rs index 98f3bbc3c9..2b2117cc16 100644 --- a/naga/src/compact/expressions.rs +++ b/naga/src/compact/expressions.rs @@ -253,8 +253,10 @@ impl ExpressionTracer<'_> { } => { self.expressions_used.insert(query); } - Ex::MulAdd { a, b, c } => { - self.expressions_used.insert_iter([a, b, c]); + Ex::CooperativeMultiplyAdd { a, b, c } => { + self.expressions_used.insert(a); + self.expressions_used.insert(b); + self.expressions_used.insert(c); } } } @@ -422,7 +424,7 @@ impl ModuleMap { ref mut query, committed: _, } => adjust(query), - Ex::MulAdd { + Ex::CooperativeMultiplyAdd { ref mut a, ref mut b, ref mut c, diff --git a/naga/src/compact/statements.rs b/naga/src/compact/statements.rs index b370501bac..e91436ca74 100644 --- a/naga/src/compact/statements.rs +++ 
b/naga/src/compact/statements.rs @@ -166,6 +166,19 @@ impl FunctionTracer<'_> { self.expressions_used.insert(argument); self.expressions_used.insert(result); } + St::CooperativeLoadStore { + store: _, + target, + pointer, + stride, + row_major: _, + } => { + self.expressions_used.insert(target); + self.expressions_used.insert(pointer); + if let Some(stride) = stride { + self.expressions_used.insert(stride); + } + } // Trivial statements. St::Break @@ -405,6 +418,19 @@ impl FunctionMap { adjust(argument); adjust(result); } + St::CooperativeLoadStore { + store: _, + ref mut target, + ref mut pointer, + ref mut stride, + row_major: _, + } => { + adjust(target); + adjust(pointer); + if let Some(ref mut stride) = *stride { + adjust(stride); + } + } // Trivial statements. St::Break diff --git a/naga/src/front/spv/mod.rs b/naga/src/front/spv/mod.rs index 2a3a971a8b..43b067a3d2 100644 --- a/naga/src/front/spv/mod.rs +++ b/naga/src/front/spv/mod.rs @@ -4695,6 +4695,7 @@ impl> Frontend { } } S::WorkGroupUniformLoad { .. } => unreachable!(), + S::CooperativeLoadStore { .. 
} => unreachable!(), } i += 1; } diff --git a/naga/src/front/wgsl/error.rs b/naga/src/front/wgsl/error.rs index 8c749acc73..f0d6a4b848 100644 --- a/naga/src/front/wgsl/error.rs +++ b/naga/src/front/wgsl/error.rs @@ -413,7 +413,7 @@ pub(crate) enum Error<'a> { span: Span, }, UnderspecifiedCooperativeMatrix, - UnknownCooperativeScalar(Span), + UnsupportedCooperativeScalar(Span), } impl From for Error<'_> { @@ -1393,8 +1393,8 @@ impl<'a> Error<'a> { labels: vec![], notes: vec![format!("must be F32")], }, - Error::UnknownCooperativeScalar(span) => ParseError { - message: "unknown cooperative scalar type".into(), + Error::UnsupportedCooperativeScalar(span) => ParseError { + message: "cooperative scalar type is not supported".into(), labels: vec![(span, "type needs the scalar type specified".into())], notes: vec![format!("must be F32")], }, diff --git a/naga/src/front/wgsl/lower/construction.rs b/naga/src/front/wgsl/lower/construction.rs index 9ac11bfc98..2159ef01ad 100644 --- a/naga/src/front/wgsl/lower/construction.rs +++ b/naga/src/front/wgsl/lower/construction.rs @@ -650,11 +650,8 @@ impl<'source> Lowerer<'source, '_> { } => { let ty = self.resolve_ast_type(ty, &mut ctx.as_const())?; let scalar = match ctx.module.types[ty].inner { - crate::TypeInner::Scalar(crate::Scalar { - kind: crate::ScalarKind::Float, - width: 4, - }) => crate::CooperativeScalar::F32, - _ => return Err(Box::new(Error::UnknownCooperativeScalar(ty_span))), + crate::TypeInner::Scalar(s) => s, + _ => return Err(Box::new(Error::UnsupportedCooperativeScalar(ty_span))), }; let ty = ctx.ensure_type_exists(crate::TypeInner::CooperativeMatrix { columns, diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index 41e7b84c26..5a1a90a0e9 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -1679,8 +1679,12 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .as_expression(block, &mut emitter) .interrupt_emitter(ir::Expression::LocalVariable(var), 
Span::UNDEFINED)?; block.extend(emitter.finish(&ctx.function.expressions)); - ctx.local_table - .insert(v.handle, Declared::Runtime(Typed::Reference(handle))); + let typed = if ctx.module.types[ty].inner.is_handle() { + Typed::Plain(handle) + } else { + Typed::Reference(handle) + }; + ctx.local_table.insert(v.handle, Declared::Runtime(typed)); match initializer { Some(initializer) => ir::Statement::Store { @@ -2136,8 +2140,11 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let expr = match *global { LoweredGlobalDecl::Var(handle) => { let expr = ir::Expression::GlobalVariable(handle); - match ctx.module.global_variables[handle].space { + let v = &ctx.module.global_variables[handle]; + let force_value = ctx.module.types[v.ty].inner.is_handle(); + match v.space { ir::AddressSpace::Handle => Typed::Plain(expr), + _ if force_value => Typed::Plain(expr), _ => Typed::Reference(expr), } } @@ -3130,14 +3137,41 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ); return Ok(Some(result)); } - "coopMulAdd" => { + "coopLoad" | "coopLoadT" | "coopStore" | "coopStoreT" => { + let mut args = ctx.prepare_args(arguments, 2, span); + let target = self.expression(args.next()?, ctx)?; + let pointer = self.expression(args.next()?, ctx)?; + let stride = if args.total_args > 2 { + Some(self.expression(args.next()?, ctx)?) 
+ } else { + None + }; + args.finish()?; + + let store = function.name.contains("Store"); + let row_major = function.name.ends_with("T"); + + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block.push( + crate::Statement::CooperativeLoadStore { + store, + target, + pointer, + stride, + row_major, + }, + span, + ); + return Ok(None); + } + "coopMultiplyAdd" => { let mut args = ctx.prepare_args(arguments, 3, span); let a = self.expression(args.next()?, ctx)?; let b = self.expression(args.next()?, ctx)?; let c = self.expression(args.next()?, ctx)?; args.finish()?; - ir::Expression::MulAdd { a, b, c } + ir::Expression::CooperativeMultiplyAdd { a, b, c } } _ => { return Err(Box::new(Error::UnknownIdent(function.span, function.name))) @@ -3973,11 +4007,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } => { let ty = self.resolve_ast_type(ty, ctx)?; let scalar = match ctx.module.types[ty].inner { - ir::TypeInner::Scalar(crate::Scalar { - kind: crate::ScalarKind::Float, - width: 4, - }) => crate::CooperativeScalar::F32, - _ => return Err(Box::new(Error::UnknownCooperativeScalar(ty_span))), + ir::TypeInner::Scalar(s) => s, + _ => return Err(Box::new(Error::UnsupportedCooperativeScalar(ty_span))), }; ir::TypeInner::CooperativeMatrix { columns, diff --git a/naga/src/front/wgsl/parse/mod.rs b/naga/src/front/wgsl/parse/mod.rs index 49d7eaab25..576bd9c977 100644 --- a/naga/src/front/wgsl/parse/mod.rs +++ b/naga/src/front/wgsl/parse/mod.rs @@ -658,12 +658,10 @@ impl Parser { ty_span: Span::UNDEFINED, })) } - "coop_mat8x8" => { - return Ok(Some(ast::ConstructorType::PartialCooperativeMatrix { - columns: crate::CooperativeSize::Eight, - rows: crate::CooperativeSize::Eight, - })) - } + "coop_mat8x8" => ast::ConstructorType::PartialCooperativeMatrix { + columns: crate::CooperativeSize::Eight, + rows: crate::CooperativeSize::Eight, + }, "array" => ast::ConstructorType::PartialArray, "atomic" | "binding_array" diff --git a/naga/src/ir/mod.rs b/naga/src/ir/mod.rs index 
1325434425..9ed185a7ce 100644 --- a/naga/src/ir/mod.rs +++ b/naga/src/ir/mod.rs @@ -523,33 +523,6 @@ pub enum ScalarKind { AbstractFloat, } -/// Primitive type for a cooperative scalar. -#[repr(u8)] -#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] -#[cfg_attr(feature = "serialize", derive(Serialize))] -#[cfg_attr(feature = "deserialize", derive(Deserialize))] -#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] -pub enum CooperativeScalar { - F32, -} - -impl CooperativeScalar { - pub const fn width(&self) -> Bytes { - match *self { - Self::F32 => 4, - } - } - - pub const fn to_scalar(&self) -> Scalar { - match *self { - Self::F32 => Scalar { - kind: ScalarKind::Float, - width: 4, - }, - } - } -} - /// Role of a cooperative variable in the equation "A * B + C" #[repr(u8)] #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] @@ -815,7 +788,7 @@ pub enum TypeInner { CooperativeMatrix { columns: CooperativeSize, rows: CooperativeSize, - scalar: CooperativeScalar, + scalar: Scalar, role: CooperativeRole, }, /// Atomic scalar. @@ -1861,10 +1834,8 @@ pub enum Expression { /// [`SubgroupGather`]: Statement::SubgroupGather SubgroupOperationResult { ty: Handle }, - /// Return a * b + c. - /// Currently only supported for [`TypeInner::CooperativeMatrix`] types, - /// where it's only valid in uniform control flow. - MulAdd { + /// Compute `a * b + c` + CooperativeMultiplyAdd { a: Handle, b: Handle, c: Handle, @@ -2312,6 +2283,14 @@ pub enum Statement { /// [`SubgroupOperationResult`]: Expression::SubgroupOperationResult result: Handle, }, + /// Load from or store into a cooperative primitive. + CooperativeLoadStore { + store: bool, + target: Handle, + pointer: Handle, + stride: Option>, + row_major: bool, + }, } /// A function argument. 
diff --git a/naga/src/proc/constant_evaluator.rs b/naga/src/proc/constant_evaluator.rs index 8b3c414e4d..a8f2601ca9 100644 --- a/naga/src/proc/constant_evaluator.rs +++ b/naga/src/proc/constant_evaluator.rs @@ -1359,7 +1359,9 @@ impl<'a> ConstantEvaluator<'a> { Expression::SubgroupOperationResult { .. } => { Err(ConstantEvaluatorError::SubgroupExpression) } - Expression::MulAdd { .. } => Err(ConstantEvaluatorError::CooperativeOperation), + Expression::CooperativeMultiplyAdd { .. } => { + Err(ConstantEvaluatorError::CooperativeOperation) + } } } diff --git a/naga/src/proc/layouter.rs b/naga/src/proc/layouter.rs index 7f9380d766..5165ac7a01 100644 --- a/naga/src/proc/layouter.rs +++ b/naga/src/proc/layouter.rs @@ -224,7 +224,7 @@ impl Layouter { scalar, role: _, } => { - let alignment = Alignment::new(scalar.width() as u32) + let alignment = Alignment::new(scalar.width as u32) .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?; TypeLayout { size, diff --git a/naga/src/proc/terminator.rs b/naga/src/proc/terminator.rs index f76d4c06a3..fc5d1aae31 100644 --- a/naga/src/proc/terminator.rs +++ b/naga/src/proc/terminator.rs @@ -44,7 +44,8 @@ pub fn ensure_block_returns(block: &mut crate::Block) { | S::SubgroupCollectiveOperation { .. } | S::SubgroupGather { .. } | S::ControlBarrier(_) - | S::MemoryBarrier(_)), + | S::MemoryBarrier(_) + | S::CooperativeLoadStore { .. }), ) | None => block.push(S::Return { value: None }, Default::default()), } diff --git a/naga/src/proc/type_methods.rs b/naga/src/proc/type_methods.rs index c4a9091c74..136ea29218 100644 --- a/naga/src/proc/type_methods.rs +++ b/naga/src/proc/type_methods.rs @@ -115,7 +115,7 @@ impl crate::TypeInner { match *self { Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => Some(scalar), Ti::Matrix { scalar, .. } => Some(scalar), - Ti::CooperativeMatrix { scalar, .. } => Some(scalar.to_scalar()), + Ti::CooperativeMatrix { scalar, .. 
} => Some(scalar), _ => None, } } @@ -183,14 +183,25 @@ impl crate::TypeInner { pub fn is_atomic_pointer(&self, types: &crate::UniqueArena) -> bool { match *self { - crate::TypeInner::Pointer { base, .. } => match types[base].inner { - crate::TypeInner::Atomic { .. } => true, + Self::Pointer { base, .. } => match types[base].inner { + Self::Atomic { .. } => true, _ => false, }, _ => false, } } + /// Returns true if a variable of this type is a handle. + pub const fn is_handle(&self) -> bool { + match *self { + Self::Image { .. } + | Self::Sampler { .. } + | Self::AccelerationStructure { .. } + | Self::CooperativeMatrix { .. } => true, + _ => false, + } + } + /// Attempt to calculate the size of this type. Returns `None` if the size /// exceeds the limit of [`crate::valid::MAX_TYPE_SIZE`]. pub fn try_size(&self, gctx: super::GlobalCtx) -> Option { @@ -208,7 +219,7 @@ impl crate::TypeInner { rows, scalar, role: _, - } => Some(columns as u32 * rows as u32 * scalar.width() as u32), + } => Some(columns as u32 * rows as u32 * scalar.width as u32), Self::Pointer { .. } | Self::ValuePointer { .. 
} => Some(POINTER_SPAN), Self::Array { base: _, diff --git a/naga/src/proc/typifier.rs b/naga/src/proc/typifier.rs index f90f36de68..8e323d7724 100644 --- a/naga/src/proc/typifier.rs +++ b/naga/src/proc/typifier.rs @@ -454,7 +454,8 @@ impl<'a> ResolveContext<'a> { } crate::Expression::GlobalVariable(h) => { let var = &self.global_vars[h]; - if var.space == crate::AddressSpace::Handle { + let ty = &types[var.ty].inner; + if var.space == crate::AddressSpace::Handle || ty.is_handle() { TypeResolution::Handle(var.ty) } else { TypeResolution::Value(Ti::Pointer { @@ -465,10 +466,15 @@ impl<'a> ResolveContext<'a> { } crate::Expression::LocalVariable(h) => { let var = &self.local_vars[h]; - TypeResolution::Value(Ti::Pointer { - base: var.ty, - space: crate::AddressSpace::Function, - }) + let ty = &types[var.ty].inner; + if ty.is_handle() { + TypeResolution::Handle(var.ty) + } else { + TypeResolution::Value(Ti::Pointer { + base: var.ty, + space: crate::AddressSpace::Function, + }) + } } crate::Expression::Load { pointer } => match *past(pointer)?.inner_with(types) { Ti::Pointer { base, space: _ } => { @@ -801,7 +807,7 @@ impl<'a> ResolveContext<'a> { scalar: crate::Scalar::U32, size: crate::VectorSize::Quad, }), - crate::Expression::MulAdd { a, b: _, c: _ } => past(a)?.clone(), + crate::Expression::CooperativeMultiplyAdd { a: _, b: _, c } => past(c)?.clone(), }) } } diff --git a/naga/src/valid/analyzer.rs b/naga/src/valid/analyzer.rs index 6f00ffda6f..a530c1869c 100644 --- a/naga/src/valid/analyzer.rs +++ b/naga/src/valid/analyzer.rs @@ -865,7 +865,7 @@ impl FunctionInfo { non_uniform_result: self.add_ref(query), requirements: UniformityRequirements::empty(), }, - E::MulAdd { a, b, c } => Uniformity { + E::CooperativeMultiplyAdd { a, b, c } => Uniformity { non_uniform_result: self.add_ref(a).or(self.add_ref(b).or(self.add_ref(c))), requirements: UniformityRequirements::COOP_OPS, }, @@ -1228,6 +1228,24 @@ impl FunctionInfo { } FunctionUniformity::new() } + 
S::CooperativeLoadStore { + store: _, + target, + pointer, + stride, + row_major: _, + } => { + if let Some(stride) = stride { + let _ = self.add_ref(stride); + } + FunctionUniformity { + result: Uniformity { + non_uniform_result: self.add_ref(target).or(self.add_ref(pointer)), + requirements: UniformityRequirements::COOP_OPS, + }, + exit: ExitFlags::empty(), + } + } }; disruptor = disruptor.or(uniformity.exit_disruptor()); diff --git a/naga/src/valid/expression.rs b/naga/src/valid/expression.rs index 01541ca65e..466ca26b60 100644 --- a/naga/src/valid/expression.rs +++ b/naga/src/valid/expression.rs @@ -141,8 +141,8 @@ pub enum ExpressionError { Literal(#[from] LiteralError), #[error("{0:?} is not supported for Width {2} {1:?} arguments yet, see https://github.com/gfx-rs/wgpu/issues/5276")] UnsupportedWidth(crate::MathFunction, crate::ScalarKind, crate::Bytes), - #[error("Invalid operand for MulAdd")] - InvalidMulAddOperand, + #[error("Invalid operand for cooperative op")] + InvalidCooperativeOperand(Handle), } #[derive(Clone, Debug, thiserror::Error)] @@ -888,24 +888,10 @@ impl super::Validator { }, ) => columns == rows && scalar1 == scalar2 && role1 == role2, // Scalar * coop matrix. - ( - &Ti::Scalar(Sc { - kind: Sk::Float, .. - }), - &Ti::CooperativeMatrix { - scalar: crate::CooperativeScalar::F32, - .. - }, - ) - | ( - &Ti::CooperativeMatrix { - scalar: crate::CooperativeScalar::F32, - .. - }, - &Ti::Scalar(Sc { - kind: Sk::Float, .. - }), - ) => true, + (&Ti::Scalar(s1), &Ti::CooperativeMatrix { scalar: s2, .. }) + | (&Ti::CooperativeMatrix { scalar: s1, .. }, &Ti::Scalar(s2)) => { + s1 == s2 + } _ => false, }; let left_width = left_inner.scalar_width().unwrap_or(0); @@ -1269,7 +1255,7 @@ impl super::Validator { } }, E::SubgroupBallotResult | E::SubgroupOperationResult { .. 
} => self.subgroup_stages, - E::MulAdd { a, b, c } => { + E::CooperativeMultiplyAdd { a, b, c } => { match resolver[a] { Ti::CooperativeMatrix { role: crate::CooperativeRole::A, @@ -1277,7 +1263,7 @@ impl super::Validator { } => {} ref other => { log::error!("A operand type: {other:?}"); - return Err(ExpressionError::InvalidMulAddOperand); + return Err(ExpressionError::InvalidCooperativeOperand(a)); } } match resolver[b] { @@ -1287,7 +1273,7 @@ impl super::Validator { } => {} ref other => { log::error!("B operand type: {other:?}"); - return Err(ExpressionError::InvalidMulAddOperand); + return Err(ExpressionError::InvalidCooperativeOperand(b)); } } match resolver[c] { @@ -1297,7 +1283,7 @@ impl super::Validator { } => {} ref other => { log::error!("C operand type: {other:?}"); - return Err(ExpressionError::InvalidMulAddOperand); + return Err(ExpressionError::InvalidCooperativeOperand(c)); } } ShaderStages::COMPUTE diff --git a/naga/src/valid/function.rs b/naga/src/valid/function.rs index 8af0ba8e98..a90bc6186a 100644 --- a/naga/src/valid/function.rs +++ b/naga/src/valid/function.rs @@ -1,6 +1,5 @@ use alloc::{format, string::String}; -use super::validate_atomic_compare_exchange_struct; use super::{ analyzer::{UniformityDisruptor, UniformityRequirements}, ExpressionError, FunctionInfo, ModuleInfo, @@ -213,6 +212,10 @@ pub enum FunctionError { WorkgroupUniformLoadInvalidPointer(Handle), #[error("Subgroup operation is invalid")] InvalidSubgroup(#[from] SubgroupError), + #[error("Invalid target type for a cooperative store")] + InvalidCooperativeStoreTarget(Handle), + #[error("Cooperative load/store data pointer has invalid type")] + InvalidCooperativeDataPointer(Handle), #[error("Emit statement should not cover \"result\" expressions like {0:?}")] EmitResult(Handle), #[error("Expression not visited by the appropriate statement")] @@ -584,7 +587,7 @@ impl super::Validator { .with_span_handle(result, context.expressions) .into_other()); }; - if 
!validate_atomic_compare_exchange_struct( + if !super::validate_atomic_compare_exchange_struct( context.types, members, |ty: &crate::TypeInner| *ty == crate::TypeInner::Scalar(pointer_scalar), @@ -805,7 +808,9 @@ impl super::Validator { | Ex::ArrayLength(_) | Ex::RayQueryGetIntersection { .. } | Ex::RayQueryVertexPositions { .. } - | Ex::MulAdd { .. } => self.emit_expression(handle, context)?, + | Ex::CooperativeMultiplyAdd { .. } => { + self.emit_expression(handle, context)? + } Ex::CallResult(_) | Ex::AtomicResult { .. } | Ex::WorkGroupUniformLoadResult { .. } @@ -1081,7 +1086,7 @@ impl super::Validator { } else if let Some(tr) = pointer_base_tr { context.compare_types(value_tr, &tr) } else { - false + value_ty.is_handle() }; if !good { @@ -1660,6 +1665,63 @@ impl super::Validator { } self.validate_subgroup_gather(mode, argument, result, context)?; } + S::CooperativeLoadStore { + store, + target, + pointer, + stride: _, + row_major: _, + } => { + stages &= super::ShaderStages::COMPUTE; + + let target_scalar = + match *context.resolve_type_inner(target, &self.valid_expression_set)? { + Ti::CooperativeMatrix { scalar, .. 
} => scalar, + ref other => { + log::error!("Target operand type: {other:?}"); + return Err(FunctionError::InvalidCooperativeStoreTarget(target) + .with_span_handle(target, context.expressions)); + } + }; + + let ty_inner = + context.resolve_type_inner(pointer, &self.valid_expression_set)?; + //TODO: validate stride + let (pty_array, space) = match *ty_inner { + crate::TypeInner::Pointer { base, space } => (base, space), + _ => { + return Err(FunctionError::InvalidCooperativeDataPointer(pointer) + .with_span_handle(pointer, context.expressions)) + } + }; + let pty_scalar = match context.types[pty_array].inner { + crate::TypeInner::Array { + base, + size: _, + stride: _, + } => base, + _ => { + return Err(FunctionError::InvalidCooperativeDataPointer(pointer) + .with_span_handle(pointer, context.expressions)) + } + }; + let space = match context.types[pty_scalar].inner { + crate::TypeInner::Scalar(s) if s == target_scalar => space, + _ => { + return Err(FunctionError::InvalidCooperativeDataPointer(pointer) + .with_span_handle(pointer, context.expressions)) + } + }; + + if store && !space.access().contains(crate::StorageAccess::STORE) { + return Err( + FunctionError::InvalidStorePointer(pointer).with_span_static( + context.expressions.get_span(pointer), + "writing to this location is not permitted", + ), + ); + } + } } } Ok(BlockInfo { stages }) diff --git a/naga/src/valid/handles.rs b/naga/src/valid/handles.rs index e2a3325f6d..6dbd542814 100644 --- a/naga/src/valid/handles.rs +++ b/naga/src/valid/handles.rs @@ -662,7 +662,7 @@ impl super::Validator { } => { handle.check_dep(query)?; } - crate::Expression::MulAdd { a, b, c } => { + crate::Expression::CooperativeMultiplyAdd { a, b, c } => { handle.check_dep(a)?.check_dep(b)?.check_dep(c)?; } } @@ -869,6 +869,18 @@ impl super::Validator { validate_expr(result)?; Ok(()) } + crate::Statement::CooperativeLoadStore { + store: _, + target, + pointer, + stride, + row_major: _, + } => { + validate_expr(target)?; + 
validate_expr(pointer)?; + validate_expr_opt(stride)?; + Ok(()) + } crate::Statement::Break | crate::Statement::Continue | crate::Statement::Kill diff --git a/naga/src/valid/type.rs b/naga/src/valid/type.rs index e222811ac2..076227eb64 100644 --- a/naga/src/valid/type.rs +++ b/naga/src/valid/type.rs @@ -425,7 +425,7 @@ impl super::Validator { role: _, } => { self.require_type_capability(Capabilities::COOPERATIVE_MATRIX)?; - if scalar != crate::CooperativeScalar::F32 { + if scalar.kind != crate::ScalarKind::Float || scalar.width != 4 { return Err(TypeError::MatrixElementNotFloat); } TypeInfo::new( @@ -436,7 +436,7 @@ impl super::Validator { | TypeFlags::ARGUMENT | TypeFlags::CONSTRUCTIBLE | TypeFlags::CREATION_RESOLVED, - Alignment::from_width(scalar.width()), + Alignment::from_width(scalar.width), ) } Ti::Atomic(scalar) => { diff --git a/naga/tests/in/wgsl/cooperative-matrix.toml b/naga/tests/in/wgsl/cooperative-matrix.toml index c06d67d21a..4a3be8b94e 100644 --- a/naga/tests/in/wgsl/cooperative-matrix.toml +++ b/naga/tests/in/wgsl/cooperative-matrix.toml @@ -1,4 +1,4 @@ -targets = "SPIRV" +targets = "IR | SPIRV | METAL" god_mode = true [spv] diff --git a/naga/tests/in/wgsl/cooperative-matrix.wgsl b/naga/tests/in/wgsl/cooperative-matrix.wgsl index 2380046fbc..24ecb9a2b3 100644 --- a/naga/tests/in/wgsl/cooperative-matrix.wgsl +++ b/naga/tests/in/wgsl/cooperative-matrix.wgsl @@ -1,8 +1,12 @@ var a: coop_mat8x8; var b: coop_mat8x8; -var c: coop_mat8x8; +@group(0) @binding(0) +var ext: array; @compute @workgroup_size(8, 8, 1) fn main() { - let d = coopMulAdd(a, b, c); + var c = coop_mat8x8(); + coopLoad(c, &ext); + var d = coopMultiplyAdd(a, b, c); + coopStore(c, &ext); } diff --git a/naga/tests/out/ir/wgsl-cooperative-matrix.compact.ron b/naga/tests/out/ir/wgsl-cooperative-matrix.compact.ron index 1298f69e2c..31d47d603a 100644 --- a/naga/tests/out/ir/wgsl-cooperative-matrix.compact.ron +++ b/naga/tests/out/ir/wgsl-cooperative-matrix.compact.ron @@ -1,14 +1,56 @@ ( 
types: [ + ( + name: None, + inner: Scalar(( + kind: Float, + width: 4, + )), + ), ( name: None, inner: CooperativeMatrix( columns: Eight, rows: Eight, - scalar: F32, + scalar: ( + kind: Float, + width: 4, + ), role: A, ), ), + ( + name: None, + inner: CooperativeMatrix( + columns: Eight, + rows: Eight, + scalar: ( + kind: Float, + width: 4, + ), + role: B, + ), + ), + ( + name: None, + inner: Array( + base: 0, + size: Dynamic, + stride: 4, + ), + ), + ( + name: None, + inner: CooperativeMatrix( + columns: Eight, + rows: Eight, + scalar: ( + kind: Float, + width: 4, + ), + role: C, + ), + ), ], special_types: ( ray_desc: None, @@ -25,7 +67,26 @@ name: Some("a"), space: Private, binding: None, - ty: 0, + ty: 1, + init: None, + ), + ( + name: Some("b"), + space: Private, + binding: None, + ty: 2, + init: None, + ), + ( + name: Some("ext"), + space: Storage( + access: ("LOAD | STORE"), + ), + binding: Some(( + group: 0, + binding: 0, + )), + ty: 3, init: None, ), ], @@ -42,34 +103,56 @@ name: Some("main"), arguments: [], result: None, - local_variables: [], - expressions: [ - GlobalVariable(0), - Load( - pointer: 0, + local_variables: [ + ( + name: Some("c"), + ty: 4, + init: Some(0), ), - GlobalVariable(0), - Load( - pointer: 2, + ( + name: Some("d"), + ty: 4, + init: None, ), - Binary( - op: Add, - left: 1, - right: 3, + ], + expressions: [ + ZeroValue(4), + LocalVariable(0), + GlobalVariable(2), + GlobalVariable(0), + GlobalVariable(1), + CooperativeMultiplyAdd( + a: 3, + b: 4, + c: 1, ), + LocalVariable(1), + GlobalVariable(2), ], - named_expressions: { - 4: "a2", - }, + named_expressions: {}, body: [ + CooperativeLoadStore( + store: false, + target: 1, + pointer: 2, + stride: None, + row_major: false, + ), Emit(( - start: 1, - end: 2, - )), - Emit(( - start: 3, - end: 5, + start: 5, + end: 6, )), + Store( + pointer: 6, + value: 5, + ), + CooperativeLoadStore( + store: true, + target: 1, + pointer: 7, + stride: None, + row_major: false, + ), Return( value: None, 
), diff --git a/naga/tests/out/ir/wgsl-cooperative-matrix.ron b/naga/tests/out/ir/wgsl-cooperative-matrix.ron index 1298f69e2c..31d47d603a 100644 --- a/naga/tests/out/ir/wgsl-cooperative-matrix.ron +++ b/naga/tests/out/ir/wgsl-cooperative-matrix.ron @@ -1,14 +1,56 @@ ( types: [ + ( + name: None, + inner: Scalar(( + kind: Float, + width: 4, + )), + ), ( name: None, inner: CooperativeMatrix( columns: Eight, rows: Eight, - scalar: F32, + scalar: ( + kind: Float, + width: 4, + ), role: A, ), ), + ( + name: None, + inner: CooperativeMatrix( + columns: Eight, + rows: Eight, + scalar: ( + kind: Float, + width: 4, + ), + role: B, + ), + ), + ( + name: None, + inner: Array( + base: 0, + size: Dynamic, + stride: 4, + ), + ), + ( + name: None, + inner: CooperativeMatrix( + columns: Eight, + rows: Eight, + scalar: ( + kind: Float, + width: 4, + ), + role: C, + ), + ), ], special_types: ( ray_desc: None, @@ -25,7 +67,26 @@ name: Some("a"), space: Private, binding: None, - ty: 0, + ty: 1, + init: None, + ), + ( + name: Some("b"), + space: Private, + binding: None, + ty: 2, + init: None, + ), + ( + name: Some("ext"), + space: Storage( + access: ("LOAD | STORE"), + ), + binding: Some(( + group: 0, + binding: 0, + )), + ty: 3, init: None, ), ], @@ -42,34 +103,56 @@ name: Some("main"), arguments: [], result: None, - local_variables: [], - expressions: [ - GlobalVariable(0), - Load( - pointer: 0, + local_variables: [ + ( + name: Some("c"), + ty: 4, + init: Some(0), ), - GlobalVariable(0), - Load( - pointer: 2, + ( + name: Some("d"), + ty: 4, + init: None, ), - Binary( - op: Add, - left: 1, - right: 3, + ], + expressions: [ + ZeroValue(4), + LocalVariable(0), + GlobalVariable(2), + GlobalVariable(0), + GlobalVariable(1), + CooperativeMultiplyAdd( + a: 3, + b: 4, + c: 1, ), + LocalVariable(1), + GlobalVariable(2), ], - named_expressions: { - 4: "a2", - }, + named_expressions: {}, body: [ + CooperativeLoadStore( + store: false, + target: 1, + pointer: 2, + stride: None, + row_major: 
false, + ), Emit(( - start: 1, - end: 2, - )), - Emit(( - start: 3, - end: 5, + start: 5, + end: 6, )), + Store( + pointer: 6, + value: 5, + ), + CooperativeLoadStore( + store: true, + target: 1, + pointer: 7, + stride: None, + row_major: false, + ), Return( value: None, ), diff --git a/naga/tests/out/spv/wgsl-cooperative-matrix.spvasm b/naga/tests/out/spv/wgsl-cooperative-matrix.spvasm index 33e7477e5d..0e8a882994 100644 --- a/naga/tests/out/spv/wgsl-cooperative-matrix.spvasm +++ b/naga/tests/out/spv/wgsl-cooperative-matrix.spvasm @@ -1,17 +1,82 @@ ; SPIR-V -; Version: 1.1 +; Version: 1.4 ; Generator: rspirv -; Bound: 7 +; Bound: 37 OpCapability Shader +OpCapability CooperativeMatrixKHR +OpCapability VulkanMemoryModel +OpExtension "SPV_KHR_cooperative_matrix" +OpExtension "SPV_KHR_vulkan_memory_model" %1 = OpExtInstImport "GLSL.std.450" -OpMemoryModel Logical GLSL450 -OpEntryPoint GLCompute %4 "main" -OpExecutionMode %4 LocalSize 8 8 1 +OpMemoryModel Logical Vulkan +OpEntryPoint GLCompute %25 "main" %15 %18 %21 +OpExecutionMode %25 LocalSize 8 8 1 +%3 = OpString "cooperative-matrix.wgsl" +OpSource Unknown 0 %3 "var a: coop_mat8x8; +var b: coop_mat8x8; +@group(0) @binding(0) +var ext: array; + +@compute @workgroup_size(8, 8, 1) +fn main() { + var c = coop_mat8x8(); + coopLoad(c, &ext); + var d = coopMultiplyAdd(a, b, c); + coopStore(c, &ext); +} +" +OpName %15 "a" +OpName %18 "b" +OpName %21 "ext" +OpName %25 "main" +OpName %30 "c" +OpName %32 "d" +OpDecorate %12 ArrayStride 4 +OpDecorate %21 DescriptorSet 0 +OpDecorate %21 Binding 0 +OpDecorate %22 Block +OpMemberDecorate %22 0 Offset 0 %2 = OpTypeVoid -%5 = OpTypeFunction %2 -%4 = OpFunction %2 None %5 -%3 = OpLabel -OpBranch %6 -%6 = OpLabel +%4 = OpTypeFloat 32 +%7 = OpTypeInt 32 0 +%6 = OpConstant %7 3 +%8 = OpConstant %7 8 +%9 = OpConstant %7 0 +%5 = OpTypeCooperativeMatrixKHR %4 %6 %8 %8 %9 +%11 = OpConstant %7 1 +%10 = OpTypeCooperativeMatrixKHR %4 %6 %8 %8 %11 +%12 = OpTypeRuntimeArray %4 +%14 = OpConstant 
%7 2 +%13 = OpTypeCooperativeMatrixKHR %4 %6 %8 %8 %14 +%16 = OpTypePointer Private %5 +%17 = OpConstantNull %5 +%15 = OpVariable %16 Private %17 +%19 = OpTypePointer Private %10 +%20 = OpConstantNull %10 +%18 = OpVariable %19 Private %20 +%22 = OpTypeStruct %12 +%23 = OpTypePointer StorageBuffer %22 +%21 = OpVariable %23 StorageBuffer +%26 = OpTypeFunction %2 +%27 = OpTypePointer StorageBuffer %12 +%29 = OpConstantNull %13 +%31 = OpTypePointer Function %13 +%33 = OpConstantNull %13 +%25 = OpFunction %2 None %26 +%24 = OpLabel +%30 = OpVariable %31 Function %29 +%32 = OpVariable %31 Function %33 +%28 = OpAccessChain %27 %21 %9 +OpBranch %34 +%34 = OpLabel +OpLine %3 9 5 +%35 = OpCooperativeMatrixLoadKHR %13 %28 %11 +OpStore %30 %35 +OpLine %3 10 13 +%36 = OpCooperativeMatrixMulAddKHR %13 %15 %18 %30 +OpLine %3 10 5 +OpStore %32 %36 +OpLine %3 11 5 +OpCooperativeMatrixStoreKHR %28 %30 %11 OpReturn OpFunctionEnd \ No newline at end of file From 6813768ac59e7d6b4941ae91a22bcce541fb2509 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Thu, 25 Sep 2025 21:32:39 -0700 Subject: [PATCH 08/10] coop: fixes and changelog --- CHANGELOG.md | 2 ++ naga/src/back/spv/block.rs | 16 ++++++++-- naga/src/valid/function.rs | 20 +++--------- naga/tests/in/wgsl/cooperative-matrix.wgsl | 4 +-- .../ir/wgsl-cooperative-matrix.compact.ron | 32 ++++++++++++++----- naga/tests/out/ir/wgsl-cooperative-matrix.ron | 32 ++++++++++++++----- .../tests/out/msl/wgsl-cooperative-matrix.msl | 31 ++++++++++++++++++ .../out/spv/wgsl-cooperative-matrix.spvasm | 22 ++++++++----- 8 files changed, 115 insertions(+), 44 deletions(-) create mode 100644 naga/tests/out/msl/wgsl-cooperative-matrix.msl diff --git a/CHANGELOG.md b/CHANGELOG.md index 8a63afa5b8..e5b42f87c3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -376,6 +376,8 @@ By @wumpf in [#8282](https://github.com/gfx-rs/wgpu/pull/8282), [#8285](https:// - Expose `naga::front::wgsl::UnimplementedEnableExtension`. 
By @ErichDonGubler in [#8237](https://github.com/gfx-rs/wgpu/pull/8237). +- Added support for cooperative load/store operations in shaders. Currently only WGSL on the input and SPIR-V with METAL on the output are supported. By @kvark in [#8251](https://github.com/gfx-rs/wgpu/issues/8251). + ### Changes #### General diff --git a/naga/src/back/spv/block.rs b/naga/src/back/spv/block.rs index 479de9e425..0ca55c31c7 100644 --- a/naga/src/back/spv/block.rs +++ b/naga/src/back/spv/block.rs @@ -3693,6 +3693,18 @@ impl BlockContext<'_> { stride, row_major, } => { + let pointer_id = match self.write_access_chain( + pointer, + &mut block, + AccessTypeAdjustment::None, + )? { + ExpressionPointer::Ready { pointer_id } => pointer_id, + ExpressionPointer::Conditional { .. } => { + return Err(Error::FeatureNotImplemented( + "Copperative load/store out-of-bounds handling", + )); + } + }; let layout = if row_major { spirv::CooperativeMatrixLayout::RowMajorKHR } else { @@ -3703,7 +3715,7 @@ impl BlockContext<'_> { if store { block.body.push(Instruction::coop_store( self.cached[target], - self.cached[pointer], + pointer_id, layout_id, stride_id, )); @@ -3713,7 +3725,7 @@ impl BlockContext<'_> { block.body.push(Instruction::coop_load( result_type_id, id, - self.cached[pointer], + pointer_id, layout_id, stride_id, )); diff --git a/naga/src/valid/function.rs b/naga/src/valid/function.rs index a90bc6186a..7911f3e2ae 100644 --- a/naga/src/valid/function.rs +++ b/naga/src/valid/function.rs @@ -1684,32 +1684,20 @@ impl super::Validator { } }; - let ty_inner = - context.resolve_type_inner(pointer, &self.valid_expression_set)?; + let ty_inner = context.resolve_pointer_type(pointer); //TODO: validate stride - let (pty_array, space) = match *ty_inner { + let (pty_scalar, space) = match *ty_inner { crate::TypeInner::Pointer { base, space } => (base, space), _ => { return Err(FunctionError::InvalidCooperativeDataPointer(pointer) - .with_span_handle(pointer, context.expressions)) - } - }; - let 
pty_scalar = match context.types[pty_array].inner { - crate::TypeInner::Array { - base, - size: _, - stride: _, - } => base, - _ => { - return Err(FunctionError::InvalidCooperativeDataPointer(pointer) - .with_span_handle(pointer, context.expressions)) + .with_span_handle(pointer, context.expressions)); } }; let space = match context.types[pty_scalar].inner { crate::TypeInner::Scalar(s) if s == target_scalar => space, _ => { return Err(FunctionError::InvalidCooperativeDataPointer(pointer) - .with_span_handle(pointer, context.expressions)) + .with_span_handle(pointer, context.expressions)); } }; diff --git a/naga/tests/in/wgsl/cooperative-matrix.wgsl b/naga/tests/in/wgsl/cooperative-matrix.wgsl index 24ecb9a2b3..e65fe0d589 100644 --- a/naga/tests/in/wgsl/cooperative-matrix.wgsl +++ b/naga/tests/in/wgsl/cooperative-matrix.wgsl @@ -6,7 +6,7 @@ var ext: array; @compute @workgroup_size(8, 8, 1) fn main() { var c = coop_mat8x8(); - coopLoad(c, &ext); + coopLoad(c, &ext[4]); var d = coopMultiplyAdd(a, b, c); - coopStore(c, &ext); + coopStore(c, &ext[0]); } diff --git a/naga/tests/out/ir/wgsl-cooperative-matrix.compact.ron b/naga/tests/out/ir/wgsl-cooperative-matrix.compact.ron index 31d47d603a..7582580360 100644 --- a/naga/tests/out/ir/wgsl-cooperative-matrix.compact.ron +++ b/naga/tests/out/ir/wgsl-cooperative-matrix.compact.ron @@ -119,40 +119,56 @@ ZeroValue(4), LocalVariable(0), GlobalVariable(2), + AccessIndex( + base: 2, + index: 4, + ), GlobalVariable(0), GlobalVariable(1), CooperativeMultiplyAdd( - a: 3, - b: 4, + a: 4, + b: 5, c: 1, ), LocalVariable(1), GlobalVariable(2), + AccessIndex( + base: 8, + index: 0, + ), ], named_expressions: {}, body: [ CooperativeLoadStore( store: false, target: 1, - pointer: 2, + pointer: 3, stride: None, row_major: false, ), Emit(( - start: 5, - end: 6, + start: 3, + end: 4, + )), + Emit(( + start: 6, + end: 7, )), Store( - pointer: 6, - value: 5, + pointer: 7, + value: 6, ), CooperativeLoadStore( store: true, target: 1, - pointer: 
7, + pointer: 9, stride: None, row_major: false, ), + Emit(( + start: 9, + end: 10, + )), Return( value: None, ), diff --git a/naga/tests/out/ir/wgsl-cooperative-matrix.ron b/naga/tests/out/ir/wgsl-cooperative-matrix.ron index 31d47d603a..7582580360 100644 --- a/naga/tests/out/ir/wgsl-cooperative-matrix.ron +++ b/naga/tests/out/ir/wgsl-cooperative-matrix.ron @@ -119,40 +119,56 @@ ZeroValue(4), LocalVariable(0), GlobalVariable(2), + AccessIndex( + base: 2, + index: 4, + ), GlobalVariable(0), GlobalVariable(1), CooperativeMultiplyAdd( - a: 3, - b: 4, + a: 4, + b: 5, c: 1, ), LocalVariable(1), GlobalVariable(2), + AccessIndex( + base: 8, + index: 0, + ), ], named_expressions: {}, body: [ CooperativeLoadStore( store: false, target: 1, - pointer: 2, + pointer: 3, stride: None, row_major: false, ), Emit(( - start: 5, - end: 6, + start: 3, + end: 4, + )), + Emit(( + start: 6, + end: 7, )), Store( - pointer: 6, - value: 5, + pointer: 7, + value: 6, ), CooperativeLoadStore( store: true, target: 1, - pointer: 7, + pointer: 9, stride: None, row_major: false, ), + Emit(( + start: 9, + end: 10, + )), Return( value: None, ), diff --git a/naga/tests/out/msl/wgsl-cooperative-matrix.msl b/naga/tests/out/msl/wgsl-cooperative-matrix.msl new file mode 100644 index 0000000000..bed4406760 --- /dev/null +++ b/naga/tests/out/msl/wgsl-cooperative-matrix.msl @@ -0,0 +1,31 @@ +// language: metal1.0 +#include +#include + +using metal::uint; + +struct _mslBufferSizes { + uint size2; +}; + +typedef float type_3[1]; +metal::simdgroup_float8x8 NagaCooperativeMultiplyAdd(const metal::simdgroup_float8x8& a, const metal::simdgroup_float8x8& b, const metal::simdgroup_float8x8& c) { + metal::simdgroup_float8x8 d; + metal::simdgroup_multiply_accumulate(d,a,b,c); + return d; +} + + +kernel void main_( + device type_3 const& ext [[user(fake0)]] +, constant _mslBufferSizes& _buffer_sizes [[user(fake0)]] +) { + metal::simdgroup_float8x8 a = {}; + metal::simdgroup_float8x8 b = {}; + 
metal::simdgroup_float8x8 c = metal::simdgroup_float8x8 {}; + metal::simdgroup_float8x8 d = {}; + metal::simdgroup_load(c, ext[4]); + d = NagaCooperativeMultiplyAdd(a, b, c); + metal::simdgroup_store(c, ext[0]); + return; +} diff --git a/naga/tests/out/spv/wgsl-cooperative-matrix.spvasm b/naga/tests/out/spv/wgsl-cooperative-matrix.spvasm index 0e8a882994..a3626919a9 100644 --- a/naga/tests/out/spv/wgsl-cooperative-matrix.spvasm +++ b/naga/tests/out/spv/wgsl-cooperative-matrix.spvasm @@ -1,7 +1,7 @@ ; SPIR-V ; Version: 1.4 ; Generator: rspirv -; Bound: 37 +; Bound: 41 OpCapability Shader OpCapability CooperativeMatrixKHR OpCapability VulkanMemoryModel @@ -20,9 +20,9 @@ var ext: array; @compute @workgroup_size(8, 8, 1) fn main() { var c = coop_mat8x8(); - coopLoad(c, &ext); + coopLoad(c, &ext[4]); var d = coopMultiplyAdd(a, b, c); - coopStore(c, &ext); + coopStore(c, &ext[0]); } " OpName %15 "a" @@ -62,6 +62,8 @@ OpMemberDecorate %22 0 Offset 0 %29 = OpConstantNull %13 %31 = OpTypePointer Function %13 %33 = OpConstantNull %13 +%35 = OpTypePointer StorageBuffer %4 +%36 = OpConstant %7 4 %25 = OpFunction %2 None %26 %24 = OpLabel %30 = OpVariable %31 Function %29 @@ -70,13 +72,17 @@ OpMemberDecorate %22 0 Offset 0 OpBranch %34 %34 = OpLabel OpLine %3 9 5 -%35 = OpCooperativeMatrixLoadKHR %13 %28 %11 -OpStore %30 %35 +%37 = OpAccessChain %35 %28 %36 +%38 = OpCooperativeMatrixLoadKHR %13 %37 %11 +OpStore %30 %38 +OpLine %3 9 18 OpLine %3 10 13 -%36 = OpCooperativeMatrixMulAddKHR %13 %15 %18 %30 +%39 = OpCooperativeMatrixMulAddKHR %13 %15 %18 %30 OpLine %3 10 5 -OpStore %32 %36 +OpStore %32 %39 OpLine %3 11 5 -OpCooperativeMatrixStoreKHR %28 %30 %11 +%40 = OpAccessChain %35 %28 %9 +OpCooperativeMatrixStoreKHR %40 %30 %11 +OpLine %3 11 19 OpReturn OpFunctionEnd \ No newline at end of file From d18bc38231f46c7e70797f0187c96d4889f28b33 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Thu, 25 Sep 2025 22:36:48 -0700 Subject: [PATCH 09/10] coop: make stride non-optional 
--- naga/src/back/dot/mod.rs | 4 +- naga/src/back/msl/writer.rs | 26 +------------ naga/src/back/pipeline_constants.rs | 4 +- naga/src/back/spv/block.rs | 5 +-- naga/src/back/spv/instructions.rs | 19 ++-------- naga/src/back/wgsl/writer.rs | 6 +-- naga/src/compact/statements.rs | 8 +--- naga/src/front/wgsl/lower/mod.rs | 24 +++++++++--- naga/src/ir/mod.rs | 2 +- naga/src/valid/analyzer.rs | 22 +++++------ naga/src/valid/handles.rs | 2 +- naga/tests/in/wgsl/cooperative-matrix.toml | 2 +- .../ir/wgsl-cooperative-matrix.compact.ron | 38 ++++++++++--------- naga/tests/out/ir/wgsl-cooperative-matrix.ron | 38 ++++++++++--------- .../tests/out/msl/wgsl-cooperative-matrix.msl | 4 +- .../out/spv/wgsl-cooperative-matrix.spvasm | 8 ++-- .../out/wgsl/wgsl-cooperative-matrix.wgsl | 13 +++++++ 17 files changed, 105 insertions(+), 120 deletions(-) create mode 100644 naga/tests/out/wgsl/wgsl-cooperative-matrix.wgsl diff --git a/naga/src/back/dot/mod.rs b/naga/src/back/dot/mod.rs index 9f5f2f3dab..2fb2d58f93 100644 --- a/naga/src/back/dot/mod.rs +++ b/naga/src/back/dot/mod.rs @@ -431,9 +431,7 @@ impl StatementGraph { } => { self.dependencies.push((id, target, "target")); self.dependencies.push((id, pointer, "pointer")); - if let Some(stride) = stride { - self.dependencies.push((id, stride, "stride")); - } + self.dependencies.push((id, stride, "stride")); if store { "Store" } else { diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index e6de2165ed..a48424c0c1 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -4254,30 +4254,8 @@ impl Writer { self.put_expression(target, &context.expression, true)?; write!(self.out, ", ")?; self.put_expression(pointer, &context.expression, true)?; - if stride.is_some() || row_major { - write!(self.out, ", ")?; - match stride { - Some(expression) => { - self.put_expression(expression, &context.expression, true)?; - } - None => { - let default_stride = match *context.expression.resolve_type(target) - { - 
crate::TypeInner::CooperativeMatrix { - columns, rows, .. - } => { - if row_major { - columns as u32 - } else { - rows as u32 - } - } - _ => 0, - }; - write!(self.out, "{default_stride}")?; - } - } - } + write!(self.out, ", ")?; + self.put_expression(stride, &context.expression, true)?; if row_major { let matrix_origin = "0"; let transpose = true; diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index 4d97e21cc7..3d1c2c3f21 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -898,9 +898,7 @@ fn adjust_stmt(new_pos: &HandleVec>, stmt: &mut S } => { adjust(target); adjust(pointer); - if let Some(ref mut stride) = *stride { - adjust(stride); - } + adjust(stride); } Statement::Break | Statement::Continue diff --git a/naga/src/back/spv/block.rs b/naga/src/back/spv/block.rs index 0ca55c31c7..77d1eb6b30 100644 --- a/naga/src/back/spv/block.rs +++ b/naga/src/back/spv/block.rs @@ -3711,13 +3711,12 @@ impl BlockContext<'_> { spirv::CooperativeMatrixLayout::ColumnMajorKHR }; let layout_id = self.get_index_constant(layout as u32); - let stride_id = stride.map(|exp| self.cached[exp]); if store { block.body.push(Instruction::coop_store( self.cached[target], pointer_id, layout_id, - stride_id, + self.cached[stride], )); } else { let result_type_id = self.get_expression_type_id(&self.fun_info[target].ty); @@ -3727,7 +3726,7 @@ impl BlockContext<'_> { id, pointer_id, layout_id, - stride_id, + self.cached[stride], )); block .body diff --git a/naga/src/back/spv/instructions.rs b/naga/src/back/spv/instructions.rs index 419c276fc4..22eaa99340 100644 --- a/naga/src/back/spv/instructions.rs +++ b/naga/src/back/spv/instructions.rs @@ -1252,33 +1252,22 @@ impl super::Instruction { id: Word, pointer_id: Word, layout_id: Word, - stride_id: Option, + stride_id: Word, ) -> Self { let mut instruction = Self::new(Op::CooperativeMatrixLoadKHR); instruction.set_type(result_type_id); instruction.set_result(id); 
instruction.add_operand(pointer_id); instruction.add_operand(layout_id); - if let Some(stride_id) = stride_id { - instruction.add_operand(stride_id); - } - + instruction.add_operand(stride_id); instruction } - pub(super) fn coop_store( - id: Word, - pointer_id: Word, - layout_id: Word, - stride_id: Option, - ) -> Self { + pub(super) fn coop_store(id: Word, pointer_id: Word, layout_id: Word, stride_id: Word) -> Self { let mut instruction = Self::new(Op::CooperativeMatrixStoreKHR); instruction.add_operand(pointer_id); instruction.add_operand(id); instruction.add_operand(layout_id); - if let Some(stride_id) = stride_id { - instruction.add_operand(stride_id); - } - + instruction.add_operand(stride_id); instruction } pub(super) fn coop_mul_add(result_type_id: Word, id: Word, a: Word, b: Word, c: Word) -> Self { diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index ad766b0f6b..ac962967eb 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -998,10 +998,8 @@ impl Writer { self.write_expr(module, target, func_ctx)?; write!(self.out, ", ")?; self.write_expr(module, pointer, func_ctx)?; - if let Some(stride) = stride { - write!(self.out, ", ")?; - self.write_expr(module, stride, func_ctx)?; - } + write!(self.out, ", ")?; + self.write_expr(module, stride, func_ctx)?; write!(self.out, ")")? } } diff --git a/naga/src/compact/statements.rs b/naga/src/compact/statements.rs index e91436ca74..1678c3be4b 100644 --- a/naga/src/compact/statements.rs +++ b/naga/src/compact/statements.rs @@ -175,9 +175,7 @@ impl FunctionTracer<'_> { } => { self.expressions_used.insert(target); self.expressions_used.insert(pointer); - if let Some(stride) = stride { - self.expressions_used.insert(stride); - } + self.expressions_used.insert(stride); } // Trivial statements. 
@@ -427,9 +425,7 @@ impl FunctionMap { } => { adjust(target); adjust(pointer); - if let Some(ref mut stride) = *stride { - adjust(stride); - } + adjust(stride); } // Trivial statements. diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index 5a1a90a0e9..29e414d978 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -3138,19 +3138,33 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { return Ok(Some(result)); } "coopLoad" | "coopLoadT" | "coopStore" | "coopStoreT" => { + let store = function.name.contains("Store"); + let row_major = function.name.ends_with("T"); + let mut args = ctx.prepare_args(arguments, 2, span); let target = self.expression(args.next()?, ctx)?; let pointer = self.expression(args.next()?, ctx)?; let stride = if args.total_args > 2 { - Some(self.expression(args.next()?, ctx)?) + self.expression(args.next()?, ctx)? } else { - None + // Infer the stride from the matrix type + let stride = match *resolve_inner!(ctx, target) { + ir::TypeInner::CooperativeMatrix { columns, rows, .. } => { + if row_major { + columns as u32 + } else { + rows as u32 + } + } + _ => 0, + }; + ctx.append_expression( + ir::Expression::Literal(ir::Literal::U32(stride)), + Span::UNDEFINED, + )? 
}; args.finish()?; - let store = function.name.contains("Store"); - let row_major = function.name.ends_with("T"); - let rctx = ctx.runtime_expression_ctx(span)?; rctx.block.push( crate::Statement::CooperativeLoadStore { diff --git a/naga/src/ir/mod.rs b/naga/src/ir/mod.rs index 9ed185a7ce..4ba53923c7 100644 --- a/naga/src/ir/mod.rs +++ b/naga/src/ir/mod.rs @@ -2288,7 +2288,7 @@ pub enum Statement { store: bool, target: Handle, pointer: Handle, - stride: Option>, + stride: Handle, row_major: bool, }, } diff --git a/naga/src/valid/analyzer.rs b/naga/src/valid/analyzer.rs index a530c1869c..e396adb0c3 100644 --- a/naga/src/valid/analyzer.rs +++ b/naga/src/valid/analyzer.rs @@ -1234,18 +1234,16 @@ impl FunctionInfo { pointer, stride, row_major: _, - } => { - if let Some(stride) = stride { - let _ = self.add_ref(stride); - } - FunctionUniformity { - result: Uniformity { - non_uniform_result: self.add_ref(target).or(self.add_ref(pointer)), - requirements: UniformityRequirements::COOP_OPS, - }, - exit: ExitFlags::empty(), - } - } + } => FunctionUniformity { + result: Uniformity { + non_uniform_result: self + .add_ref(target) + .or(self.add_ref(pointer)) + .or(self.add_ref(stride)), + requirements: UniformityRequirements::COOP_OPS, + }, + exit: ExitFlags::empty(), + }, }; disruptor = disruptor.or(uniformity.exit_disruptor()); diff --git a/naga/src/valid/handles.rs b/naga/src/valid/handles.rs index 6dbd542814..db91466e6e 100644 --- a/naga/src/valid/handles.rs +++ b/naga/src/valid/handles.rs @@ -878,7 +878,7 @@ impl super::Validator { } => { validate_expr(target)?; validate_expr(pointer)?; - validate_expr_opt(stride)?; + validate_expr(stride)?; Ok(()) } crate::Statement::Break diff --git a/naga/tests/in/wgsl/cooperative-matrix.toml b/naga/tests/in/wgsl/cooperative-matrix.toml index 4a3be8b94e..a95da7bf80 100644 --- a/naga/tests/in/wgsl/cooperative-matrix.toml +++ b/naga/tests/in/wgsl/cooperative-matrix.toml @@ -1,4 +1,4 @@ -targets = "IR | SPIRV | METAL" +targets = "IR | 
SPIRV | METAL | WGSL" god_mode = true [spv] diff --git a/naga/tests/out/ir/wgsl-cooperative-matrix.compact.ron b/naga/tests/out/ir/wgsl-cooperative-matrix.compact.ron index 7582580360..7f8fc73568 100644 --- a/naga/tests/out/ir/wgsl-cooperative-matrix.compact.ron +++ b/naga/tests/out/ir/wgsl-cooperative-matrix.compact.ron @@ -123,52 +123,54 @@ base: 2, index: 4, ), + Literal(U32(8)), GlobalVariable(0), GlobalVariable(1), CooperativeMultiplyAdd( - a: 4, - b: 5, + a: 5, + b: 6, c: 1, ), LocalVariable(1), GlobalVariable(2), AccessIndex( - base: 8, + base: 9, index: 0, ), + Literal(U32(8)), ], named_expressions: {}, body: [ + Emit(( + start: 3, + end: 4, + )), CooperativeLoadStore( store: false, target: 1, pointer: 3, - stride: None, + stride: 4, row_major: false, ), Emit(( - start: 3, - end: 4, - )), - Emit(( - start: 6, - end: 7, + start: 7, + end: 8, )), Store( - pointer: 7, - value: 6, + pointer: 8, + value: 7, ), + Emit(( + start: 10, + end: 11, + )), CooperativeLoadStore( store: true, target: 1, - pointer: 9, - stride: None, + pointer: 10, + stride: 11, row_major: false, ), - Emit(( - start: 9, - end: 10, - )), Return( value: None, ), diff --git a/naga/tests/out/ir/wgsl-cooperative-matrix.ron b/naga/tests/out/ir/wgsl-cooperative-matrix.ron index 7582580360..7f8fc73568 100644 --- a/naga/tests/out/ir/wgsl-cooperative-matrix.ron +++ b/naga/tests/out/ir/wgsl-cooperative-matrix.ron @@ -123,52 +123,54 @@ base: 2, index: 4, ), + Literal(U32(8)), GlobalVariable(0), GlobalVariable(1), CooperativeMultiplyAdd( - a: 4, - b: 5, + a: 5, + b: 6, c: 1, ), LocalVariable(1), GlobalVariable(2), AccessIndex( - base: 8, + base: 9, index: 0, ), + Literal(U32(8)), ], named_expressions: {}, body: [ + Emit(( + start: 3, + end: 4, + )), CooperativeLoadStore( store: false, target: 1, pointer: 3, - stride: None, + stride: 4, row_major: false, ), Emit(( - start: 3, - end: 4, - )), - Emit(( - start: 6, - end: 7, + start: 7, + end: 8, )), Store( - pointer: 7, - value: 6, + pointer: 8, + value: 
7, ), + Emit(( + start: 10, + end: 11, + )), CooperativeLoadStore( store: true, target: 1, - pointer: 9, - stride: None, + pointer: 10, + stride: 11, row_major: false, ), - Emit(( - start: 9, - end: 10, - )), Return( value: None, ), diff --git a/naga/tests/out/msl/wgsl-cooperative-matrix.msl b/naga/tests/out/msl/wgsl-cooperative-matrix.msl index bed4406760..4e17948e6b 100644 --- a/naga/tests/out/msl/wgsl-cooperative-matrix.msl +++ b/naga/tests/out/msl/wgsl-cooperative-matrix.msl @@ -24,8 +24,8 @@ kernel void main_( metal::simdgroup_float8x8 b = {}; metal::simdgroup_float8x8 c = metal::simdgroup_float8x8 {}; metal::simdgroup_float8x8 d = {}; - metal::simdgroup_load(c, ext[4]); + metal::simdgroup_load(c, ext[4], 8u); d = NagaCooperativeMultiplyAdd(a, b, c); - metal::simdgroup_store(c, ext[0]); + metal::simdgroup_store(c, ext[0], 8u); return; } diff --git a/naga/tests/out/spv/wgsl-cooperative-matrix.spvasm b/naga/tests/out/spv/wgsl-cooperative-matrix.spvasm index a3626919a9..56d9e8c7ae 100644 --- a/naga/tests/out/spv/wgsl-cooperative-matrix.spvasm +++ b/naga/tests/out/spv/wgsl-cooperative-matrix.spvasm @@ -71,18 +71,18 @@ OpMemberDecorate %22 0 Offset 0 %28 = OpAccessChain %27 %21 %9 OpBranch %34 %34 = OpLabel +OpLine %3 9 18 OpLine %3 9 5 %37 = OpAccessChain %35 %28 %36 -%38 = OpCooperativeMatrixLoadKHR %13 %37 %11 +%38 = OpCooperativeMatrixLoadKHR %13 %37 %11 %8 OpStore %30 %38 -OpLine %3 9 18 OpLine %3 10 13 %39 = OpCooperativeMatrixMulAddKHR %13 %15 %18 %30 OpLine %3 10 5 OpStore %32 %39 +OpLine %3 11 19 OpLine %3 11 5 %40 = OpAccessChain %35 %28 %9 -OpCooperativeMatrixStoreKHR %40 %30 %11 -OpLine %3 11 19 +OpCooperativeMatrixStoreKHR %40 %30 %11 %8 OpReturn OpFunctionEnd \ No newline at end of file diff --git a/naga/tests/out/wgsl/wgsl-cooperative-matrix.wgsl b/naga/tests/out/wgsl/wgsl-cooperative-matrix.wgsl new file mode 100644 index 0000000000..2b249bb4d5 --- /dev/null +++ b/naga/tests/out/wgsl/wgsl-cooperative-matrix.wgsl @@ -0,0 +1,13 @@ +var a: coop_mat8x8; 
+var b: coop_mat8x8; +@group(0) @binding(0) +var ext: array; + +@compute @workgroup_size(8, 8, 1) +fn main() { + var c: coop_mat8x8 = coop_mat8x8(); + var d: coop_mat8x8; + +coopLoad((&c), (&ext[4]), 8u) d = coopMultiplyAdd((&a), (&b), (&c)); +coopStore((&c), (&ext[0]), 8u) return; +} From d801c1eac8f8ae7c5e67cf8482c0c9a8096518d5 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Thu, 25 Sep 2025 23:20:20 -0700 Subject: [PATCH 10/10] coop: rewire WGSL support using references --- CHANGELOG.md | 2 +- naga/src/back/mod.rs | 10 ++ naga/src/back/msl/writer.rs | 20 ++-- naga/src/back/spv/block.rs | 8 +- naga/src/back/spv/writer.rs | 6 +- naga/src/back/wgsl/writer.rs | 12 +-- naga/src/front/wgsl/error.rs | 6 ++ naga/src/front/wgsl/lower/mod.rs | 49 +++++---- naga/src/proc/type_methods.rs | 11 -- naga/src/proc/typifier.rs | 26 ++--- naga/src/valid/expression.rs | 45 +++----- naga/src/valid/function.rs | 15 ++- naga/tests/in/wgsl/cooperative-matrix.wgsl | 5 +- .../analysis/wgsl-cooperative-matrix.info.ron | 78 -------------- .../ir/wgsl-cooperative-matrix.compact.ron | 86 +-------------- naga/tests/out/ir/wgsl-cooperative-matrix.ron | 40 +------ .../tests/out/msl/wgsl-cooperative-matrix.msl | 19 +--- .../out/spv/wgsl-cooperative-matrix.spvasm | 100 +++++++----------- .../out/wgsl/wgsl-cooperative-matrix.wgsl | 7 +- 19 files changed, 174 insertions(+), 371 deletions(-) delete mode 100644 naga/tests/out/analysis/wgsl-cooperative-matrix.info.ron diff --git a/CHANGELOG.md b/CHANGELOG.md index e5b42f87c3..d54292e53a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -376,7 +376,7 @@ By @wumpf in [#8282](https://github.com/gfx-rs/wgpu/pull/8282), [#8285](https:// - Expose `naga::front::wgsl::UnimplementedEnableExtension`. By @ErichDonGubler in [#8237](https://github.com/gfx-rs/wgpu/pull/8237). -- Added support for cooperative load/store operations in shaders. Currently only WGSL on the input and SPIR-V with METAL on the output are supported. 
By @kvark in [#8251](https://github.com/gfx-rs/wgpu/issues/8251). +- Added support for cooperative load/store operations in shaders. Currently only WGSL on the input and SPIR-V, METAL, and WGSL on the output are supported. By @kvark in [#8251](https://github.com/gfx-rs/wgpu/issues/8251). ### Changes diff --git a/naga/src/back/mod.rs b/naga/src/back/mod.rs index ef9d829969..54007b387a 100644 --- a/naga/src/back/mod.rs +++ b/naga/src/back/mod.rs @@ -311,6 +311,16 @@ pub const fn binary_operation_str(op: crate::BinaryOperator) -> &'static str { } } +impl crate::TypeInner { + /// Returns true if a variable of this type is a handle. + pub const fn is_handle(&self) -> bool { + match *self { + Self::Image { .. } | Self::Sampler { .. } | Self::AccelerationStructure { .. } => true, + _ => false, + } + } +} + impl crate::Statement { /// Returns true if the statement directly terminates the current block. /// diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index a48424c0c1..498ba45769 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -6327,16 +6327,22 @@ template b: Handle, ) -> BackendResult { let (a_c, a_r, scalar) = match *func_ctx.resolve_type(a, &module.types) { - crate::TypeInner::CooperativeMatrix { - columns, - rows, - scalar, - .. - } => (columns, rows, scalar), + crate::TypeInner::Pointer { base, space: _ } => match module.types[base].inner { + crate::TypeInner::CooperativeMatrix { + columns, + rows, + scalar, + .. + } => (columns, rows, scalar), + _ => unreachable!(), + }, _ => unreachable!(), }; let (b_c, b_r) = match *func_ctx.resolve_type(b, &module.types) { - crate::TypeInner::CooperativeMatrix { columns, rows, .. } => (columns, rows), + crate::TypeInner::Pointer { base, space: _ } => match module.types[base].inner { + crate::TypeInner::CooperativeMatrix { columns, rows, .. 
} => (columns, rows), + _ => unreachable!(), + }, _ => unreachable!(), }; let wrapped = WrappedFunction::CooperativeMultiplyAdd { diff --git a/naga/src/back/spv/block.rs b/naga/src/back/spv/block.rs index 77d1eb6b30..4a7949425c 100644 --- a/naga/src/back/spv/block.rs +++ b/naga/src/back/spv/block.rs @@ -3719,7 +3719,13 @@ impl BlockContext<'_> { self.cached[stride], )); } else { - let result_type_id = self.get_expression_type_id(&self.fun_info[target].ty); + let result_type_id = + match *self.fun_info[target].ty.inner_with(&self.ir_module.types) { + crate::TypeInner::Pointer { base, space: _ } => { + self.get_handle_type_id(base) + } + _ => unreachable!(), + }; let id = self.gen_id(); block.body.push(Instruction::coop_load( result_type_id, diff --git a/naga/src/back/spv/writer.rs b/naga/src/back/spv/writer.rs index 8e9b3553a0..73270c2543 100644 --- a/naga/src/back/spv/writer.rs +++ b/naga/src/back/spv/writer.rs @@ -970,14 +970,13 @@ impl Writer { } } - // Handle globals are pre-emitted and should be loaded automatically. - // - // Any that are binding arrays we skip as we cannot load the array, we must load the result after indexing. match ir_module.types[var.ty].inner { + // Any that are binding arrays we skip as we cannot load the array, we must load the result after indexing. crate::TypeInner::BindingArray { .. } => { gv.access_id = gv.var_id; } _ => { + // Handle globals are pre-emitted and should be loaded automatically. 
if var.space == crate::AddressSpace::Handle { let var_type_id = self.get_handle_type_id(var.ty); let id = self.id_gen.next(); @@ -1063,6 +1062,7 @@ impl Writer { } }), ); + context .function .variables diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index ac962967eb..379af9160f 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -994,13 +994,13 @@ impl Writer { } => { let op_str = if store { "Store" } else { "Load" }; let suffix = if row_major { "T" } else { "" }; - write!(self.out, "coop{op_str}{suffix}(")?; - self.write_expr(module, target, func_ctx)?; + write!(self.out, "{level}coop{op_str}{suffix}(")?; + self.write_expr_with_indirection(module, target, func_ctx, Indirection::Reference)?; write!(self.out, ", ")?; self.write_expr(module, pointer, func_ctx)?; write!(self.out, ", ")?; self.write_expr(module, stride, func_ctx)?; - write!(self.out, ")")? + writeln!(self.out, ");")? } } @@ -1715,11 +1715,11 @@ impl Writer { | Expression::WorkGroupUniformLoadResult { .. 
} => {} Expression::CooperativeMultiplyAdd { a, b, c } => { write!(self.out, "coopMultiplyAdd(")?; - self.write_expr(module, a, func_ctx)?; + self.write_expr_with_indirection(module, a, func_ctx, Indirection::Reference)?; write!(self.out, ", ")?; - self.write_expr(module, b, func_ctx)?; + self.write_expr_with_indirection(module, b, func_ctx, Indirection::Reference)?; write!(self.out, ", ")?; - self.write_expr(module, c, func_ctx)?; + self.write_expr_with_indirection(module, c, func_ctx, Indirection::Reference)?; write!(self.out, ")")?; } } diff --git a/naga/src/front/wgsl/error.rs b/naga/src/front/wgsl/error.rs index f0d6a4b848..1487d4bef3 100644 --- a/naga/src/front/wgsl/error.rs +++ b/naga/src/front/wgsl/error.rs @@ -412,6 +412,7 @@ pub(crate) enum Error<'a> { TypeTooLarge { span: Span, }, + InvalidCooperativeMatrix, UnderspecifiedCooperativeMatrix, UnsupportedCooperativeScalar(Span), } @@ -1388,6 +1389,11 @@ impl<'a> Error<'a> { crate::valid::MAX_TYPE_SIZE )], }, + Error::InvalidCooperativeMatrix => ParseError { + message: "given type is not a cooperative matrix".into(), + labels: vec![], + notes: vec![format!("must be coop_mat")], + }, Error::UnderspecifiedCooperativeMatrix => ParseError { message: "cooperative matrix constructor is underspecified".into(), labels: vec![], diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index 29e414d978..46e03cc543 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -846,6 +846,15 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { fn ensure_type_exists(&mut self, inner: ir::TypeInner) -> Handle { self.as_global().ensure_type_exists(None, inner) } + + fn _get_runtime_expression(&self, expr: Handle) -> &ir::Expression { + match self.expr_type { + ExpressionContextType::Runtime(ref ctx) => &ctx.function.expressions[expr], + ExpressionContextType::Constant(_) | ExpressionContextType::Override => { + unreachable!() + } + } + } } struct 
ArgumentContext<'ctx, 'source> { @@ -955,6 +964,13 @@ impl Typed { Self::Plain(expr) => Typed::Plain(f(expr)?), }) } + + fn ref_or(self, error: E) -> core::result::Result { + match self { + Self::Reference(v) => Ok(v), + Self::Plain(_) => Err(error), + } + } } /// A single vector component or swizzle. @@ -1679,12 +1695,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .as_expression(block, &mut emitter) .interrupt_emitter(ir::Expression::LocalVariable(var), Span::UNDEFINED)?; block.extend(emitter.finish(&ctx.function.expressions)); - let typed = if ctx.module.types[ty].inner.is_handle() { - Typed::Plain(handle) - } else { - Typed::Reference(handle) - }; - ctx.local_table.insert(v.handle, Declared::Runtime(typed)); + ctx.local_table + .insert(v.handle, Declared::Runtime(Typed::Reference(handle))); match initializer { Some(initializer) => ir::Statement::Store { @@ -1979,12 +1991,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let value_span = ctx.ast_expressions.get_span(value); let target = self .expression_for_reference(value, &mut ctx.as_expression(block, &mut emitter))?; - let target_handle = match target { - Typed::Reference(handle) => handle, - Typed::Plain(_) => { - return Err(Box::new(Error::BadIncrDecrReferenceType(value_span))) - } - }; + let target_handle = target.ref_or(Error::BadIncrDecrReferenceType(value_span))?; let mut ectx = ctx.as_expression(block, &mut emitter); let scalar = match *resolve_inner!(ectx, target_handle) { @@ -2141,10 +2148,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { LoweredGlobalDecl::Var(handle) => { let expr = ir::Expression::GlobalVariable(handle); let v = &ctx.module.global_variables[handle]; - let force_value = ctx.module.types[v.ty].inner.is_handle(); match v.space { ir::AddressSpace::Handle => Typed::Plain(expr), - _ if force_value => Typed::Plain(expr), _ => Typed::Reference(expr), } } @@ -3142,7 +3147,9 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let row_major = function.name.ends_with("T"); let mut args 
= ctx.prepare_args(arguments, 2, span); - let target = self.expression(args.next()?, ctx)?; + let target = self + .expression_for_reference(args.next()?, ctx)? + .ref_or(Error::InvalidCooperativeMatrix)?; let pointer = self.expression(args.next()?, ctx)?; let stride = if args.total_args > 2 { self.expression(args.next()?, ctx)? @@ -3180,9 +3187,15 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } "coopMultiplyAdd" => { let mut args = ctx.prepare_args(arguments, 3, span); - let a = self.expression(args.next()?, ctx)?; - let b = self.expression(args.next()?, ctx)?; - let c = self.expression(args.next()?, ctx)?; + let a = self + .expression_for_reference(args.next()?, ctx)? + .ref_or(Error::InvalidCooperativeMatrix)?; + let b = self + .expression_for_reference(args.next()?, ctx)? + .ref_or(Error::InvalidCooperativeMatrix)?; + let c = self + .expression_for_reference(args.next()?, ctx)? + .ref_or(Error::InvalidCooperativeMatrix)?; args.finish()?; ir::Expression::CooperativeMultiplyAdd { a, b, c } diff --git a/naga/src/proc/type_methods.rs b/naga/src/proc/type_methods.rs index 136ea29218..fe3eb4b626 100644 --- a/naga/src/proc/type_methods.rs +++ b/naga/src/proc/type_methods.rs @@ -191,17 +191,6 @@ impl crate::TypeInner { } } - /// Returns true if a variable of this type is a handle. - pub const fn is_handle(&self) -> bool { - match *self { - Self::Image { .. } - | Self::Sampler { .. } - | Self::AccelerationStructure { .. } - | Self::CooperativeMatrix { .. } => true, - _ => false, - } - } - /// Attempt to calculate the size of this type. Returns `None` if the size /// exceeds the limit of [`crate::valid::MAX_TYPE_SIZE`]. 
pub fn try_size(&self, gctx: super::GlobalCtx) -> Option { diff --git a/naga/src/proc/typifier.rs b/naga/src/proc/typifier.rs index 8e323d7724..0147fdec46 100644 --- a/naga/src/proc/typifier.rs +++ b/naga/src/proc/typifier.rs @@ -454,8 +454,7 @@ impl<'a> ResolveContext<'a> { } crate::Expression::GlobalVariable(h) => { let var = &self.global_vars[h]; - let ty = &types[var.ty].inner; - if var.space == crate::AddressSpace::Handle || ty.is_handle() { + if var.space == crate::AddressSpace::Handle { TypeResolution::Handle(var.ty) } else { TypeResolution::Value(Ti::Pointer { @@ -466,15 +465,10 @@ impl<'a> ResolveContext<'a> { } crate::Expression::LocalVariable(h) => { let var = &self.local_vars[h]; - let ty = &types[var.ty].inner; - if ty.is_handle() { - TypeResolution::Handle(var.ty) - } else { - TypeResolution::Value(Ti::Pointer { - base: var.ty, - space: crate::AddressSpace::Function, - }) - } + TypeResolution::Value(Ti::Pointer { + base: var.ty, + space: crate::AddressSpace::Function, + }) } crate::Expression::Load { pointer } => match *past(pointer)?.inner_with(types) { Ti::Pointer { base, space: _ } => { @@ -807,7 +801,15 @@ impl<'a> ResolveContext<'a> { scalar: crate::Scalar::U32, size: crate::VectorSize::Quad, }), - crate::Expression::CooperativeMultiplyAdd { a: _, b: _, c } => past(c)?.clone(), + crate::Expression::CooperativeMultiplyAdd { a: _, b: _, c } => { + match *past(c)?.inner_with(types) { + Ti::Pointer { base, space: _ } => TypeResolution::Handle(base), + ref other => { + log::error!("Pointer type {other:?}"); + return Err(ResolveError::InvalidPointer(c)); + } + } + } }) } } diff --git a/naga/src/valid/expression.rs b/naga/src/valid/expression.rs index 466ca26b60..909556477b 100644 --- a/naga/src/valid/expression.rs +++ b/naga/src/valid/expression.rs @@ -1256,34 +1256,23 @@ impl super::Validator { }, E::SubgroupBallotResult | E::SubgroupOperationResult { .. 
} => self.subgroup_stages, E::CooperativeMultiplyAdd { a, b, c } => { - match resolver[a] { - Ti::CooperativeMatrix { - role: crate::CooperativeRole::A, - .. - } => {} - ref other => { - log::error!("A operand type: {other:?}"); - return Err(ExpressionError::InvalidCooperativeOperand(a)); - } - } - match resolver[b] { - Ti::CooperativeMatrix { - role: crate::CooperativeRole::B, - .. - } => {} - ref other => { - log::error!("B operand type: {other:?}"); - return Err(ExpressionError::InvalidCooperativeOperand(b)); - } - } - match resolver[c] { - Ti::CooperativeMatrix { - role: crate::CooperativeRole::C, - .. - } => {} - ref other => { - log::error!("C operand type: {other:?}"); - return Err(ExpressionError::InvalidCooperativeOperand(c)); + let roles = [ + crate::CooperativeRole::A, + crate::CooperativeRole::B, + crate::CooperativeRole::C, + ]; + for (operand, expected_role) in [a, b, c].into_iter().zip(roles) { + match resolver[operand] { + Ti::Pointer { base, space: _ } => match module.types[base].inner { + Ti::CooperativeMatrix { role, .. } if role == expected_role => {} + ref other => { + log::error!("{expected_role:?} operand type: {other:?}"); + return Err(ExpressionError::InvalidCooperativeOperand(a)); + } + }, + _ => { + return Err(ExpressionError::InvalidPointerType(operand)); + } } } ShaderStages::COMPUTE diff --git a/naga/src/valid/function.rs b/naga/src/valid/function.rs index 7911f3e2ae..119460b64d 100644 --- a/naga/src/valid/function.rs +++ b/naga/src/valid/function.rs @@ -1086,7 +1086,7 @@ impl super::Validator { } else if let Some(tr) = pointer_base_tr { context.compare_types(value_tr, &tr) } else { - value_ty.is_handle() + false }; if !good { @@ -1674,15 +1674,22 @@ impl super::Validator { } => { stages &= super::ShaderStages::COMPUTE; - let target_scalar = - match *context.resolve_type_inner(target, &self.valid_expression_set)? 
{ + let target_scalar = match *context.resolve_pointer_type(target) { + crate::TypeInner::Pointer { base, space: _ } => match context.types[base] + .inner + { Ti::CooperativeMatrix { scalar, .. } => scalar, ref other => { log::error!("Target operand type: {other:?}"); return Err(FunctionError::InvalidCooperativeStoreTarget(target) .with_span_handle(target, context.expressions)); } - }; + }, + _ => { + return Err(FunctionError::InvalidCooperativeStoreTarget(target) + .with_span_handle(target, context.expressions)); + } + }; let ty_inner = context.resolve_pointer_type(pointer); //TODO: validate stride diff --git a/naga/tests/in/wgsl/cooperative-matrix.wgsl b/naga/tests/in/wgsl/cooperative-matrix.wgsl index e65fe0d589..602e52ae6c 100644 --- a/naga/tests/in/wgsl/cooperative-matrix.wgsl +++ b/naga/tests/in/wgsl/cooperative-matrix.wgsl @@ -7,6 +7,7 @@ var ext: array; fn main() { var c = coop_mat8x8(); coopLoad(c, &ext[4]); - var d = coopMultiplyAdd(a, b, c); - coopStore(c, &ext[0]); + //var d = coopMultiplyAdd(a, b, c); + //coopStore(d, &ext[0]); + //c = d; } diff --git a/naga/tests/out/analysis/wgsl-cooperative-matrix.info.ron b/naga/tests/out/analysis/wgsl-cooperative-matrix.info.ron deleted file mode 100644 index f806c3f3dd..0000000000 --- a/naga/tests/out/analysis/wgsl-cooperative-matrix.info.ron +++ /dev/null @@ -1,78 +0,0 @@ -( - type_flags: [ - ("DATA | SIZED | COPY | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), - ], - functions: [], - entry_points: [ - ( - flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), - uniformity: ( - non_uniform_result: None, - requirements: (""), - ), - may_kill: false, - sampling_set: [], - global_uses: [ - ("READ"), - ], - expressions: [ - ( - uniformity: ( - non_uniform_result: Some(0), - requirements: (""), - ), - ref_count: 1, - assignable_global: Some(0), - ty: Value(Pointer( - base: 0, - space: Private, - )), - 
), - ( - uniformity: ( - non_uniform_result: Some(0), - requirements: (""), - ), - ref_count: 1, - assignable_global: None, - ty: Handle(0), - ), - ( - uniformity: ( - non_uniform_result: Some(2), - requirements: (""), - ), - ref_count: 1, - assignable_global: Some(0), - ty: Value(Pointer( - base: 0, - space: Private, - )), - ), - ( - uniformity: ( - non_uniform_result: Some(2), - requirements: (""), - ), - ref_count: 1, - assignable_global: None, - ty: Handle(0), - ), - ( - uniformity: ( - non_uniform_result: Some(0), - requirements: (""), - ), - ref_count: 0, - assignable_global: None, - ty: Handle(0), - ), - ], - sampling: [], - dual_source_blending: false, - diagnostic_filter_leaf: None, - ), - ], - const_expression_types: [], -) \ No newline at end of file diff --git a/naga/tests/out/ir/wgsl-cooperative-matrix.compact.ron b/naga/tests/out/ir/wgsl-cooperative-matrix.compact.ron index 7f8fc73568..cafa1a037d 100644 --- a/naga/tests/out/ir/wgsl-cooperative-matrix.compact.ron +++ b/naga/tests/out/ir/wgsl-cooperative-matrix.compact.ron @@ -7,30 +7,6 @@ width: 4, )), ), - ( - name: None, - inner: CooperativeMatrix( - columns: Eight, - rows: Eight, - scalar: ( - kind: Float, - width: 4, - ), - role: A, - ), - ), - ( - name: None, - inner: CooperativeMatrix( - columns: Eight, - rows: Eight, - scalar: ( - kind: Float, - width: 4, - ), - role: B, - ), - ), ( name: None, inner: Array( @@ -63,20 +39,6 @@ constants: [], overrides: [], global_variables: [ - ( - name: Some("a"), - space: Private, - binding: None, - ty: 1, - init: None, - ), - ( - name: Some("b"), - space: Private, - binding: None, - ty: 2, - init: None, - ), ( name: Some("ext"), space: Storage( @@ -86,7 +48,7 @@ group: 0, binding: 0, )), - ty: 3, + ty: 1, init: None, ), ], @@ -106,38 +68,19 @@ local_variables: [ ( name: Some("c"), - ty: 4, + ty: 2, init: Some(0), ), - ( - name: Some("d"), - ty: 4, - init: None, - ), ], expressions: [ - ZeroValue(4), + ZeroValue(2), LocalVariable(0), - GlobalVariable(2), + 
GlobalVariable(0), AccessIndex( base: 2, index: 4, ), - Literal(U32(8)), - GlobalVariable(0), - GlobalVariable(1), - CooperativeMultiplyAdd( - a: 5, - b: 6, - c: 1, - ), - LocalVariable(1), - GlobalVariable(2), - AccessIndex( - base: 9, - index: 0, - ), - Literal(U32(8)), + Literal(U32(0)), ], named_expressions: {}, body: [ @@ -152,25 +95,6 @@ stride: 4, row_major: false, ), - Emit(( - start: 7, - end: 8, - )), - Store( - pointer: 8, - value: 7, - ), - Emit(( - start: 10, - end: 11, - )), - CooperativeLoadStore( - store: true, - target: 1, - pointer: 10, - stride: 11, - row_major: false, - ), Return( value: None, ), diff --git a/naga/tests/out/ir/wgsl-cooperative-matrix.ron b/naga/tests/out/ir/wgsl-cooperative-matrix.ron index 7f8fc73568..25f6eae72a 100644 --- a/naga/tests/out/ir/wgsl-cooperative-matrix.ron +++ b/naga/tests/out/ir/wgsl-cooperative-matrix.ron @@ -109,11 +109,6 @@ ty: 4, init: Some(0), ), - ( - name: Some("d"), - ty: 4, - init: None, - ), ], expressions: [ ZeroValue(4), @@ -123,21 +118,7 @@ base: 2, index: 4, ), - Literal(U32(8)), - GlobalVariable(0), - GlobalVariable(1), - CooperativeMultiplyAdd( - a: 5, - b: 6, - c: 1, - ), - LocalVariable(1), - GlobalVariable(2), - AccessIndex( - base: 9, - index: 0, - ), - Literal(U32(8)), + Literal(U32(0)), ], named_expressions: {}, body: [ @@ -152,25 +133,6 @@ stride: 4, row_major: false, ), - Emit(( - start: 7, - end: 8, - )), - Store( - pointer: 8, - value: 7, - ), - Emit(( - start: 10, - end: 11, - )), - CooperativeLoadStore( - store: true, - target: 1, - pointer: 10, - stride: 11, - row_major: false, - ), Return( value: None, ), diff --git a/naga/tests/out/msl/wgsl-cooperative-matrix.msl b/naga/tests/out/msl/wgsl-cooperative-matrix.msl index 4e17948e6b..716d116b39 100644 --- a/naga/tests/out/msl/wgsl-cooperative-matrix.msl +++ b/naga/tests/out/msl/wgsl-cooperative-matrix.msl @@ -5,27 +5,16 @@ using metal::uint; struct _mslBufferSizes { - uint size2; + uint size0; }; -typedef float type_3[1]; 
-metal::simdgroup_float8x8 NagaCooperativeMultiplyAdd(const metal::simdgroup_float8x8& a, const metal::simdgroup_float8x8& b, const metal::simdgroup_float8x8& c) { - metal::simdgroup_float8x8 d; - metal::simdgroup_multiply_accumulate(d,a,b,c); - return d; -} - +typedef float type_1[1]; kernel void main_( - device type_3 const& ext [[user(fake0)]] + device type_1 const& ext [[user(fake0)]] , constant _mslBufferSizes& _buffer_sizes [[user(fake0)]] ) { - metal::simdgroup_float8x8 a = {}; - metal::simdgroup_float8x8 b = {}; metal::simdgroup_float8x8 c = metal::simdgroup_float8x8 {}; - metal::simdgroup_float8x8 d = {}; - metal::simdgroup_load(c, ext[4], 8u); - d = NagaCooperativeMultiplyAdd(a, b, c); - metal::simdgroup_store(c, ext[0], 8u); + metal::simdgroup_load(c, ext[4], 0u); return; } diff --git a/naga/tests/out/spv/wgsl-cooperative-matrix.spvasm b/naga/tests/out/spv/wgsl-cooperative-matrix.spvasm index 56d9e8c7ae..37a134eee0 100644 --- a/naga/tests/out/spv/wgsl-cooperative-matrix.spvasm +++ b/naga/tests/out/spv/wgsl-cooperative-matrix.spvasm @@ -1,7 +1,7 @@ ; SPIR-V ; Version: 1.4 ; Generator: rspirv -; Bound: 41 +; Bound: 29 OpCapability Shader OpCapability CooperativeMatrixKHR OpCapability VulkanMemoryModel @@ -9,8 +9,8 @@ OpExtension "SPV_KHR_cooperative_matrix" OpExtension "SPV_KHR_vulkan_memory_model" %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical Vulkan -OpEntryPoint GLCompute %25 "main" %15 %18 %21 -OpExecutionMode %25 LocalSize 8 8 1 +OpEntryPoint GLCompute %15 "main" %11 +OpExecutionMode %15 LocalSize 8 8 1 %3 = OpString "cooperative-matrix.wgsl" OpSource Unknown 0 %3 "var a: coop_mat8x8; var b: coop_mat8x8; @@ -21,68 +21,48 @@ var ext: array; fn main() { var c = coop_mat8x8(); coopLoad(c, &ext[4]); - var d = coopMultiplyAdd(a, b, c); - coopStore(c, &ext[0]); + //var d = coopMultiplyAdd(a, b, c); + //coopStore(d, &ext[0]); + //c = d; } " -OpName %15 "a" -OpName %18 "b" -OpName %21 "ext" -OpName %25 "main" -OpName %30 "c" -OpName %32 "d" 
-OpDecorate %12 ArrayStride 4 -OpDecorate %21 DescriptorSet 0 -OpDecorate %21 Binding 0 -OpDecorate %22 Block -OpMemberDecorate %22 0 Offset 0 +OpName %11 "ext" +OpName %15 "main" +OpName %21 "c" +OpDecorate %5 ArrayStride 4 +OpDecorate %11 DescriptorSet 0 +OpDecorate %11 Binding 0 +OpDecorate %12 Block +OpMemberDecorate %12 0 Offset 0 %2 = OpTypeVoid %4 = OpTypeFloat 32 -%7 = OpTypeInt 32 0 -%6 = OpConstant %7 3 -%8 = OpConstant %7 8 -%9 = OpConstant %7 0 -%5 = OpTypeCooperativeMatrixKHR %4 %6 %8 %8 %9 -%11 = OpConstant %7 1 -%10 = OpTypeCooperativeMatrixKHR %4 %6 %8 %8 %11 -%12 = OpTypeRuntimeArray %4 -%14 = OpConstant %7 2 -%13 = OpTypeCooperativeMatrixKHR %4 %6 %8 %8 %14 -%16 = OpTypePointer Private %5 -%17 = OpConstantNull %5 -%15 = OpVariable %16 Private %17 -%19 = OpTypePointer Private %10 -%20 = OpConstantNull %10 -%18 = OpVariable %19 Private %20 -%22 = OpTypeStruct %12 -%23 = OpTypePointer StorageBuffer %22 -%21 = OpVariable %23 StorageBuffer -%26 = OpTypeFunction %2 -%27 = OpTypePointer StorageBuffer %12 -%29 = OpConstantNull %13 -%31 = OpTypePointer Function %13 -%33 = OpConstantNull %13 -%35 = OpTypePointer StorageBuffer %4 -%36 = OpConstant %7 4 -%25 = OpFunction %2 None %26 -%24 = OpLabel -%30 = OpVariable %31 Function %29 -%32 = OpVariable %31 Function %33 -%28 = OpAccessChain %27 %21 %9 -OpBranch %34 -%34 = OpLabel +%5 = OpTypeRuntimeArray %4 +%8 = OpTypeInt 32 0 +%7 = OpConstant %8 3 +%9 = OpConstant %8 8 +%10 = OpConstant %8 2 +%6 = OpTypeCooperativeMatrixKHR %4 %7 %9 %9 %10 +%12 = OpTypeStruct %5 +%13 = OpTypePointer StorageBuffer %12 +%11 = OpVariable %13 StorageBuffer +%16 = OpTypeFunction %2 +%17 = OpTypePointer StorageBuffer %5 +%18 = OpConstant %8 0 +%20 = OpConstantNull %6 +%22 = OpTypePointer Function %6 +%24 = OpTypePointer StorageBuffer %4 +%25 = OpConstant %8 4 +%27 = OpConstant %8 1 +%15 = OpFunction %2 None %16 +%14 = OpLabel +%21 = OpVariable %22 Function %20 +%19 = OpAccessChain %17 %11 %18 +OpBranch %23 +%23 = OpLabel OpLine %3 9 
18 OpLine %3 9 5 -%37 = OpAccessChain %35 %28 %36 -%38 = OpCooperativeMatrixLoadKHR %13 %37 %11 %8 -OpStore %30 %38 -OpLine %3 10 13 -%39 = OpCooperativeMatrixMulAddKHR %13 %15 %18 %30 -OpLine %3 10 5 -OpStore %32 %39 -OpLine %3 11 19 -OpLine %3 11 5 -%40 = OpAccessChain %35 %28 %9 -OpCooperativeMatrixStoreKHR %40 %30 %11 %8 +%26 = OpAccessChain %24 %19 %25 +%28 = OpCooperativeMatrixLoadKHR %6 %26 %27 %18 +OpStore %21 %28 OpReturn OpFunctionEnd \ No newline at end of file diff --git a/naga/tests/out/wgsl/wgsl-cooperative-matrix.wgsl b/naga/tests/out/wgsl/wgsl-cooperative-matrix.wgsl index 2b249bb4d5..a03b0b5190 100644 --- a/naga/tests/out/wgsl/wgsl-cooperative-matrix.wgsl +++ b/naga/tests/out/wgsl/wgsl-cooperative-matrix.wgsl @@ -1,13 +1,10 @@ -var a: coop_mat8x8; -var b: coop_mat8x8; @group(0) @binding(0) var ext: array; @compute @workgroup_size(8, 8, 1) fn main() { var c: coop_mat8x8 = coop_mat8x8(); - var d: coop_mat8x8; -coopLoad((&c), (&ext[4]), 8u) d = coopMultiplyAdd((&a), (&b), (&c)); -coopStore((&c), (&ext[0]), 8u) return; + coopLoad(c, (&ext[4]), 0u); + return; }