diff --git a/zeroize/derive/Cargo.toml b/zeroize/derive/Cargo.toml index 303bb37c..baaacf13 100644 --- a/zeroize/derive/Cargo.toml +++ b/zeroize/derive/Cargo.toml @@ -17,7 +17,7 @@ proc-macro = true [dependencies] proc-macro2 = "1" quote = "1" -syn = {version = "2", features = ["full", "extra-traits"]} +syn = {version = "2", features = ["full", "extra-traits", "visit"]} [package.metadata.docs.rs] rustdoc-args = ["--document-private-items"] diff --git a/zeroize/derive/src/lib.rs b/zeroize/derive/src/lib.rs index 215bd785..2f31fc65 100644 --- a/zeroize/derive/src/lib.rs +++ b/zeroize/derive/src/lib.rs @@ -8,8 +8,10 @@ use proc_macro2::{Ident, TokenStream}; use quote::{format_ident, quote}; use syn::{ parse::{Parse, ParseStream}, + parse_quote, punctuated::Punctuated, token::Comma, + visit::Visit, Attribute, Data, DeriveInput, Expr, ExprLit, Field, Fields, Lit, Meta, Result, Variant, WherePredicate, }; @@ -36,12 +38,19 @@ pub fn derive_zeroize(input: proc_macro::TokenStream) -> proc_macro::TokenStream fn derive_zeroize_impl(input: DeriveInput) -> TokenStream { let attributes = ZeroizeAttrs::parse(&input); + let mut generics = input.generics.clone(); + let extra_bounds = match attributes.bound { Some(bounds) => bounds.0, - None => Default::default(), + None => attributes + .auto_params + .iter() + .map(|type_param| -> WherePredicate { + parse_quote! {#type_param: Zeroize} + }) + .collect(), }; - let mut generics = input.generics.clone(); generics.make_where_clause().predicates.extend(extra_bounds); let ty_name = &input.ident; @@ -117,6 +126,8 @@ struct ZeroizeAttrs { drop: bool, /// Custom bounds as defined by the user bound: Option, + /// Type parameters in use by fields + auto_params: Vec, } /// Parsing helper for custom bounds @@ -128,10 +139,37 @@ impl Parse for Bounds { } } +struct BoundAccumulator<'a> { + generics: &'a syn::Generics, + params: Vec, +} + +impl<'ast> Visit<'ast> for BoundAccumulator<'ast> { + fn visit_path(&mut self, path: &'ast syn::Path) { + if path.segments.len() != 1 { + return; + } + + if let Some(segment) = path.segments.first() { + for param in &self.generics.params { + if let syn::GenericParam::Type(type_param) = param { + if type_param.ident == segment.ident && !self.params.contains(&segment.ident) { + self.params.push(type_param.ident.clone()); + } + } + } + } + } +} + impl ZeroizeAttrs { /// Parse attributes from the incoming AST fn parse(input: &DeriveInput) -> Self { let mut result = Self::default(); + let mut bound_accumulator = BoundAccumulator { + generics: &input.generics, + params: Vec::new(), + }; for attr in &input.attrs { result.parse_attr(attr, None, None); @@ -147,6 +185,9 @@ impl ZeroizeAttrs { for attr in &field.attrs { result.parse_attr(attr, Some(variant), Some(field)); } + if !attr_skip(&field.attrs) { + bound_accumulator.visit_type(&field.ty); + } } } } @@ -155,11 +196,16 @@ impl ZeroizeAttrs { for attr in &field.attrs { result.parse_attr(attr, None, Some(field)); } + if !attr_skip(&field.attrs) { + bound_accumulator.visit_type(&field.ty); + } } } syn::Data::Union(union_) => panic!("Unsupported untagged union {:?}", union_), } + result.auto_params = bound_accumulator.params; + result } diff --git a/zeroize/tests/zeroize_derive.rs b/zeroize/tests/zeroize_derive.rs index 1277cc55..21d0cb52 100644 --- a/zeroize/tests/zeroize_derive.rs +++ b/zeroize/tests/zeroize_derive.rs @@ -351,3 +351,12 @@ fn derive_zeroize_with_marker() { trait Marker {} } + +#[test] +// Issue #878 +fn derive_zeroize_used_param() { + #[derive(Zeroize)] + struct Z { + used: T + } +}