Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement hashing to curve #3293

Merged
merged 7 commits into from
Sep 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crypto/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ SRCS += zkp_ecdsa.c
SRCS += zkp_bip340.c
SRCS += cardano.c
SRCS += tls_prf.c
SRCS += hash_to_curve.c

OBJS = $(SRCS:.c=.o)
OBJS += secp256k1-zkp.o
Expand Down
113 changes: 87 additions & 26 deletions crypto/bignum.c
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,15 @@
#define BN_MAX_DECIMAL_DIGITS \
79 // floor(log(2**(LIMBS * BITS_PER_LIMB), 10)) + 1

// y = (bignum256) x
// Assumes x is normalized and x < 2**261 == 2**(BITS_PER_LIMB * LIMBS)
// Guarantees y is normalized
void bn_copy_lower(const bignum512 *x, bignum256 *y) {
for (int i = 0; i < BN_LIMBS; i++) {
y->val[i] = x->val[i];
}
}

// out_number = (bignum256) in_number
// Assumes in_number is a raw bigendian 256-bit number
// Guarantees out_number is normalized
Expand All @@ -98,6 +107,30 @@ void bn_read_be(const uint8_t *in_number, bignum256 *out_number) {
out_number->val[BN_LIMBS - 1] = temp;
}

// out_number = (bignum512) in_number
// Assumes in_number is a raw bigendian 512-bit number
// Guarantees out_number is normalized
void bn_read_be_512(const uint8_t *in_number, bignum512 *out_number) {
bignum256 lower = {0}, upper = {0};

bn_read_be(in_number + 32, &lower);
bn_read_be(in_number, &upper);

const int shift_length = BN_BITS_PER_LIMB * BN_LIMBS - 256;
uint32_t shift = upper.val[0] & ((1 << shift_length) - 1);
for (int i = 0; i < shift_length; i++) {
bn_rshift(&upper);
}
lower.val[BN_LIMBS - 1] |= shift << (BN_BITS_PER_LIMB - shift_length);

for (int i = 0; i < BN_LIMBS; i++) {
out_number->val[i] = lower.val[i];
}
for (int i = 0; i < BN_LIMBS; i++) {
out_number->val[i + BN_LIMBS] = upper.val[i];
}
}

// out_number = (256BE) in_number
// Assumes in_number < 2**256
// Guarantess out_number is a raw bigendian 256-bit number
Expand Down Expand Up @@ -525,8 +558,7 @@ void bn_mod(bignum256 *x, const bignum256 *prime) {
// res = k * x
// Assumes k and x are normalized
// Guarantees res is normalized 18 digit little endian number in base 2**29
void bn_multiply_long(const bignum256 *k, const bignum256 *x,
uint32_t res[2 * BN_LIMBS]) {
void bn_multiply_long(const bignum256 *k, const bignum256 *x, bignum512 *res) {
// Uses long multiplication in base 2**29, see
// https://en.wikipedia.org/wiki/Multiplication_algorithm#Long_multiplication

Expand All @@ -545,7 +577,7 @@ void bn_multiply_long(const bignum256 *k, const bignum256 *x,
// <= 2**35 + 9 * 2**58 < 2**64
}

res[i] = acc & BN_LIMB_MASK;
res->val[i] = acc & BN_LIMB_MASK;
acc >>= BN_BITS_PER_LIMB;
// acc <= 2**35 - 1 == 2**(64 - BITS_PER_LIMB) - 1
}
Expand All @@ -563,20 +595,20 @@ void bn_multiply_long(const bignum256 *k, const bignum256 *x,
// <= 2**35 + 9 * 2**58 < 2**64
}

res[i] = acc & (BN_BASE - 1);
res->val[i] = acc & (BN_BASE - 1);
acc >>= BN_BITS_PER_LIMB;
// acc < 2**35 == 2**(64 - BITS_PER_LIMB)
}

res[2 * BN_LIMBS - 1] = acc;
res->val[2 * BN_LIMBS - 1] = acc;
}

// Auxiliary function for bn_multiply
// Assumes 0 <= d <= 8 == LIMBS - 1
// Assumes res is normalized and res < 2**(256 + 29*d + 31)
// Guarantess res in normalized and res < 2 * prime * 2**(29*d)
// Assumes prime is normalized, 2**256 - 2**224 <= prime <= 2**256
void bn_multiply_reduce_step(uint32_t res[2 * BN_LIMBS], const bignum256 *prime,
void bn_multiply_reduce_step(bignum512 *res, const bignum256 *prime,
uint32_t d) {
// clang-format off
// Computes res = res - (res // 2**(256 + BITS_PER_LIMB * d)) * prime * 2**(BITS_PER_LIMB * d)
Expand All @@ -598,8 +630,9 @@ void bn_multiply_reduce_step(uint32_t res[2 * BN_LIMBS], const bignum256 *prime,
// clang-format on

uint32_t coef =
(res[d + BN_LIMBS - 1] >> (256 - (BN_LIMBS - 1) * BN_BITS_PER_LIMB)) +
(res[d + BN_LIMBS] << ((BN_LIMBS * BN_BITS_PER_LIMB) - 256));
(res->val[d + BN_LIMBS - 1] >>
(256 - (BN_LIMBS - 1) * BN_BITS_PER_LIMB)) +
(res->val[d + BN_LIMBS] << ((BN_LIMBS * BN_BITS_PER_LIMB) - 256));

// coef == res // 2**(256 + BITS_PER_LIMB * d)

Expand All @@ -613,7 +646,7 @@ void bn_multiply_reduce_step(uint32_t res[2 * BN_LIMBS], const bignum256 *prime,
uint64_t acc = 1ull << shift;

for (int i = 0; i < BN_LIMBS; i++) {
acc += (((uint64_t)(BN_BASE - 1)) << shift) + res[d + i] -
acc += (((uint64_t)(BN_BASE - 1)) << shift) + res->val[d + i] -
prime->val[i] * (uint64_t)coef;
// acc neither overflow 64 bits nor underflow zero
// Proof:
Expand All @@ -633,7 +666,7 @@ void bn_multiply_reduce_step(uint32_t res[2 * BN_LIMBS], const bignum256 *prime,
// == (2**35 - 1) + (2**31 + 1) * (2**29 - 1)
// <= 2**35 + 2**60 + 2**29 < 2**64

res[d + i] = acc & BN_LIMB_MASK;
res->val[d + i] = acc & BN_LIMB_MASK;
acc >>= BN_BITS_PER_LIMB;
// acc <= 2**(64 - BITS_PER_LIMB) - 1 == 2**35 - 1

Expand Down Expand Up @@ -664,16 +697,14 @@ void bn_multiply_reduce_step(uint32_t res[2 * BN_LIMBS], const bignum256 *prime,
// == 1 << shift
// clang-format on

res[d + BN_LIMBS] = 0;
res->val[d + BN_LIMBS] = 0;
}

// Auxiliary function for bn_multiply
// Partly reduces res and stores both in x and res
// Assumes res in normalized and res < 2**519
// Partly reduces x
// Assumes x in normalized and res < 2**519
// Guarantees x is normalized and partly reduced modulo prime
// Assumes prime is normalized, 2**256 - 2**224 <= prime <= 2**256
void bn_multiply_reduce(bignum256 *x, uint32_t res[2 * BN_LIMBS],
const bignum256 *prime) {
void bn_reduce(bignum512 *x, const bignum256 *prime) {
for (int i = BN_LIMBS - 1; i >= 0; i--) {
// res < 2**(256 + 29*i + 31)
// Proof:
Expand All @@ -684,11 +715,7 @@ void bn_multiply_reduce(bignum256 *x, uint32_t res[2 * BN_LIMBS],
// else:
// res < 2 * prime * 2**(29 * (i + 1))
// <= 2**256 * 2**(29*i + 29) < 2**(256 + 29*i + 31)
bn_multiply_reduce_step(res, prime, i);
}

for (int i = 0; i < BN_LIMBS; i++) {
x->val[i] = res[i];
bn_multiply_reduce_step(x, prime, i);
}
}

Expand All @@ -697,12 +724,13 @@ void bn_multiply_reduce(bignum256 *x, uint32_t res[2 * BN_LIMBS],
// Guarantees x is normalized and partly reduced modulo prime
// Assumes prime is normalized, 2**256 - 2**224 <= prime <= 2**256
void bn_multiply(const bignum256 *k, bignum256 *x, const bignum256 *prime) {
uint32_t res[2 * BN_LIMBS] = {0};
bignum512 res = {0};

bn_multiply_long(k, x, res);
bn_multiply_reduce(x, res, prime);
bn_multiply_long(k, x, &res);
bn_reduce(&res, prime);
bn_copy_lower(&res, x);

memzero(res, sizeof(res));
memzero(&res, sizeof(res));
}

// Partly reduces x modulo prime
Expand Down Expand Up @@ -858,7 +886,7 @@ void bn_sqrt(bignum256 *x, const bignum256 *prime) {
// http://en.wikipedia.org/wiki/Quadratic_residue#Prime_or_prime_power_modulus
// If prime % 4 == 3, then sqrt(x) % prime == x**((prime+1)//4) % prime

assert(prime->val[BN_LIMBS - 1] % 4 == 3);
assert(prime->val[0] % 4 == 3);

// e = (prime + 1) // 4
bignum256 e = {0};
Expand Down Expand Up @@ -1591,6 +1619,39 @@ void bn_subtract(const bignum256 *x, const bignum256 *y, bignum256 *res) {
// == 1
}

// Returns 0 if x is zero
// Returns 1 if x is a square modulo prime
// Returns -1 if x is not a square modulo prime
// Assumes x is normalized, x < 2**259
// Assumes prime is normalized, 2**256 - 2**224 <= prime <= 2**256
// Assumes prime is a prime
// The function doesn't have neither constant control flow nor constant memory
// access flow with regard to prime
int bn_legendre(const bignum256 *x, const bignum256 *prime) {
// This is a naive implementation
// A better implementation would be to use the Euclidean algorithm together with the quadratic reciprocity law

// e = (prime - 1) / 2
bignum256 e = {0};
bn_copy(prime, &e);
bn_rshift(&e);

// res = x**e % prime
bignum256 res = {0};
bn_power_mod(x, &e, prime, &res);
bn_mod(&res, prime);

if (bn_is_one(&res)) {
return 1;
}

if (bn_is_zero(&res)) {
return 0;
}

return -1;
}

// q = x // d, r = x % d
// Assumes x is normalized, 1 <= d <= 61304
// Guarantees q is normalized
Expand Down
9 changes: 9 additions & 0 deletions crypto/bignum.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ typedef struct {
uint32_t val[BN_LIMBS];
} bignum256;

// Represents the number sum([val[i] * 2**(29*i) for i in range(18))
typedef struct {
uint32_t val[2 * BN_LIMBS];
} bignum512;

static inline uint32_t read_be(const uint8_t *data) {
return (((uint32_t)data[0]) << 24) | (((uint32_t)data[1]) << 16) |
(((uint32_t)data[2]) << 8) | (((uint32_t)data[3]));
Expand All @@ -67,7 +72,9 @@ static inline void write_le(uint8_t *data, uint32_t x) {
data[0] = x;
}

void bn_copy_lower(const bignum512 *x, bignum256 *y);
void bn_read_be(const uint8_t *in_number, bignum256 *out_number);
void bn_read_be_512(const uint8_t *in_number, bignum512 *out_number);
void bn_write_be(const bignum256 *in_number, uint8_t *out_number);
void bn_read_le(const uint8_t *in_number, bignum256 *out_number);
void bn_write_le(const bignum256 *in_number, uint8_t *out_number);
Expand All @@ -94,6 +101,7 @@ void bn_mult_half(bignum256 *x, const bignum256 *prime);
void bn_mult_k(bignum256 *x, uint8_t k, const bignum256 *prime);
void bn_mod(bignum256 *x, const bignum256 *prime);
void bn_multiply(const bignum256 *k, bignum256 *x, const bignum256 *prime);
void bn_reduce(bignum512 *x, const bignum256 *prime);
void bn_fast_mod(bignum256 *x, const bignum256 *prime);
void bn_power_mod(const bignum256 *x, const bignum256 *e,
const bignum256 *prime, bignum256 *res);
Expand All @@ -108,6 +116,7 @@ void bn_subi(bignum256 *x, uint32_t y, const bignum256 *prime);
void bn_subtractmod(const bignum256 *x, const bignum256 *y, bignum256 *res,
const bignum256 *prime);
void bn_subtract(const bignum256 *x, const bignum256 *y, bignum256 *res);
int bn_legendre(const bignum256 *x, const bignum256 *prime);
void bn_long_division(bignum256 *x, uint32_t d, bignum256 *q, uint32_t *r);
void bn_divmod58(bignum256 *x, uint32_t *r);
void bn_divmod1000(bignum256 *x, uint32_t *r);
Expand Down
Loading