@@ -128,176 +128,194 @@ impl super::Device {
128128 primitive_class : MTLPrimitiveTopologyClass ,
129129 naga_stage : naga:: ShaderStage ,
130130 ) -> Result < CompiledShader , crate :: PipelineError > {
131- let naga_shader = if let ShaderModuleSource :: Naga ( naga) = & stage. module . source {
132- naga
133- } else {
134- panic ! ( "load_shader required a naga shader" ) ;
135- } ;
136- let stage_bit = map_naga_stage ( naga_stage) ;
137- let ( module, module_info) = naga:: back:: pipeline_constants:: process_overrides (
138- & naga_shader. module ,
139- & naga_shader. info ,
140- Some ( ( naga_stage, stage. entry_point ) ) ,
141- stage. constants ,
142- )
143- . map_err ( |e| crate :: PipelineError :: PipelineConstants ( stage_bit, format ! ( "MSL: {e:?}" ) ) ) ?;
144-
145- let ep_resources = & layout. per_stage_map [ naga_stage] ;
146-
147- let bounds_check_policy = if stage. module . bounds_checks . bounds_checks {
148- naga:: proc:: BoundsCheckPolicy :: Restrict
149- } else {
150- naga:: proc:: BoundsCheckPolicy :: Unchecked
151- } ;
131+ match stage. module . source {
132+ ShaderModuleSource :: Naga ( ref naga_shader) => {
133+ let stage_bit = map_naga_stage ( naga_stage) ;
134+ let ( module, module_info) = naga:: back:: pipeline_constants:: process_overrides (
135+ & naga_shader. module ,
136+ & naga_shader. info ,
137+ Some ( ( naga_stage, stage. entry_point ) ) ,
138+ stage. constants ,
139+ )
140+ . map_err ( |e| {
141+ crate :: PipelineError :: PipelineConstants ( stage_bit, format ! ( "MSL: {e:?}" ) )
142+ } ) ?;
152143
153- let options = naga:: back:: msl:: Options {
154- lang_version : match self . shared . private_caps . msl_version {
155- MTLLanguageVersion :: V1_0 => ( 1 , 0 ) ,
156- MTLLanguageVersion :: V1_1 => ( 1 , 1 ) ,
157- MTLLanguageVersion :: V1_2 => ( 1 , 2 ) ,
158- MTLLanguageVersion :: V2_0 => ( 2 , 0 ) ,
159- MTLLanguageVersion :: V2_1 => ( 2 , 1 ) ,
160- MTLLanguageVersion :: V2_2 => ( 2 , 2 ) ,
161- MTLLanguageVersion :: V2_3 => ( 2 , 3 ) ,
162- MTLLanguageVersion :: V2_4 => ( 2 , 4 ) ,
163- MTLLanguageVersion :: V3_0 => ( 3 , 0 ) ,
164- MTLLanguageVersion :: V3_1 => ( 3 , 1 ) ,
165- } ,
166- inline_samplers : Default :: default ( ) ,
167- spirv_cross_compatibility : false ,
168- fake_missing_bindings : false ,
169- per_entry_point_map : naga:: back:: msl:: EntryPointResourceMap :: from ( [ (
170- stage. entry_point . to_owned ( ) ,
171- ep_resources. clone ( ) ,
172- ) ] ) ,
173- bounds_check_policies : naga:: proc:: BoundsCheckPolicies {
174- index : bounds_check_policy,
175- buffer : bounds_check_policy,
176- image_load : bounds_check_policy,
177- // TODO: support bounds checks on binding arrays
178- binding_array : naga:: proc:: BoundsCheckPolicy :: Unchecked ,
179- } ,
180- zero_initialize_workgroup_memory : stage. zero_initialize_workgroup_memory ,
181- force_loop_bounding : stage. module . bounds_checks . force_loop_bounding ,
182- } ;
144+ let ep_resources = & layout. per_stage_map [ naga_stage] ;
183145
184- let pipeline_options = naga:: back:: msl:: PipelineOptions {
185- entry_point : Some ( ( naga_stage, stage. entry_point . to_owned ( ) ) ) ,
186- allow_and_force_point_size : match primitive_class {
187- MTLPrimitiveTopologyClass :: Point => true ,
188- _ => false ,
189- } ,
190- vertex_pulling_transform : true ,
191- vertex_buffer_mappings : vertex_buffer_mappings. to_vec ( ) ,
192- } ;
146+ let bounds_check_policy = if stage. module . bounds_checks . bounds_checks {
147+ naga:: proc:: BoundsCheckPolicy :: Restrict
148+ } else {
149+ naga:: proc:: BoundsCheckPolicy :: Unchecked
150+ } ;
193151
194- let ( source, info) =
195- naga:: back:: msl:: write_string ( & module, & module_info, & options, & pipeline_options)
196- . map_err ( |e| crate :: PipelineError :: Linkage ( stage_bit, format ! ( "MSL: {e:?}" ) ) ) ?;
152+ let options = naga:: back:: msl:: Options {
153+ lang_version : match self . shared . private_caps . msl_version {
154+ MTLLanguageVersion :: V1_0 => ( 1 , 0 ) ,
155+ MTLLanguageVersion :: V1_1 => ( 1 , 1 ) ,
156+ MTLLanguageVersion :: V1_2 => ( 1 , 2 ) ,
157+ MTLLanguageVersion :: V2_0 => ( 2 , 0 ) ,
158+ MTLLanguageVersion :: V2_1 => ( 2 , 1 ) ,
159+ MTLLanguageVersion :: V2_2 => ( 2 , 2 ) ,
160+ MTLLanguageVersion :: V2_3 => ( 2 , 3 ) ,
161+ MTLLanguageVersion :: V2_4 => ( 2 , 4 ) ,
162+ MTLLanguageVersion :: V3_0 => ( 3 , 0 ) ,
163+ MTLLanguageVersion :: V3_1 => ( 3 , 1 ) ,
164+ } ,
165+ inline_samplers : Default :: default ( ) ,
166+ spirv_cross_compatibility : false ,
167+ fake_missing_bindings : false ,
168+ per_entry_point_map : naga:: back:: msl:: EntryPointResourceMap :: from ( [ (
169+ stage. entry_point . to_owned ( ) ,
170+ ep_resources. clone ( ) ,
171+ ) ] ) ,
172+ bounds_check_policies : naga:: proc:: BoundsCheckPolicies {
173+ index : bounds_check_policy,
174+ buffer : bounds_check_policy,
175+ image_load : bounds_check_policy,
176+ // TODO: support bounds checks on binding arrays
177+ binding_array : naga:: proc:: BoundsCheckPolicy :: Unchecked ,
178+ } ,
179+ zero_initialize_workgroup_memory : stage. zero_initialize_workgroup_memory ,
180+ force_loop_bounding : stage. module . bounds_checks . force_loop_bounding ,
181+ } ;
197182
198- log:: debug!(
199- "Naga generated shader for entry point '{}' and stage {:?}\n {}" ,
200- stage. entry_point,
201- naga_stage,
202- & source
203- ) ;
183+ let pipeline_options = naga:: back:: msl:: PipelineOptions {
184+ entry_point : Some ( ( naga_stage, stage. entry_point . to_owned ( ) ) ) ,
185+ allow_and_force_point_size : match primitive_class {
186+ MTLPrimitiveTopologyClass :: Point => true ,
187+ _ => false ,
188+ } ,
189+ vertex_pulling_transform : true ,
190+ vertex_buffer_mappings : vertex_buffer_mappings. to_vec ( ) ,
191+ } ;
204192
205- let options = metal:: CompileOptions :: new ( ) ;
206- options. set_language_version ( self . shared . private_caps . msl_version ) ;
193+ let ( source, info) = naga:: back:: msl:: write_string (
194+ & module,
195+ & module_info,
196+ & options,
197+ & pipeline_options,
198+ )
199+ . map_err ( |e| crate :: PipelineError :: Linkage ( stage_bit, format ! ( "MSL: {e:?}" ) ) ) ?;
207200
208- if self . shared . private_caps . supports_preserve_invariance {
209- options. set_preserve_invariance ( true ) ;
210- }
201+ log:: debug!(
202+ "Naga generated shader for entry point '{}' and stage {:?}\n {}" ,
203+ stage. entry_point,
204+ naga_stage,
205+ & source
206+ ) ;
211207
212- let library = self
213- . shared
214- . device
215- . lock ( )
216- . new_library_with_source ( source. as_ref ( ) , & options)
217- . map_err ( |err| {
218- log:: warn!( "Naga generated shader:\n {source}" ) ;
219- crate :: PipelineError :: Linkage ( stage_bit, format ! ( "Metal: {err}" ) )
220- } ) ?;
221-
222- let ep_index = module
223- . entry_points
224- . iter ( )
225- . position ( |ep| ep. stage == naga_stage && ep. name == stage. entry_point )
226- . ok_or ( crate :: PipelineError :: EntryPoint ( naga_stage) ) ?;
227- let ep = & module. entry_points [ ep_index] ;
228- let translated_ep_name = info. entry_point_names [ 0 ]
229- . as_ref ( )
230- . map_err ( |e| crate :: PipelineError :: Linkage ( stage_bit, format ! ( "{e}" ) ) ) ?;
231-
232- let wg_size = MTLSize {
233- width : ep. workgroup_size [ 0 ] as _ ,
234- height : ep. workgroup_size [ 1 ] as _ ,
235- depth : ep. workgroup_size [ 2 ] as _ ,
236- } ;
208+ let options = metal:: CompileOptions :: new ( ) ;
209+ options. set_language_version ( self . shared . private_caps . msl_version ) ;
237210
238- let function = library
239- . get_function ( translated_ep_name, None )
240- . map_err ( |e| {
241- log:: error!( "get_function: {e:?}" ) ;
242- crate :: PipelineError :: EntryPoint ( naga_stage)
243- } ) ?;
244-
245- // collect sizes indices, immutable buffers, and work group memory sizes
246- let ep_info = & module_info. get_entry_point ( ep_index) ;
247- let mut wg_memory_sizes = Vec :: new ( ) ;
248- let mut sized_bindings = Vec :: new ( ) ;
249- let mut immutable_buffer_mask = 0 ;
250- for ( var_handle, var) in module. global_variables . iter ( ) {
251- match var. space {
252- naga:: AddressSpace :: WorkGroup => {
253- if !ep_info[ var_handle] . is_empty ( ) {
254- let size = module. types [ var. ty ] . inner . size ( module. to_ctx ( ) ) ;
255- wg_memory_sizes. push ( size) ;
256- }
211+ if self . shared . private_caps . supports_preserve_invariance {
212+ options. set_preserve_invariance ( true ) ;
257213 }
258- naga:: AddressSpace :: Uniform | naga:: AddressSpace :: Storage { .. } => {
259- let br = match var. binding {
260- Some ( br) => br,
261- None => continue ,
262- } ;
263- let storage_access_store = match var. space {
264- naga:: AddressSpace :: Storage { access } => {
265- access. contains ( naga:: StorageAccess :: STORE )
214+
215+ let library = self
216+ . shared
217+ . device
218+ . lock ( )
219+ . new_library_with_source ( source. as_ref ( ) , & options)
220+ . map_err ( |err| {
221+ log:: warn!( "Naga generated shader:\n {source}" ) ;
222+ crate :: PipelineError :: Linkage ( stage_bit, format ! ( "Metal: {err}" ) )
223+ } ) ?;
224+
225+ let ep_index = module
226+ . entry_points
227+ . iter ( )
228+ . position ( |ep| ep. stage == naga_stage && ep. name == stage. entry_point )
229+ . ok_or ( crate :: PipelineError :: EntryPoint ( naga_stage) ) ?;
230+ let ep = & module. entry_points [ ep_index] ;
231+ let translated_ep_name = info. entry_point_names [ 0 ]
232+ . as_ref ( )
233+ . map_err ( |e| crate :: PipelineError :: Linkage ( stage_bit, format ! ( "{e}" ) ) ) ?;
234+
235+ let wg_size = MTLSize {
236+ width : ep. workgroup_size [ 0 ] as _ ,
237+ height : ep. workgroup_size [ 1 ] as _ ,
238+ depth : ep. workgroup_size [ 2 ] as _ ,
239+ } ;
240+
241+ let function = library
242+ . get_function ( translated_ep_name, None )
243+ . map_err ( |e| {
244+ log:: error!( "get_function: {e:?}" ) ;
245+ crate :: PipelineError :: EntryPoint ( naga_stage)
246+ } ) ?;
247+
248+ // collect sizes indices, immutable buffers, and work group memory sizes
249+ let ep_info = & module_info. get_entry_point ( ep_index) ;
250+ let mut wg_memory_sizes = Vec :: new ( ) ;
251+ let mut sized_bindings = Vec :: new ( ) ;
252+ let mut immutable_buffer_mask = 0 ;
253+ for ( var_handle, var) in module. global_variables . iter ( ) {
254+ match var. space {
255+ naga:: AddressSpace :: WorkGroup => {
256+ if !ep_info[ var_handle] . is_empty ( ) {
257+ let size = module. types [ var. ty ] . inner . size ( module. to_ctx ( ) ) ;
258+ wg_memory_sizes. push ( size) ;
259+ }
266260 }
267- _ => false ,
268- } ;
261+ naga:: AddressSpace :: Uniform | naga:: AddressSpace :: Storage { .. } => {
262+ let br = match var. binding {
263+ Some ( br) => br,
264+ None => continue ,
265+ } ;
266+ let storage_access_store = match var. space {
267+ naga:: AddressSpace :: Storage { access } => {
268+ access. contains ( naga:: StorageAccess :: STORE )
269+ }
270+ _ => false ,
271+ } ;
269272
270- // check for an immutable buffer
271- if !ep_info[ var_handle] . is_empty ( ) && !storage_access_store {
272- let slot = ep_resources. resources [ & br] . buffer . unwrap ( ) ;
273- immutable_buffer_mask |= 1 << slot;
274- }
273+ // check for an immutable buffer
274+ if !ep_info[ var_handle] . is_empty ( ) && !storage_access_store {
275+ let slot = ep_resources. resources [ & br] . buffer . unwrap ( ) ;
276+ immutable_buffer_mask |= 1 << slot;
277+ }
275278
276- let mut dynamic_array_container_ty = var. ty ;
277- if let naga:: TypeInner :: Struct { ref members, .. } = module. types [ var. ty ] . inner
278- {
279- dynamic_array_container_ty = members. last ( ) . unwrap ( ) . ty ;
280- }
281- if let naga:: TypeInner :: Array {
282- size : naga:: ArraySize :: Dynamic ,
283- ..
284- } = module. types [ dynamic_array_container_ty] . inner
285- {
286- sized_bindings. push ( br) ;
279+ let mut dynamic_array_container_ty = var. ty ;
280+ if let naga:: TypeInner :: Struct { ref members, .. } =
281+ module. types [ var. ty ] . inner
282+ {
283+ dynamic_array_container_ty = members. last ( ) . unwrap ( ) . ty ;
284+ }
285+ if let naga:: TypeInner :: Array {
286+ size : naga:: ArraySize :: Dynamic ,
287+ ..
288+ } = module. types [ dynamic_array_container_ty] . inner
289+ {
290+ sized_bindings. push ( br) ;
291+ }
292+ }
293+ _ => { }
287294 }
288295 }
289- _ => { }
296+
297+ Ok ( CompiledShader {
298+ library,
299+ function,
300+ wg_size,
301+ wg_memory_sizes,
302+ sized_bindings,
303+ immutable_buffer_mask,
304+ } )
290305 }
306+ ShaderModuleSource :: Passthrough ( ref shader) => Ok ( CompiledShader {
307+ library : shader. library . clone ( ) ,
308+ function : shader. function . clone ( ) ,
309+ wg_size : MTLSize {
310+ width : shader. num_workgroups . 0 as u64 ,
311+ height : shader. num_workgroups . 1 as u64 ,
312+ depth : shader. num_workgroups . 2 as u64 ,
313+ } ,
314+ wg_memory_sizes : vec ! [ ] ,
315+ sized_bindings : vec ! [ ] ,
316+ immutable_buffer_mask : 0 ,
317+ } ) ,
291318 }
292-
293- Ok ( CompiledShader {
294- library,
295- function,
296- wg_size,
297- wg_memory_sizes,
298- sized_bindings,
299- immutable_buffer_mask,
300- } )
301319 }
302320
303321 fn set_buffers_mutability (
0 commit comments