LibCrypto: Improve GHash / GCM performance #24951

Open · wants to merge 3 commits into master
8 changes: 8 additions & 0 deletions AK/SIMD.h
@@ -108,11 +108,19 @@ struct IndexVectorFor<T> {
};
#endif

template<typename T, size_t element_count>
struct MakeVectorImpl {
using Type __attribute__((vector_size(sizeof(T) * element_count))) = T;
};
Comment on lines +112 to +114

Contributor: You sure this works on GCC? I had issues with dependent vector sizes there.

Contributor Author: Yes, this works on my machine, Ubuntu with GCC. Source: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=88600#c1
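
As a quick illustration (a sketch, not part of the diff; IsSame and u32x4 as already used in this header):

```cpp
// MakeVector<u32, 4> expands to a 16-byte GCC/Clang vector of u32,
// i.e. the same type the u32x4 alias names.
static_assert(IsSame<MakeVector<u32, 4>, u32x4>);
```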


}

template<SIMDVector T>
using IndexVectorFor = typename Detail::IndexVectorFor<T>::Type;

template<typename T, size_t element_count>
using MakeVector = typename Detail::MakeVectorImpl<T, element_count>::Type;

static_assert(IsSame<IndexVectorFor<i8x16>, i8x16>);
static_assert(IsSame<IndexVectorFor<u32x4>, u32x4>);
static_assert(IsSame<IndexVectorFor<u64x4>, u64x4>);
70 changes: 70 additions & 0 deletions AK/SIMDExtras.h
@@ -253,6 +253,40 @@ ALWAYS_INLINE static T elementwise_byte_reverse_impl(T a, IndexSequence<Idx...>)
}
}

template<SIMDVector T, size_t... Idx>
ALWAYS_INLINE static ElementOf<T> reduce_or_impl(T const& a, IndexSequence<Idx...> const&)
{
static_assert(is_power_of_two(vector_length<T>));
Contributor: Technically not a requirement, just a limitation of the generic impl; Clang officially uses an even-odd pattern for their builtin.

Contributor Author (@MarekKnapek, Aug 22, 2024): Yes, you are correct. But in practice the length is always a power of two, and I was not feeling like implementing the general case. Might add /*todo*/.

static_assert(vector_length<T> == sizeof...(Idx) * 2);

Comment on lines +259 to +261

Contributor: There should also be requires clauses.

Contributor Author: OK
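
A minimal sketch of such a clause (mirroring the constraints the static_asserts already encode):

```cpp
template<SIMDVector T, size_t... Idx>
requires(is_power_of_two(vector_length<T>) && vector_length<T> == sizeof...(Idx) * 2)
ALWAYS_INLINE static ElementOf<T> reduce_or_impl(T const& a, IndexSequence<Idx...> const&);
```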

using E = ElementOf<T>;

constexpr size_t N = sizeof...(Idx);

if constexpr (N == 1) {
return a[0] | a[1];
} else {
return reduce_or_impl(MakeVector<E, N> { (a[Idx])... }, MakeIndexSequence<N / 2>()) | reduce_or_impl(MakeVector<E, N> { (a[N + Idx])... }, MakeIndexSequence<N / 2>());
Contributor: Shouldn't (a[Idx] | ...) or similar work as well?

Contributor Author: Aaaaa, that would save me much effort. Thank you, will test it out.
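
A minimal sketch of that suggestion (untested here): a unary fold over a full-length index sequence replaces the recursive halving, and drops the power-of-two restriction as well:

```cpp
template<SIMDVector T, size_t... Idx>
ALWAYS_INLINE static ElementOf<T> reduce_or_impl(T const& a, IndexSequence<Idx...> const&)
{
    // OR together all lanes; callers would pass MakeIndexSequence<vector_length<T>>().
    return (... | a[Idx]);
}
```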

}
}

template<SIMDVector T, size_t... Idx>
ALWAYS_INLINE static ElementOf<T> reduce_xor_impl(T const& a, IndexSequence<Idx...> const&)
Contributor: See both comments above.

{
static_assert(is_power_of_two(vector_length<T>));
static_assert(vector_length<T> == sizeof...(Idx) * 2);

using E = ElementOf<T>;

constexpr size_t N = sizeof...(Idx);

if constexpr (N == 1) {
return a[0] ^ a[1];
} else {
return reduce_xor_impl(MakeVector<E, N> { (a[Idx])... }, MakeIndexSequence<N / 2>()) ^ reduce_xor_impl(MakeVector<E, N> { (a[N + Idx])... }, MakeIndexSequence<N / 2>());
}
}

}

// FIXME: Shuffles only work with integral types for now
@@ -286,4 +320,40 @@ ALWAYS_INLINE static T elementwise_byte_reverse(T a)
return Detail::elementwise_byte_reverse_impl(a, MakeIndexSequence<vector_length<T>>());
}

template<SIMDVector T>
ALWAYS_INLINE static ElementOf<T> reduce_or(T const& a)
{
static_assert(is_power_of_two(vector_length<T>));
static_assert(IsUnsigned<ElementOf<T>>);

#if defined __has_builtin
# if __has_builtin(__builtin_reduce_or)
if (true) {
return __builtin_reduce_or(a);
} else
# endif
Comment on lines +330 to +334

Contributor: Weird pattern; shouldn't if constexpr (...) work as well and be nicer?

Contributor: Also, in which case is __has_builtin not defined? We don't really care about MSVC.

Contributor Author: I was trying to fix this problem (it seems the compiler doesn't like the if constexpr pattern in that specific case): https://github.com/SerenityOS/serenity/actions/runs/10498331182/job/29083055403#step:10:2402

Contributor: Ah sure, then go through the preprocessor, but the #if defined __has_builtin should be redundant.

Contributor: Oh, cross that: making it __has_builtin(__builtin_reduce_or) && DependentTrue<T> should make it work (weird rules around when a false constexpr branch is checked).
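
A sketch of that last suggestion (untested; DependentTrue<T> is assumed here as an always-true dependent trait, by analogy with AK's DependentFalse):

```cpp
template<SIMDVector T>
ALWAYS_INLINE static ElementOf<T> reduce_or(T const& a)
{
    // Because the condition is value-dependent, the discarded branch is never
    // instantiated, so the builtin is only needed when the condition is true.
    if constexpr (__has_builtin(__builtin_reduce_or) && DependentTrue<T>)
        return __builtin_reduce_or(a);
    else
        return Detail::reduce_or_impl(a, MakeIndexSequence<vector_length<T> / 2>());
}
```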

#endif
{
return Detail::reduce_or_impl(a, MakeIndexSequence<vector_length<T> / 2>());
}
}

template<SIMDVector T>
ALWAYS_INLINE static ElementOf<T> reduce_xor(T const& a)
{
static_assert(is_power_of_two(vector_length<T>));
static_assert(IsUnsigned<ElementOf<T>>);

#if defined __has_builtin
# if __has_builtin(__builtin_reduce_xor)
if (true) {
Comment on lines +347 to +349

Contributor: Same as above.

return __builtin_reduce_xor(a);
} else
# endif
#endif
{
return Detail::reduce_xor_impl(a, MakeIndexSequence<vector_length<T> / 2>());
}
}

}
188 changes: 158 additions & 30 deletions Userland/Libraries/LibCrypto/Authentication/GHash.cpp
@@ -6,6 +6,8 @@

#include <AK/ByteReader.h>
#include <AK/Debug.h>
#include <AK/SIMD.h>
#include <AK/SIMDExtras.h>
#include <AK/Types.h>
#include <LibCrypto/Authentication/GHash.h>

@@ -84,39 +86,165 @@ GHash::TagType GHash::process(ReadonlyBytes aad, ReadonlyBytes cipher)
return digest;
}

/// Galois Field multiplication using <x^128 + x^7 + x^2 + x + 1>.
/// Note that x, y, and z are strictly BE.
void galois_multiply(u32 (&_z)[4], u32 const (&_x)[4], u32 const (&_y)[4])
{
// Note: Copied upfront to stack to avoid memory access in the loop.
u32 x[4] { _x[0], _x[1], _x[2], _x[3] };
u32 const y[4] { _y[0], _y[1], _y[2], _y[3] };
u32 z[4] { 0, 0, 0, 0 };

// Unrolled by 32, the access in y[j / 32] can be cached throughout the loop.
#pragma GCC unroll 32
for (ssize_t i = 127, j = 0; i > -1; --i, j++) {
auto r = -((y[j / 32] >> (i % 32)) & 1);
z[0] ^= x[0] & r;
z[1] ^= x[1] & r;
z[2] ^= x[2] & r;
z[3] ^= x[3] & r;
auto a0 = x[0] & 1;
x[0] >>= 1;
auto a1 = x[1] & 1;
x[1] >>= 1;
x[1] |= a0 << 31;
auto a2 = x[2] & 1;
x[2] >>= 1;
x[2] |= a1 << 31;
auto a3 = x[3] & 1;
x[3] >>= 1;
x[3] |= a2 << 31;

x[0] ^= 0xe1000000 & -a3;
}
/** This function computes a 128-bit x 128-bit multiplication in the Galois field, producing a 128-bit result.
 * It uses 9 32-bit x 32-bit to 64-bit carry-less multiplications in a two-level Karatsuba decomposition.
 */
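// Two-level Karatsuba over GF(2)[x], where addition is XOR: writing
// x = x1*2^64 + x0 and y = y1*2^64 + y0 gives
//   x*y = x1*y1*2^128 + (x0*y0 + x1*y1 + (x0 + x1)*(y0 + y1))*2^64 + x0*y0,
// i.e. 3 64x64 carry-less products; splitting each 64-bit half the same way
// again yields 3 * 3 = 9 32x32 products (clmul_32_x_32_64 below).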
using namespace AK::SIMD;

static auto const rotate_left = [](u32x4 const& x) -> u32x4 {
Contributor: That's a rotation right, isn't it?

Contributor Author: I thought that 0x1234 << 4 == 0x2340 is a left shift and 0x1234 >> 4 == 0x0123 is a right shift. There you write the digits in "big endian" order, but in the case of u32x4 vec { 1, 2, 3, 4 } the vector elements are written in "little endian" order. So the rotation is "right" on screen, but "left" regarding the bits, I guess.

Contributor: Yeah, probably a point-of-view thing; maybe rename the helper to express the scope it's working on.

return u32x4 { x[3], x[0], x[1], x[2] };
};

static auto const mul_32_x_32_64 = [](u32x4 const& a, u32x4 const& b) -> u64x4 {
/** This function computes 32-bit x 32-bit unsigned integer multiplications, producing 64-bit results.
 * It does this for 4 lanes at a time: 4 32-bit integers times 4 32-bit integers, producing 4 64-bit integers.
 */
u64x2 r1;
u64x2 r2;

#if defined __has_builtin
# if __has_builtin(__builtin_ia32_pmuludq128)
if (true) {
Comment on lines +107 to +109

Contributor: See above.

Contributor: Ah, forgot to mention it: this needs an x86 guard.

Contributor Author: OK

r1 = simd_cast<u64x2>(__builtin_ia32_pmuludq128(simd_cast<i32x4>(u32x4 { a[0], 0, a[1], 0 }), simd_cast<i32x4>(u32x4 { b[0], 0, b[1], 0 })));
r2 = simd_cast<u64x2>(__builtin_ia32_pmuludq128(simd_cast<i32x4>(u32x4 { a[2], 0, a[3], 0 }), simd_cast<i32x4>(u32x4 { b[2], 0, b[3], 0 })));
} else
# endif
#endif
{
r1 = u64x2 { static_cast<u64>(a[0]) * static_cast<u64>(b[0]), static_cast<u64>(a[1]) * static_cast<u64>(b[1]) };
r2 = u64x2 { static_cast<u64>(a[2]) * static_cast<u64>(b[2]), static_cast<u64>(a[3]) * static_cast<u64>(b[3]) };
Comment on lines +116 to +117

Contributor: Turns out return to<u64x4>(a) * to<u64x4>(b) has slightly better codegen (and emits pmuludq in my tests).

Contributor Author: This is copy & pasta from your suggestion. But OK, will improve.

Contributor: Yeah, on Discord I corrected myself; the link posted here has the other version. Also s/to/simd_cast/, I can't remember names.
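
A sketch of the suggested form (untested against this tree; simd_cast as used elsewhere in this diff):

```cpp
static auto const mul_32_x_32_64 = [](u32x4 const& a, u32x4 const& b) -> u64x4 {
    // Widen each 32-bit lane to 64 bits, then multiply lane-wise; the
    // products are exact because both inputs fit in 32 bits.
    return simd_cast<u64x4>(a) * simd_cast<u64x4>(b);
};
```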

}
return u64x4 { r1[0], r1[1], r2[0], r2[1] };
};

static auto const clmul_32_x_32_64 = [](u32 const& a, u32 const& b, u32& lo, u32& hi) -> void {
/** This function computes a 32-bit x 32-bit unsigned integer carry-less multiplication, producing a 64-bit result.
 * It works on 4 bits of each integer at a time: the masks spread those bits into 32-bit integers with holes
 * (1 significant bit followed by 3 hole bits), and repeating this over the 4 rotations below multiplies all
 * 32 input bits. The holes keep integer carries from spilling into more significant bits; any carry that does
 * spill lands in a hole and is discarded by the final masking.
 * https://www.bearssl.org/constanttime.html#ghash-for-gcm
 */
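// Why 3 hole bits are enough: a masked 32-bit operand has at most 32 / 4 = 8
// set bits, so any output position accumulates at most 8 one-bit partial
// products; a column sum of at most 8 fits in the kept bit plus its 3 holes
// and never carries into the next kept bit.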
constexpr u32x4 mask32 = { 0x11111111, 0x22222222, 0x44444444, 0x88888888 };
constexpr u64x4 mask64 = { 0x1111111111111111ull, 0x2222222222222222ull, 0x4444444444444444ull, 0x8888888888888888ull };

u32x4 ta;
u32x4 tb;
u64x4 tu64;
u64x4 tc;
u64 cc;

ta = a & mask32;
tb = b & mask32;
tb = item_reverse(tb);
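// b's masked lanes are reversed and then rotated one step per round, pairing
// the lanes so that the 4 lane products of each round share the same bit
// phase (mod 4); reduce_xor then collapses them into one phase of the result,
// as in BearSSL's bmul.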

tb = rotate_left(tb);
tu64 = mul_32_x_32_64(ta, tb);
tc[0] = reduce_xor(u64x4 { tu64[0], tu64[1], tu64[2], tu64[3] });

tb = rotate_left(tb);
tu64 = mul_32_x_32_64(ta, tb);
tc[1] = reduce_xor(u64x4 { tu64[0], tu64[1], tu64[2], tu64[3] });

tb = rotate_left(tb);
tu64 = mul_32_x_32_64(ta, tb);
tc[2] = reduce_xor(u64x4 { tu64[0], tu64[1], tu64[2], tu64[3] });

tb = rotate_left(tb);
tu64 = mul_32_x_32_64(ta, tb);
tc[3] = reduce_xor(u64x4 { tu64[0], tu64[1], tu64[2], tu64[3] });

tc &= mask64;
cc = reduce_or(tc);
lo = static_cast<u32>((cc >> (0 * 32)) & 0xfffffffful);
hi = static_cast<u32>((cc >> (1 * 32)) & 0xfffffffful);
};

memcpy(_z, z, sizeof(z));
u32 aa[4];
u32 bb[4];
u32 ta[9];
u32 tb[9];
u32 tc[4];
u32 tu32[4];
u32 td[4];
u32 te[4];
u32 z[8];

aa[3] = _x[0];
aa[2] = _x[1];
aa[1] = _x[2];
aa[0] = _x[3];
bb[3] = _y[0];
bb[2] = _y[1];
bb[1] = _y[2];
bb[0] = _y[3];
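// Load both operands as little-endian 32-bit limbs (GHASH stores them big-endian).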
ta[0] = aa[0];
ta[1] = aa[1];
ta[2] = ta[0] ^ ta[1];
ta[3] = aa[2];
ta[4] = aa[3];
ta[5] = ta[3] ^ ta[4];
ta[6] = ta[0] ^ ta[3];
ta[7] = ta[1] ^ ta[4];
ta[8] = ta[6] ^ ta[7];
tb[0] = bb[0];
tb[1] = bb[1];
tb[2] = tb[0] ^ tb[1];
tb[3] = bb[2];
tb[4] = bb[3];
tb[5] = tb[3] ^ tb[4];
tb[6] = tb[0] ^ tb[3];
tb[7] = tb[1] ^ tb[4];
tb[8] = tb[6] ^ tb[7];
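// Karatsuba operand pairs: ta[0..2]/tb[0..2] are the low 64-bit halves as
// (low limb, high limb, their XOR), ta[3..5]/tb[3..5] are the high halves,
// and ta[6..8]/tb[6..8] are the outer level's cross operands (low half XOR
// high half), in the same (low, high, XOR) layout.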
Comment on lines +166 to +201

Contributor: This part still feels odd; it looks like it might be vectors, but then the ta/tb fall out of place. Any neater way this can be described?

Contributor: Also, comments on how this works would be nice.

Contributor: I think you could also do Karatsuba with the clmul, if that isn't what's happening here. That's what the Intel white paper does with the 128-bit-wide pclmulqdq, so two/four rounds of it should get us there.

Contributor Author: This is Karatsuba, inspired by BearSSL.

for (int i = 0; i != 9; ++i) {
clmul_32_x_32_64(ta[i], tb[i], ta[i], tb[i]);
}
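// Each of the 9 pairs is multiplied carry-lessly in place: the low 32 bits
// of each 64-bit product land in ta[i], the high 32 bits in tb[i].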
tc[0] = ta[0];
tc[1] = ta[0] ^ ta[1] ^ ta[2] ^ tb[0];
tc[2] = ta[1] ^ tb[0] ^ tb[1] ^ tb[2];
tc[3] = tb[1];
td[0] = ta[3];
td[1] = ta[3] ^ ta[4] ^ ta[5] ^ tb[3];
td[2] = ta[4] ^ tb[3] ^ tb[4] ^ tb[5];
td[3] = tb[4];
te[0] = ta[6];
te[1] = ta[6] ^ ta[7] ^ ta[8] ^ tb[6];
te[2] = ta[7] ^ tb[6] ^ tb[7] ^ tb[8];
te[3] = tb[7];
te[0] ^= (tc[0] ^ td[0]);
te[1] ^= (tc[1] ^ td[1]);
te[2] ^= (tc[2] ^ td[2]);
te[3] ^= (tc[3] ^ td[3]);
tc[2] ^= te[0];
tc[3] ^= te[1];
td[0] ^= te[2];
td[1] ^= te[3];
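// tc and td now hold the low and high 128-bit partial products with the
// Karatsuba middle term te folded into the overlapping middle words, so
// tc[0..3] followed by td[0..3] is the full 255-bit carry-less product.
// GHASH uses a reflected bit order, so shift everything left by one bit to
// align the product in z[0..7] for the reduction below.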
z[0] = tc[0] << 1;
z[1] = (tc[1] << 1) | (tc[0] >> 31);
z[2] = (tc[2] << 1) | (tc[1] >> 31);
z[3] = (tc[3] << 1) | (tc[2] >> 31);
z[4] = (td[0] << 1) | (tc[3] >> 31);
z[5] = (td[1] << 1) | (td[0] >> 31);
z[6] = (td[2] << 1) | (td[1] >> 31);
z[7] = (td[3] << 1) | (td[2] >> 31);
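// Reduction modulo x^128 + x^7 + x^2 + x + 1: fold the overflow words
// z[0..3] back into z[3..7]. Each tap k in {0, 1, 2, 7} contributes
// z[i] >> k to z[i + 4], and each nonzero tap also contributes
// z[i] << (32 - k) to z[i + 3], because the fold straddles the 32-bit
// word boundary. The result is the upper half, written out in GHASH's
// big-endian word order below.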
for (int i = 0; i != 4; ++i) {
tu32[0] = z[i] << 31;
tu32[1] = z[i] << 30;
tu32[2] = z[i] << 25;
z[i + 3] ^= (tu32[0] ^ tu32[1] ^ tu32[2]);
tu32[0] = z[i] >> 0;
tu32[1] = z[i] >> 1;
tu32[2] = z[i] >> 2;
tu32[3] = z[i] >> 7;
z[i + 4] ^= (tu32[0] ^ tu32[1] ^ tu32[2] ^ tu32[3]);
}
_z[0] = z[7];
_z[1] = z[6];
_z[2] = z[5];
_z[3] = z[4];
}

}