Skip to content

Commit

Permalink
add aarch64
Browse files Browse the repository at this point in the history
  • Loading branch information
llogiq committed Sep 26, 2023
1 parent fbad8d4 commit 5c13966
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 6 deletions.
35 changes: 29 additions & 6 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
//! still on small strings.

#![deny(missing_docs)]

#![cfg_attr(not(feature = "runtime-dispatch-simd"), no_std)]

#[cfg(not(feature = "runtime-dispatch-simd"))]
Expand All @@ -45,7 +44,11 @@ pub use naive::*;
mod integer_simd;

#[cfg(any(
all(feature = "runtime-dispatch-simd", any(target_arch = "x86", target_arch = "x86_64")),
all(
feature = "runtime-dispatch-simd",
any(target_arch = "x86", target_arch = "x86_64")
),
target_arch = "aarch64",
feature = "generic-simd"
))]
mod simd;
Expand All @@ -64,7 +67,9 @@ pub fn count(haystack: &[u8], needle: u8) -> usize {
#[cfg(all(feature = "runtime-dispatch-simd", target_arch = "x86_64"))]
{
if is_x86_feature_detected!("avx2") {
unsafe { return simd::x86_avx2::chunk_count(haystack, needle); }
unsafe {
return simd::x86_avx2::chunk_count(haystack, needle);
}
}
}

Expand All @@ -80,7 +85,15 @@ pub fn count(haystack: &[u8], needle: u8) -> usize {
))]
{
if is_x86_feature_detected!("sse2") {
unsafe { return simd::x86_sse2::chunk_count(haystack, needle); }
unsafe {
return simd::x86_sse2::chunk_count(haystack, needle);
}
}
}
// NEON path for AArch64. The opt-out feature is spelled "generic-simd"
// (hyphen), matching the `mod simd` cfg; the underscore spelling never
// matched a real feature, so this guard could never be switched off.
#[cfg(all(target_arch = "aarch64", not(feature = "generic-simd")))]
{
    // SAFETY: `chunk_count` is `unsafe` only because it is compiled with
    // `#[target_feature(enable = "neon")]`; it asserts its own minimum
    // length. NOTE(review): NEON is assumed to be baseline on every aarch64
    // target this crate supports — confirm for custom targets.
    unsafe {
        return simd::aarch64::chunk_count(haystack, needle);
    }
}
}
Expand Down Expand Up @@ -109,7 +122,9 @@ pub fn num_chars(utf8_chars: &[u8]) -> usize {
#[cfg(all(feature = "runtime-dispatch-simd", target_arch = "x86_64"))]
{
if is_x86_feature_detected!("avx2") {
unsafe { return simd::x86_avx2::chunk_num_chars(utf8_chars); }
unsafe {
return simd::x86_avx2::chunk_num_chars(utf8_chars);
}
}
}

Expand All @@ -125,7 +140,15 @@ pub fn num_chars(utf8_chars: &[u8]) -> usize {
))]
{
if is_x86_feature_detected!("sse2") {
unsafe { return simd::x86_sse2::chunk_num_chars(utf8_chars); }
unsafe {
return simd::x86_sse2::chunk_num_chars(utf8_chars);
}
}
}
// NEON path for AArch64. The opt-out feature is spelled "generic-simd"
// (hyphen), matching the `mod simd` cfg; the underscore spelling never
// matched a real feature, so this guard could never be switched off.
#[cfg(all(target_arch = "aarch64", not(feature = "generic-simd")))]
{
    // SAFETY: `chunk_num_chars` is `unsafe` only because it is compiled with
    // `#[target_feature(enable = "neon")]`; it asserts its own minimum
    // length. NOTE(review): NEON is assumed to be baseline on every aarch64
    // target this crate supports — confirm for custom targets.
    unsafe {
        return simd::aarch64::chunk_num_chars(utf8_chars);
    }
}
}
Expand Down
139 changes: 139 additions & 0 deletions src/simd/aarch64.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
use core::arch::aarch64::{
uint8x16_t, vaddlvq_u8, vandq_u8, vceqq_u8, vcgtq_u8, vdupq_n_u8, vld1q_u8, vmvnq_u8, vsubq_u8,
};

/// Lane mask applied to the final, overlapping 16-byte load.
///
/// The table is sixteen `0x00` bytes followed by sixteen `0xFF` bytes.
/// Loading 16 bytes starting at index `r` yields a vector whose leading
/// `16 - r` lanes are zero and whose trailing `r` lanes are all-ones, which
/// keeps only the last `r` bytes of the vector it is ANDed with.
const MASK: [u8; 32] = {
    let mut mask = [0u8; 32];
    let mut lane = 16;
    while lane < 32 {
        mask[lane] = 255;
        lane += 1;
    }
    mask
};

/// Loads the 16 bytes starting at `slice[offset]` into a NEON vector.
///
/// The load goes through a plain `*const u8`, so no alignment beyond that of
/// `u8` is needed.
///
/// # Safety
///
/// No bounds check is performed: the caller must guarantee that
/// `offset + 16 <= slice.len()`. NEON must be available on the running CPU.
#[target_feature(enable = "neon")]
unsafe fn u8x16_from_offset(slice: &[u8], offset: usize) -> uint8x16_t {
    let start: *const u8 = slice.as_ptr().add(offset);
    vld1q_u8(start)
}

/// Adds up all 16 lanes of `u8s` and returns the total.
///
/// `vaddlvq_u8` widens while adding, so the lane total is not truncated to
/// eight bits before it is converted to `usize`.
///
/// # Safety
///
/// NEON must be available on the running CPU.
#[target_feature(enable = "neon")]
unsafe fn sum(u8s: &uint8x16_t) -> usize {
    let total = vaddlvq_u8(*u8s);
    usize::from(total)
}

/// Counts the bytes in `haystack` that are equal to `needle`, using NEON.
///
/// The input is consumed 16 bytes at a time. Each equality comparison yields
/// `0xFF` (i.e. `-1`) in matching lanes, so subtracting the comparison result
/// from a per-lane `u8` accumulator increments exactly the matching lanes.
/// An accumulator is folded into `count` after at most 255 additions so that
/// no lane can wrap.
///
/// # Panics
///
/// Panics if `haystack.len() < 16`.
///
/// # Safety
///
/// NEON must be available on the running CPU.
#[target_feature(enable = "neon")]
pub unsafe fn chunk_count(haystack: &[u8], needle: u8) -> usize {
    assert!(haystack.len() >= 16);

    let mut offset = 0;
    let mut count = 0;

    let needles = vdupq_n_u8(needle);

    // Blocks of 255 vectors (4080 bytes): the most a `u8` lane can absorb
    // before it has to be summed out.
    while haystack.len() >= offset + 16 * 255 {
        let mut counts = vdupq_n_u8(0);
        for _ in 0..255 {
            counts = vsubq_u8(
                counts,
                vceqq_u8(u8x16_from_offset(haystack, offset), needles),
            );
            offset += 16;
        }
        count += sum(&counts);
    }

    // At most one block of 128 vectors (2048 bytes). Afterwards fewer than
    // 128 whole vectors remain, so the tail accumulator below cannot wrap.
    if haystack.len() >= offset + 16 * 128 {
        let mut counts = vdupq_n_u8(0);
        for _ in 0..128 {
            counts = vsubq_u8(
                counts,
                vceqq_u8(u8x16_from_offset(haystack, offset), needles),
            );
            offset += 16;
        }
        count += sum(&counts);
    }

    // Remaining whole vectors: consecutive 16-byte chunks starting at
    // `offset`, compared for equality like the blocks above.
    let mut counts = vdupq_n_u8(0);
    for i in 0..(haystack.len() - offset) / 16 {
        counts = vsubq_u8(
            counts,
            vceqq_u8(u8x16_from_offset(haystack, offset + i * 16), needles),
        );
    }
    // Trailing `len % 16` bytes: reload the last 16 bytes of the slice (which
    // overlaps data already counted) and keep only the lanes selected by
    // `MASK`, i.e. the final `len % 16` bytes.
    if haystack.len() % 16 != 0 {
        counts = vsubq_u8(
            counts,
            vandq_u8(
                vceqq_u8(u8x16_from_offset(haystack, haystack.len() - 16), needles),
                u8x16_from_offset(&MASK, haystack.len() % 16),
            ),
        );
    }
    count += sum(&counts);

    count
}

/// Flags every lane that is *not* a UTF-8 continuation byte.
///
/// A continuation byte has the bit pattern `10xx_xxxx`. Lanes holding any
/// other value come back as `0xFF`; continuation lanes come back as `0x00`.
///
/// # Safety
///
/// NEON must be available on the running CPU.
#[target_feature(enable = "neon")]
unsafe fn is_leading_utf8_byte(u8s: uint8x16_t) -> uint8x16_t {
    // Isolate the two most significant bits of every lane.
    let top_two_bits = vandq_u8(u8s, vdupq_n_u8(0b1100_0000));
    // `10` in those bits marks a continuation byte.
    let is_continuation = vceqq_u8(top_two_bits, vdupq_n_u8(0b1000_0000));
    // Invert: everything that is not a continuation byte starts a character.
    vmvnq_u8(is_continuation)
}

/// Counts the bytes in `utf8_chars` that are not UTF-8 continuation bytes,
/// using NEON.
///
/// The input is consumed 16 bytes at a time. `is_leading_utf8_byte` yields
/// `0xFF` (i.e. `-1`) in every non-continuation lane, so subtracting it from
/// a per-lane `u8` accumulator increments exactly those lanes. An accumulator
/// is folded into `count` after at most 255 additions so that no lane can
/// wrap.
///
/// # Panics
///
/// Panics if `utf8_chars.len() < 16`.
///
/// # Safety
///
/// NEON must be available on the running CPU.
#[target_feature(enable = "neon")]
pub unsafe fn chunk_num_chars(utf8_chars: &[u8]) -> usize {
    assert!(utf8_chars.len() >= 16);

    let mut offset = 0;
    let mut count = 0;

    // Blocks of 255 vectors (4080 bytes): the most a `u8` lane can absorb
    // before it has to be summed out.
    while utf8_chars.len() >= offset + 16 * 255 {
        let mut counts = vdupq_n_u8(0);

        for _ in 0..255 {
            counts = vsubq_u8(
                counts,
                is_leading_utf8_byte(u8x16_from_offset(utf8_chars, offset)),
            );
            offset += 16;
        }
        count += sum(&counts);
    }

    // At most one block of 128 vectors (2048 bytes). Afterwards fewer than
    // 128 whole vectors remain, so the tail accumulator below cannot wrap.
    if utf8_chars.len() >= offset + 16 * 128 {
        let mut counts = vdupq_n_u8(0);
        for _ in 0..128 {
            counts = vsubq_u8(
                counts,
                is_leading_utf8_byte(u8x16_from_offset(utf8_chars, offset)),
            );
            offset += 16;
        }
        count += sum(&counts);
    }

    // Remaining whole vectors: consecutive 16-byte chunks starting at
    // `offset`.
    let mut counts = vdupq_n_u8(0);
    for i in 0..(utf8_chars.len() - offset) / 16 {
        counts = vsubq_u8(
            counts,
            is_leading_utf8_byte(u8x16_from_offset(utf8_chars, offset + i * 16)),
        );
    }
    // Trailing `len % 16` bytes: reload the last 16 bytes of the slice (which
    // overlaps data already counted) and keep only the lanes selected by
    // `MASK`, i.e. the final `len % 16` bytes.
    if utf8_chars.len() % 16 != 0 {
        counts = vsubq_u8(
            counts,
            vandq_u8(
                is_leading_utf8_byte(u8x16_from_offset(utf8_chars, utf8_chars.len() - 16)),
                u8x16_from_offset(&MASK, utf8_chars.len() % 16),
            ),
        );
    }
    count += sum(&counts);

    count
}
4 changes: 4 additions & 0 deletions src/simd/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,7 @@ pub mod x86_sse2;
// Runtime feature detection is not available with no_std.
#[cfg(all(feature = "runtime-dispatch-simd", target_arch = "x86_64"))]
pub mod x86_avx2;

/// NEON-based implementations for AArch64 targets; modern ARM machines are
/// also quite capable thanks to NEON.
#[cfg(target_arch = "aarch64")]
pub mod aarch64;

0 comments on commit 5c13966

Please sign in to comment.