Skip to content

Commit

Permalink
der: use Reader<'a> as input for Decode::decode
Browse files Browse the repository at this point in the history
Implements decoding generically in terms of the `Reader` trait, similar
to what #611 did for encoding.

This approach can enable 1-pass on-the-fly PEM decoding for
`DecodeOwned` types (although that will require some additional work
beyond what's in this PR).
  • Loading branch information
tarcieri committed May 4, 2022
1 parent 53e2304 commit 8c3ab7a
Show file tree
Hide file tree
Showing 67 changed files with 785 additions and 781 deletions.
14 changes: 7 additions & 7 deletions der/derive/src/asn1_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,13 @@ impl Asn1Type {
/// Get a `der::Decoder` object for a particular ASN.1 type
pub fn decoder(self) -> TokenStream {
match self {
Asn1Type::BitString => quote!(decoder.bit_string()?),
Asn1Type::Ia5String => quote!(decoder.ia5_string()?),
Asn1Type::GeneralizedTime => quote!(decoder.generalized_time()?),
Asn1Type::OctetString => quote!(decoder.octet_string()?),
Asn1Type::PrintableString => quote!(decoder.printable_string()?),
Asn1Type::UtcTime => quote!(decoder.utc_time()?),
Asn1Type::Utf8String => quote!(decoder.utf8_string()?),
Asn1Type::BitString => quote!(::der::asn1::BitString::decode(reader)?),
Asn1Type::Ia5String => quote!(::der::asn1::Ia5String::decode(reader)?),
Asn1Type::GeneralizedTime => quote!(::der::asn1::GeneralizedTime::decode(reader)?),
Asn1Type::OctetString => quote!(::der::asn1::OctetString::decode(reader)?),
Asn1Type::PrintableString => quote!(::der::asn1::PrintableString::decode(reader)?),
Asn1Type::UtcTime => quote!(::der::asn1::UtcTime::decode(reader)?),
Asn1Type::Utf8String => quote!(::der::asn1::Utf8String::decode(reader)?),
}
}

Expand Down
17 changes: 9 additions & 8 deletions der/derive/src/attributes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ impl FieldAttrs {
pub fn parse(attrs: &[Attribute], type_attrs: &TypeAttrs) -> Self {
let mut asn1_type = None;
let mut context_specific = None;

let mut default = None;
let mut extensible = None;
let mut optional = None;
Expand Down Expand Up @@ -203,13 +202,13 @@ impl FieldAttrs {
if self.extensible || self.is_optional() {
quote! {
::der::asn1::ContextSpecific::<#type_params>::decode_explicit(
decoder,
reader,
#tag_number
)?
}
} else {
quote! {
match ::der::asn1::ContextSpecific::<#type_params>::decode(decoder)? {
match ::der::asn1::ContextSpecific::<#type_params>::decode(reader)? {
field if field.tag_number == #tag_number => Some(field),
_ => None
}
Expand All @@ -219,7 +218,7 @@ impl FieldAttrs {
TagMode::Implicit => {
quote! {
::der::asn1::ContextSpecific::<#type_params>::decode_implicit(
decoder,
reader,
#tag_number
)?
}
Expand All @@ -246,13 +245,15 @@ impl FieldAttrs {
}
} else if let Some(default) = &self.default {
let type_params = self.asn1_type.map(|ty| ty.type_path()).unwrap_or_default();
self.asn1_type.map(|ty| ty.decoder()).unwrap_or_else(
|| quote!(decoder.decode::<Option<#type_params>>()?.unwrap_or_else(#default)),
)
self.asn1_type.map(|ty| ty.decoder()).unwrap_or_else(|| {
quote! {
Option::<#type_params>::decode(reader)?.unwrap_or_else(#default),
}
})
} else {
self.asn1_type
.map(|ty| ty.decoder())
.unwrap_or_else(|| quote!(decoder.decode()?))
.unwrap_or_else(|| quote!(reader.decode()?))
}
}

Expand Down
13 changes: 6 additions & 7 deletions der/derive/src/choice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
mod variant;

use self::variant::ChoiceVariant;
use crate::TypeAttrs;
use crate::{default_lifetime, TypeAttrs};
use proc_macro2::TokenStream;
use proc_macro_error::abort;
use quote::quote;
Expand Down Expand Up @@ -59,10 +59,9 @@ impl DeriveChoice {
pub fn to_tokens(&self) -> TokenStream {
let ident = &self.ident;

// Explicit lifetime or `'_`
let lifetime = match self.lifetime {
Some(ref lifetime) => quote!(#lifetime),
None => quote!('_),
None => default_lifetime(),
};

// Lifetime parameters
Expand All @@ -88,16 +87,16 @@ impl DeriveChoice {
}

quote! {
impl<#lt_params> ::der::Choice<#lifetime> for #ident<#lt_params> {
impl<#lifetime> ::der::Choice<#lifetime> for #ident<#lt_params> {
fn can_decode(tag: ::der::Tag) -> bool {
matches!(tag, #(#can_decode_body)|*)
}
}

impl<#lt_params> ::der::Decode<#lifetime> for #ident<#lt_params> {
fn decode(decoder: &mut ::der::Decoder<#lifetime>) -> ::der::Result<Self> {
impl<#lifetime> ::der::Decode<#lifetime> for #ident<#lt_params> {
fn decode<R: ::der::Reader<#lifetime>>(reader: &mut R) -> ::der::Result<Self> {
use der::Reader as _;
match decoder.peek_tag()? {
match reader.peek_tag()? {
#(#decode_body)*
actual => Err(der::ErrorKind::TagUnexpected {
expected: None,
Expand Down
8 changes: 4 additions & 4 deletions der/derive/src/choice/variant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ mod tests {
variant.to_decode_tokens().to_string(),
quote! {
::der::Tag::Utf8String => Ok(Self::ExampleVariant(
decoder.decode()?
reader.decode()?
)),
}
.to_string()
Expand Down Expand Up @@ -214,7 +214,7 @@ mod tests {
variant.to_decode_tokens().to_string(),
quote! {
::der::Tag::Utf8String => Ok(Self::ExampleVariant(
decoder.utf8_string()?
::der::asn1::Utf8String::decode(reader)?
.try_into()?
)),
}
Expand Down Expand Up @@ -273,7 +273,7 @@ mod tests {
constructed: #constructed,
number: #tag_number,
} => Ok(Self::ExplicitVariant(
match ::der::asn1::ContextSpecific::<>::decode(decoder)? {
match ::der::asn1::ContextSpecific::<>::decode(reader)? {
field if field.tag_number == #tag_number => Some(field),
_ => None
}
Expand Down Expand Up @@ -359,7 +359,7 @@ mod tests {
number: #tag_number,
} => Ok(Self::ImplicitVariant(
::der::asn1::ContextSpecific::<>::decode_implicit(
decoder,
reader,
#tag_number
)?
.ok_or_else(|| {
Expand Down
11 changes: 6 additions & 5 deletions der/derive/src/enumerated.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
//! the purposes of decoding/encoding ASN.1 `ENUMERATED` types as mapped to
//! enum variants.

use crate::ATTR_NAME;
use crate::{default_lifetime, ATTR_NAME};
use proc_macro2::TokenStream;
use proc_macro_error::abort;
use quote::quote;
Expand Down Expand Up @@ -102,6 +102,7 @@ impl DeriveEnumerated {

/// Lower the derived output into a [`TokenStream`].
pub fn to_tokens(&self) -> TokenStream {
let default_lifetime = default_lifetime();
let ident = &self.ident;
let repr = &self.repr;
let tag = match self.integer {
Expand All @@ -115,12 +116,12 @@ impl DeriveEnumerated {
}

quote! {
impl ::der::DecodeValue<'_> for #ident {
fn decode_value(
decoder: &mut ::der::Decoder<'_>,
impl<#default_lifetime> ::der::DecodeValue<#default_lifetime> for #ident {
fn decode_value<R: ::der::Reader<#default_lifetime>>(
reader: &mut R,
header: ::der::Header
) -> ::der::Result<Self> {
<#repr as ::der::DecodeValue>::decode_value(decoder, header)?.try_into()
<#repr as ::der::DecodeValue>::decode_value(reader, header)?.try_into()
}
}

Expand Down
10 changes: 9 additions & 1 deletion der/derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,16 @@ use crate::{
value_ord::DeriveValueOrd,
};
use proc_macro::TokenStream;
use proc_macro2::Span;
use proc_macro_error::proc_macro_error;
use syn::{parse_macro_input, DeriveInput};
use quote::quote;
use syn::{parse_macro_input, DeriveInput, Lifetime};

/// Get the default lifetime.
fn default_lifetime() -> proc_macro2::TokenStream {
let lifetime = Lifetime::new("'__der_lifetime", Span::call_site());
quote!(#lifetime)
}

/// Derive the [`Choice`][1] trait on an `enum`.
///
Expand Down
18 changes: 9 additions & 9 deletions der/derive/src/sequence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

mod field;

use crate::TypeAttrs;
use crate::{default_lifetime, TypeAttrs};
use field::SequenceField;
use proc_macro2::TokenStream;
use proc_macro_error::abort;
Expand Down Expand Up @@ -59,10 +59,9 @@ impl DeriveSequence {
pub fn to_tokens(&self) -> TokenStream {
let ident = &self.ident;

// Explicit lifetime or `'_`
let lifetime = match self.lifetime {
Some(ref lifetime) => quote!(#lifetime),
None => quote!('_),
None => default_lifetime(),
};

// Lifetime parameters
Expand All @@ -84,13 +83,14 @@ impl DeriveSequence {
}

quote! {
impl<#lt_params> ::der::DecodeValue<#lifetime> for #ident<#lt_params> {
fn decode_value(
decoder: &mut ::der::Decoder<#lifetime>,
impl<#lifetime> ::der::DecodeValue<#lifetime> for #ident<#lt_params> {
fn decode_value<R: ::der::Reader<#lifetime>>(
reader: &mut R,
header: ::der::Header,
) -> ::der::Result<Self> {
use ::der::DecodeValue;
::der::asn1::SequenceRef::decode_value(decoder, header)?.decode_body(|decoder| {
use ::der::{Decode as _, DecodeValue as _, Reader as _};

reader.read_nested(header.length, |reader| {
#(#decode_body)*

Ok(Self {
Expand All @@ -100,7 +100,7 @@ impl DeriveSequence {
}
}

impl<#lt_params> ::der::Sequence<#lifetime> for #ident<#lt_params> {
impl<#lifetime> ::der::Sequence<#lifetime> for #ident<#lt_params> {
fn fields<F, T>(&self, f: F) -> ::der::Result<T>
where
F: FnOnce(&[&dyn der::Encode]) -> ::der::Result<T>,
Expand Down
8 changes: 4 additions & 4 deletions der/derive/src/sequence/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,8 @@ impl LowerFieldDecoder {
/// Handle default value for a type.
fn apply_default(&mut self, default: &Path, field_type: &Type) {
self.decoder = quote! {
decoder.decode::<Option<#field_type>>()?.unwrap_or_else(#default);
}
Option::<#field_type>::decode(reader)?.unwrap_or_else(#default);
};
}
}

Expand Down Expand Up @@ -287,7 +287,7 @@ mod tests {
assert_eq!(
field.to_decode_tokens().to_string(),
quote! {
let example_field = decoder.decode()?;
let example_field = reader.decode()?;
}
.to_string()
);
Expand Down Expand Up @@ -328,7 +328,7 @@ mod tests {
field.to_decode_tokens().to_string(),
quote! {
let implicit_field = ::der::asn1::ContextSpecific::<>::decode_implicit(
decoder,
reader,
::der::TagNumber::N0
)?
.ok_or_else(|| {
Expand Down
19 changes: 3 additions & 16 deletions der/src/arrayvec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,22 +66,9 @@ impl<T, const N: usize> ArrayVec<T, N> {
self.length.checked_sub(1).and_then(|n| self.get(n))
}

/// Try to convert this [`ArrayVec`] into a `[T; N]`.
///
/// Returns `None` if the [`ArrayVec`] does not contain `N` elements.
pub fn try_into_array(self) -> Result<[T; N]> {
if self.length != N {
return Err(ErrorKind::Incomplete {
expected_len: N.try_into()?,
actual_len: self.length.try_into()?,
}
.into());
}

Ok(self.elements.map(|elem| match elem {
Some(e) => e,
None => unreachable!(),
}))
/// Extract the inner array.
pub fn into_array(self) -> [Option<T>; N] {
self.elements
}
}

Expand Down
9 changes: 5 additions & 4 deletions der/src/asn1/any.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

use crate::{
asn1::*, ByteSlice, Choice, Decode, DecodeValue, Decoder, DerOrd, EncodeValue, Error,
ErrorKind, FixedTag, Header, Length, Result, Tag, Tagged, ValueOrd, Writer,
ErrorKind, FixedTag, Header, Length, Reader, Result, Tag, Tagged, ValueOrd, Writer,
};
use core::cmp::Ordering;

Expand Down Expand Up @@ -153,11 +153,12 @@ impl<'a> Choice<'a> for Any<'a> {
}

impl<'a> Decode<'a> for Any<'a> {
fn decode(decoder: &mut Decoder<'a>) -> Result<Any<'a>> {
let header = Header::decode(decoder)?;
fn decode<R: Reader<'a>>(reader: &mut R) -> Result<Any<'a>> {
let header = Header::decode(reader)?;

Ok(Self {
tag: header.tag,
value: ByteSlice::decode_value(decoder, header)?,
value: ByteSlice::decode_value(reader, header)?,
})
}
}
Expand Down
18 changes: 9 additions & 9 deletions der/src/asn1/bit_string.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
//! ASN.1 `BIT STRING` support.

use crate::{
asn1::Any, ByteSlice, DecodeValue, Decoder, DerOrd, EncodeValue, Error, ErrorKind, FixedTag,
Header, Length, Reader, Result, Tag, ValueOrd, Writer,
asn1::Any, ByteSlice, DecodeValue, DerOrd, EncodeValue, Error, ErrorKind, FixedTag, Header,
Length, Reader, Result, Tag, ValueOrd, Writer,
};
use core::{cmp::Ordering, iter::FusedIterator};

Expand Down Expand Up @@ -116,14 +116,14 @@ impl<'a> BitString<'a> {
}

impl<'a> DecodeValue<'a> for BitString<'a> {
fn decode_value(decoder: &mut Decoder<'a>, header: Header) -> Result<Self> {
fn decode_value<R: Reader<'a>>(reader: &mut R, header: Header) -> Result<Self> {
let header = Header {
tag: header.tag,
length: (header.length - Length::ONE)?,
};

let unused_bits = decoder.read_byte()?;
let inner = ByteSlice::decode_value(decoder, header)?;
let unused_bits = reader.read_byte()?;
let inner = ByteSlice::decode_value(reader, header)?;
Self::new(unused_bits, inner.as_slice())
}
}
Expand Down Expand Up @@ -239,12 +239,12 @@ where
T::Type: From<bool>,
T::Type: core::ops::Shl<usize, Output = T::Type>,
{
fn decode_value(decoder: &mut Decoder<'a>, header: Header) -> Result<Self> {
let position = decoder.position();

let bits = BitString::decode_value(decoder, header)?;
fn decode_value<R: Reader<'a>>(reader: &mut R, header: Header) -> Result<Self> {
let position = reader.position();
let bits = BitString::decode_value(reader, header)?;

let mut flags = T::none().bits();

if bits.bit_len() > core::mem::size_of_val(&flags) * 8 {
return Err(Error::new(ErrorKind::Overlength, position));
}
Expand Down
Loading

0 comments on commit 8c3ab7a

Please sign in to comment.