Skip to content

Commit 96fc7ca

Browse files
committed
macros: move gen_sample_param_permutations macro to separate mod
1 parent 8f3a0b7 commit 96fc7ca

File tree

2 files changed

+207
-198
lines changed

2 files changed

+207
-198
lines changed

crates/spirv-std/macros/src/lib.rs

Lines changed: 3 additions & 198 deletions
Original file line numberDiff line numberDiff line change
@@ -73,13 +73,13 @@
7373

7474
mod debug_printf;
7575
mod image;
76+
mod sample_param_permutations;
7677

7778
use crate::debug_printf::{DebugPrintfInput, debug_printf_inner};
7879
use proc_macro::TokenStream;
79-
use proc_macro2::{Delimiter, Group, Ident, Span, TokenTree};
80+
use proc_macro2::{Delimiter, Group, Ident, TokenTree};
8081
use quote::{ToTokens, TokenStreamExt, format_ident, quote};
8182
use spirv_std_types::spirv_attr_version::spirv_attr_with_version;
82-
use syn::{ImplItemFn, visit_mut::VisitMut};
8383

8484
/// A macro for creating SPIR-V `OpTypeImage` types. Always produces a
8585
/// `spirv_std::image::Image<...>` type.
@@ -301,192 +301,6 @@ pub fn debug_printfln(input: TokenStream) -> TokenStream {
301301
debug_printf_inner(input)
302302
}
303303

304-
const SAMPLE_PARAM_COUNT: usize = 4;
305-
const SAMPLE_PARAM_GENERICS: [&str; SAMPLE_PARAM_COUNT] = ["B", "L", "G", "S"];
306-
const SAMPLE_PARAM_TYPES: [&str; SAMPLE_PARAM_COUNT] = ["B", "L", "(G,G)", "S"];
307-
const SAMPLE_PARAM_OPERANDS: [&str; SAMPLE_PARAM_COUNT] = ["Bias", "Lod", "Grad", "Sample"];
308-
const SAMPLE_PARAM_NAMES: [&str; SAMPLE_PARAM_COUNT] = ["bias", "lod", "grad", "sample_index"];
309-
const SAMPLE_PARAM_GRAD_INDEX: usize = 2; // Grad requires some special handling because it uses 2 arguments
310-
const SAMPLE_PARAM_EXPLICIT_LOD_MASK: usize = 0b0110; // which params require the use of ExplicitLod rather than ImplicitLod
311-
312-
fn is_grad(i: usize) -> bool {
313-
i == SAMPLE_PARAM_GRAD_INDEX
314-
}
315-
316-
struct SampleImplRewriter(usize, syn::Type);
317-
318-
impl SampleImplRewriter {
319-
fn rewrite(mask: usize, f: &syn::ItemImpl) -> syn::ItemImpl {
320-
let mut new_impl = f.clone();
321-
let mut ty_str = String::from("SampleParams<");
322-
323-
// based on the mask, form a `SampleParams` type string and add the generic parameters to the `impl<>` generics
324-
// example type string: `"SampleParams<SomeTy<B>, NoneTy, NoneTy>"`
325-
for i in 0..SAMPLE_PARAM_COUNT {
326-
if mask & (1 << i) != 0 {
327-
new_impl.generics.params.push(syn::GenericParam::Type(
328-
syn::Ident::new(SAMPLE_PARAM_GENERICS[i], Span::call_site()).into(),
329-
));
330-
ty_str.push_str("SomeTy<");
331-
ty_str.push_str(SAMPLE_PARAM_TYPES[i]);
332-
ty_str.push('>');
333-
} else {
334-
ty_str.push_str("NoneTy");
335-
}
336-
ty_str.push(',');
337-
}
338-
ty_str.push('>');
339-
let ty: syn::Type = syn::parse(ty_str.parse().unwrap()).unwrap();
340-
341-
// use the type to insert it into the generic argument of the trait we're implementing
342-
// e.g., `ImageWithMethods<Dummy>` becomes `ImageWithMethods<SampleParams<SomeTy<B>, NoneTy, NoneTy>>`
343-
if let Some(t) = &mut new_impl.trait_
344-
&& let syn::PathArguments::AngleBracketed(a) =
345-
&mut t.1.segments.last_mut().unwrap().arguments
346-
&& let Some(syn::GenericArgument::Type(t)) = a.args.last_mut()
347-
{
348-
*t = ty.clone();
349-
}
350-
351-
// rewrite the implemented functions
352-
SampleImplRewriter(mask, ty).visit_item_impl_mut(&mut new_impl);
353-
new_impl
354-
}
355-
356-
// generates an operands string for use in the assembly, e.g. "Bias %bias Lod %lod", based on the mask
357-
#[allow(clippy::needless_range_loop)]
358-
fn get_operands(&self) -> String {
359-
let mut op = String::new();
360-
for i in 0..SAMPLE_PARAM_COUNT {
361-
if self.0 & (1 << i) != 0 {
362-
if is_grad(i) {
363-
op.push_str("Grad %grad_x %grad_y ");
364-
} else {
365-
op.push_str(SAMPLE_PARAM_OPERANDS[i]);
366-
op.push_str(" %");
367-
op.push_str(SAMPLE_PARAM_NAMES[i]);
368-
op.push(' ');
369-
}
370-
}
371-
}
372-
op
373-
}
374-
375-
// generates list of assembly loads for the data, e.g. "%bias = OpLoad _ {bias}", etc.
376-
#[allow(clippy::needless_range_loop)]
377-
fn add_loads(&self, t: &mut Vec<TokenTree>) {
378-
for i in 0..SAMPLE_PARAM_COUNT {
379-
if self.0 & (1 << i) != 0 {
380-
if is_grad(i) {
381-
t.push(TokenTree::Literal(proc_macro2::Literal::string(
382-
"%grad_x = OpLoad _ {grad_x}",
383-
)));
384-
t.push(TokenTree::Punct(proc_macro2::Punct::new(
385-
',',
386-
proc_macro2::Spacing::Alone,
387-
)));
388-
t.push(TokenTree::Literal(proc_macro2::Literal::string(
389-
"%grad_y = OpLoad _ {grad_y}",
390-
)));
391-
t.push(TokenTree::Punct(proc_macro2::Punct::new(
392-
',',
393-
proc_macro2::Spacing::Alone,
394-
)));
395-
} else {
396-
let s = format!("%{0} = OpLoad _ {{{0}}}", SAMPLE_PARAM_NAMES[i]);
397-
t.push(TokenTree::Literal(proc_macro2::Literal::string(s.as_str())));
398-
t.push(TokenTree::Punct(proc_macro2::Punct::new(
399-
',',
400-
proc_macro2::Spacing::Alone,
401-
)));
402-
}
403-
}
404-
}
405-
}
406-
407-
// generates list of register specifications, e.g. `bias = in(reg) &params.bias.0, ...` as separate tokens
408-
#[allow(clippy::needless_range_loop)]
409-
fn add_regs(&self, t: &mut Vec<TokenTree>) {
410-
for i in 0..SAMPLE_PARAM_COUNT {
411-
if self.0 & (1 << i) != 0 {
412-
// HACK(eddyb) the extra `{...}` force the pointers to be to
413-
// fresh variables holding value copies, instead of the originals,
414-
// allowing `OpLoad _` inference to pick the appropriate type.
415-
let s = if is_grad(i) {
416-
"grad_x=in(reg) &{params.grad.0.0},grad_y=in(reg) &{params.grad.0.1},"
417-
.to_string()
418-
} else {
419-
format!("{0} = in(reg) &{{params.{0}.0}},", SAMPLE_PARAM_NAMES[i])
420-
};
421-
let ts: proc_macro2::TokenStream = s.parse().unwrap();
422-
t.extend(ts);
423-
}
424-
}
425-
}
426-
}
427-
428-
impl VisitMut for SampleImplRewriter {
429-
fn visit_impl_item_fn_mut(&mut self, item: &mut ImplItemFn) {
430-
// rewrite the last parameter of this method to be of type `SampleParams<...>` we generated earlier
431-
if let Some(syn::FnArg::Typed(p)) = item.sig.inputs.last_mut() {
432-
*p.ty.as_mut() = self.1.clone();
433-
}
434-
syn::visit_mut::visit_impl_item_fn_mut(self, item);
435-
}
436-
437-
fn visit_macro_mut(&mut self, m: &mut syn::Macro) {
438-
if m.path.is_ident("asm") {
439-
// this is where the asm! block is manipulated
440-
let t = m.tokens.clone();
441-
let mut new_t = Vec::new();
442-
let mut altered = false;
443-
444-
for tt in t {
445-
match tt {
446-
TokenTree::Literal(l) => {
447-
if let Ok(l) = syn::parse::<syn::LitStr>(l.to_token_stream().into()) {
448-
// found a string literal
449-
let s = l.value();
450-
if s.contains("$PARAMS") {
451-
altered = true;
452-
// add load instructions before the sampling instruction
453-
self.add_loads(&mut new_t);
454-
// and insert image operands
455-
let s = s.replace("$PARAMS", &self.get_operands());
456-
let lod_type = if self.0 & SAMPLE_PARAM_EXPLICIT_LOD_MASK != 0 {
457-
"ExplicitLod"
458-
} else {
459-
"ImplicitLod"
460-
};
461-
let s = s.replace("$LOD", lod_type);
462-
463-
new_t.push(TokenTree::Literal(proc_macro2::Literal::string(
464-
s.as_str(),
465-
)));
466-
} else {
467-
new_t.push(TokenTree::Literal(l.token()));
468-
}
469-
} else {
470-
new_t.push(TokenTree::Literal(l));
471-
}
472-
}
473-
_ => {
474-
new_t.push(tt);
475-
}
476-
}
477-
}
478-
479-
if altered {
480-
// finally, add register specs
481-
self.add_regs(&mut new_t);
482-
}
483-
484-
// replace all tokens within the asm! block with our new list
485-
m.tokens = new_t.into_iter().collect();
486-
}
487-
}
488-
}
489-
490304
/// Generates permutations of an `ImageWithMethods` implementation containing sampling functions
491305
/// that have asm instruction ending with a placeholder `$PARAMS` operand. The last parameter
492306
/// of each function must be named `params`, its type will be rewritten. Relevant generic
@@ -495,14 +309,5 @@ impl VisitMut for SampleImplRewriter {
495309
#[proc_macro_attribute]
496310
#[doc(hidden)]
497311
pub fn gen_sample_param_permutations(_attr: TokenStream, item: TokenStream) -> TokenStream {
498-
let item_impl = syn::parse_macro_input!(item as syn::ItemImpl);
499-
let mut fns = Vec::new();
500-
501-
for m in 1..(1 << SAMPLE_PARAM_COUNT) {
502-
fns.push(SampleImplRewriter::rewrite(m, &item_impl));
503-
}
504-
505-
// uncomment to output generated tokenstream to stdout
506-
//println!("{}", quote! { #(#fns)* }.to_string());
507-
quote! { #(#fns)* }.into()
312+
sample_param_permutations::gen_sample_param_permutations(item)
508313
}

0 commit comments

Comments
 (0)