From 6846efb6189aaa102b3dfadbc77d65d872a8ab0d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marek=20Kn=C3=A1pek?=
Date: Wed, 21 Aug 2024 23:53:51 +0200
Subject: [PATCH] LibCrypto: SIMDify GHash

---
 AK/SIMD.h                                     |  8 ++
 AK/SIMDExtras.h                               | 70 ++++++++++++++
 .../LibCrypto/Authentication/GHash.cpp        | 93 +++++++++++--------
 3 files changed, 132 insertions(+), 39 deletions(-)

diff --git a/AK/SIMD.h b/AK/SIMD.h
index 49a8cbfb98cb3b9..ebfa3d38694ae1b 100644
--- a/AK/SIMD.h
+++ b/AK/SIMD.h
@@ -106,11 +106,19 @@ struct IndexVectorFor {
         u64 __attribute__((vector_size(sizeof(T))))>;
 };
 
+template<typename T, size_t element_count>
+struct MakeVectorImpl {
+    using Type __attribute__((vector_size(sizeof(T) * element_count))) = T;
+};
+
 }
 
 template<SIMDVector T>
 using IndexVectorFor = typename Detail::IndexVectorFor<T>::Type;
 
+template<typename T, size_t element_count>
+using MakeVector = typename Detail::MakeVectorImpl<T, element_count>::Type;
+
 static_assert(IsSame<IndexVectorFor<i8x16>, i8x16>);
 static_assert(IsSame<IndexVectorFor<u32x4>, u32x4>);
 static_assert(IsSame<IndexVectorFor<f64x4>, u64x4>);
diff --git a/AK/SIMDExtras.h b/AK/SIMDExtras.h
index e03c70c6a2d7710..f48003ecb9e724a 100644
--- a/AK/SIMDExtras.h
+++ b/AK/SIMDExtras.h
@@ -253,6 +253,40 @@ ALWAYS_INLINE static T elementwise_byte_reverse_impl(T a, IndexSequence<Idx...>)
 }
 
+template<SIMDVector T, size_t... Idx>
+ALWAYS_INLINE static ElementOf<T> reduce_or_impl(T const& a, IndexSequence<Idx...> const&)
+{
+    static_assert(is_power_of_two(vector_length<T>));
+    static_assert(vector_length<T> == sizeof...(Idx) * 2);
+
+    using E = ElementOf<T>;
+
+    constexpr size_t N = sizeof...(Idx);
+
+    if constexpr (N == 1) {
+        return a[0] | a[1];
+    } else {
+        return reduce_or_impl(MakeVector<E, N> { (a[Idx])... }, MakeIndexSequence<N / 2>()) | reduce_or_impl(MakeVector<E, N> { (a[N + Idx])... }, MakeIndexSequence<N / 2>());
+    }
+}
+
+template<SIMDVector T, size_t... Idx>
+ALWAYS_INLINE static ElementOf<T> reduce_xor_impl(T const& a, IndexSequence<Idx...> const&)
+{
+    static_assert(is_power_of_two(vector_length<T>));
+    static_assert(vector_length<T> == sizeof...(Idx) * 2);
+
+    using E = ElementOf<T>;
+
+    constexpr size_t N = sizeof...(Idx);
+
+    if constexpr (N == 1) {
+        return a[0] ^ a[1];
+    } else {
+        return reduce_xor_impl(MakeVector<E, N> { (a[Idx])... }, MakeIndexSequence<N / 2>()) ^ reduce_xor_impl(MakeVector<E, N> { (a[N + Idx])... }, MakeIndexSequence<N / 2>());
+    }
+}
+
 }
 
 // FIXME: Shuffles only work with integral types for now
 
@@ -286,4 +320,40 @@ ALWAYS_INLINE static T elementwise_byte_reverse(T a)
     return Detail::elementwise_byte_reverse_impl(a, MakeIndexSequence<sizeof(ElementOf<T>)>());
 }
 
+template<SIMDVector T>
+ALWAYS_INLINE static ElementOf<T> reduce_or(T const& a)
+{
+    static_assert(is_power_of_two(vector_length<T>));
+    static_assert(IsUnsigned<ElementOf<T>>);
+
+#if defined __has_builtin
+#    if __has_builtin(__builtin_reduce_or)
+    if (true) {
+        return __builtin_reduce_or(a);
+    } else
+#    endif
+#endif
+    {
+        return Detail::reduce_or_impl(a, MakeIndexSequence<vector_length<T> / 2>());
+    }
+}
+
+template<SIMDVector T>
+ALWAYS_INLINE static ElementOf<T> reduce_xor(T const& a)
+{
+    static_assert(is_power_of_two(vector_length<T>));
+    static_assert(IsUnsigned<ElementOf<T>>);
+
+#if defined __has_builtin
+#    if __has_builtin(__builtin_reduce_xor)
+    if (true) {
+        return __builtin_reduce_xor(a);
+    } else
+#    endif
+#endif
+    {
+        return Detail::reduce_xor_impl(a, MakeIndexSequence<vector_length<T> / 2>());
+    }
+}
+
 }
diff --git a/Userland/Libraries/LibCrypto/Authentication/GHash.cpp b/Userland/Libraries/LibCrypto/Authentication/GHash.cpp
index 3c5ff6ba6981ab1..9f440ff55f2159a 100644
--- a/Userland/Libraries/LibCrypto/Authentication/GHash.cpp
+++ b/Userland/Libraries/LibCrypto/Authentication/GHash.cpp
@@ -6,6 +6,8 @@
 #include
 #include
+#include <AK/SIMD.h>
+#include <AK/SIMDExtras.h>
 #include
 #include
 
@@ -86,50 +88,63 @@ GHash::TagType GHash::process(ReadonlyBytes aad, ReadonlyBytes cipher)
 
 void galois_multiply(u32 (&_z)[4], u32 const (&_x)[4], u32 const (&_y)[4])
 {
-    static auto const mul_32_x_32_64 = [](u32 const& a, u32 const& b) -> u64 {
-        return static_cast<u64>(a) * static_cast<u64>(b);
+    using namespace AK::SIMD;
+
+    static auto const rotate_left = [](u32x4 const& x) -> u32x4 {
+        return u32x4 { x[3], x[0], x[1], x[2] };
+    };
+
+    static auto const mul_32_x_32_64 = [](u32x4 const& a, u32x4 const& b) -> u64x4 {
+        u64x2 r1;
+        u64x2 r2;
+
+#if defined __has_builtin
+#    if __has_builtin(__builtin_ia32_pmuludq128)
+        if (true) {
+            r1 = simd_cast<u64x2>(__builtin_ia32_pmuludq128(simd_cast<i32x4>(u32x4 { a[0], 0, a[1], 0 }), simd_cast<i32x4>(u32x4 { b[0], 0, b[1], 0 })));
+            r2 = simd_cast<u64x2>(__builtin_ia32_pmuludq128(simd_cast<i32x4>(u32x4 { a[2], 0, a[3], 0 }), simd_cast<i32x4>(u32x4 { b[2], 0, b[3], 0 })));
+        } else
+#    endif
+#endif
+        {
+            r1 = u64x2 { static_cast<u64>(a[0]) * static_cast<u64>(b[0]), static_cast<u64>(a[1]) * static_cast<u64>(b[1]) };
+            r2 = u64x2 { static_cast<u64>(a[2]) * static_cast<u64>(b[2]), static_cast<u64>(a[3]) * static_cast<u64>(b[3]) };
+        }
+        return u64x4 { r1[0], r1[1], r2[0], r2[1] };
     };
 
     static auto const clmul_32_x_32_64 = [](u32 const& a, u32 const& b, u32& lo, u32& hi) -> void {
-        u32 ta[4];
-        u32 tb[4];
-        u64 tu64[4];
-        u64 tc[4];
+        constexpr u32x4 mask32 = { 0x11111111, 0x22222222, 0x44444444, 0x88888888 };
+        constexpr u64x4 mask64 = { 0x1111111111111111ull, 0x2222222222222222ull, 0x4444444444444444ull, 0x8888888888888888ull };
+
+        u32x4 ta;
+        u32x4 tb;
+        u64x4 tu64;
+        u64x4 tc;
         u64 cc;
 
-        ta[0] = a & static_cast<u32>(0x11111111ul);
-        ta[1] = a & static_cast<u32>(0x22222222ul);
-        ta[2] = a & static_cast<u32>(0x44444444ul);
-        ta[3] = a & static_cast<u32>(0x88888888ul);
-        tb[0] = b & static_cast<u32>(0x11111111ul);
-        tb[1] = b & static_cast<u32>(0x22222222ul);
-        tb[2] = b & static_cast<u32>(0x44444444ul);
-        tb[3] = b & static_cast<u32>(0x88888888ul);
-        tu64[0] = mul_32_x_32_64(ta[0], tb[0]);
-        tu64[1] = mul_32_x_32_64(ta[1], tb[3]);
-        tu64[2] = mul_32_x_32_64(ta[2], tb[2]);
-        tu64[3] = mul_32_x_32_64(ta[3], tb[1]);
-        tc[0] = tu64[0] ^ tu64[1] ^ tu64[2] ^ tu64[3];
-        tu64[0] = mul_32_x_32_64(ta[0], tb[1]);
-        tu64[1] = mul_32_x_32_64(ta[1], tb[0]);
-        tu64[2] = mul_32_x_32_64(ta[2], tb[3]);
-        tu64[3] = mul_32_x_32_64(ta[3], tb[2]);
-        tc[1] = tu64[0] ^ tu64[1] ^ tu64[2] ^ tu64[3];
-        tu64[0] = mul_32_x_32_64(ta[0], tb[2]);
-        tu64[1] = mul_32_x_32_64(ta[1], tb[1]);
-        tu64[2] = mul_32_x_32_64(ta[2], tb[0]);
-        tu64[3] = mul_32_x_32_64(ta[3], tb[3]);
-        tc[2] = tu64[0] ^ tu64[1] ^ tu64[2] ^ tu64[3];
-        tu64[0] = mul_32_x_32_64(ta[0], tb[3]);
-        tu64[1] = mul_32_x_32_64(ta[1], tb[2]);
-        tu64[2] = mul_32_x_32_64(ta[2], tb[1]);
-        tu64[3] = mul_32_x_32_64(ta[3], tb[0]);
-        tc[3] = tu64[0] ^ tu64[1] ^ tu64[2] ^ tu64[3];
-        tc[0] &= static_cast<u64>(0x1111111111111111ull);
-        tc[1] &= static_cast<u64>(0x2222222222222222ull);
-        tc[2] &= static_cast<u64>(0x4444444444444444ull);
-        tc[3] &= static_cast<u64>(0x8888888888888888ull);
-        cc = tc[0] | tc[1] | tc[2] | tc[3];
+        ta = a & mask32;
+        tb = b & mask32;
+        tb = item_reverse(tb);
+
+        tb = rotate_left(tb);
+        tu64 = mul_32_x_32_64(ta, tb);
+        tc[0] = reduce_xor(u64x4 { tu64[0], tu64[1], tu64[2], tu64[3] });
+
+        tb = rotate_left(tb);
+        tu64 = mul_32_x_32_64(ta, tb);
+        tc[1] = reduce_xor(u64x4 { tu64[0], tu64[1], tu64[2], tu64[3] });
+
+        tb = rotate_left(tb);
+        tu64 = mul_32_x_32_64(ta, tb);
+        tc[2] = reduce_xor(u64x4 { tu64[0], tu64[1], tu64[2], tu64[3] });
+
+        tb = rotate_left(tb);
+        tu64 = mul_32_x_32_64(ta, tb);
+        tc[3] = reduce_xor(u64x4 { tu64[0], tu64[1], tu64[2], tu64[3] });
+
+        tc &= mask64;
+        cc = reduce_or(tc);
         lo = static_cast<u32>((cc >> (0 * 32)) & 0xfffffffful);
         hi = static_cast<u32>((cc >> (1 * 32)) & 0xfffffffful);
     };
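
The new `reduce_or`/`reduce_xor` helpers compute a horizontal reduction across all vector lanes: Clang's `__builtin_reduce_or`/`__builtin_reduce_xor` when available, otherwise the recursive lane-halving fallback in `Detail`. A minimal standalone sketch of the halving scheme (plain GCC/Clang vector extensions; my names, not AK code):

```cpp
#include <cassert>
#include <cstdint>

typedef uint64_t u64x4 __attribute__((vector_size(32)));

// Horizontal XOR over all lanes, following reduce_xor_impl's shape:
// split the vector into halves and recurse until the N == 1 base case,
// which folds the remaining pair as a[0] ^ a[1].
static uint64_t horizontal_xor(u64x4 v)
{
    // First half { v[0], v[1] } and second half { v[2], v[3] },
    // each of length 2, hit the base case directly.
    return (v[0] ^ v[1]) ^ (v[2] ^ v[3]);
}

int main()
{
    assert(horizontal_xor(u64x4 { 1, 2, 4, 8 }) == 15);
    assert(horizontal_xor(u64x4 { 5, 5, 9, 9 }) == 0);
    return 0;
}
```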
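The `if (true) { ... } else` construct wrapped in `__has_builtin` guards, used by both reduce helpers and by `mul_32_x_32_64`, selects the builtin at preprocessing time while letting one trailing block serve either as the dead `else` branch or, when the builtin is missing, as the entire statement. A standalone sketch of the same pattern (popcount is chosen purely for illustration):

```cpp
#include <cassert>
#include <cstdint>

// Returns the number of set bits; uses the compiler builtin when present,
// otherwise falls back to Kernighan's bit-clearing loop.
static uint32_t popcount32(uint32_t v)
{
#if defined __has_builtin
#    if __has_builtin(__builtin_popcount)
    // Builtin available: this branch is taken and the block below
    // becomes an unreachable else.
    if (true) {
        return static_cast<uint32_t>(__builtin_popcount(v));
    } else
#    endif
#endif
    // Builtin unavailable: the preprocessor drops the lines above and
    // this brace block is all that remains of the statement.
    {
        uint32_t n = 0;
        for (; v != 0; v &= v - 1)
            ++n;
        return n;
    }
}

int main()
{
    assert(popcount32(0) == 0);
    assert(popcount32(0xdeadbeefu) == 24);
    return 0;
}
```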
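`mul_32_x_32_64` leans on `pmuludq` (`__builtin_ia32_pmuludq128`), which multiplies only the even 32-bit lanes of its operands as unsigned and widens them to two 64-bit products; that is why the patch stages its inputs as `u32x4 { a[0], 0, a[1], 0 }`. An intrinsics-level sketch of that behaviour (assumes SSE2 on x86-64; illustrative, not part of the patch):

```cpp
#include <emmintrin.h> // SSE2: _mm_mul_epu32 compiles to pmuludq
#include <cassert>
#include <cstdint>

int main()
{
    // _mm_set_epi32 takes lanes high-to-low, so viewed low-to-high these
    // are { 3, 0, 0xffffffff, 0 } and { 5, 0, 0xffffffff, 0 }: the inputs
    // sit in the even lanes, mirroring u32x4 { a[0], 0, a[1], 0 } above.
    __m128i a = _mm_set_epi32(0, -1, 0, 3);
    __m128i b = _mm_set_epi32(0, -1, 0, 5);

    __m128i p = _mm_mul_epu32(a, b); // two unsigned 32x32 -> 64 products

    uint64_t lo = static_cast<uint64_t>(_mm_cvtsi128_si64(p));
    uint64_t hi = static_cast<uint64_t>(_mm_cvtsi128_si64(_mm_unpackhi_epi64(p, p)));
    assert(lo == 15);                    // 3 * 5
    assert(hi == 0xfffffffe00000001ull); // 0xffffffff * 0xffffffff
    return 0;
}
```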
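`clmul_32_x_32_64` emulates a carry-less multiply with ordinary integer multiplies, the constant-time trick known from, e.g., BearSSL's GHASH: masking each operand to every fourth bit means any column of a partial product sums at most eight terms, so carries never reach the next bit of the same residue class four positions away, and the final mask-and-combine keeps exactly the carry-less bits. A scalar sketch checking the masked form against a bit-by-bit reference (my names; same masks and product grouping as the patch):

```cpp
#include <cassert>
#include <cstdint>

// Reference: carry-less (GF(2)[x]) 32x32 -> 64 multiply, one bit at a time.
static uint64_t clmul_ref(uint32_t a, uint32_t b)
{
    uint64_t r = 0;
    for (int i = 0; i < 32; ++i)
        if ((b >> i) & 1)
            r ^= static_cast<uint64_t>(a) << i;
    return r;
}

// Masked-multiply emulation: the scalar form that the patch vectorizes.
static uint64_t clmul_masked(uint32_t a, uint32_t b)
{
    uint64_t ta[4], tb[4], tc[4];
    // Split operands into four streams, one bit in every four.
    for (int i = 0; i < 4; ++i) {
        ta[i] = a & (0x11111111u << i);
        tb[i] = b & (0x11111111u << i);
    }
    // tc[k] gathers the products whose bit classes sum to k (mod 4).
    tc[0] = ta[0] * tb[0] ^ ta[1] * tb[3] ^ ta[2] * tb[2] ^ ta[3] * tb[1];
    tc[1] = ta[0] * tb[1] ^ ta[1] * tb[0] ^ ta[2] * tb[3] ^ ta[3] * tb[2];
    tc[2] = ta[0] * tb[2] ^ ta[1] * tb[1] ^ ta[2] * tb[0] ^ ta[3] * tb[3];
    tc[3] = ta[0] * tb[3] ^ ta[1] * tb[2] ^ ta[2] * tb[1] ^ ta[3] * tb[0];
    // Keep only the carry-free bits of each class, then recombine.
    tc[0] &= 0x1111111111111111ull;
    tc[1] &= 0x2222222222222222ull;
    tc[2] &= 0x4444444444444444ull;
    tc[3] &= 0x8888888888888888ull;
    return tc[0] | tc[1] | tc[2] | tc[3];
}

int main()
{
    uint32_t samples[] = { 0, 1, 0x87654321u, 0xdeadbeefu, 0xffffffffu };
    for (uint32_t a : samples)
        for (uint32_t b : samples)
            assert(clmul_masked(a, b) == clmul_ref(a, b));
    return 0;
}
```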