diff --git a/crates/bevy_reflect/bevy_reflect_derive/src/lib.rs b/crates/bevy_reflect/bevy_reflect_derive/src/lib.rs index 29cc4060993454..1bed53e0006d93 100644 --- a/crates/bevy_reflect/bevy_reflect_derive/src/lib.rs +++ b/crates/bevy_reflect/bevy_reflect_derive/src/lib.rs @@ -14,15 +14,10 @@ use syn::{ parse_macro_input, punctuated::Punctuated, token::{Comma, Paren, Where}, - Data, DataStruct, DeriveInput, Field, Fields, Generics, Ident, Index, Member, Meta, NestedMeta, - Path, + Attribute, Data, DataStruct, DeriveInput, Field, Fields, Generics, Ident, Index, Member, Meta, + NestedMeta, Path, }; -#[derive(Default)] -struct PropAttributeArgs { - pub ignore: Option, -} - #[derive(Clone)] enum TraitImpl { NotImplemented, @@ -46,11 +41,47 @@ enum DeriveType { static REFLECT_ATTRIBUTE_NAME: &str = "reflect"; static REFLECT_VALUE_ATTRIBUTE_NAME: &str = "reflect_value"; +fn active_fields(punctuated: &Punctuated) -> impl Iterator { + punctuated.iter().enumerate().filter_map(|(idx, field)| { + field + .attrs + .iter() + .find(|attr| attr.path.get_ident().unwrap() == REFLECT_ATTRIBUTE_NAME) + .map(|attr| { + syn::custom_keyword!(ignore); + attr.parse_args::>() + .expect("Invalid 'property' attribute format.") + .is_none() + }) + .unwrap_or(true) + .then(|| (field, idx)) + }) +} + +fn reflect_attrs(attrs: &[Attribute]) -> (ReflectAttrs, Option) { + for attribute in attrs.iter().filter_map(|attr| attr.parse_meta().ok()) { + if let Meta::List(meta_list) = attribute { + if let Some(ident) = meta_list.path.get_ident() { + if ident == REFLECT_ATTRIBUTE_NAME { + return (ReflectAttrs::from_nested_metas(&meta_list.nested), None); + } else if ident == REFLECT_VALUE_ATTRIBUTE_NAME { + return ( + ReflectAttrs::from_nested_metas(&meta_list.nested), + Some(DeriveType::Value), + ); + } + } + } + } + + Default::default() +} + #[proc_macro_derive(Reflect, attributes(reflect, reflect_value, module))] pub fn derive_reflect(input: TokenStream) -> TokenStream { let ast = parse_macro_input!(input as DeriveInput); let unit_struct_punctuated = Punctuated::new(); - let (fields, mut derive_type) = match &ast.data { + let (fields, derive_type) = match &ast.data { Data::Struct(DataStruct { fields: Fields::Named(fields), .. @@ -66,66 +97,14 @@ pub fn derive_reflect(input: TokenStream) -> TokenStream { _ => (&unit_struct_punctuated, DeriveType::Value), }; - let fields_and_args = fields - .iter() - .enumerate() - .map(|(i, f)| { - ( - f, - f.attrs - .iter() - .find(|a| *a.path.get_ident().as_ref().unwrap() == REFLECT_ATTRIBUTE_NAME) - .map(|a| { - syn::custom_keyword!(ignore); - let mut attribute_args = PropAttributeArgs { ignore: None }; - a.parse_args_with(|input: ParseStream| { - if input.parse::>()?.is_some() { - attribute_args.ignore = Some(true); - return Ok(()); - } - Ok(()) - }) - .expect("Invalid 'property' attribute format."); + let active_fields = active_fields(&fields).collect::>(); - attribute_args - }), - i, - ) - }) - .collect::, usize)>>(); - let active_fields = fields_and_args - .iter() - .filter(|(_field, attrs, _i)| { - attrs.is_none() - || match attrs.as_ref().unwrap().ignore { - Some(ignore) => !ignore, - None => true, - } - }) - .map(|(f, _attr, i)| (*f, *i)) - .collect::>(); + let (reflect_attrs, modified_derive_type) = reflect_attrs(&ast.attrs); + let derive_type = modified_derive_type.unwrap_or(derive_type); let bevy_reflect_path = BevyManifest::default().get_path("bevy_reflect"); let type_name = &ast.ident; - let mut reflect_attrs = ReflectAttrs::default(); - for attribute in ast.attrs.iter().filter_map(|attr| attr.parse_meta().ok()) { - let meta_list = if let Meta::List(meta_list) = attribute { - meta_list - } else { - continue; - }; - - if let Some(ident) = meta_list.path.get_ident() { - if ident == REFLECT_ATTRIBUTE_NAME { - reflect_attrs = ReflectAttrs::from_nested_metas(&meta_list.nested); - } else if ident == REFLECT_VALUE_ATTRIBUTE_NAME { - derive_type = DeriveType::Value; - reflect_attrs = ReflectAttrs::from_nested_metas(&meta_list.nested); - } - } - } - let registration_data = &reflect_attrs.data; let get_type_registration_impl = impl_get_type_registration( type_name,