From 5c13966f0cd08c963fa729f2c26e9943a13ede5d Mon Sep 17 00:00:00 2001
From: Andre Bogus
Date: Sun, 27 Aug 2023 02:02:23 +0200
Subject: [PATCH] add aarch64

---
 src/lib.rs          |  35 +++++++++--
 src/simd/aarch64.rs | 172 ++++++++++++++++++++++++++++++++++++++++++++
 src/simd/mod.rs     |   4 ++
 3 files changed, 205 insertions(+), 6 deletions(-)
 create mode 100644 src/simd/aarch64.rs

diff --git a/src/lib.rs b/src/lib.rs
index ef4235c..24f4018 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -32,7 +32,6 @@
 //! still on small strings.
 
 #![deny(missing_docs)]
-
 #![cfg_attr(not(feature = "runtime-dispatch-simd"), no_std)]
 
 #[cfg(not(feature = "runtime-dispatch-simd"))]
@@ -45,7 +44,11 @@ pub use naive::*;
 mod integer_simd;
 
 #[cfg(any(
-    all(feature = "runtime-dispatch-simd", any(target_arch = "x86", target_arch = "x86_64")),
+    all(
+        feature = "runtime-dispatch-simd",
+        any(target_arch = "x86", target_arch = "x86_64")
+    ),
+    target_arch = "aarch64",
     feature = "generic-simd"
 ))]
 mod simd;
@@ -64,7 +67,9 @@ pub fn count(haystack: &[u8], needle: u8) -> usize {
         #[cfg(all(feature = "runtime-dispatch-simd", target_arch = "x86_64"))]
         {
             if is_x86_feature_detected!("avx2") {
-                unsafe { return simd::x86_avx2::chunk_count(haystack, needle); }
+                unsafe {
+                    return simd::x86_avx2::chunk_count(haystack, needle);
+                }
             }
         }
 
@@ -80,7 +85,15 @@ pub fn count(haystack: &[u8], needle: u8) -> usize {
         ))]
         {
             if is_x86_feature_detected!("sse2") {
-                unsafe { return simd::x86_sse2::chunk_count(haystack, needle); }
+                unsafe {
+                    return simd::x86_sse2::chunk_count(haystack, needle);
+                }
+            }
+        }
+        #[cfg(all(target_arch = "aarch64", not(feature = "generic-simd")))]
+        {
+            unsafe {
+                return simd::aarch64::chunk_count(haystack, needle);
             }
         }
     }
@@ -109,7 +122,9 @@ pub fn num_chars(utf8_chars: &[u8]) -> usize {
         #[cfg(all(feature = "runtime-dispatch-simd", target_arch = "x86_64"))]
         {
             if is_x86_feature_detected!("avx2") {
-                unsafe { return simd::x86_avx2::chunk_num_chars(utf8_chars); }
+                unsafe {
+                    return simd::x86_avx2::chunk_num_chars(utf8_chars);
+                }
             }
         }
 
@@ -125,7 +140,15 @@ pub fn num_chars(utf8_chars: &[u8]) -> usize {
         ))]
         {
             if is_x86_feature_detected!("sse2") {
-                unsafe { return simd::x86_sse2::chunk_num_chars(utf8_chars); }
+                unsafe {
+                    return simd::x86_sse2::chunk_num_chars(utf8_chars);
+                }
+            }
+        }
+        #[cfg(all(target_arch = "aarch64", not(feature = "generic-simd")))]
+        {
+            unsafe {
+                return simd::aarch64::chunk_num_chars(utf8_chars);
             }
         }
     }
diff --git a/src/simd/aarch64.rs b/src/simd/aarch64.rs
new file mode 100644
index 0000000..56e8b71
--- /dev/null
+++ b/src/simd/aarch64.rs
@@ -0,0 +1,172 @@
+use core::arch::aarch64::{
+    uint8x16_t, vaddlvq_u8, vandq_u8, vceqq_u8, vdupq_n_u8, vld1q_u8, vmvnq_u8, vsubq_u8,
+};
+
+/// Sliding lane mask for the final, overlapping 16-byte load: the 16 bytes starting at
+/// offset `r` are `16 - r` zero lanes followed by `r` all-ones lanes.
+const MASK: [u8; 32] = [
+    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 255, 255, 255, 255, 255, 255,
+    255, 255, 255, 255, 255, 255, 255,
+];
+
+/// Loads the 16 bytes of `slice` starting at `offset` into a vector register.
+///
+/// # Safety
+///
+/// `offset + 16` must not exceed `slice.len()`, and NEON must be available.
+#[target_feature(enable = "neon")]
+unsafe fn u8x16_from_offset(slice: &[u8], offset: usize) -> uint8x16_t {
+    debug_assert!(offset + 16 <= slice.len());
+    // `vld1q_u8` has no alignment requirement beyond that of `u8`.
+    vld1q_u8(slice.as_ptr().add(offset))
+}
+
+/// Adds up all 16 lanes; the widening add cannot overflow (16 * 255 fits in `u16`).
+#[target_feature(enable = "neon")]
+unsafe fn sum(u8s: &uint8x16_t) -> usize {
+    vaddlvq_u8(*u8s) as usize
+}
+
+/// Counts the bytes in `haystack` that are equal to `needle`.
+///
+/// # Panics
+///
+/// Panics if `haystack` holds fewer than 16 bytes.
+///
+/// # Safety
+///
+/// NEON must be available on the executing CPU.
+#[target_feature(enable = "neon")]
+pub unsafe fn chunk_count(haystack: &[u8], needle: u8) -> usize {
+    assert!(haystack.len() >= 16);
+
+    let mut offset = 0;
+    let mut count = 0;
+
+    let needles = vdupq_n_u8(needle);
+
+    // 4080 bytes per round: a `u8` lane can take 255 hits before it would wrap.
+    while haystack.len() >= offset + 16 * 255 {
+        let mut counts = vdupq_n_u8(0);
+        for _ in 0..255 {
+            // A matching lane is 0xFF (-1), so subtracting it increments the counter.
+            counts = vsubq_u8(
+                counts,
+                vceqq_u8(u8x16_from_offset(haystack, offset), needles),
+            );
+            offset += 16;
+        }
+        count += sum(&counts);
+    }
+
+    // 2048 bytes in a single round, if that much is left.
+    if haystack.len() >= offset + 16 * 128 {
+        let mut counts = vdupq_n_u8(0);
+        for _ in 0..128 {
+            counts = vsubq_u8(
+                counts,
+                vceqq_u8(u8x16_from_offset(haystack, offset), needles),
+            );
+            offset += 16;
+        }
+        count += sum(&counts);
+    }
+
+    // Remaining whole 16-byte chunks; `offset` is not advanced here, so index from it.
+    let mut counts = vdupq_n_u8(0);
+    for i in 0..(haystack.len() - offset) / 16 {
+        counts = vsubq_u8(
+            counts,
+            vceqq_u8(u8x16_from_offset(haystack, offset + i * 16), needles),
+        );
+    }
+    if haystack.len() % 16 != 0 {
+        // Reload the last 16 bytes and mask off the lanes that were already counted.
+        counts = vsubq_u8(
+            counts,
+            vandq_u8(
+                vceqq_u8(u8x16_from_offset(haystack, haystack.len() - 16), needles),
+                u8x16_from_offset(&MASK, haystack.len() % 16),
+            ),
+        );
+    }
+    count += sum(&counts);
+
+    count
+}
+
+/// Returns 0xFF in every lane whose byte is not a UTF-8 continuation byte
+/// (`0b10xx_xxxx`) and 0x00 in every other lane.
+#[target_feature(enable = "neon")]
+unsafe fn is_leading_utf8_byte(u8s: uint8x16_t) -> uint8x16_t {
+    vmvnq_u8(vceqq_u8(
+        vandq_u8(u8s, vdupq_n_u8(0b1100_0000)),
+        vdupq_n_u8(0b1000_0000),
+    ))
+}
+
+/// Counts the bytes in `utf8_chars` that are not UTF-8 continuation bytes.
+///
+/// # Panics
+///
+/// Panics if `utf8_chars` holds fewer than 16 bytes.
+///
+/// # Safety
+///
+/// NEON must be available on the executing CPU.
+#[target_feature(enable = "neon")]
+pub unsafe fn chunk_num_chars(utf8_chars: &[u8]) -> usize {
+    assert!(utf8_chars.len() >= 16);
+
+    let mut offset = 0;
+    let mut count = 0;
+
+    // 4080 bytes per round: a `u8` lane can take 255 hits before it would wrap.
+    while utf8_chars.len() >= offset + 16 * 255 {
+        let mut counts = vdupq_n_u8(0);
+
+        for _ in 0..255 {
+            counts = vsubq_u8(
+                counts,
+                is_leading_utf8_byte(u8x16_from_offset(utf8_chars, offset)),
+            );
+            offset += 16;
+        }
+        count += sum(&counts);
+    }
+
+    // 2048 bytes in a single round, if that much is left.
+    if utf8_chars.len() >= offset + 16 * 128 {
+        let mut counts = vdupq_n_u8(0);
+        for _ in 0..128 {
+            counts = vsubq_u8(
+                counts,
+                is_leading_utf8_byte(u8x16_from_offset(utf8_chars, offset)),
+            );
+            offset += 16;
+        }
+        count += sum(&counts);
+    }
+
+    // Remaining whole 16-byte chunks; `offset` is not advanced here, so index from it.
+    let mut counts = vdupq_n_u8(0);
+    for i in 0..(utf8_chars.len() - offset) / 16 {
+        counts = vsubq_u8(
+            counts,
+            is_leading_utf8_byte(u8x16_from_offset(utf8_chars, offset + i * 16)),
+        );
+    }
+    if utf8_chars.len() % 16 != 0 {
+        // Reload the last 16 bytes and mask off the lanes that were already counted.
+        counts = vsubq_u8(
+            counts,
+            vandq_u8(
+                is_leading_utf8_byte(u8x16_from_offset(utf8_chars, utf8_chars.len() - 16)),
+                u8x16_from_offset(&MASK, utf8_chars.len() % 16),
+            ),
+        );
+    }
+    count += sum(&counts);
+
+    count
+}
diff --git a/src/simd/mod.rs b/src/simd/mod.rs
index d144e18..fa98575 100644
--- a/src/simd/mod.rs
+++ b/src/simd/mod.rs
@@ -15,3 +15,7 @@ pub mod x86_sse2;
 // Runtime feature detection is not available with no_std.
 #[cfg(all(feature = "runtime-dispatch-simd", target_arch = "x86_64"))]
 pub mod x86_avx2;
+
+/// Modern ARM machines are also quite capable thanks to NEON
+#[cfg(target_arch = "aarch64")]
+pub mod aarch64;