Skip to content

Commit 24a31c4

Browse files
Only commit hopefully (#8140)
1 parent 57368b3 commit 24a31c4

File tree

1 file changed

+172
-154
lines changed

1 file changed

+172
-154
lines changed

wgpu-hal/src/metal/device.rs

Lines changed: 172 additions & 154 deletions
Original file line numberDiff line numberDiff line change
@@ -128,176 +128,194 @@ impl super::Device {
128128
primitive_class: MTLPrimitiveTopologyClass,
129129
naga_stage: naga::ShaderStage,
130130
) -> Result<CompiledShader, crate::PipelineError> {
131-
let naga_shader = if let ShaderModuleSource::Naga(naga) = &stage.module.source {
132-
naga
133-
} else {
134-
panic!("load_shader required a naga shader");
135-
};
136-
let stage_bit = map_naga_stage(naga_stage);
137-
let (module, module_info) = naga::back::pipeline_constants::process_overrides(
138-
&naga_shader.module,
139-
&naga_shader.info,
140-
Some((naga_stage, stage.entry_point)),
141-
stage.constants,
142-
)
143-
.map_err(|e| crate::PipelineError::PipelineConstants(stage_bit, format!("MSL: {e:?}")))?;
144-
145-
let ep_resources = &layout.per_stage_map[naga_stage];
146-
147-
let bounds_check_policy = if stage.module.bounds_checks.bounds_checks {
148-
naga::proc::BoundsCheckPolicy::Restrict
149-
} else {
150-
naga::proc::BoundsCheckPolicy::Unchecked
151-
};
131+
match stage.module.source {
132+
ShaderModuleSource::Naga(ref naga_shader) => {
133+
let stage_bit = map_naga_stage(naga_stage);
134+
let (module, module_info) = naga::back::pipeline_constants::process_overrides(
135+
&naga_shader.module,
136+
&naga_shader.info,
137+
Some((naga_stage, stage.entry_point)),
138+
stage.constants,
139+
)
140+
.map_err(|e| {
141+
crate::PipelineError::PipelineConstants(stage_bit, format!("MSL: {e:?}"))
142+
})?;
152143

153-
let options = naga::back::msl::Options {
154-
lang_version: match self.shared.private_caps.msl_version {
155-
MTLLanguageVersion::V1_0 => (1, 0),
156-
MTLLanguageVersion::V1_1 => (1, 1),
157-
MTLLanguageVersion::V1_2 => (1, 2),
158-
MTLLanguageVersion::V2_0 => (2, 0),
159-
MTLLanguageVersion::V2_1 => (2, 1),
160-
MTLLanguageVersion::V2_2 => (2, 2),
161-
MTLLanguageVersion::V2_3 => (2, 3),
162-
MTLLanguageVersion::V2_4 => (2, 4),
163-
MTLLanguageVersion::V3_0 => (3, 0),
164-
MTLLanguageVersion::V3_1 => (3, 1),
165-
},
166-
inline_samplers: Default::default(),
167-
spirv_cross_compatibility: false,
168-
fake_missing_bindings: false,
169-
per_entry_point_map: naga::back::msl::EntryPointResourceMap::from([(
170-
stage.entry_point.to_owned(),
171-
ep_resources.clone(),
172-
)]),
173-
bounds_check_policies: naga::proc::BoundsCheckPolicies {
174-
index: bounds_check_policy,
175-
buffer: bounds_check_policy,
176-
image_load: bounds_check_policy,
177-
// TODO: support bounds checks on binding arrays
178-
binding_array: naga::proc::BoundsCheckPolicy::Unchecked,
179-
},
180-
zero_initialize_workgroup_memory: stage.zero_initialize_workgroup_memory,
181-
force_loop_bounding: stage.module.bounds_checks.force_loop_bounding,
182-
};
144+
let ep_resources = &layout.per_stage_map[naga_stage];
183145

184-
let pipeline_options = naga::back::msl::PipelineOptions {
185-
entry_point: Some((naga_stage, stage.entry_point.to_owned())),
186-
allow_and_force_point_size: match primitive_class {
187-
MTLPrimitiveTopologyClass::Point => true,
188-
_ => false,
189-
},
190-
vertex_pulling_transform: true,
191-
vertex_buffer_mappings: vertex_buffer_mappings.to_vec(),
192-
};
146+
let bounds_check_policy = if stage.module.bounds_checks.bounds_checks {
147+
naga::proc::BoundsCheckPolicy::Restrict
148+
} else {
149+
naga::proc::BoundsCheckPolicy::Unchecked
150+
};
193151

194-
let (source, info) =
195-
naga::back::msl::write_string(&module, &module_info, &options, &pipeline_options)
196-
.map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("MSL: {e:?}")))?;
152+
let options = naga::back::msl::Options {
153+
lang_version: match self.shared.private_caps.msl_version {
154+
MTLLanguageVersion::V1_0 => (1, 0),
155+
MTLLanguageVersion::V1_1 => (1, 1),
156+
MTLLanguageVersion::V1_2 => (1, 2),
157+
MTLLanguageVersion::V2_0 => (2, 0),
158+
MTLLanguageVersion::V2_1 => (2, 1),
159+
MTLLanguageVersion::V2_2 => (2, 2),
160+
MTLLanguageVersion::V2_3 => (2, 3),
161+
MTLLanguageVersion::V2_4 => (2, 4),
162+
MTLLanguageVersion::V3_0 => (3, 0),
163+
MTLLanguageVersion::V3_1 => (3, 1),
164+
},
165+
inline_samplers: Default::default(),
166+
spirv_cross_compatibility: false,
167+
fake_missing_bindings: false,
168+
per_entry_point_map: naga::back::msl::EntryPointResourceMap::from([(
169+
stage.entry_point.to_owned(),
170+
ep_resources.clone(),
171+
)]),
172+
bounds_check_policies: naga::proc::BoundsCheckPolicies {
173+
index: bounds_check_policy,
174+
buffer: bounds_check_policy,
175+
image_load: bounds_check_policy,
176+
// TODO: support bounds checks on binding arrays
177+
binding_array: naga::proc::BoundsCheckPolicy::Unchecked,
178+
},
179+
zero_initialize_workgroup_memory: stage.zero_initialize_workgroup_memory,
180+
force_loop_bounding: stage.module.bounds_checks.force_loop_bounding,
181+
};
197182

198-
log::debug!(
199-
"Naga generated shader for entry point '{}' and stage {:?}\n{}",
200-
stage.entry_point,
201-
naga_stage,
202-
&source
203-
);
183+
let pipeline_options = naga::back::msl::PipelineOptions {
184+
entry_point: Some((naga_stage, stage.entry_point.to_owned())),
185+
allow_and_force_point_size: match primitive_class {
186+
MTLPrimitiveTopologyClass::Point => true,
187+
_ => false,
188+
},
189+
vertex_pulling_transform: true,
190+
vertex_buffer_mappings: vertex_buffer_mappings.to_vec(),
191+
};
204192

205-
let options = metal::CompileOptions::new();
206-
options.set_language_version(self.shared.private_caps.msl_version);
193+
let (source, info) = naga::back::msl::write_string(
194+
&module,
195+
&module_info,
196+
&options,
197+
&pipeline_options,
198+
)
199+
.map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("MSL: {e:?}")))?;
207200

208-
if self.shared.private_caps.supports_preserve_invariance {
209-
options.set_preserve_invariance(true);
210-
}
201+
log::debug!(
202+
"Naga generated shader for entry point '{}' and stage {:?}\n{}",
203+
stage.entry_point,
204+
naga_stage,
205+
&source
206+
);
211207

212-
let library = self
213-
.shared
214-
.device
215-
.lock()
216-
.new_library_with_source(source.as_ref(), &options)
217-
.map_err(|err| {
218-
log::warn!("Naga generated shader:\n{source}");
219-
crate::PipelineError::Linkage(stage_bit, format!("Metal: {err}"))
220-
})?;
221-
222-
let ep_index = module
223-
.entry_points
224-
.iter()
225-
.position(|ep| ep.stage == naga_stage && ep.name == stage.entry_point)
226-
.ok_or(crate::PipelineError::EntryPoint(naga_stage))?;
227-
let ep = &module.entry_points[ep_index];
228-
let translated_ep_name = info.entry_point_names[0]
229-
.as_ref()
230-
.map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("{e}")))?;
231-
232-
let wg_size = MTLSize {
233-
width: ep.workgroup_size[0] as _,
234-
height: ep.workgroup_size[1] as _,
235-
depth: ep.workgroup_size[2] as _,
236-
};
208+
let options = metal::CompileOptions::new();
209+
options.set_language_version(self.shared.private_caps.msl_version);
237210

238-
let function = library
239-
.get_function(translated_ep_name, None)
240-
.map_err(|e| {
241-
log::error!("get_function: {e:?}");
242-
crate::PipelineError::EntryPoint(naga_stage)
243-
})?;
244-
245-
// collect sizes indices, immutable buffers, and work group memory sizes
246-
let ep_info = &module_info.get_entry_point(ep_index);
247-
let mut wg_memory_sizes = Vec::new();
248-
let mut sized_bindings = Vec::new();
249-
let mut immutable_buffer_mask = 0;
250-
for (var_handle, var) in module.global_variables.iter() {
251-
match var.space {
252-
naga::AddressSpace::WorkGroup => {
253-
if !ep_info[var_handle].is_empty() {
254-
let size = module.types[var.ty].inner.size(module.to_ctx());
255-
wg_memory_sizes.push(size);
256-
}
211+
if self.shared.private_caps.supports_preserve_invariance {
212+
options.set_preserve_invariance(true);
257213
}
258-
naga::AddressSpace::Uniform | naga::AddressSpace::Storage { .. } => {
259-
let br = match var.binding {
260-
Some(br) => br,
261-
None => continue,
262-
};
263-
let storage_access_store = match var.space {
264-
naga::AddressSpace::Storage { access } => {
265-
access.contains(naga::StorageAccess::STORE)
214+
215+
let library = self
216+
.shared
217+
.device
218+
.lock()
219+
.new_library_with_source(source.as_ref(), &options)
220+
.map_err(|err| {
221+
log::warn!("Naga generated shader:\n{source}");
222+
crate::PipelineError::Linkage(stage_bit, format!("Metal: {err}"))
223+
})?;
224+
225+
let ep_index = module
226+
.entry_points
227+
.iter()
228+
.position(|ep| ep.stage == naga_stage && ep.name == stage.entry_point)
229+
.ok_or(crate::PipelineError::EntryPoint(naga_stage))?;
230+
let ep = &module.entry_points[ep_index];
231+
let translated_ep_name = info.entry_point_names[0]
232+
.as_ref()
233+
.map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("{e}")))?;
234+
235+
let wg_size = MTLSize {
236+
width: ep.workgroup_size[0] as _,
237+
height: ep.workgroup_size[1] as _,
238+
depth: ep.workgroup_size[2] as _,
239+
};
240+
241+
let function = library
242+
.get_function(translated_ep_name, None)
243+
.map_err(|e| {
244+
log::error!("get_function: {e:?}");
245+
crate::PipelineError::EntryPoint(naga_stage)
246+
})?;
247+
248+
// collect sizes indices, immutable buffers, and work group memory sizes
249+
let ep_info = &module_info.get_entry_point(ep_index);
250+
let mut wg_memory_sizes = Vec::new();
251+
let mut sized_bindings = Vec::new();
252+
let mut immutable_buffer_mask = 0;
253+
for (var_handle, var) in module.global_variables.iter() {
254+
match var.space {
255+
naga::AddressSpace::WorkGroup => {
256+
if !ep_info[var_handle].is_empty() {
257+
let size = module.types[var.ty].inner.size(module.to_ctx());
258+
wg_memory_sizes.push(size);
259+
}
266260
}
267-
_ => false,
268-
};
261+
naga::AddressSpace::Uniform | naga::AddressSpace::Storage { .. } => {
262+
let br = match var.binding {
263+
Some(br) => br,
264+
None => continue,
265+
};
266+
let storage_access_store = match var.space {
267+
naga::AddressSpace::Storage { access } => {
268+
access.contains(naga::StorageAccess::STORE)
269+
}
270+
_ => false,
271+
};
269272

270-
// check for an immutable buffer
271-
if !ep_info[var_handle].is_empty() && !storage_access_store {
272-
let slot = ep_resources.resources[&br].buffer.unwrap();
273-
immutable_buffer_mask |= 1 << slot;
274-
}
273+
// check for an immutable buffer
274+
if !ep_info[var_handle].is_empty() && !storage_access_store {
275+
let slot = ep_resources.resources[&br].buffer.unwrap();
276+
immutable_buffer_mask |= 1 << slot;
277+
}
275278

276-
let mut dynamic_array_container_ty = var.ty;
277-
if let naga::TypeInner::Struct { ref members, .. } = module.types[var.ty].inner
278-
{
279-
dynamic_array_container_ty = members.last().unwrap().ty;
280-
}
281-
if let naga::TypeInner::Array {
282-
size: naga::ArraySize::Dynamic,
283-
..
284-
} = module.types[dynamic_array_container_ty].inner
285-
{
286-
sized_bindings.push(br);
279+
let mut dynamic_array_container_ty = var.ty;
280+
if let naga::TypeInner::Struct { ref members, .. } =
281+
module.types[var.ty].inner
282+
{
283+
dynamic_array_container_ty = members.last().unwrap().ty;
284+
}
285+
if let naga::TypeInner::Array {
286+
size: naga::ArraySize::Dynamic,
287+
..
288+
} = module.types[dynamic_array_container_ty].inner
289+
{
290+
sized_bindings.push(br);
291+
}
292+
}
293+
_ => {}
287294
}
288295
}
289-
_ => {}
296+
297+
Ok(CompiledShader {
298+
library,
299+
function,
300+
wg_size,
301+
wg_memory_sizes,
302+
sized_bindings,
303+
immutable_buffer_mask,
304+
})
290305
}
306+
ShaderModuleSource::Passthrough(ref shader) => Ok(CompiledShader {
307+
library: shader.library.clone(),
308+
function: shader.function.clone(),
309+
wg_size: MTLSize {
310+
width: shader.num_workgroups.0 as u64,
311+
height: shader.num_workgroups.1 as u64,
312+
depth: shader.num_workgroups.2 as u64,
313+
},
314+
wg_memory_sizes: vec![],
315+
sized_bindings: vec![],
316+
immutable_buffer_mask: 0,
317+
}),
291318
}
292-
293-
Ok(CompiledShader {
294-
library,
295-
function,
296-
wg_size,
297-
wg_memory_sizes,
298-
sized_bindings,
299-
immutable_buffer_mask,
300-
})
301319
}
302320

303321
fn set_buffers_mutability(

0 commit comments

Comments
 (0)