From 1e40f58cb2bdb2bcb64c8bd94301a2c145d3f072 Mon Sep 17 00:00:00 2001 From: Abhinav Saxena Date: Mon, 2 Dec 2024 11:45:33 +0530 Subject: [PATCH 1/4] add checks for ML-KEM keys Signed-off-by: Abhinav Saxena --- tests/CMakeLists.txt | 9 ++++ tests/vectors_kem.c | 119 ++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 121 insertions(+), 7 deletions(-) diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 6d08516a8..fae241991 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -173,6 +173,15 @@ target_link_libraries(vectors_sig PRIVATE ${TEST_DEPS}) add_executable(vectors_kem vectors_kem.c) target_link_libraries(vectors_kem PRIVATE ${TEST_DEPS}) +if(CMAKE_SYSTEM_NAME STREQUAL "Windows" AND BUILD_SHARED_LIBS) + # workaround for Windows .dll + if(MINGW OR MSYS OR CYGWIN OR CMAKE_CROSSCOMPILING) + target_link_options(vectors_kem PRIVATE -Wl,--allow-multiple-definition) + else() + target_link_options(vectors_kem PRIVATE "/FORCE:MULTIPLE") + endif() +endif() + # Enable Valgrind-based timing side-channel analysis for test_kem and test_sig if(OQS_ENABLE_TEST_CONSTANT_TIME AND NOT OQS_DEBUG_BUILD) message(WARNING "OQS_ENABLE_TEST_CONSTANT_TIME is incompatible with CMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}.") diff --git a/tests/vectors_kem.c b/tests/vectors_kem.c index a7a1dc6a7..2b0ac88ff 100644 --- a/tests/vectors_kem.c +++ b/tests/vectors_kem.c @@ -11,15 +11,34 @@ #include #include - +#include #include "system_info.c" +#ifdef OQS_ENABLE_KEM_ML_KEM +/* macros for sanity checks for encaps and decaps key */ +#define ML_KEM_BLOCKSIZE 384 +#define ML_KEM_K_MAX 4 +#define ML_KEM_N 256 +#define ML_KEM_1024_PK_SIZE 1568 +#define ML_KEM_Q 3329 +#define SHA256_OP_LEN 32 +/* since x is 12 bits, max value could be 4095. the below macro uses this to implement a simple time constant mod 3329 */ +#define MOD_Q(x) ((x) - ((x >= ML_KEM_Q) * ML_KEM_Q)) +#endif //OQS_ENABLE_KEM_ML_KEM + struct { const uint8_t *pos; } prng_state = { .pos = 0 }; +/* MLKEM-specific functions */ +static inline bool is_ml_kem(const char *method_name) { + return (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_512)) + || (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_768)) + || (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_1024)); +} + static void fprintBstr(FILE *fp, const char *S, const uint8_t *A, size_t L) { size_t i; fprintf(fp, "%s", S); @@ -58,13 +77,75 @@ static void hexStringToByteArray(const char *hexString, uint8_t *byteArray) { } } -/* ML_KEM-specific functions */ -static inline bool is_ml_kem(const char *method_name) { - return (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_512)) - || (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_768)) - || (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_1024)); +#ifdef OQS_ENABLE_KEM_ML_KEM +static inline bool sanityCheckSK(const uint8_t *sk, const char *method_name) { + /* sanity checks */ + if ((NULL == sk) || (NULL == method_name) || (false == is_ml_kem(method_name))) { + fprintf(stderr, "[vectors_kem] %s ERROR: inputs NULL or invalid method !\n", method_name); + return false; + } + /* buffer to hold public key hash */ + uint8_t pkdig[SHA256_OP_LEN] = {0}; + /* fetch the value of k according to the ML-KEM algorithm as per FIPS-203 + K = 2 for ML-KEM-512, K = 3 for ML-KEM-768 & K = 4 for ML-KEM-1024 */ + uint8_t K = (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_512)) ? 2 : (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_768)) ? 3 : 4; + /* calcualte hash of the public key(len = 384k+32) stored in private key at offset of 384k */ + OQS_SHA3_sha3_256(pkdig, sk + (ML_KEM_BLOCKSIZE * K), (ML_KEM_BLOCKSIZE * K) + 32); + /* compare it with public key hash stored at 768k+32 offset */ + if (0 != memcmp(pkdig, sk + (ML_KEM_BLOCKSIZE * K * 2) + 32, SHA256_OP_LEN)) { + return false; + } + return true; } +static inline bool sanityCheckPK(const uint8_t *pk, uint32_t pkLen, const char *method_name) { + /* sanity checks */ + if ((NULL == pk) || (0 == pkLen) || (NULL == method_name) || (false == is_ml_kem(method_name))) { + fprintf(stderr, "[vectors_kem] %s ERROR: inputs NULL or zero or invalid method !\n", method_name); + return false; + } + unsigned int i, j; + /* fetch the value of k according to the ML-KEM algorithm as per FIPS-203 + K = 2 for ML-KEM-512, K = 3 for ML-KEM-768 & K = 4 for ML-KEM-1024 */ + uint8_t K = (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_512)) ? 2 : (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_768)) ? 3 : 4; + /* buffer to hold decoded value. max value used, so same buffer could be used for ML-KEM versions + encaps key is of length 384K bytes(384K*8 bits). Grouped into 12-bit values, the buffer requires (384*K*8)/12 = 256*K entries of 12 bits */ + uint16_t buffd[ML_KEM_N * ML_KEM_K_MAX] = {0}; + /* buffer to hold encoded value */ + uint8_t buffe[ML_KEM_1024_PK_SIZE] = {0}; + uint16_t *buff_dec; + /* perform byte decoding as per Algo 6 of FIPS 203 */ + for (i = 0; i < K; i++) { + buff_dec = &buffd[i * ML_KEM_N]; + const uint8_t *curr_pk = &pk[i * ML_KEM_BLOCKSIZE]; + for (j = 0; j < ML_KEM_N / 2; j++) { + buff_dec[2 * j] = ((curr_pk[3 * j + 0] >> 0) | ((uint16_t)curr_pk[3 * j + 1] << 8)) & 0xFFF; + buff_dec[2 * j] = MOD_Q(buff_dec[2 * j]); + buff_dec[2 * j + 1] = ((curr_pk[3 * j + 1] >> 4) | ((uint16_t)curr_pk[3 * j + 2] << 4)) & 0xFFF; + buff_dec[2 * j + 1] = MOD_Q(buff_dec[2 * j + 1]); + } + } + /* perform byte encoding as per Algo 5 of FIPS 203 */ + for (i = 0; i < K; i++) { + uint16_t t0, t1; + buff_dec = &buffd[i * ML_KEM_N]; + uint8_t *buff_enc = &buffe[i * ML_KEM_BLOCKSIZE]; + for (j = 0; j < ML_KEM_N / 2; j++) { + t0 = buff_dec[2 * j]; + t1 = buff_dec[2 * j + 1]; + buff_enc[3 * j + 0] = (t0 >> 0); + buff_enc[3 * j + 1] = (t0 >> 8) | (t1 << 4); + buff_enc[3 * j + 2] = (t1 >> 4); + } + } + /* compare the encoded value with original public key. discard value of `rho(32 bytes)` during comparision as its not encoded */ + if (0 != memcmp(buffe, pk, pkLen - 32)) { + return false; + } + return true; +} +#endif //OQS_ENABLE_KEM_ML_KEM + static void MLKEM_randombytes_init(const uint8_t *entropy_input, const uint8_t *personalization_string) { (void) personalization_string; prng_state.pos = entropy_input; @@ -134,6 +215,14 @@ static OQS_STATUS kem_kg_vector(const char *method_name, fprintBstr(fh, "ek: ", public_key, kem->length_public_key); fprintBstr(fh, "dk: ", secret_key, kem->length_secret_key); +#ifdef OQS_ENABLE_KEM_ML_KEM + if ((false == sanityCheckPK(public_key, kem->length_public_key, method_name)) || (false == sanityCheckSK(secret_key, method_name))) { + ret = OQS_ERROR; + fprintf(stderr, "[vectors_kem] %s ERROR: generated public key or private key are corrupted !\n", method_name); + goto err; + } +#endif //OQS_ENABLE_KEM_ML_KEM + if (!memcmp(public_key, kg_pk, kem->length_public_key) && !memcmp(secret_key, kg_sk, kem->length_secret_key)) { ret = OQS_SUCCESS; } else { @@ -208,6 +297,14 @@ static OQS_STATUS kem_vector_encdec_aft(const char *method_name, goto err; } +#ifdef OQS_ENABLE_KEM_ML_KEM + if (false == sanityCheckPK(encdec_pk, kem->length_public_key, method_name)) { + ret = OQS_ERROR; + fprintf(stderr, "[vectors_kem] %s ERROR: passed encapsulation key is corrupted !\n", method_name); + goto err; + } +#endif //OQS_ENABLE_KEM_ML_KEM + rc = OQS_KEM_encaps(kem, ct_encaps, ss_encaps, encdec_pk); if (rc != OQS_SUCCESS) { fprintf(stderr, "[vectors_kem] %s ERROR: OQS_KEM_encaps failed!\n", method_name); @@ -273,6 +370,14 @@ static OQS_STATUS kem_vector_encdec_val(const char *method_name, goto err; } +#ifdef OQS_ENABLE_KEM_ML_KEM + if (false == sanityCheckSK(encdec_sk, method_name)) { + ret = OQS_ERROR; + fprintf(stderr, "[vectors_kem] %s ERROR: passed decapsulation key is corrupted !\n", method_name); + goto err; + } +#endif //OQS_ENABLE_KEM_ML_KEM + rc = OQS_KEM_decaps(kem, ss_decaps, encdec_c, encdec_sk); if (rc != OQS_SUCCESS) { fprintf(stderr, "[vectors_kem] %s ERROR: OQS_KEM_encaps failed!\n", method_name); @@ -469,4 +574,4 @@ int main(int argc, char **argv) { } else { return EXIT_SUCCESS; } -} +} \ No newline at end of file From 5005d7f8c1ca81bb94aa7cbc78e5664c10dcb3f8 Mon Sep 17 00:00:00 2001 From: Abhinav Saxena Date: Tue, 3 Dec 2024 11:02:27 +0530 Subject: [PATCH 2/4] fix build issues Signed-off-by: Abhinav Saxena --- tests/vectors_kem.c | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/vectors_kem.c b/tests/vectors_kem.c index 2b0ac88ff..437c356c4 100644 --- a/tests/vectors_kem.c +++ b/tests/vectors_kem.c @@ -98,7 +98,7 @@ static inline bool sanityCheckSK(const uint8_t *sk, const char *method_name) { return true; } -static inline bool sanityCheckPK(const uint8_t *pk, uint32_t pkLen, const char *method_name) { +static inline bool sanityCheckPK(const uint8_t *pk, size_t pkLen, const char *method_name) { /* sanity checks */ if ((NULL == pk) || (0 == pkLen) || (NULL == method_name) || (false == is_ml_kem(method_name))) { fprintf(stderr, "[vectors_kem] %s ERROR: inputs NULL or zero or invalid method !\n", method_name); @@ -133,9 +133,9 @@ static inline bool sanityCheckPK(const uint8_t *pk, uint32_t pkLen, const char * for (j = 0; j < ML_KEM_N / 2; j++) { t0 = buff_dec[2 * j]; t1 = buff_dec[2 * j + 1]; - buff_enc[3 * j + 0] = (t0 >> 0); - buff_enc[3 * j + 1] = (t0 >> 8) | (t1 << 4); - buff_enc[3 * j + 2] = (t1 >> 4); + buff_enc[3 * j + 0] = (uint8_t)(t0 >> 0); + buff_enc[3 * j + 1] = (uint8_t)((t0 >> 8) | (t1 << 4)); + buff_enc[3 * j + 2] = (uint8_t)(t1 >> 4); } } /* compare the encoded value with original public key. discard value of `rho(32 bytes)` during comparision as its not encoded */ @@ -574,4 +574,4 @@ int main(int argc, char **argv) { } else { return EXIT_SUCCESS; } -} \ No newline at end of file +} From faf5669f63cbdad21a1397c8eac035d3e33de481 Mon Sep 17 00:00:00 2001 From: Abhinav Saxena Date: Tue, 3 Dec 2024 11:24:35 +0530 Subject: [PATCH 3/4] deadcode removal Signed-off-by: Abhinav Saxena --- tests/vectors_kem.c | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/vectors_kem.c b/tests/vectors_kem.c index 437c356c4..a3dfe99cf 100644 --- a/tests/vectors_kem.c +++ b/tests/vectors_kem.c @@ -217,7 +217,6 @@ static OQS_STATUS kem_kg_vector(const char *method_name, #ifdef OQS_ENABLE_KEM_ML_KEM if ((false == sanityCheckPK(public_key, kem->length_public_key, method_name)) || (false == sanityCheckSK(secret_key, method_name))) { - ret = OQS_ERROR; fprintf(stderr, "[vectors_kem] %s ERROR: generated public key or private key are corrupted !\n", method_name); goto err; } @@ -299,7 +298,6 @@ static OQS_STATUS kem_vector_encdec_aft(const char *method_name, #ifdef OQS_ENABLE_KEM_ML_KEM if (false == sanityCheckPK(encdec_pk, kem->length_public_key, method_name)) { - ret = OQS_ERROR; fprintf(stderr, "[vectors_kem] %s ERROR: passed encapsulation key is corrupted !\n", method_name); goto err; } @@ -372,7 +370,6 @@ static OQS_STATUS kem_vector_encdec_val(const char *method_name, #ifdef OQS_ENABLE_KEM_ML_KEM if (false == sanityCheckSK(encdec_sk, method_name)) { - ret = OQS_ERROR; fprintf(stderr, "[vectors_kem] %s ERROR: passed decapsulation key is corrupted !\n", method_name); goto err; } From f9bc387f023a473de25933b5bb9d6acfc67cd102 Mon Sep 17 00:00:00 2001 From: Abhinav Saxena Date: Mon, 16 Dec 2024 11:41:52 +0530 Subject: [PATCH 4/4] fix the review comments Signed-off-by: Abhinav Saxena --- tests/vectors_kem.c | 31 +++++++++++++++++++++++++++---- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/tests/vectors_kem.c b/tests/vectors_kem.c index a3dfe99cf..eef566358 100644 --- a/tests/vectors_kem.c +++ b/tests/vectors_kem.c @@ -78,6 +78,20 @@ static void hexStringToByteArray(const char *hexString, uint8_t *byteArray) { } #ifdef OQS_ENABLE_KEM_ML_KEM +/* fetch value of 'K' from MlL-KEM version */ +uint8_t get_ml_kem_k(const char *method) { + if (0 == strcmp(method, OQS_KEM_alg_ml_kem_512)) { + return 2; + } else if (0 == strcmp(method, OQS_KEM_alg_ml_kem_768)) { + return 3; + } else if (0 == strcmp(method, OQS_KEM_alg_ml_kem_1024)) { + return 4; + } else { + return 0; // Default/error case + } +} + +/* sanity check for private/decaps key */ static inline bool sanityCheckSK(const uint8_t *sk, const char *method_name) { /* sanity checks */ if ((NULL == sk) || (NULL == method_name) || (false == is_ml_kem(method_name))) { @@ -88,7 +102,11 @@ static inline bool sanityCheckSK(const uint8_t *sk, const char *method_name) { uint8_t pkdig[SHA256_OP_LEN] = {0}; /* fetch the value of k according to the ML-KEM algorithm as per FIPS-203 K = 2 for ML-KEM-512, K = 3 for ML-KEM-768 & K = 4 for ML-KEM-1024 */ - uint8_t K = (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_512)) ? 2 : (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_768)) ? 3 : 4; + uint8_t K = get_ml_kem_k(method_name); + if (0 == K) { + fprintf(stderr, "K value can be fetched only for ML-KEM !\n"); + return false; + } /* calcualte hash of the public key(len = 384k+32) stored in private key at offset of 384k */ OQS_SHA3_sha3_256(pkdig, sk + (ML_KEM_BLOCKSIZE * K), (ML_KEM_BLOCKSIZE * K) + 32); /* compare it with public key hash stored at 768k+32 offset */ @@ -98,6 +116,7 @@ static inline bool sanityCheckSK(const uint8_t *sk, const char *method_name) { return true; } +/* sanity check for public/encaps key */ static inline bool sanityCheckPK(const uint8_t *pk, size_t pkLen, const char *method_name) { /* sanity checks */ if ((NULL == pk) || (0 == pkLen) || (NULL == method_name) || (false == is_ml_kem(method_name))) { @@ -107,7 +126,11 @@ static inline bool sanityCheckPK(const uint8_t *pk, size_t pkLen, const char *me unsigned int i, j; /* fetch the value of k according to the ML-KEM algorithm as per FIPS-203 K = 2 for ML-KEM-512, K = 3 for ML-KEM-768 & K = 4 for ML-KEM-1024 */ - uint8_t K = (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_512)) ? 2 : (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_768)) ? 3 : 4; + uint8_t K = get_ml_kem_k(method_name); + if (0 == K) { + fprintf(stderr, "K value can be fetched only for ML-KEM !\n"); + return false; + } /* buffer to hold decoded value. max value used, so same buffer could be used for ML-KEM versions encaps key is of length 384K bytes(384K*8 bits). Grouped into 12-bit values, the buffer requires (384*K*8)/12 = 256*K entries of 12 bits */ uint16_t buffd[ML_KEM_N * ML_KEM_K_MAX] = {0}; @@ -119,8 +142,8 @@ static inline bool sanityCheckPK(const uint8_t *pk, size_t pkLen, const char *me buff_dec = &buffd[i * ML_KEM_N]; const uint8_t *curr_pk = &pk[i * ML_KEM_BLOCKSIZE]; for (j = 0; j < ML_KEM_N / 2; j++) { - buff_dec[2 * j] = ((curr_pk[3 * j + 0] >> 0) | ((uint16_t)curr_pk[3 * j + 1] << 8)) & 0xFFF; - buff_dec[2 * j] = MOD_Q(buff_dec[2 * j]); + buff_dec[2 * j + 0] = ((curr_pk[3 * j + 0] >> 0) | ((uint16_t)curr_pk[3 * j + 1] << 8)) & 0xFFF; + buff_dec[2 * j + 0] = MOD_Q(buff_dec[2 * j]); buff_dec[2 * j + 1] = ((curr_pk[3 * j + 1] >> 4) | ((uint16_t)curr_pk[3 * j + 2] << 4)) & 0xFFF; buff_dec[2 * j + 1] = MOD_Q(buff_dec[2 * j + 1]); }