diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 6d08516a8..fae241991 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -173,6 +173,15 @@ target_link_libraries(vectors_sig PRIVATE ${TEST_DEPS}) add_executable(vectors_kem vectors_kem.c) target_link_libraries(vectors_kem PRIVATE ${TEST_DEPS}) +if(CMAKE_SYSTEM_NAME STREQUAL "Windows" AND BUILD_SHARED_LIBS) + # workaround for Windows .dll + if(MINGW OR MSYS OR CYGWIN OR CMAKE_CROSSCOMPILING) + target_link_options(vectors_kem PRIVATE -Wl,--allow-multiple-definition) + else() + target_link_options(vectors_kem PRIVATE "/FORCE:MULTIPLE") + endif() +endif() + # Enable Valgrind-based timing side-channel analysis for test_kem and test_sig if(OQS_ENABLE_TEST_CONSTANT_TIME AND NOT OQS_DEBUG_BUILD) message(WARNING "OQS_ENABLE_TEST_CONSTANT_TIME is incompatible with CMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}.") diff --git a/tests/vectors_kem.c b/tests/vectors_kem.c index a7a1dc6a7..eef566358 100644 --- a/tests/vectors_kem.c +++ b/tests/vectors_kem.c @@ -11,15 +11,34 @@ #include #include - +#include #include "system_info.c" +#ifdef OQS_ENABLE_KEM_ML_KEM +/* macros for sanity checks for encaps and decaps key */ +#define ML_KEM_BLOCKSIZE 384 +#define ML_KEM_K_MAX 4 +#define ML_KEM_N 256 +#define ML_KEM_1024_PK_SIZE 1568 +#define ML_KEM_Q 3329 +#define SHA256_OP_LEN 32 +/* since x is 12 bits, max value could be 4095. the below macro uses this to implement a simple time constant mod 3329 */ +#define MOD_Q(x) ((x) - ((x >= ML_KEM_Q) * ML_KEM_Q)) +#endif //OQS_ENABLE_KEM_ML_KEM + struct { const uint8_t *pos; } prng_state = { .pos = 0 }; +/* MLKEM-specific functions */ +static inline bool is_ml_kem(const char *method_name) { + return (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_512)) + || (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_768)) + || (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_1024)); +} + static void fprintBstr(FILE *fp, const char *S, const uint8_t *A, size_t L) { size_t i; fprintf(fp, "%s", S); @@ -58,13 +77,98 @@ static void hexStringToByteArray(const char *hexString, uint8_t *byteArray) { } } -/* ML_KEM-specific functions */ -static inline bool is_ml_kem(const char *method_name) { - return (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_512)) - || (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_768)) - || (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_1024)); +#ifdef OQS_ENABLE_KEM_ML_KEM +/* fetch value of 'K' from MlL-KEM version */ +uint8_t get_ml_kem_k(const char *method) { + if (0 == strcmp(method, OQS_KEM_alg_ml_kem_512)) { + return 2; + } else if (0 == strcmp(method, OQS_KEM_alg_ml_kem_768)) { + return 3; + } else if (0 == strcmp(method, OQS_KEM_alg_ml_kem_1024)) { + return 4; + } else { + return 0; // Default/error case + } +} + +/* sanity check for private/decaps key */ +static inline bool sanityCheckSK(const uint8_t *sk, const char *method_name) { + /* sanity checks */ + if ((NULL == sk) || (NULL == method_name) || (false == is_ml_kem(method_name))) { + fprintf(stderr, "[vectors_kem] %s ERROR: inputs NULL or invalid method !\n", method_name); + return false; + } + /* buffer to hold public key hash */ + uint8_t pkdig[SHA256_OP_LEN] = {0}; + /* fetch the value of k according to the ML-KEM algorithm as per FIPS-203 + K = 2 for ML-KEM-512, K = 3 for ML-KEM-768 & K = 4 for ML-KEM-1024 */ + uint8_t K = get_ml_kem_k(method_name); + if (0 == K) { + fprintf(stderr, "K value can be fetched only for ML-KEM !\n"); + return false; + } + /* calcualte hash of the public key(len = 384k+32) stored in private key at offset of 384k */ + OQS_SHA3_sha3_256(pkdig, sk + (ML_KEM_BLOCKSIZE * K), (ML_KEM_BLOCKSIZE * K) + 32); + /* compare it with public key hash stored at 768k+32 offset */ + if (0 != memcmp(pkdig, sk + (ML_KEM_BLOCKSIZE * K * 2) + 32, SHA256_OP_LEN)) { + return false; + } + return true; } +/* sanity check for public/encaps key */ +static inline bool sanityCheckPK(const uint8_t *pk, size_t pkLen, const char *method_name) { + /* sanity checks */ + if ((NULL == pk) || (0 == pkLen) || (NULL == method_name) || (false == is_ml_kem(method_name))) { + fprintf(stderr, "[vectors_kem] %s ERROR: inputs NULL or zero or invalid method !\n", method_name); + return false; + } + unsigned int i, j; + /* fetch the value of k according to the ML-KEM algorithm as per FIPS-203 + K = 2 for ML-KEM-512, K = 3 for ML-KEM-768 & K = 4 for ML-KEM-1024 */ + uint8_t K = get_ml_kem_k(method_name); + if (0 == K) { + fprintf(stderr, "K value can be fetched only for ML-KEM !\n"); + return false; + } + /* buffer to hold decoded value. max value used, so same buffer could be used for ML-KEM versions + encaps key is of length 384K bytes(384K*8 bits). Grouped into 12-bit values, the buffer requires (384*K*8)/12 = 256*K entries of 12 bits */ + uint16_t buffd[ML_KEM_N * ML_KEM_K_MAX] = {0}; + /* buffer to hold encoded value */ + uint8_t buffe[ML_KEM_1024_PK_SIZE] = {0}; + uint16_t *buff_dec; + /* perform byte decoding as per Algo 6 of FIPS 203 */ + for (i = 0; i < K; i++) { + buff_dec = &buffd[i * ML_KEM_N]; + const uint8_t *curr_pk = &pk[i * ML_KEM_BLOCKSIZE]; + for (j = 0; j < ML_KEM_N / 2; j++) { + buff_dec[2 * j + 0] = ((curr_pk[3 * j + 0] >> 0) | ((uint16_t)curr_pk[3 * j + 1] << 8)) & 0xFFF; + buff_dec[2 * j + 0] = MOD_Q(buff_dec[2 * j]); + buff_dec[2 * j + 1] = ((curr_pk[3 * j + 1] >> 4) | ((uint16_t)curr_pk[3 * j + 2] << 4)) & 0xFFF; + buff_dec[2 * j + 1] = MOD_Q(buff_dec[2 * j + 1]); + } + } + /* perform byte encoding as per Algo 5 of FIPS 203 */ + for (i = 0; i < K; i++) { + uint16_t t0, t1; + buff_dec = &buffd[i * ML_KEM_N]; + uint8_t *buff_enc = &buffe[i * ML_KEM_BLOCKSIZE]; + for (j = 0; j < ML_KEM_N / 2; j++) { + t0 = buff_dec[2 * j]; + t1 = buff_dec[2 * j + 1]; + buff_enc[3 * j + 0] = (uint8_t)(t0 >> 0); + buff_enc[3 * j + 1] = (uint8_t)((t0 >> 8) | (t1 << 4)); + buff_enc[3 * j + 2] = (uint8_t)(t1 >> 4); + } + } + /* compare the encoded value with original public key. discard value of `rho(32 bytes)` during comparision as its not encoded */ + if (0 != memcmp(buffe, pk, pkLen - 32)) { + return false; + } + return true; +} +#endif //OQS_ENABLE_KEM_ML_KEM + static void MLKEM_randombytes_init(const uint8_t *entropy_input, const uint8_t *personalization_string) { (void) personalization_string; prng_state.pos = entropy_input; @@ -134,6 +238,13 @@ static OQS_STATUS kem_kg_vector(const char *method_name, fprintBstr(fh, "ek: ", public_key, kem->length_public_key); fprintBstr(fh, "dk: ", secret_key, kem->length_secret_key); +#ifdef OQS_ENABLE_KEM_ML_KEM + if ((false == sanityCheckPK(public_key, kem->length_public_key, method_name)) || (false == sanityCheckSK(secret_key, method_name))) { + fprintf(stderr, "[vectors_kem] %s ERROR: generated public key or private key are corrupted !\n", method_name); + goto err; + } +#endif //OQS_ENABLE_KEM_ML_KEM + if (!memcmp(public_key, kg_pk, kem->length_public_key) && !memcmp(secret_key, kg_sk, kem->length_secret_key)) { ret = OQS_SUCCESS; } else { @@ -208,6 +319,13 @@ static OQS_STATUS kem_vector_encdec_aft(const char *method_name, goto err; } +#ifdef OQS_ENABLE_KEM_ML_KEM + if (false == sanityCheckPK(encdec_pk, kem->length_public_key, method_name)) { + fprintf(stderr, "[vectors_kem] %s ERROR: passed encapsulation key is corrupted !\n", method_name); + goto err; + } +#endif //OQS_ENABLE_KEM_ML_KEM + rc = OQS_KEM_encaps(kem, ct_encaps, ss_encaps, encdec_pk); if (rc != OQS_SUCCESS) { fprintf(stderr, "[vectors_kem] %s ERROR: OQS_KEM_encaps failed!\n", method_name); @@ -273,6 +391,13 @@ static OQS_STATUS kem_vector_encdec_val(const char *method_name, goto err; } +#ifdef OQS_ENABLE_KEM_ML_KEM + if (false == sanityCheckSK(encdec_sk, method_name)) { + fprintf(stderr, "[vectors_kem] %s ERROR: passed decapsulation key is corrupted !\n", method_name); + goto err; + } +#endif //OQS_ENABLE_KEM_ML_KEM + rc = OQS_KEM_decaps(kem, ss_decaps, encdec_c, encdec_sk); if (rc != OQS_SUCCESS) { fprintf(stderr, "[vectors_kem] %s ERROR: OQS_KEM_encaps failed!\n", method_name);