@@ -6,10 +6,11 @@ use spirv::Word;
66use super :: {
77 block:: DebugInfoInner ,
88 helpers:: { contains_builtin, global_needs_wrapper, map_storage_class} ,
9- Block , BlockContext , CachedConstant , CachedExpressions , DebugInfo , EntryPointContext , Error ,
10- Function , FunctionArgument , GlobalVariable , IdGenerator , Instruction , LocalImageType ,
11- LocalType , LocalVariable , LogicalLayout , LookupFunctionType , LookupType , NumericType , Options ,
12- PhysicalLayout , PipelineOptions , ResultMember , Writer , WriterFlags , BITS_PER_BYTE ,
9+ Block , BlockContext , CachedConstant , CachedExpressions , CooperativeType , DebugInfo ,
10+ EntryPointContext , Error , Function , FunctionArgument , GlobalVariable , IdGenerator , Instruction ,
11+ LocalImageType , LocalType , LocalVariable , LogicalLayout , LookupFunctionType , LookupType ,
12+ NumericType , Options , PhysicalLayout , PipelineOptions , ResultMember , Writer , WriterFlags ,
13+ BITS_PER_BYTE ,
1314} ;
1415use crate :: {
1516 arena:: { Handle , HandleVec , UniqueArena } ,
@@ -375,6 +376,12 @@ impl Writer {
375376 } )
376377 }
377378
379+ pub ( super ) fn get_cooperative_type_id ( & mut self , scalar : crate :: CooperativeScalar ) -> Word {
380+ match scalar {
381+ crate :: CooperativeScalar :: F32 => self . get_f32_type_id ( ) ,
382+ }
383+ }
384+
378385 pub ( super ) fn get_f32_pointer_type_id ( & mut self , class : spirv:: StorageClass ) -> Word {
379386 let f32_id = self . get_f32_type_id ( ) ;
380387 self . get_pointer_type_id ( f32_id, class)
@@ -436,7 +443,9 @@ impl Writer {
436443 // these cases, so unwrap.
437444 LocalType :: Numeric ( NumericType :: from_inner ( inner) . unwrap ( ) )
438445 }
439- crate :: TypeInner :: CooperativeMatrix { .. } => return None ,
446+ crate :: TypeInner :: CooperativeMatrix { .. } => {
447+ LocalType :: Cooperative ( CooperativeType :: from_inner ( inner) . unwrap ( ) )
448+ }
440449 crate :: TypeInner :: Pointer { base, space } => {
441450 let base_type_id = self . get_handle_type_id ( base) ;
442451 LocalType :: Pointer {
@@ -1353,6 +1362,14 @@ impl Writer {
13531362 self . require_any ( "16 bit floating-point" , & [ spirv:: Capability :: Float16 ] ) ?;
13541363 self . use_extension ( "SPV_KHR_16bit_storage" ) ;
13551364 }
1365+ // Cooperative types and ops
1366+ crate :: TypeInner :: CooperativeMatrix { .. } => {
1367+ self . require_any (
1368+ "cooperative matrix" ,
1369+ & [ spirv:: Capability :: CooperativeMatrixKHR ] ,
1370+ ) ?;
1371+ self . use_extension ( "SPV_KHR_cooperative_matrix" ) ;
1372+ }
13561373 _ => { }
13571374 }
13581375 Ok ( ( ) )
@@ -1379,12 +1396,31 @@ impl Writer {
13791396 instruction. to_words ( & mut self . logical_layout . declarations ) ;
13801397 }
13811398
1399+ fn write_cooperative_type_declaration_local ( & mut self , id : Word , coop : CooperativeType ) {
1400+ let instruction = match coop {
1401+ CooperativeType :: Matrix {
1402+ columns,
1403+ rows,
1404+ scalar,
1405+ } => {
1406+ let scalar_id = self . get_cooperative_type_id ( scalar) ;
1407+ Instruction :: type_coop_matrix ( id, scalar_id, rows, columns)
1408+ }
1409+ } ;
1410+
1411+ instruction. to_words ( & mut self . logical_layout . declarations ) ;
1412+ }
1413+
13821414 fn write_type_declaration_local ( & mut self , id : Word , local_ty : LocalType ) {
13831415 let instruction = match local_ty {
13841416 LocalType :: Numeric ( numeric) => {
13851417 self . write_numeric_type_declaration_local ( id, numeric) ;
13861418 return ;
13871419 }
1420+ LocalType :: Cooperative ( coop) => {
1421+ self . write_cooperative_type_declaration_local ( id, coop) ;
1422+ return ;
1423+ }
13881424 LocalType :: Pointer { base, class } => Instruction :: type_pointer ( id, class, base) ,
13891425 LocalType :: Image ( image) => {
13901426 let local_type = LocalType :: Numeric ( NumericType :: Scalar ( image. sampled_type ) ) ;
0 commit comments