diff --git a/src/implementation/aarch64/neon.rs b/src/implementation/aarch64/neon.rs index c8f4065f..efc24b06 100644 --- a/src/implementation/aarch64/neon.rs +++ b/src/implementation/aarch64/neon.rs @@ -86,33 +86,8 @@ impl SimdU8Value { } #[inline] - #[flexpect::e(clippy::too_many_arguments)] - unsafe fn lookup_16( - self, - v0: u8, - v1: u8, - v2: u8, - v3: u8, - v4: u8, - v5: u8, - v6: u8, - v7: u8, - v8: u8, - v9: u8, - v10: u8, - v11: u8, - v12: u8, - v13: u8, - v14: u8, - v15: u8, - ) -> Self { - Self::from(vqtbl1q_u8( - Self::repeat_16( - v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, - ) - .0, - self.0, - )) + unsafe fn lookup_16(self, tbl: Self) -> Self { + Self::from(vqtbl1q_u8(tbl.0, self.0)) } #[inline] @@ -210,3 +185,4 @@ const PREFETCH: bool = false; use crate::implementation::helpers::TempSimdChunkA16 as TempSimdChunk; simd_input_128_bit!(); algorithm_simd!(); +algorithm_simd_default_special_case_fns!(); diff --git a/src/implementation/algorithm.rs b/src/implementation/algorithm.rs index 2da91405..090ebd90 100644 --- a/src/implementation/algorithm.rs +++ b/src/implementation/algorithm.rs @@ -64,7 +64,18 @@ macro_rules! algorithm_simd { $(#[$feat])* #[inline] - unsafe fn check_special_cases(input: SimdU8Value, prev1: SimdU8Value) -> SimdU8Value { + unsafe fn has_error(&self) -> bool { + self.error.any_bit_set() + } + + /// Lookup tables used by `check_special_cases`, returns + /// `(byte_1_high, byte_1_low, byte_2_high)` + /// + /// In optimized builds the returned values compile to constants, which are + /// directly loaded into SIMD registers. + $(#[$feat])* + #[inline] + unsafe fn special_case_tables() -> (SimdU8Value, SimdU8Value, SimdU8Value) { const TOO_SHORT: u8 = 1 << 0; const TOO_LONG: u8 = 1 << 1; const OVERLONG_3: u8 = 1 << 2; @@ -76,7 +87,7 @@ macro_rules! algorithm_simd { const OVERLONG_4: u8 = 1 << 6; const CARRY: u8 = TOO_SHORT | TOO_LONG | TWO_CONTS; - let byte_1_high = prev1.shr4().lookup_16( + let byte_1_high = SimdU8Value::repeat_16( TOO_LONG, TOO_LONG, TOO_LONG, @@ -95,7 +106,7 @@ macro_rules! algorithm_simd { TOO_SHORT | TOO_LARGE | TOO_LARGE_1000 | OVERLONG_4, ); - let byte_1_low = prev1.and(SimdU8Value::splat(0x0F)).lookup_16( + let byte_1_low = SimdU8Value::repeat_16( CARRY | OVERLONG_3 | OVERLONG_2 | OVERLONG_4, CARRY | OVERLONG_2, CARRY, @@ -114,7 +125,7 @@ macro_rules! algorithm_simd { CARRY | TOO_LARGE | TOO_LARGE_1000, ); - let byte_2_high = input.shr4().lookup_16( + let byte_2_high = SimdU8Value::repeat_16( TOO_SHORT, TOO_SHORT, TOO_SHORT, @@ -133,35 +144,7 @@ macro_rules! algorithm_simd { TOO_SHORT, ); - byte_1_high.and(byte_1_low).and(byte_2_high) - } - - $(#[$feat])* - #[inline] - unsafe fn must_be_2_3_continuation(prev2: SimdU8Value, prev3: SimdU8Value) -> SimdU8Value { - let is_third_byte = prev2.saturating_sub(SimdU8Value::splat(0xe0 - 0x80)); - let is_fourth_byte = prev3.saturating_sub(SimdU8Value::splat(0xf0 - 0x80)); - is_third_byte.or(is_fourth_byte) - } - - $(#[$feat])* - #[inline] - unsafe fn check_multibyte_lengths( - input: SimdU8Value, - prev: SimdU8Value, - special_cases: SimdU8Value, - ) -> SimdU8Value { - let prev2 = input.prev2(prev); - let prev3 = input.prev3(prev); - let must23 = Self::must_be_2_3_continuation(prev2, prev3); - let must23_80 = must23.and(SimdU8Value::splat(0x80)); - must23_80.xor(special_cases) - } - - $(#[$feat])* - #[inline] - unsafe fn has_error(&self) -> bool { - self.error.any_bit_set() + (byte_1_high, byte_1_low, byte_2_high) } $(#[$feat])* @@ -169,9 +152,7 @@ macro_rules! algorithm_simd { unsafe fn check_bytes(&mut self, input: SimdU8Value) { let prev1 = input.prev1(self.prev); let sc = Self::check_special_cases(input, prev1); - self.error = self - .error - .or(Self::check_multibyte_lengths(input, self.prev, sc)); + self.error = Self::check_multibyte_lengths(input, self.prev, sc, self.error); self.prev = input; } @@ -511,6 +492,55 @@ macro_rules! algorithm_simd { }; } +/// Default implementations of `check_special_cases` and `check_multibyte_lengths` +/// (and its `must_be_2_3_continuation` helper). +/// +/// Every architecture except AVX-512 invokes this macro. AVX-512 provides its own +/// micro-optimized versions. +macro_rules! algorithm_simd_default_special_case_fns { + ($(#[$feat:meta])*) => { + impl Utf8CheckAlgorithm { + $(#[$feat])* + #[inline] + unsafe fn check_special_cases(input: SimdU8Value, prev1: SimdU8Value) -> SimdU8Value { + let (byte_1_high_table, byte_1_low_table, byte_2_high_table) = + Self::special_case_tables(); + + let byte_1_high = prev1.shr4().lookup_16(byte_1_high_table); + let byte_1_low = prev1.and(SimdU8Value::splat(0x0F)).lookup_16(byte_1_low_table); + let byte_2_high = input.shr4().lookup_16(byte_2_high_table); + + byte_1_high.and(byte_1_low).and(byte_2_high) + } + + $(#[$feat])* + #[inline] + unsafe fn must_be_2_3_continuation(prev2: SimdU8Value, prev3: SimdU8Value) -> SimdU8Value { + let is_third_byte = prev2.saturating_sub(SimdU8Value::splat(0xe0 - 0x80)); + let is_fourth_byte = prev3.saturating_sub(SimdU8Value::splat(0xf0 - 0x80)); + is_third_byte.or(is_fourth_byte) + } + + // Slightly different from the original algorithm. The error is or-ed in here to allow + // AVX-512 ternlog optimization. + $(#[$feat])* + #[inline] + unsafe fn check_multibyte_lengths( + input: SimdU8Value, + prev: SimdU8Value, + special_cases: SimdU8Value, + error: SimdU8Value, + ) -> SimdU8Value { + let prev2 = input.prev2(prev); + let prev3 = input.prev3(prev); + let must23 = Self::must_be_2_3_continuation(prev2, prev3); + let must23_80 = must23.and(SimdU8Value::splat(0x80)); + error.or(must23_80.xor(special_cases)) + } + } + }; +} + macro_rules! simd_input_128_bit { ($(#[$feat:meta])*) => { #[repr(C)] diff --git a/src/implementation/armv7/neon.rs b/src/implementation/armv7/neon.rs index 3e5cde94..070811d8 100644 --- a/src/implementation/armv7/neon.rs +++ b/src/implementation/armv7/neon.rs @@ -122,31 +122,8 @@ impl SimdU8Value { #[inline] #[target_feature(enable = "neon")] - #[flexpect::e(clippy::too_many_arguments)] - unsafe fn lookup_16( - self, - v0: u8, - v1: u8, - v2: u8, - v3: u8, - v4: u8, - v5: u8, - v6: u8, - v7: u8, - v8: u8, - v9: u8, - v10: u8, - v11: u8, - v12: u8, - v13: u8, - v14: u8, - v15: u8, - ) -> Self { - let rep = Self::repeat_16( - v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, - ) - .0; - Self(vqtbl1q_u8(rep, self.0)) + unsafe fn lookup_16(self, tbl: Self) -> Self { + Self(vqtbl1q_u8(tbl.0, self.0)) } #[inline] @@ -244,3 +221,4 @@ const PREFETCH: bool = false; use crate::implementation::helpers::TempSimdChunkA16 as TempSimdChunk; simd_input_128_bit!(#[target_feature(enable = "neon")]); algorithm_simd!(#[target_feature(enable = "neon")]); +algorithm_simd_default_special_case_fns!(#[target_feature(enable = "neon")]); diff --git a/src/implementation/portable/simd128.rs b/src/implementation/portable/simd128.rs index 55939ceb..42d02d50 100644 --- a/src/implementation/portable/simd128.rs +++ b/src/implementation/portable/simd128.rs @@ -80,32 +80,11 @@ impl SimdU8Value { } #[inline] - fn lookup_16( - self, - v0: u8, - v1: u8, - v2: u8, - v3: u8, - v4: u8, - v5: u8, - v6: u8, - v7: u8, - v8: u8, - v9: u8, - v10: u8, - v11: u8, - v12: u8, - v13: u8, - v14: u8, - v15: u8, - ) -> Self { + fn lookup_16(self, tbl: Self) -> Self { // We need to ensure that 'self' only contains the lower 4 bits, unlike the avx instruction // this will otherwise lead to bad results let idx: u8x16 = self.0; - let src: u8x16 = Self::repeat_16( - v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, - ) - .0; + let src: u8x16 = tbl.0; let res = src.swizzle_dyn(idx); Self::from(res) } @@ -221,3 +200,4 @@ const PREFETCH: bool = false; use crate::implementation::helpers::TempSimdChunkA16 as TempSimdChunk; simd_input_128_bit!(); algorithm_simd!(); +algorithm_simd_default_special_case_fns!(); diff --git a/src/implementation/portable/simd256.rs b/src/implementation/portable/simd256.rs index ac01403c..3555b51b 100644 --- a/src/implementation/portable/simd256.rs +++ b/src/implementation/portable/simd256.rs @@ -82,33 +82,11 @@ impl SimdU8Value { } #[inline] - fn lookup_16( - self, - v0: u8, - v1: u8, - v2: u8, - v3: u8, - v4: u8, - v5: u8, - v6: u8, - v7: u8, - v8: u8, - v9: u8, - v10: u8, - v11: u8, - v12: u8, - v13: u8, - v14: u8, - v15: u8, - ) -> Self { + fn lookup_16(self, tbl: Self) -> Self { // We need to ensure that 'self' only contains the lower 4 bits, unlike the avx instruction // this will otherwise lead to bad results let idx: u8x32 = self.0.cast(); - let src: u8x32 = Self::repeat_16( - v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, - ) - .0 - .cast(); + let src: u8x32 = tbl.0.cast(); let res = src.swizzle_dyn(idx); Self::from(res.cast()) } @@ -214,3 +192,4 @@ const PREFETCH: bool = false; use crate::implementation::helpers::TempSimdChunkA32 as TempSimdChunk; simd_input_256_bit!(); algorithm_simd!(); +algorithm_simd_default_special_case_fns!(); diff --git a/src/implementation/wasm32/simd128.rs b/src/implementation/wasm32/simd128.rs index 94230d3a..d4b58424 100644 --- a/src/implementation/wasm32/simd128.rs +++ b/src/implementation/wasm32/simd128.rs @@ -91,33 +91,8 @@ impl SimdU8Value { } #[inline] - #[flexpect::e(clippy::too_many_arguments)] - unsafe fn lookup_16( - self, - v0: u8, - v1: u8, - v2: u8, - v3: u8, - v4: u8, - v5: u8, - v6: u8, - v7: u8, - v8: u8, - v9: u8, - v10: u8, - v11: u8, - v12: u8, - v13: u8, - v14: u8, - v15: u8, - ) -> Self { - Self::from(u8x16_swizzle( - Self::repeat_16( - v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, - ) - .0, - self.0, - )) + unsafe fn lookup_16(self, tbl: Self) -> Self { + Self::from(u8x16_swizzle(tbl.0, self.0)) } #[inline] @@ -255,3 +230,4 @@ const PREFETCH: bool = false; use crate::implementation::helpers::TempSimdChunkA16 as TempSimdChunk; simd_input_128_bit!(#[target_feature(enable = "simd128")]); algorithm_simd!(#[target_feature(enable = "simd128")]); +algorithm_simd_default_special_case_fns!(#[target_feature(enable = "simd128")]); diff --git a/src/implementation/x86/avx2.rs b/src/implementation/x86/avx2.rs index 872ef1a4..0cb61b06 100644 --- a/src/implementation/x86/avx2.rs +++ b/src/implementation/x86/avx2.rs @@ -102,35 +102,10 @@ impl SimdU8Value { Self::from(_mm256_loadu_si256(ptr.cast::<__m256i>())) } - #[flexpect::e(clippy::too_many_arguments)] #[target_feature(enable = "avx2")] #[inline] - unsafe fn lookup_16( - self, - v0: u8, - v1: u8, - v2: u8, - v3: u8, - v4: u8, - v5: u8, - v6: u8, - v7: u8, - v8: u8, - v9: u8, - v10: u8, - v11: u8, - v12: u8, - v13: u8, - v14: u8, - v15: u8, - ) -> Self { - Self::from(_mm256_shuffle_epi8( - Self::repeat_16( - v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, - ) - .0, - self.0, - )) + unsafe fn lookup_16(self, tbl: Self) -> Self { + Self::from(_mm256_shuffle_epi8(tbl.0, self.0)) } #[flexpect::e(clippy::cast_possible_wrap)] @@ -240,3 +215,4 @@ const PREFETCH: bool = true; use crate::implementation::helpers::TempSimdChunkA32 as TempSimdChunk; simd_input_256_bit!(#[target_feature(enable = "avx2")]); algorithm_simd!(#[target_feature(enable = "avx2")]); +algorithm_simd_default_special_case_fns!(#[target_feature(enable = "avx2")]); diff --git a/src/implementation/x86/avx512.rs b/src/implementation/x86/avx512.rs index f46e1a51..eaf1505d 100644 --- a/src/implementation/x86/avx512.rs +++ b/src/implementation/x86/avx512.rs @@ -2,20 +2,18 @@ #[cfg(target_arch = "x86")] use core::arch::x86::{ - __m512i, _mm512_alignr_epi8, _mm512_and_si512, _mm512_loadu_si512, _mm512_maskz_loadu_epi8, - _mm512_movepi8_mask, _mm512_or_si512, _mm512_permutex2var_epi64, _mm512_set1_epi8, - _mm512_set_epi64, _mm512_set_epi8, _mm512_setzero_si512, _mm512_shuffle_epi8, - _mm512_srli_epi16, _mm512_subs_epu8, _mm512_test_epi8_mask, _mm512_xor_si512, _mm_prefetch, - _MM_HINT_T0, + __m512i, _mm512_alignr_epi8, _mm512_loadu_si512, _mm512_maskz_loadu_epi8, _mm512_movepi8_mask, + _mm512_or_si512, _mm512_permutex2var_epi64, _mm512_permutexvar_epi8, _mm512_set1_epi8, + _mm512_set_epi64, _mm512_set_epi8, _mm512_setzero_si512, _mm512_srli_epi16, _mm512_subs_epu8, + _mm512_ternarylogic_epi32, _mm512_test_epi8_mask, _mm_prefetch, _MM_HINT_T0, }; #[cfg(target_arch = "x86_64")] use core::arch::x86_64::{ - __m512i, _mm512_alignr_epi8, _mm512_and_si512, _mm512_loadu_si512, _mm512_maskz_loadu_epi8, - _mm512_movepi8_mask, _mm512_or_si512, _mm512_permutex2var_epi64, _mm512_set1_epi8, - _mm512_set_epi64, _mm512_set_epi8, _mm512_setzero_si512, _mm512_shuffle_epi8, - _mm512_srli_epi16, _mm512_subs_epu8, _mm512_test_epi8_mask, _mm512_xor_si512, _mm_prefetch, - _MM_HINT_T0, + __m512i, _mm512_alignr_epi8, _mm512_loadu_si512, _mm512_maskz_loadu_epi8, _mm512_movepi8_mask, + _mm512_or_si512, _mm512_permutex2var_epi64, _mm512_permutexvar_epi8, _mm512_set1_epi8, + _mm512_set_epi64, _mm512_set_epi8, _mm512_setzero_si512, _mm512_srli_epi16, _mm512_subs_epu8, + _mm512_ternarylogic_epi32, _mm512_test_epi8_mask, _mm_prefetch, _MM_HINT_T0, }; use crate::implementation::helpers::Utf8CheckAlgorithm; @@ -123,35 +121,13 @@ impl SimdU8Value { Self::from(res) } - #[flexpect::e(clippy::too_many_arguments)] #[target_feature(enable = "avx512f,avx512bw,avx512vbmi")] #[inline] - unsafe fn lookup_16( - self, - v0: u8, - v1: u8, - v2: u8, - v3: u8, - v4: u8, - v5: u8, - v6: u8, - v7: u8, - v8: u8, - v9: u8, - v10: u8, - v11: u8, - v12: u8, - v13: u8, - v14: u8, - v15: u8, - ) -> Self { - Self::from(_mm512_shuffle_epi8( - Self::repeat_16( - v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, - ) - .0, - self.0, - )) + unsafe fn lookup_16(self, tbl: Self) -> Self { + // VPERMB (avx512vbmi) differs from other dynamic swizzle instructions in that it + // completely ignores the high bits of the index. Only the low 6 bits of each + // byte are used to select a byte from the 64-byte table. + Self::from(_mm512_permutexvar_epi8(self.0, tbl.0)) } #[flexpect::e(clippy::cast_possible_wrap)] @@ -175,27 +151,40 @@ impl SimdU8Value { #[target_feature(enable = "avx512f,avx512bw,avx512vbmi")] #[inline] - unsafe fn and(self, b: Self) -> Self { - Self::from(_mm512_and_si512(self.0, b.0)) + unsafe fn saturating_sub(self, b: Self) -> Self { + Self::from(_mm512_subs_epu8(self.0, b.0)) } + // For ternary ops reference see https://www.felixcloutier.com/x86/vpternlogd:vpternlogq + + /// `self & b & c` fused into a single `vpternlogd`. #[target_feature(enable = "avx512f,avx512bw,avx512vbmi")] #[inline] - unsafe fn xor(self, b: Self) -> Self { - Self::from(_mm512_xor_si512(self.0, b.0)) + unsafe fn and3(self, b: Self, c: Self) -> Self { + Self::from(_mm512_ternarylogic_epi32(self.0, b.0, c.0, 0x80)) } + /// `(self | b) & c` fused into a single `vpternlogd`. #[target_feature(enable = "avx512f,avx512bw,avx512vbmi")] #[inline] - unsafe fn saturating_sub(self, b: Self) -> Self { - Self::from(_mm512_subs_epu8(self.0, b.0)) + unsafe fn or_and(self, b: Self, c: Self) -> Self { + Self::from(_mm512_ternarylogic_epi32(self.0, b.0, c.0, 0xA8)) } - // ugly but shr requires const generics + /// `(self ^ b) | c` fused into a single `vpternlogd`. #[target_feature(enable = "avx512f,avx512bw,avx512vbmi")] #[inline] - unsafe fn shr4(self) -> Self { - Self::from(_mm512_srli_epi16(self.0, 4)).and(Self::splat(0xFF >> 4)) + unsafe fn xor_or(self, b: Self, c: Self) -> Self { + Self::from(_mm512_ternarylogic_epi32(c.0, self.0, b.0, 0xF6)) + } + + /// CAVE: Uses 16-bit-lane shifts so the high nibble of every + /// other byte is polluted with the low nibble of the neighbouring byte. + /// This is fine for our use case since we only care about the low nibble. + #[target_feature(enable = "avx512f,avx512bw,avx512vbmi")] + #[inline] + unsafe fn shr4_dirty(self) -> Self { + Self::from(_mm512_srli_epi16(self.0, 4)) } // ugly but prev requires const generics @@ -253,6 +242,41 @@ impl From<__m512i> for SimdU8Value { } } +/// AVX-512 specializations of `check_special_cases` and `check_multibyte_lengths`. +impl Utf8CheckAlgorithm { + #[target_feature(enable = "avx512f,avx512bw,avx512vbmi")] + #[inline] + unsafe fn check_special_cases(input: SimdU8Value, prev1: SimdU8Value) -> SimdU8Value { + let (byte_1_high_table, byte_1_low_table, byte_2_high_table) = Self::special_case_tables(); + + // `lookup_16` (VPERMB) on tables where the 128-bit lanes are all identical + // ignores the high nibble, so the pollution by `shr4_dirty` does not matter. + let byte_1_high = prev1.shr4_dirty().lookup_16(byte_1_high_table); + // `lookup_16` (VPERMB) on tables where the 128-bit lanes are all identical + // ignores the high nibble, so no masking needed. + let byte_1_low = prev1.lookup_16(byte_1_low_table); + let byte_2_high = input.shr4_dirty().lookup_16(byte_2_high_table); + + byte_1_high.and3(byte_1_low, byte_2_high) + } + + #[target_feature(enable = "avx512f,avx512bw,avx512vbmi")] + #[inline] + unsafe fn check_multibyte_lengths( + input: SimdU8Value, + prev: SimdU8Value, + special_cases: SimdU8Value, + error: SimdU8Value, + ) -> SimdU8Value { + let prev2 = input.prev2(prev); + let prev3 = input.prev3(prev); + let is_third_byte = prev2.saturating_sub(SimdU8Value::splat(0xe0 - 0x80)); + let is_fourth_byte = prev3.saturating_sub(SimdU8Value::splat(0xf0 - 0x80)); + let must23_80 = is_third_byte.or_and(is_fourth_byte, SimdU8Value::splat(0x80)); + must23_80.xor_or(special_cases, error) + } +} + #[target_feature(enable = "avx512f,avx512bw,avx512vbmi")] #[inline] unsafe fn simd_prefetch(ptr: *const u8) { diff --git a/src/implementation/x86/sse42.rs b/src/implementation/x86/sse42.rs index 309f64fd..9a3640e6 100644 --- a/src/implementation/x86/sse42.rs +++ b/src/implementation/x86/sse42.rs @@ -99,35 +99,10 @@ impl SimdU8Value { Self::from(_mm_loadu_si128(ptr.cast::<__m128i>())) } - #[flexpect::e(clippy::too_many_arguments)] #[target_feature(enable = "sse4.2")] #[inline] - unsafe fn lookup_16( - self, - v0: u8, - v1: u8, - v2: u8, - v3: u8, - v4: u8, - v5: u8, - v6: u8, - v7: u8, - v8: u8, - v9: u8, - v10: u8, - v11: u8, - v12: u8, - v13: u8, - v14: u8, - v15: u8, - ) -> Self { - Self::from(_mm_shuffle_epi8( - Self::repeat_16( - v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, - ) - .0, - self.0, - )) + unsafe fn lookup_16(self, tbl: Self) -> Self { + Self::from(_mm_shuffle_epi8(tbl.0, self.0)) } #[flexpect::e(clippy::cast_possible_wrap)] @@ -225,3 +200,4 @@ const PREFETCH: bool = false; use crate::implementation::helpers::TempSimdChunkA16 as TempSimdChunk; simd_input_128_bit!(#[target_feature(enable = "sse4.2")]); algorithm_simd!(#[target_feature(enable = "sse4.2")]); +algorithm_simd_default_special_case_fns!(#[target_feature(enable = "sse4.2")]);