LibCrypto: SIMDify GHash
MarekKnapek committed Nov 5, 2024
1 parent 5c77f1c commit 6846efb
Showing 3 changed files with 132 additions and 39 deletions.
AK/SIMD.h: 8 additions, 0 deletions
@@ -106,11 +106,19 @@ struct IndexVectorFor<T> {
         u64 __attribute__((vector_size(sizeof(T))))>;
 };
 
+template<typename T, size_t element_count>
+struct MakeVectorImpl {
+    using Type __attribute__((vector_size(sizeof(T) * element_count))) = T;
+};
+
 }
 
 template<SIMDVector T>
 using IndexVectorFor = typename Detail::IndexVectorFor<T>::Type;
 
+template<typename T, size_t element_count>
+using MakeVector = typename Detail::MakeVectorImpl<T, element_count>::Type;
+
 static_assert(IsSame<IndexVectorFor<i8x16>, i8x16>);
 static_assert(IsSame<IndexVectorFor<u32x4>, u32x4>);
 static_assert(IsSame<IndexVectorFor<u64x4>, u64x4>);
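
For context: MakeVector<T, element_count> names the Clang/GCC vector-extension type with element_count lanes of T, generalizing the hand-written aliases such as u32x4. Illustratively (a sketch assuming the fixed aliases AK/SIMD.h already defines; these asserts are not part of the diff):

    // These should hold by construction of MakeVector:
    static_assert(IsSame<MakeVector<u32, 4>, u32x4>);
    static_assert(IsSame<MakeVector<u64, 2>, u64x2>);
    static_assert(IsSame<MakeVector<u8, 16>, u8x16>);
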
AK/SIMDExtras.h: 70 additions, 0 deletions
@@ -253,6 +253,40 @@ ALWAYS_INLINE static T elementwise_byte_reverse_impl(T a, IndexSequence<Idx...>)
     }
 }
 
+template<SIMDVector T, size_t... Idx>
+ALWAYS_INLINE static ElementOf<T> reduce_or_impl(T const& a, IndexSequence<Idx...> const&)
+{
+    static_assert(is_power_of_two(vector_length<T>));
+    static_assert(vector_length<T> == sizeof...(Idx) * 2);
+
+    using E = ElementOf<T>;
+
+    constexpr size_t N = sizeof...(Idx);
+
+    if constexpr (N == 1) {
+        return a[0] | a[1];
+    } else {
+        return reduce_or_impl(MakeVector<E, N> { (a[Idx])... }, MakeIndexSequence<N / 2>()) | reduce_or_impl(MakeVector<E, N> { (a[N + Idx])... }, MakeIndexSequence<N / 2>());
+    }
+}
+
+template<SIMDVector T, size_t... Idx>
+ALWAYS_INLINE static ElementOf<T> reduce_xor_impl(T const& a, IndexSequence<Idx...> const&)
+{
+    static_assert(is_power_of_two(vector_length<T>));
+    static_assert(vector_length<T> == sizeof...(Idx) * 2);
+
+    using E = ElementOf<T>;
+
+    constexpr size_t N = sizeof...(Idx);
+
+    if constexpr (N == 1) {
+        return a[0] ^ a[1];
+    } else {
+        return reduce_xor_impl(MakeVector<E, N> { (a[Idx])... }, MakeIndexSequence<N / 2>()) ^ reduce_xor_impl(MakeVector<E, N> { (a[N + Idx])... }, MakeIndexSequence<N / 2>());
+    }
+}
+
 }
 
 // FIXME: Shuffles only work with integral types for now
@@ -286,4 +320,40 @@ ALWAYS_INLINE static T elementwise_byte_reverse(T a)
     return Detail::elementwise_byte_reverse_impl(a, MakeIndexSequence<vector_length<T>>());
 }
 
+template<SIMDVector T>
+ALWAYS_INLINE static ElementOf<T> reduce_or(T const& a)
+{
+    static_assert(is_power_of_two(vector_length<T>));
+    static_assert(IsUnsigned<ElementOf<T>>);
+
+#if defined __has_builtin
+#    if __has_builtin(__builtin_reduce_or)
+    if (true) {
+        return __builtin_reduce_or(a);
+    } else
+#    endif
+#endif
+    {
+        return Detail::reduce_or_impl(a, MakeIndexSequence<vector_length<T> / 2>());
+    }
+}
+
+template<SIMDVector T>
+ALWAYS_INLINE static ElementOf<T> reduce_xor(T const& a)
+{
+    static_assert(is_power_of_two(vector_length<T>));
+    static_assert(IsUnsigned<ElementOf<T>>);
+
+#if defined __has_builtin
+#    if __has_builtin(__builtin_reduce_xor)
+    if (true) {
+        return __builtin_reduce_xor(a);
+    } else
+#    endif
+#endif
+    {
+        return Detail::reduce_xor_impl(a, MakeIndexSequence<vector_length<T> / 2>());
+    }
+}
+
 }
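
The public reduce_or and reduce_xor prefer the compiler builtins __builtin_reduce_or/__builtin_reduce_xor where __has_builtin reports them (the `if (true) { ... } else` shape lets the preprocessor splice the builtin path in front of the generic fallback as a single statement). The fallback halves the vector recursively: each step splits the lanes into a low half and a high half, bottoming out at a two-lane pair. An illustrative use (not code from the diff), assuming the u64x4 alias:

    // For { 1, 2, 4, 8 } the halving fallback computes (1 ^ 2) ^ (4 ^ 8) == 15,
    // the XOR of all lanes, matching the builtin path.
    u64x4 v = { 1, 2, 4, 8 };
    u64 r = reduce_xor(v); // r == 15
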
Userland/Libraries/LibCrypto/Authentication/GHash.cpp: 54 additions, 39 deletions
@@ -6,6 +6,8 @@
 
 #include <AK/ByteReader.h>
 #include <AK/Debug.h>
+#include <AK/SIMD.h>
+#include <AK/SIMDExtras.h>
 #include <AK/Types.h>
 #include <LibCrypto/Authentication/GHash.h>
 
@@ -86,50 +88,63 @@ GHash::TagType GHash::process(ReadonlyBytes aad, ReadonlyBytes cipher)
 
 void galois_multiply(u32 (&_z)[4], u32 const (&_x)[4], u32 const (&_y)[4])
 {
-    static auto const mul_32_x_32_64 = [](u32 const& a, u32 const& b) -> u64 {
-        return static_cast<u64>(a) * static_cast<u64>(b);
+    using namespace AK::SIMD;
+
+    static auto const rotate_left = [](u32x4 const& x) -> u32x4 {
+        return u32x4 { x[3], x[0], x[1], x[2] };
     };
 
+    static auto const mul_32_x_32_64 = [](u32x4 const& a, u32x4 const& b) -> u64x4 {
+        u64x2 r1;
+        u64x2 r2;
+
+#if defined __has_builtin
+#    if __has_builtin(__builtin_ia32_pmuludq128)
+        if (true) {
+            r1 = simd_cast<u64x2>(__builtin_ia32_pmuludq128(simd_cast<i32x4>(u32x4 { a[0], 0, a[1], 0 }), simd_cast<i32x4>(u32x4 { b[0], 0, b[1], 0 })));
+            r2 = simd_cast<u64x2>(__builtin_ia32_pmuludq128(simd_cast<i32x4>(u32x4 { a[2], 0, a[3], 0 }), simd_cast<i32x4>(u32x4 { b[2], 0, b[3], 0 })));
+        } else
+#    endif
+#endif
+        {
+            r1 = u64x2 { static_cast<u64>(a[0]) * static_cast<u64>(b[0]), static_cast<u64>(a[1]) * static_cast<u64>(b[1]) };
+            r2 = u64x2 { static_cast<u64>(a[2]) * static_cast<u64>(b[2]), static_cast<u64>(a[3]) * static_cast<u64>(b[3]) };
+        }
+        return u64x4 { r1[0], r1[1], r2[0], r2[1] };
+    };
+
     static auto const clmul_32_x_32_64 = [](u32 const& a, u32 const& b, u32& lo, u32& hi) -> void {
-        u32 ta[4];
-        u32 tb[4];
-        u64 tu64[4];
-        u64 tc[4];
+        constexpr u32x4 mask32 = { 0x11111111, 0x22222222, 0x44444444, 0x88888888 };
+        constexpr u64x4 mask64 = { 0x1111111111111111ull, 0x2222222222222222ull, 0x4444444444444444ull, 0x8888888888888888ull };
+
+        u32x4 ta;
+        u32x4 tb;
+        u64x4 tu64;
+        u64x4 tc;
         u64 cc;
 
-        ta[0] = a & static_cast<u32>(0x11111111ul);
-        ta[1] = a & static_cast<u32>(0x22222222ul);
-        ta[2] = a & static_cast<u32>(0x44444444ul);
-        ta[3] = a & static_cast<u32>(0x88888888ul);
-        tb[0] = b & static_cast<u32>(0x11111111ul);
-        tb[1] = b & static_cast<u32>(0x22222222ul);
-        tb[2] = b & static_cast<u32>(0x44444444ul);
-        tb[3] = b & static_cast<u32>(0x88888888ul);
-        tu64[0] = mul_32_x_32_64(ta[0], tb[0]);
-        tu64[1] = mul_32_x_32_64(ta[1], tb[3]);
-        tu64[2] = mul_32_x_32_64(ta[2], tb[2]);
-        tu64[3] = mul_32_x_32_64(ta[3], tb[1]);
-        tc[0] = tu64[0] ^ tu64[1] ^ tu64[2] ^ tu64[3];
-        tu64[0] = mul_32_x_32_64(ta[0], tb[1]);
-        tu64[1] = mul_32_x_32_64(ta[1], tb[0]);
-        tu64[2] = mul_32_x_32_64(ta[2], tb[3]);
-        tu64[3] = mul_32_x_32_64(ta[3], tb[2]);
-        tc[1] = tu64[0] ^ tu64[1] ^ tu64[2] ^ tu64[3];
-        tu64[0] = mul_32_x_32_64(ta[0], tb[2]);
-        tu64[1] = mul_32_x_32_64(ta[1], tb[1]);
-        tu64[2] = mul_32_x_32_64(ta[2], tb[0]);
-        tu64[3] = mul_32_x_32_64(ta[3], tb[3]);
-        tc[2] = tu64[0] ^ tu64[1] ^ tu64[2] ^ tu64[3];
-        tu64[0] = mul_32_x_32_64(ta[0], tb[3]);
-        tu64[1] = mul_32_x_32_64(ta[1], tb[2]);
-        tu64[2] = mul_32_x_32_64(ta[2], tb[1]);
-        tu64[3] = mul_32_x_32_64(ta[3], tb[0]);
-        tc[3] = tu64[0] ^ tu64[1] ^ tu64[2] ^ tu64[3];
-        tc[0] &= static_cast<u64>(0x1111111111111111ull);
-        tc[1] &= static_cast<u64>(0x2222222222222222ull);
-        tc[2] &= static_cast<u64>(0x4444444444444444ull);
-        tc[3] &= static_cast<u64>(0x8888888888888888ull);
-        cc = tc[0] | tc[1] | tc[2] | tc[3];
+        ta = a & mask32;
+        tb = b & mask32;
+        tb = item_reverse(tb);
+
+        tb = rotate_left(tb);
+        tu64 = mul_32_x_32_64(ta, tb);
+        tc[0] = reduce_xor(u64x4 { tu64[0], tu64[1], tu64[2], tu64[3] });
+
+        tb = rotate_left(tb);
+        tu64 = mul_32_x_32_64(ta, tb);
+        tc[1] = reduce_xor(u64x4 { tu64[0], tu64[1], tu64[2], tu64[3] });
+
+        tb = rotate_left(tb);
+        tu64 = mul_32_x_32_64(ta, tb);
+        tc[2] = reduce_xor(u64x4 { tu64[0], tu64[1], tu64[2], tu64[3] });
+
+        tb = rotate_left(tb);
+        tu64 = mul_32_x_32_64(ta, tb);
+        tc[3] = reduce_xor(u64x4 { tu64[0], tu64[1], tu64[2], tu64[3] });
+
+        tc &= mask64;
+        cc = reduce_or(tc);
         lo = static_cast<u32>((cc >> (0 * 32)) & 0xfffffffful);
         hi = static_cast<u32>((cc >> (1 * 32)) & 0xfffffffful);
     };
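
The rewritten clmul_32_x_32_64 is still a carry-less (GF(2)[x]) multiply built from ordinary integer multiplies: each operand is split into four strided groups (every fourth bit), so a 32x32->64 multiply of two groups never stacks enough partial products for a carry to reach the bit positions the final mask keeps, and XOR-ing the group products then re-masking yields the carry-less result. The SIMD version holds all four groups of a and b in one u32x4 each; item_reverse plus the successive rotate_left calls align, for each output group i, lane j of ta with group (i - j) mod 4 of tb, so one lane-wise multiply and one reduce_xor produce each tc[i]. A scalar model of the same trick (illustrative only, not part of the commit; u32/u64 as in AK/Types.h):

    // Reference carry-less multiply via masked integer multiplies.
    // Bits spaced four apart keep every partial-product column small enough
    // that carries never corrupt the bits selected by the final mask.
    static u64 clmul_32_x_32_64_reference(u32 a, u32 b)
    {
        u64 cc = 0;
        for (int i = 0; i < 4; ++i) {
            u64 tc = 0;
            for (int j = 0; j < 4; ++j)
                tc ^= static_cast<u64>(a & (0x11111111u << j)) * (b & (0x11111111u << ((i - j) & 3)));
            cc |= tc & (0x1111111111111111ull << i);
        }
        return cc;
    }

The __builtin_ia32_pmuludq128 path maps onto SSE2's widening 32x32->64 multiply, which only reads the even lanes of its inputs; that is why the operands are interleaved with zeros before the simd_cast.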
