diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 25c8e55..678caf5 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -22,6 +22,7 @@ jobs:
         arch:
           - i686
           - x86_64
+          - aarch64
         features:
           - default
           - runtime-dispatch-simd
diff --git a/src/simd/aarch64.rs b/src/simd/aarch64.rs
index 56e8b71..6544355 100644
--- a/src/simd/aarch64.rs
+++ b/src/simd/aarch64.rs
@@ -1,5 +1,6 @@
 use core::arch::aarch64::{
-    uint8x16_t, vaddlvq_u8, vandq_u8, vceqq_u8, vcgtq_u8, vdupq_n_u8, vld1q_u8, vmvnq_u8, vsubq_u8,
+    uint8x16_t, uint8x16x4_t, vaddlvq_u8, vandq_u8, vceqq_u8, vdupq_n_u8, vld1q_u8, vld1q_u8_x4,
+    vmvnq_u8, vsubq_u8,
 };
 
 const MASK: [u8; 32] = [
@@ -9,12 +10,40 @@ const MASK: [u8; 32] = [
 
+/// Loads 16 bytes of `slice`, starting at `offset`, into one NEON vector.
+///
+/// # Safety
+///
+/// `offset + 16` must not exceed `slice.len()`; this is only checked in debug builds.
 #[target_feature(enable = "neon")]
 unsafe fn u8x16_from_offset(slice: &[u8], offset: usize) -> uint8x16_t {
+    debug_assert!(
+        offset + 16 <= slice.len(),
+        "{} + 16 > {}",
+        offset,
+        slice.len()
+    );
     vld1q_u8(slice.as_ptr().add(offset) as *const _) // TODO: does this need to be aligned?
 }
 
+/// Loads 64 bytes of `slice`, starting at `offset`, into four NEON vectors.
+///
+/// # Safety
+///
+/// `offset + 64` must not exceed `slice.len()`; this is only checked in debug builds.
 #[target_feature(enable = "neon")]
-unsafe fn sum(u8s: &uint8x16_t) -> usize {
-    vaddlvq_u8(*u8s) as usize
+unsafe fn u8x16_x4_from_offset(slice: &[u8], offset: usize) -> uint8x16x4_t {
+    debug_assert!(
+        offset + 64 <= slice.len(),
+        "{} + 64 > {}",
+        offset,
+        slice.len()
+    );
+    vld1q_u8_x4(slice.as_ptr().add(offset) as *const _)
+}
+
+/// Adds up all 16 lanes of `u8s`.
+#[target_feature(enable = "neon")]
+unsafe fn sum(u8s: uint8x16_t) -> usize {
+    vaddlvq_u8(u8s) as usize
 }
 
 #[target_feature(enable = "neon")]
@@ -26,38 +55,40 @@ pub unsafe fn chunk_count(haystack: &[u8], needle: u8) -> usize {
 
     let needles = vdupq_n_u8(needle);
 
-    // 4080
-    while haystack.len() >= offset + 16 * 255 {
-        let mut counts = vdupq_n_u8(0);
+    // 16320
+    while haystack.len() >= offset + 64 * 255 {
+        let (mut count1, mut count2, mut count3, mut count4) =
+            (vdupq_n_u8(0), vdupq_n_u8(0), vdupq_n_u8(0), vdupq_n_u8(0));
         for _ in 0..255 {
-            counts = vsubq_u8(
-                counts,
-                vceqq_u8(u8x16_from_offset(haystack, offset), needles),
-            );
-            offset += 16;
+            let uint8x16x4_t(h1, h2, h3, h4) = u8x16_x4_from_offset(haystack, offset);
+            count1 = vsubq_u8(count1, vceqq_u8(h1, needles));
+            count2 = vsubq_u8(count2, vceqq_u8(h2, needles));
+            count3 = vsubq_u8(count3, vceqq_u8(h3, needles));
+            count4 = vsubq_u8(count4, vceqq_u8(h4, needles));
+            offset += 64;
         }
-        count += sum(&counts);
+        count += sum(count1) + sum(count2) + sum(count3) + sum(count4);
     }
 
-    // 2048
-    if haystack.len() >= offset + 16 * 128 {
-        let mut counts = vdupq_n_u8(0);
-        for _ in 0..128 {
-            counts = vsubq_u8(
-                counts,
-                vceqq_u8(u8x16_from_offset(haystack, offset), needles),
-            );
-            offset += 16;
-        }
-        count += sum(&counts);
+    // 64
+    let (mut count1, mut count2, mut count3, mut count4) =
+        (vdupq_n_u8(0), vdupq_n_u8(0), vdupq_n_u8(0), vdupq_n_u8(0));
+    for _ in 0..(haystack.len() - offset) / 64 {
+        let uint8x16x4_t(h1, h2, h3, h4) = u8x16_x4_from_offset(haystack, offset);
+        count1 = vsubq_u8(count1, vceqq_u8(h1, needles));
+        count2 = vsubq_u8(count2, vceqq_u8(h2, needles));
+        count3 = vsubq_u8(count3, vceqq_u8(h3, needles));
+        count4 = vsubq_u8(count4, vceqq_u8(h4, needles));
+        offset += 64;
     }
+    count += sum(count1) + sum(count2) + sum(count3) + sum(count4);
 
-    // 16
     let mut counts = vdupq_n_u8(0);
+    // 16
     for i in 0..(haystack.len() - offset) / 16 {
         counts = vsubq_u8(
             counts,
-            vcgtq_u8(u8x16_from_offset(haystack, offset + i * 32), needles),
+            vceqq_u8(u8x16_from_offset(haystack, offset + i * 16), needles),
         );
     }
     if haystack.len() % 16 != 0 {
@@ -69,9 +100,7 @@ pub unsafe fn chunk_count(haystack: &[u8], needle: u8) -> usize {
             ),
         );
     }
-    count += sum(&counts);
-
-    count
+    count + sum(counts)
 }
 
 #[target_feature(enable = "neon")]
@@ -100,7 +129,7 @@ pub unsafe fn chunk_num_chars(utf8_chars: &[u8]) -> usize {
             );
             offset += 16;
         }
-        count += sum(&counts);
+        count += sum(counts);
     }
 
     // 2048
@@ -113,7 +142,7 @@ pub unsafe fn chunk_num_chars(utf8_chars: &[u8]) -> usize {
             );
             offset += 16;
         }
-        count += sum(&counts);
+        count += sum(counts);
     }
 
     // 16
@@ -121,7 +150,7 @@ pub unsafe fn chunk_num_chars(utf8_chars: &[u8]) -> usize {
     for i in 0..(utf8_chars.len() - offset) / 16 {
         counts = vsubq_u8(
             counts,
-            is_leading_utf8_byte(u8x16_from_offset(utf8_chars, offset + i * 32)),
+            is_leading_utf8_byte(u8x16_from_offset(utf8_chars, offset + i * 16)),
         );
     }
     if utf8_chars.len() % 16 != 0 {
@@ -133,7 +162,7 @@ pub unsafe fn chunk_num_chars(utf8_chars: &[u8]) -> usize {
             ),
         );
     }
-    count += sum(&counts);
+    count += sum(counts);
 
     count
 }