diff --git a/Cargo.toml b/Cargo.toml index 45caf3c4b..bccc290ea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "bech32" -version = "0.4.1" +version = "0.5.0" authors = ["Clark Moody"] repository = "https://github.com/rust-bitcoin/rust-bech32" description = "Encodes and decodes the Bech32 format" diff --git a/src/lib.rs b/src/lib.rs index 1db25aa2a..1d754fbf8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -63,8 +63,9 @@ #![deny(unused_mut)] use std::{error, fmt}; -use std::str::FromStr; +use std::ascii::AsciiExt; use std::fmt::{Display, Formatter}; +use std::str::FromStr; /// Integer in the range `0..32` #[derive(PartialEq, Eq, Debug, Copy, Clone, Default, PartialOrd, Ord, Hash)] @@ -225,7 +226,7 @@ impl Bech32 { for b in raw_hrp.bytes() { // Valid subset of ASCII if b < 33 || b > 126 { - return Err(Error::InvalidChar(b)) + return Err(Error::InvalidChar(b as char)) } let mut c = b; // Lowercase @@ -242,34 +243,28 @@ impl Bech32 { } // Check data payload - let mut data_bytes: Vec = Vec::new(); - for b in raw_data.bytes() { - // Alphanumeric only - if !((b >= b'0' && b <= b'9') || (b >= b'A' && b <= b'Z') || (b >= b'a' && b <= b'z')) { - return Err(Error::InvalidChar(b)) - } - // Excludes these characters: [1,b,i,o] - if b == b'1' || b == b'b' || b == b'i' || b == b'o' { - return Err(Error::InvalidChar(b)) + let mut data_bytes = raw_data.chars().map(|c| { + // Only check if c is in the ASCII range, all invalid ASCII characters have the value -1 + // in CHARSET_REV (which covers the whole ASCII range) and will be filtered out later. + if !c.is_ascii() { + return Err(Error::InvalidChar(c)) } - // Lowercase - if b >= b'a' && b <= b'z' { + + if c.is_lowercase() { has_lower = true; + } else if c.is_uppercase() { + has_upper = true; } - // Uppercase - let c = if b >= b'A' && b <= b'Z' { - has_upper = true; - // Convert to lowercase - b + (b'a'-b'A') - } else { - b - }; - - data_bytes.push(u5::try_from_u8(CHARSET_REV[c as usize] as u8).expect( - "range was already checked above" - )); - } + // c should be <128 since it is in the ASCII range, CHARSET_REV.len() == 128 + let num_value = CHARSET_REV[c as usize]; + + if num_value > 31 || num_value < 0 { + return Err(Error::InvalidChar(c)); + } + + Ok(u5::try_from_u8(num_value as u8).expect("range checked above, num_value <= 31")) + }).collect::, Error>>()?; // Ensure no mixed case if has_lower && has_upper { @@ -402,7 +397,7 @@ pub enum Error { /// The data or human-readable part is too long or too short InvalidLength, /// Some part of the string contains an invalid character - InvalidChar(u8), + InvalidChar(char), /// Some part of the data has an invalid value InvalidData(u8), /// The bit conversion failed due to a padding issue @@ -545,9 +540,9 @@ mod tests { fn invalid_strings() { let pairs: Vec<(&str, Error)> = vec!( (" 1nwldj5", - Error::InvalidChar(b' ')), - ("\x7f1axkwrx", - Error::InvalidChar(0x7f)), + Error::InvalidChar(' ')), + ("abc1\u{2192}axkwrx", + Error::InvalidChar('\u{2192}')), ("an84characterslonghumanreadablepartthatcontainsthenumber1andtheexcludedcharactersbio1569pvx", Error::InvalidLength), ("pzry9x0s0muk", @@ -555,11 +550,13 @@ mod tests { ("1pzry9x0s0muk", Error::InvalidLength), ("x1b4n0q5v", - Error::InvalidChar(b'b')), + Error::InvalidChar('b')), + ("ABC1DEFGOH", + Error::InvalidChar('O')), ("li1dgmt3", Error::InvalidLength), ("de1lg7wt\u{ff}", - Error::InvalidChar(0xc3)), // ASCII 0xff -> \uC3BF in UTF-8 + Error::InvalidChar('\u{ff}')), ); for p in pairs { let (s, expected_error) = p; @@ -568,7 +565,7 @@ mod tests { println!("{:?}", dec_result.unwrap()); panic!("Should be invalid: {:?}", s); } - assert_eq!(dec_result.unwrap_err(), expected_error); + assert_eq!(dec_result.unwrap_err(), expected_error, "testing input '{}'", s); } } @@ -655,4 +652,24 @@ mod tests { use ToBase32; assert_eq!([0xffu8].to_base32(), [0x1f, 0x1c].check_base32().unwrap()); } + + #[test] + fn reverse_charset() { + use std::ascii::AsciiExt; + use ::CHARSET_REV; + + fn get_char_value(c: char) -> i8 { + let charset = "qpzry9x8gf2tvdw0s3jn54khce6mua7l"; + match charset.find(c.to_ascii_lowercase()) { + Some(x) => x as i8, + None => -1, + } + } + + let expected_rev_charset = (0u8..128).map(|i| { + get_char_value(i as char) + }).collect::>(); + + assert_eq!(&(CHARSET_REV[..]), expected_rev_charset.as_slice()); + } }