From c25b8e4e924fbafa1d0e3768482aa524e20e2cd3 Mon Sep 17 00:00:00 2001 From: Marijn Suijten Date: Fri, 5 Aug 2022 10:14:08 +0200 Subject: [PATCH] Generate traits and impls for all `validstructs` on command parameters --- .github/workflows/ci.yml | 2 +- ash/src/extensions/ext/pipeline_properties.rs | 14 ++--- ash/src/vk/extensions.rs | 14 +++++ generator/src/lib.rs | 52 +++++++++++++++++-- 4 files changed, 71 insertions(+), 11 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 697745623..78770f29f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -64,7 +64,7 @@ jobs: command: fmt args: -p ash - name: Diff autogen result - run: git diff --quiet || (echo "::error::Generated files are different, please regenerate with cargo run -p generator!"; git diff; false) + run: test -z "$(git status --porcelain)" || (echo "::error::Generated files are different, please regenerate with cargo run -p generator!"; git diff; false) test: name: Test Suite diff --git a/ash/src/extensions/ext/pipeline_properties.rs b/ash/src/extensions/ext/pipeline_properties.rs index 2ceed9be9..211d45f40 100644 --- a/ash/src/extensions/ext/pipeline_properties.rs +++ b/ash/src/extensions/ext/pipeline_properties.rs @@ -21,16 +21,18 @@ impl PipelineProperties { } /// - /// - /// Currently only accepts [`vk::PipelinePropertiesIdentifierEXT`]. #[inline] - pub unsafe fn get_pipeline_properties( + pub unsafe fn get_pipeline_properties( &self, pipeline_info: &vk::PipelineInfoEXT, - pipeline_properties: *mut vk::BaseOutStructure, + pipeline_properties: &mut impl vk::PFN_vkGetPipelinePropertiesEXT_param_p_pipeline_properties_validstructs, ) -> VkResult<()> { - (self.fp.get_pipeline_properties_ext)(self.handle, pipeline_info, pipeline_properties) - .result() + (self.fp.get_pipeline_properties_ext)( + self.handle, + pipeline_info, + pipeline_properties.as_base_out_struct(), + ) + .result() } pub const NAME: &CStr = vk::ExtPipelinePropertiesFn::NAME; diff --git a/ash/src/vk/extensions.rs b/ash/src/vk/extensions.rs index 1fbc1bbc5..7aba88a1a 100644 --- a/ash/src/vk/extensions.rs +++ b/ash/src/vk/extensions.rs @@ -17672,6 +17672,20 @@ impl ExtPipelinePropertiesFn { pub const SPEC_VERSION: u32 = 1u32; } #[allow(non_camel_case_types)] +#[doc = "Implemented for all types that can be passed as argument to `p_pipeline_properties` in [`PFN_vkGetPipelinePropertiesEXT`]"] +pub unsafe trait PFN_vkGetPipelinePropertiesEXT_param_p_pipeline_properties_validstructs { + unsafe fn as_base_in_struct(&self) -> *const BaseInStructure { + <*const _>::cast(self) + } + unsafe fn as_base_out_struct(&mut self) -> *mut BaseOutStructure { + <*mut _>::cast(self) + } +} +unsafe impl PFN_vkGetPipelinePropertiesEXT_param_p_pipeline_properties_validstructs + for PipelinePropertiesIdentifierEXT<'_> +{ +} +#[allow(non_camel_case_types)] pub type PFN_vkGetPipelinePropertiesEXT = unsafe extern "system" fn( device: Device, p_pipeline_info: *const PipelineInfoEXT, diff --git a/generator/src/lib.rs b/generator/src/lib.rs index 5624bd2b4..c647c6c98 100644 --- a/generator/src/lib.rs +++ b/generator/src/lib.rs @@ -911,6 +911,7 @@ fn generate_function_pointers<'a>( parameters: TokenStream, parameters_unused: TokenStream, returns: TokenStream, + parameter_validstructs: Vec<(Ident, Vec)>, } let commands = commands @@ -929,10 +930,13 @@ fn generate_function_pointers<'a>( function_name_c.strip_prefix("vk").unwrap().to_snake_case() ); - let params: Vec<_> = cmd + let params = cmd .params .iter() - .filter(|param| matches!(param.api.as_deref(), None | Some(DESIRED_API))) + .filter(|param| matches!(param.api.as_deref(), None | Some(DESIRED_API))); + + let params_tokens: Vec<_> = params + .clone() .map(|param| { let name = param.param_ident(); let ty = param.type_tokens(true); @@ -940,17 +944,22 @@ fn generate_function_pointers<'a>( }) .collect(); - let params_iter = params + let params_iter = params_tokens .iter() .map(|(param_name, param_ty)| quote!(#param_name: #param_ty)); let parameters = quote!(#(#params_iter,)*); - let params_iter = params.iter().map(|(param_name, param_ty)| { + let params_iter = params_tokens.iter().map(|(param_name, param_ty)| { let unused_name = format_ident!("_{}", param_name); quote!(#unused_name: #param_ty) }); let parameters_unused = quote!(#(#params_iter,)*); + let parameter_validstructs: Vec<_> = params + .filter(|param| !param.validstructs.is_empty()) + .map(|param| (param.param_ident(), param.validstructs.clone())) + .collect(); + let ret = cmd .proto .type_name @@ -972,10 +981,43 @@ fn generate_function_pointers<'a>( let ret_ty_tokens = name_to_tokens(ret); quote!(-> #ret_ty_tokens) }, + parameter_validstructs, } }) .collect::>(); + struct CommandToParamTraits<'a>(&'a Command); + impl<'a> quote::ToTokens for CommandToParamTraits<'a> { + fn to_tokens(&self, tokens: &mut TokenStream) { + for (param_ident, validstructs) in &self.0.parameter_validstructs { + let param_trait_name = + format_ident!("{}_param_{}_validstructs", self.0.type_name, param_ident); + let doc_string = format!( + "Implemented for all types that can be passed as argument to `{}` in [`{}`]", + param_ident, self.0.type_name + ); + quote! { + #[allow(non_camel_case_types)] + #[doc = #doc_string] + pub unsafe trait #param_trait_name { + unsafe fn as_base_in_struct(&self) -> *const BaseInStructure { + <*const _>::cast(self) + } + unsafe fn as_base_out_struct(&mut self) -> *mut BaseOutStructure { + <*mut _>::cast(self) + } + } + } + .to_tokens(tokens); + + for validstruct in validstructs { + let structname = name_to_tokens(validstruct); + quote!(unsafe impl #param_trait_name for #structname<'_> {}).to_tokens(tokens); + } + } + } + } + struct CommandToType<'a>(&'a Command); impl<'a> quote::ToTokens for CommandToType<'a> { fn to_tokens(&self, tokens: &mut TokenStream) { @@ -1034,6 +1076,7 @@ fn generate_function_pointers<'a>( } } + let param_traits = commands.iter().map(CommandToParamTraits); let pfn_typedefs = commands .iter() .filter(|pfn| pfn.type_needs_defining) @@ -1042,6 +1085,7 @@ fn generate_function_pointers<'a>( let loaders = commands.iter().map(CommandToLoader); quote! { + #(#param_traits)* #(#pfn_typedefs)* #[derive(Clone)]