From 8b63d1db821d4b0d772733ed7dc80f1deca98182 Mon Sep 17 00:00:00 2001 From: Heemin Kim Date: Thu, 11 Jul 2024 11:12:12 -0700 Subject: [PATCH] Add jni interface to use a binary hnsw index with faiss (#1778) (#1817) * Add jni interface to use a binary hnsw index with faiss (#1747) * Fix memory leak on test code (#1776) --------- Signed-off-by: Heemin Kim --- jni/CMakeLists.txt | 4 +- jni/include/commons.h | 24 +- jni/include/faiss_index_service.h | 113 ++++++++++ jni/include/faiss_methods.h | 42 ++++ jni/include/faiss_wrapper.h | 14 +- jni/include/jni_util.h | 7 + .../org_opensearch_knn_jni_FaissService.h | 29 ++- .../org_opensearch_knn_jni_JNICommons.h | 16 ++ jni/src/commons.cpp | 22 ++ jni/src/faiss_index_service.cpp | 164 ++++++++++++++ jni/src/faiss_methods.cpp | 40 ++++ jni/src/faiss_wrapper.cpp | 205 ++++++++++++++---- jni/src/jni_util.cpp | 51 +++++ .../org_opensearch_knn_jni_FaissService.cpp | 49 ++++- jni/src/org_opensearch_knn_jni_JNICommons.cpp | 23 ++ jni/tests/faiss_index_service_test.cpp | 134 ++++++++++++ jni/tests/faiss_wrapper_test.cpp | 197 +++++++++++++++-- jni/tests/faiss_wrapper_unit_test.cpp | 17 +- jni/tests/mocks/faiss_index_mock.h | 35 +++ jni/tests/mocks/faiss_index_service_mock.h | 44 ++++ jni/tests/mocks/faiss_methods_mock.h | 28 +++ jni/tests/test_util.cpp | 56 +++++ jni/tests/test_util.h | 14 +- .../org/opensearch/knn/index/IndexUtil.java | 13 +- .../opensearch/knn/index/KNNIndexShard.java | 14 +- .../opensearch/knn/index/query/KNNWeight.java | 8 +- .../knn/index/util/FieldInfoExtractor.java | 37 ++++ .../org/opensearch/knn/jni/FaissService.java | 45 ++++ .../org/opensearch/knn/jni/JNICommons.java | 19 ++ .../org/opensearch/knn/jni/JNIService.java | 65 +++++- .../opensearch/knn/index/IndexUtilTests.java | 7 +- .../knn/index/KNNIndexShardTests.java | 4 +- .../knn/index/query/KNNWeightTests.java | 75 ++++++- .../index/util/FieldInfoExtractorTests.java | 42 ++++ .../opensearch/knn/jni/JNIServiceTests.java | 47 ++++ .../java/org/opensearch/knn/TestUtils.java | 43 ++++ 36 files changed, 1639 insertions(+), 108 deletions(-) create mode 100644 jni/include/faiss_index_service.h create mode 100644 jni/include/faiss_methods.h create mode 100644 jni/src/faiss_index_service.cpp create mode 100644 jni/src/faiss_methods.cpp create mode 100644 jni/tests/faiss_index_service_test.cpp create mode 100644 jni/tests/mocks/faiss_index_mock.h create mode 100644 jni/tests/mocks/faiss_index_service_mock.h create mode 100644 jni/tests/mocks/faiss_methods_mock.h create mode 100644 src/main/java/org/opensearch/knn/index/util/FieldInfoExtractor.java create mode 100644 src/test/java/org/opensearch/knn/index/util/FieldInfoExtractorTests.java diff --git a/jni/CMakeLists.txt b/jni/CMakeLists.txt index e2003e0f7..2fe26875d 100644 --- a/jni/CMakeLists.txt +++ b/jni/CMakeLists.txt @@ -20,7 +20,6 @@ set(TARGET_LIBS "") # Libs to be installed set(CMAKE_CXX_STANDARD 11) set(CMAKE_CXX_STANDARD_REQUIRED True) - option(CONFIG_FAISS "Configure faiss library build when this is on") option(CONFIG_NMSLIB "Configure nmslib library build when this is on") option(CONFIG_TEST "Configure tests when this is on") @@ -112,6 +111,8 @@ if (${CONFIG_FAISS} STREQUAL ON OR ${CONFIG_ALL} STREQUAL ON OR ${CONFIG_TEST} S ${CMAKE_CURRENT_SOURCE_DIR}/src/org_opensearch_knn_jni_FaissService.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/faiss_wrapper.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/faiss_util.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/faiss_index_service.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/src/faiss_methods.cpp ) target_link_libraries(${TARGET_LIB_FAISS} ${TARGET_LINK_FAISS_LIB} ${TARGET_LIB_UTIL} OpenMP::OpenMP_CXX) target_include_directories(${TARGET_LIB_FAISS} PRIVATE @@ -153,6 +154,7 @@ if ("${WIN32}" STREQUAL "") tests/nmslib_wrapper_unit_test.cpp tests/test_util.cpp tests/commons_test.cpp + tests/faiss_index_service_test.cpp ) target_link_libraries( diff --git a/jni/include/commons.h b/jni/include/commons.h index 67a141c8b..d02439377 100644 --- a/jni/include/commons.h +++ b/jni/include/commons.h @@ -22,10 +22,24 @@ namespace knn_jni { * @param memoryAddress The address of the memory location where data will be stored. * @param data 2D float array containing data to be stored in native memory. * @param initialCapacity The initial capacity of the memory location. - * @return memory address where the data is stored. + * @return memory address of std::vector where the data is stored. */ jlong storeVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong); + /** + * This is utility function that can be used to store data in native memory. This function will allocate memory for + * the data(rows*columns) with initialCapacity and return the memory address where the data is stored. + * If you are using this function for first time use memoryAddress = 0 to ensure that a new memory location is created. + * For subsequent calls you can pass the same memoryAddress. If the data cannot be stored in the memory location + * will throw Exception. + * + * @param memoryAddress The address of the memory location where data will be stored. + * @param data 2D byte array containing data to be stored in native memory. + * @param initialCapacity The initial capacity of the memory location. + * @return memory address of std::vector where the data is stored. + */ + jlong storeByteVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong); + /** * Free up the memory allocated for the data stored in memory address. This function should be used with the memory * address returned by {@link JNICommons#storeVectorData(long, float[][], long, long)} @@ -34,6 +48,14 @@ namespace knn_jni { */ void freeVectorData(jlong); + /** + * Free up the memory allocated for the data stored in memory address. This function should be used with the memory + * address returned by {@link JNICommons#storeByteVectorData(long, byte[][], long, long)} + * + * @param memoryAddress address to be freed. + */ + void freeByteVectorData(jlong); + /** * Extracts query time efSearch from method parameters **/ diff --git a/jni/include/faiss_index_service.h b/jni/include/faiss_index_service.h new file mode 100644 index 000000000..59f15fda9 --- /dev/null +++ b/jni/include/faiss_index_service.h @@ -0,0 +1,113 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// The OpenSearch Contributors require contributions made to +// this file be licensed under the Apache-2.0 license or a +// compatible open source license. +// +// Modifications Copyright OpenSearch Contributors. See +// GitHub history for details. + +/** + * This file contains classes for index operations which are free of JNI + */ + +#ifndef OPENSEARCH_KNN_FAISS_INDEX_SERVICE_H +#define OPENSEARCH_KNN_FAISS_INDEX_SERVICE_H + +#include +#include "faiss/MetricType.h" +#include "jni_util.h" +#include "faiss_methods.h" +#include + +namespace knn_jni { +namespace faiss_wrapper { + + +/** + * A class to provide operations on index + * This class should evolve to have only cpp object but not jni object + */ +class IndexService { +public: + IndexService(std::unique_ptr faissMethods); + //TODO Remove dependency on JNIUtilInterface and JNIEnv + //TODO Reduce the number of parameters + + /** + * Create index + * + * @param jniUtil jni util + * @param env jni environment + * @param metric space type for distance calculation + * @param indexDescription index description to be used by faiss index factory + * @param dim dimension of vectors + * @param numIds number of vectors + * @param threadCount number of thread count to be used while adding data + * @param vectorsAddress memory address which is holding vector data + * @param ids a list of document ids for corresponding vectors + * @param indexPath path to write index + * @param parameters parameters to be applied to faiss index + */ + virtual void createIndex( + knn_jni::JNIUtilInterface * jniUtil, + JNIEnv * env, + faiss::MetricType metric, + std::string indexDescription, + int dim, + int numIds, + int threadCount, + int64_t vectorsAddress, + std::vector ids, + std::string indexPath, + std::unordered_map parameters); + virtual ~IndexService() = default; +protected: + std::unique_ptr faissMethods; +}; + +/** + * A class to provide operations on index + * This class should evolve to have only cpp object but not jni object + */ +class BinaryIndexService : public IndexService { +public: + //TODO Remove dependency on JNIUtilInterface and JNIEnv + //TODO Reduce the number of parameters + BinaryIndexService(std::unique_ptr faissMethods); + /** + * Create binary index + * + * @param jniUtil jni util + * @param env jni environment + * @param metric space type for distance calculation + * @param indexDescription index description to be used by faiss index factory + * @param dim dimension of vectors + * @param numIds number of vectors + * @param threadCount number of thread count to be used while adding data + * @param vectorsAddress memory address which is holding vector data + * @param ids a list of document ids for corresponding vectors + * @param indexPath path to write index + * @param parameters parameters to be applied to faiss index + */ + virtual void createIndex( + knn_jni::JNIUtilInterface * jniUtil, + JNIEnv * env, + faiss::MetricType metric, + std::string indexDescription, + int dim, + int numIds, + int threadCount, + int64_t vectorsAddress, + std::vector ids, + std::string indexPath, + std::unordered_map parameters + ) override; + virtual ~BinaryIndexService() = default; +}; + +} +} + + +#endif //OPENSEARCH_KNN_FAISS_INDEX_SERVICE_H diff --git a/jni/include/faiss_methods.h b/jni/include/faiss_methods.h new file mode 100644 index 000000000..38d8d756a --- /dev/null +++ b/jni/include/faiss_methods.h @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// The OpenSearch Contributors require contributions made to +// this file be licensed under the Apache-2.0 license or a +// compatible open source license. +// +// Modifications Copyright OpenSearch Contributors. See +// GitHub history for details. + +#ifndef OPENSEARCH_KNN_FAISS_METHODS_H +#define OPENSEARCH_KNN_FAISS_METHODS_H + +#include "faiss/Index.h" +#include "faiss/IndexBinary.h" +#include "faiss/IndexIDMap.h" +#include "faiss/index_io.h" + +namespace knn_jni { +namespace faiss_wrapper { + +/** + * A class having wrapped faiss methods + * + * This class helps to mock faiss methods during unit test + */ +class FaissMethods { +public: + FaissMethods() = default; + virtual faiss::Index* indexFactory(int d, const char* description, faiss::MetricType metric); + virtual faiss::IndexBinary* indexBinaryFactory(int d, const char* description); + virtual faiss::IndexIDMapTemplate* indexIdMap(faiss::Index* index); + virtual faiss::IndexIDMapTemplate* indexBinaryIdMap(faiss::IndexBinary* index); + virtual void writeIndex(const faiss::Index* idx, const char* fname); + virtual void writeIndexBinary(const faiss::IndexBinary* idx, const char* fname); + virtual ~FaissMethods() = default; +}; + +} //namespace faiss_wrapper +} //namespace knn_jni + + +#endif //OPENSEARCH_KNN_FAISS_METHODS_H \ No newline at end of file diff --git a/jni/include/faiss_wrapper.h b/jni/include/faiss_wrapper.h index 3bfc66325..5ac17cfd1 100644 --- a/jni/include/faiss_wrapper.h +++ b/jni/include/faiss_wrapper.h @@ -13,6 +13,7 @@ #define OPENSEARCH_KNN_FAISS_WRAPPER_H #include "jni_util.h" +#include "faiss_index_service.h" #include namespace knn_jni { @@ -20,7 +21,7 @@ namespace knn_jni { // Create an index with ids and vectors. The configuration is defined by values in the Java map, parametersJ. // The index is serialized to indexPathJ. void CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, - jstring indexPathJ, jobject parametersJ); + jstring indexPathJ, jobject parametersJ, IndexService* indexService); // Create an index with ids and vectors. Instead of creating a new index, this function creates the index // based off of the template index passed in. The index is serialized to indexPathJ. @@ -33,6 +34,11 @@ namespace knn_jni { // Return a pointer to the loaded index jlong LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ); + // Load a binary index from indexPathJ into memory. + // + // Return a pointer to the loaded index + jlong LoadBinaryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ); + // Check if a loaded index requires shared state bool IsSharedIndexStateRequired(jlong indexPointerJ); @@ -68,6 +74,12 @@ namespace knn_jni { jfloatArray queryVectorJ, jint kJ, jobject methodParamsJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ); + // Execute a query against the binary index located in memory at indexPointerJ along with Filters + // + // Return an array of KNNQueryResults + jobjectArray QueryBinaryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ, + jbyteArray queryVectorJ, jint kJ, jobject methodParamsJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ); + // Free the index located in memory at indexPointerJ void Free(jlong indexPointer); diff --git a/jni/include/jni_util.h b/jni/include/jni_util.h index b3d55f1c1..97a8f063c 100644 --- a/jni/include/jni_util.h +++ b/jni/include/jni_util.h @@ -71,6 +71,8 @@ namespace knn_jni { virtual void Convert2dJavaObjectArrayAndStoreToFloatVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector *vect ) = 0; + virtual void Convert2dJavaObjectArrayAndStoreToByteVector(JNIEnv *env, jobjectArray array2dJ, + int dim, std::vector *vect ) = 0; virtual std::vector ConvertJavaIntArrayToCppIntVector(JNIEnv *env, jintArray arrayJ) = 0; @@ -79,6 +81,8 @@ namespace knn_jni { // ------------------------------ MISC HELPERS ------------------------------ virtual int GetInnerDimensionOf2dJavaFloatArray(JNIEnv *env, jobjectArray array2dJ) = 0; + virtual int GetInnerDimensionOf2dJavaByteArray(JNIEnv *env, jobjectArray array2dJ) = 0; + virtual int GetJavaObjectArrayLength(JNIEnv *env, jobjectArray arrayJ) = 0; virtual int GetJavaIntArrayLength(JNIEnv *env, jintArray arrayJ) = 0; @@ -146,6 +150,7 @@ namespace knn_jni { std::vector Convert2dJavaObjectArrayToCppFloatVector(JNIEnv *env, jobjectArray array2dJ, int dim); std::vector ConvertJavaIntArrayToCppIntVector(JNIEnv *env, jintArray arrayJ); int GetInnerDimensionOf2dJavaFloatArray(JNIEnv *env, jobjectArray array2dJ); + int GetInnerDimensionOf2dJavaByteArray(JNIEnv *env, jobjectArray array2dJ); int GetJavaObjectArrayLength(JNIEnv *env, jobjectArray arrayJ); int GetJavaIntArrayLength(JNIEnv *env, jintArray arrayJ); int GetJavaLongArrayLength(JNIEnv *env, jlongArray arrayJ); @@ -168,6 +173,7 @@ namespace knn_jni { void SetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index, jobject val); void SetByteArrayRegion(JNIEnv *env, jbyteArray array, jsize start, jsize len, const jbyte * buf); void Convert2dJavaObjectArrayAndStoreToFloatVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector *vect); + void Convert2dJavaObjectArrayAndStoreToByteVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector *vect); private: std::unordered_map cachedClasses; @@ -193,6 +199,7 @@ namespace knn_jni { extern const std::string COSINESIMIL; extern const std::string INNER_PRODUCT; extern const std::string NEG_DOT_PRODUCT; + extern const std::string HAMMING_BIT; extern const std::string NPROBES; extern const std::string COARSE_QUANTIZER; diff --git a/jni/include/org_opensearch_knn_jni_FaissService.h b/jni/include/org_opensearch_knn_jni_FaissService.h index ef382507a..3d6aef45c 100644 --- a/jni/include/org_opensearch_knn_jni_FaissService.h +++ b/jni/include/org_opensearch_knn_jni_FaissService.h @@ -18,6 +18,7 @@ #ifdef __cplusplus extern "C" { #endif + /* * Class: org_opensearch_knn_jni_FaissService * Method: createIndex @@ -26,6 +27,14 @@ extern "C" { JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndex (JNIEnv *, jclass, jintArray, jlong, jint, jstring, jobject); +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: createBinaryIndex + * Signature: ([IJILjava/lang/String;Ljava/util/Map;)V + */ +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryIndex + (JNIEnv *, jclass, jintArray, jlong, jint, jstring, jobject); + /* * Class: org_opensearch_knn_jni_FaissService * Method: createIndexFromTemplate @@ -42,6 +51,14 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromT JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadIndex (JNIEnv *, jclass, jstring); +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: loadBinaryIndex + * Signature: (Ljava/lang/String;)J + */ +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadBinaryIndex + (JNIEnv *, jclass, jstring); + /* * Class: org_opensearch_knn_jni_FaissService * Method: isSharedIndexStateRequired @@ -69,7 +86,7 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_setSharedIndexSt /* * Class: org_opensearch_knn_jni_FaissService * Method: queryIndex - * Signature: (J[FI[Ljava/util/MapI)[Lorg/opensearch/knn/index/query/KNNQueryResult; + * Signature: (J[FILjava/util/Map[I)[Lorg/opensearch/knn/index/query/KNNQueryResult; */ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndex (JNIEnv *, jclass, jlong, jfloatArray, jint, jobject, jintArray); @@ -77,11 +94,19 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryInd /* * Class: org_opensearch_knn_jni_FaissService * Method: queryIndexWithFilter - * Signature: (J[FI[JLjava/util/MapI[I)[Lorg/opensearch/knn/index/query/KNNQueryResult; + * Signature: (J[FILjava/util/Map[JI[I)[Lorg/opensearch/knn/index/query/KNNQueryResult; */ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndexWithFilter (JNIEnv *, jclass, jlong, jfloatArray, jint, jobject, jlongArray, jint, jintArray); +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: queryBIndexWithFilter + * Signature: (J[BILjava/util/Map[JI[I)[Lorg/opensearch/knn/index/query/KNNQueryResult; + */ +JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryBinaryIndexWithFilter + (JNIEnv *, jclass, jlong, jbyteArray, jint, jobject, jlongArray, jint, jintArray); + /* * Class: org_opensearch_knn_jni_FaissService * Method: free diff --git a/jni/include/org_opensearch_knn_jni_JNICommons.h b/jni/include/org_opensearch_knn_jni_JNICommons.h index d0758d7c8..89de76520 100644 --- a/jni/include/org_opensearch_knn_jni_JNICommons.h +++ b/jni/include/org_opensearch_knn_jni_JNICommons.h @@ -26,6 +26,14 @@ extern "C" { JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeVectorData (JNIEnv *, jclass, jlong, jobjectArray, jlong); +/* + * Class: org_opensearch_knn_jni_JNICommons + * Method: storeVectorData + * Signature: (J[[FJJ) + */ +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeByteVectorData + (JNIEnv *, jclass, jlong, jobjectArray, jlong); + /* * Class: org_opensearch_knn_jni_JNICommons * Method: freeVectorData @@ -34,6 +42,14 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeVectorData JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_JNICommons_freeVectorData (JNIEnv *, jclass, jlong); +/* +* Class: org_opensearch_knn_jni_JNICommons +* Method: freeVectorData +* Signature: (J)V +*/ +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_JNICommons_freeByteVectorData +(JNIEnv *, jclass, jlong); + #ifdef __cplusplus } #endif diff --git a/jni/src/commons.cpp b/jni/src/commons.cpp index c2b2354cc..13f59194e 100644 --- a/jni/src/commons.cpp +++ b/jni/src/commons.cpp @@ -32,6 +32,21 @@ jlong knn_jni::commons::storeVectorData(knn_jni::JNIUtilInterface *jniUtil, JNIE return (jlong) vect; } +jlong knn_jni::commons::storeByteVectorData(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong memoryAddressJ, + jobjectArray dataJ, jlong initialCapacityJ) { + std::vector *vect; + if ((long) memoryAddressJ == 0) { + vect = new std::vector(); + vect->reserve((long)initialCapacityJ); + } else { + vect = reinterpret_cast*>(memoryAddressJ); + } + int dim = jniUtil->GetInnerDimensionOf2dJavaByteArray(env, dataJ); + jniUtil->Convert2dJavaObjectArrayAndStoreToByteVector(env, dataJ, dim, vect); + + return (jlong) vect; +} + void knn_jni::commons::freeVectorData(jlong memoryAddressJ) { if (memoryAddressJ != 0) { auto *vect = reinterpret_cast*>(memoryAddressJ); @@ -39,6 +54,13 @@ void knn_jni::commons::freeVectorData(jlong memoryAddressJ) { } } +void knn_jni::commons::freeByteVectorData(jlong memoryAddressJ) { + if (memoryAddressJ != 0) { + auto *vect = reinterpret_cast*>(memoryAddressJ); + delete vect; + } +} + int knn_jni::commons::getIntegerMethodParameter(JNIEnv * env, knn_jni::JNIUtilInterface * jniUtil, std::unordered_map methodParams, std::string methodParam, int defaultValue) { if (methodParams.empty()) { return defaultValue; diff --git a/jni/src/faiss_index_service.cpp b/jni/src/faiss_index_service.cpp new file mode 100644 index 000000000..8c5ba36af --- /dev/null +++ b/jni/src/faiss_index_service.cpp @@ -0,0 +1,164 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// The OpenSearch Contributors require contributions made to +// this file be licensed under the Apache-2.0 license or a +// compatible open source license. +// +// Modifications Copyright OpenSearch Contributors. See +// GitHub history for details. + +#include "faiss_index_service.h" +#include "faiss_methods.h" +#include "faiss/index_factory.h" +#include "faiss/Index.h" +#include "faiss/IndexBinary.h" +#include "faiss/IndexHNSW.h" +#include "faiss/IndexBinaryHNSW.h" +#include "faiss/IndexIVFFlat.h" +#include "faiss/IndexBinaryIVF.h" +#include "faiss/IndexIDMap.h" +#include "faiss/index_io.h" +#include +#include +#include +#include +#include + +namespace knn_jni { +namespace faiss_wrapper { + +template +void SetExtraParameters(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, + const std::unordered_map& parametersCpp, INDEX * index) { + std::unordered_map::const_iterator value; + if (auto * indexIvf = dynamic_cast(index)) { + if ((value = parametersCpp.find(knn_jni::NPROBES)) != parametersCpp.end()) { + indexIvf->nprobe = jniUtil->ConvertJavaObjectToCppInteger(env, value->second); + } + + if ((value = parametersCpp.find(knn_jni::COARSE_QUANTIZER)) != parametersCpp.end() + && indexIvf->quantizer != nullptr) { + auto subParametersCpp = jniUtil->ConvertJavaMapToCppMap(env, value->second); + SetExtraParameters(jniUtil, env, subParametersCpp, indexIvf->quantizer); + } + } + + if (auto * indexHnsw = dynamic_cast(index)) { + + if ((value = parametersCpp.find(knn_jni::EF_CONSTRUCTION)) != parametersCpp.end()) { + indexHnsw->hnsw.efConstruction = jniUtil->ConvertJavaObjectToCppInteger(env, value->second); + } + + if ((value = parametersCpp.find(knn_jni::EF_SEARCH)) != parametersCpp.end()) { + indexHnsw->hnsw.efSearch = jniUtil->ConvertJavaObjectToCppInteger(env, value->second); + } + } +} + +IndexService::IndexService(std::unique_ptr faissMethods) : faissMethods(std::move(faissMethods)) {} + +void IndexService::createIndex( + knn_jni::JNIUtilInterface * jniUtil, + JNIEnv * env, + faiss::MetricType metric, + std::string indexDescription, + int dim, + int numIds, + int threadCount, + int64_t vectorsAddress, + std::vector ids, + std::string indexPath, + std::unordered_map parameters + ) { + // Read vectors from memory address + auto *inputVectors = reinterpret_cast*>(vectorsAddress); + + // The number of vectors can be int here because a lucene segment number of total docs never crosses INT_MAX value + int numVectors = (int) (inputVectors->size() / (uint64_t) dim); + if(numVectors == 0) { + throw std::runtime_error("Number of vectors cannot be 0"); + } + + if (numIds != numVectors) { + throw std::runtime_error("Number of IDs does not match number of vectors"); + } + + std::unique_ptr indexWriter(faissMethods->indexFactory(dim, indexDescription.c_str(), metric)); + + // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread + if(threadCount != 0) { + omp_set_num_threads(threadCount); + } + + // Add extra parameters that cant be configured with the index factory + SetExtraParameters(jniUtil, env, parameters, indexWriter.get()); + + // Check that the index does not need to be trained + if(!indexWriter->is_trained) { + throw std::runtime_error("Index is not trained"); + } + + // Add vectors + std::unique_ptr idMap(faissMethods->indexIdMap(indexWriter.get())); + idMap->add_with_ids(numVectors, inputVectors->data(), ids.data()); + + // Write the index to disk + faissMethods->writeIndex(idMap.get(), indexPath.c_str()); +} + +BinaryIndexService::BinaryIndexService(std::unique_ptr faissMethods) : IndexService(std::move(faissMethods)) {} + +void BinaryIndexService::createIndex( + knn_jni::JNIUtilInterface * jniUtil, + JNIEnv * env, + faiss::MetricType metric, + std::string indexDescription, + int dim, + int numIds, + int threadCount, + int64_t vectorsAddress, + std::vector ids, + std::string indexPath, + std::unordered_map parameters + ) { + // Read vectors from memory address + auto *inputVectors = reinterpret_cast*>(vectorsAddress); + + if (dim % 8 != 0) { + throw std::runtime_error("Dimensions should be multiply of 8"); + } + // The number of vectors can be int here because a lucene segment number of total docs never crosses INT_MAX value + int numVectors = (int) (inputVectors->size() / (uint64_t) (dim / 8)); + if(numVectors == 0) { + throw std::runtime_error("Number of vectors cannot be 0"); + } + + if (numIds != numVectors) { + throw std::runtime_error("Number of IDs does not match number of vectors"); + } + + std::unique_ptr indexWriter(faissMethods->indexBinaryFactory(dim, indexDescription.c_str())); + + // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread + if(threadCount != 0) { + omp_set_num_threads(threadCount); + } + + // Add extra parameters that cant be configured with the index factory + SetExtraParameters(jniUtil, env, parameters, indexWriter.get()); + + // Check that the index does not need to be trained + if(!indexWriter->is_trained) { + throw std::runtime_error("Index is not trained"); + } + + // Add vectors + std::unique_ptr idMap(faissMethods->indexBinaryIdMap(indexWriter.get())); + idMap->add_with_ids(numVectors, inputVectors->data(), ids.data()); + + // Write the index to disk + faissMethods->writeIndexBinary(idMap.get(), indexPath.c_str()); +} + +} // namespace faiss_wrapper +} // namesapce knn_jni diff --git a/jni/src/faiss_methods.cpp b/jni/src/faiss_methods.cpp new file mode 100644 index 000000000..05c8f459a --- /dev/null +++ b/jni/src/faiss_methods.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// The OpenSearch Contributors require contributions made to +// this file be licensed under the Apache-2.0 license or a +// compatible open source license. +// +// Modifications Copyright OpenSearch Contributors. See +// GitHub history for details. + +#include "faiss_methods.h" +#include "faiss/index_factory.h" + +namespace knn_jni { +namespace faiss_wrapper { + +faiss::Index* FaissMethods::indexFactory(int d, const char* description, faiss::MetricType metric) { + return faiss::index_factory(d, description, metric); +} + +faiss::IndexBinary* FaissMethods::indexBinaryFactory(int d, const char* description) { + return faiss::index_binary_factory(d, description); +} + +faiss::IndexIDMapTemplate* FaissMethods::indexIdMap(faiss::Index* index) { + return new faiss::IndexIDMap(index); +} + +faiss::IndexIDMapTemplate* FaissMethods::indexBinaryIdMap(faiss::IndexBinary* index) { + return new faiss::IndexBinaryIDMap(index); +} + +void FaissMethods::writeIndex(const faiss::Index* idx, const char* fname) { + faiss::write_index(idx, fname); +} +void FaissMethods::writeIndexBinary(const faiss::IndexBinary* idx, const char* fname) { + faiss::write_index_binary(idx, fname); +} + +} // namespace faiss_wrapper +} // namesapce knn_jni diff --git a/jni/src/faiss_wrapper.cpp b/jni/src/faiss_wrapper.cpp index 692a33aee..c4c6e18eb 100644 --- a/jni/src/faiss_wrapper.cpp +++ b/jni/src/faiss_wrapper.cpp @@ -12,6 +12,7 @@ #include "jni_util.h" #include "faiss_wrapper.h" #include "faiss_util.h" +#include "faiss_index_service.h" #include "faiss/impl/io.h" #include "faiss/index_factory.h" @@ -23,6 +24,8 @@ #include "faiss/impl/IDSelector.h" #include "faiss/IndexIVFPQ.h" #include "commons.h" +#include "faiss/IndexBinaryIVF.h" +#include "faiss/IndexBinaryHNSW.h" #include #include @@ -83,8 +86,7 @@ bool isIndexIVFPQL2(faiss::Index * index); faiss::IndexIVFPQ * extractIVFPQIndex(faiss::Index * index); void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, - jstring indexPathJ, jobject parametersJ) { - + jstring indexPathJ, jobject parametersJ, IndexService* indexService) { if (idsJ == nullptr) { throw std::runtime_error("IDs cannot be null"); } @@ -109,63 +111,49 @@ void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JN // so that it is easier to access. auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ); - // Get space type for this index + // Parameters to pass + // Metric type jobject spaceTypeJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::SPACE_TYPE); std::string spaceTypeCpp(jniUtil->ConvertJavaObjectToCppString(env, spaceTypeJ)); faiss::MetricType metric = TranslateSpaceToMetric(spaceTypeCpp); + jniUtil->DeleteLocalRef(env, spaceTypeJ); - // Read vectors from memory address - auto *inputVectors = reinterpret_cast*>(vectorsAddressJ); + // Dimension int dim = (int)dimJ; - // The number of vectors can be int here because a lucene segment number of total docs never crosses INT_MAX value - int numVectors = (int) (inputVectors->size() / (uint64_t) dim); - if(numVectors == 0) { - throw std::runtime_error("Number of vectors cannot be 0"); - } + // Number of vectors int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ); - if (numIds != numVectors) { - throw std::runtime_error("Number of IDs does not match number of vectors"); - } - // Create faiss index + // Index description jobject indexDescriptionJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::INDEX_DESCRIPTION); std::string indexDescriptionCpp(jniUtil->ConvertJavaObjectToCppString(env, indexDescriptionJ)); + jniUtil->DeleteLocalRef(env, indexDescriptionJ); - std::unique_ptr indexWriter; - indexWriter.reset(faiss::index_factory(dim, indexDescriptionCpp.c_str(), metric)); - - // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread + // Thread count + int threadCount = 0; if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) { - auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]); - omp_set_num_threads(threadCount); - } - - // Add extra parameters that cant be configured with the index factory - if(parametersCpp.find(knn_jni::PARAMETERS) != parametersCpp.end()) { - jobject subParametersJ = parametersCpp[knn_jni::PARAMETERS]; - auto subParametersCpp = jniUtil->ConvertJavaMapToCppMap(env, subParametersJ); - SetExtraParameters(jniUtil, env, subParametersCpp, indexWriter.get()); - jniUtil->DeleteLocalRef(env, subParametersJ); + threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]); } - jniUtil->DeleteLocalRef(env, parametersJ); - // Check that the index does not need to be trained - if(!indexWriter->is_trained) { - throw std::runtime_error("Index is not trained"); - } + // Vectors address + int64_t vectorsAddress = (int64_t)vectorsAddressJ; - auto idVector = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ); - faiss::IndexIDMap idMap = faiss::IndexIDMap(indexWriter.get()); - idMap.add_with_ids(numVectors, inputVectors->data(), idVector.data()); + // Ids + auto ids = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ); - // Write the index to disk + // Index path std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ)); - faiss::write_index(&idMap, indexPathCpp.c_str()); - // Releasing the vectorsAddressJ memory as that is not required once we have created the index. - // This is not the ideal approach, please refer this gh issue for long term solution: - // https://github.com/opensearch-project/k-NN/issues/1600 - delete inputVectors; + + // Extra parameters + // TODO: parse the entire map and remove jni object + std::unordered_map subParametersCpp; + if(parametersCpp.find(knn_jni::PARAMETERS) != parametersCpp.end()) { + subParametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersCpp[knn_jni::PARAMETERS]); + } + // end parameters to pass + + // Create index + indexService->createIndex(jniUtil, env, metric, indexDescriptionCpp, dim, numIds, threadCount, vectorsAddress, ids, indexPathCpp, subParametersCpp); } void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, @@ -248,6 +236,19 @@ jlong knn_jni::faiss_wrapper::LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNI return (jlong) indexReader; } +jlong knn_jni::faiss_wrapper::LoadBinaryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ) { + if (indexPathJ == nullptr) { + throw std::runtime_error("Index path cannot be null"); + } + + std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ)); + // Skipping IO_FLAG_PQ_SKIP_SDC_TABLE because the index is read only and the sdc table is only used during ingestion + // Skipping IO_PRECOMPUTE_TABLE because it is only needed for IVFPQ-l2 and it leads to high memory consumption if + // done for each segment. Instead, we will set it later on with `setSharedIndexState` + faiss::IndexBinary* indexReader = faiss::read_index_binary(indexPathCpp.c_str(), faiss::IO_FLAG_READ_ONLY | faiss::IO_FLAG_PQ_SKIP_SDC_TABLE | faiss::IO_FLAG_SKIP_PRECOMPUTE_TABLE); + return (jlong) indexReader; +} + bool knn_jni::faiss_wrapper::IsSharedIndexStateRequired(jlong indexPointerJ) { auto * index = reinterpret_cast(indexPointerJ); return isIndexIVFPQL2(index); @@ -415,6 +416,121 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInter return results; } +jobjectArray knn_jni::faiss_wrapper::QueryBinaryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ, + jbyteArray queryVectorJ, jint kJ, jobject methodParamsJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ) { + + if (queryVectorJ == nullptr) { + throw std::runtime_error("Query Vector cannot be null"); + } + + auto *indexReader = reinterpret_cast(indexPointerJ); + + if (indexReader == nullptr) { + throw std::runtime_error("Invalid pointer to index"); + } + + std::unordered_map methodParams; + if (methodParamsJ != nullptr) { + methodParams = jniUtil->ConvertJavaMapToCppMap(env, methodParamsJ); + } + + // The ids vector will hold the top k ids from the search and the dis vector will hold the top k distances from + // the query point + std::vector dis(kJ); + std::vector ids(kJ); + int8_t* rawQueryvector = jniUtil->GetByteArrayElements(env, queryVectorJ, nullptr); + /* + Setting the omp_set_num_threads to 1 to make sure that no new OMP threads are getting created. + */ + omp_set_num_threads(1); + // create the filterSearch params if the filterIdsJ is not a null pointer + if(filterIdsJ != nullptr) { + jlong *filteredIdsArray = jniUtil->GetLongArrayElements(env, filterIdsJ, nullptr); + int filterIdsLength = jniUtil->GetJavaLongArrayLength(env, filterIdsJ); + std::unique_ptr idSelector; + if(filterIdsTypeJ == BITMAP) { + idSelector.reset(new faiss::IDSelectorJlongBitmap(filterIdsLength, filteredIdsArray)); + } else { + faiss::idx_t* batchIndices = reinterpret_cast(filteredIdsArray); + idSelector.reset(new faiss::IDSelectorBatch(filterIdsLength, batchIndices)); + } + faiss::SearchParameters *searchParameters; + faiss::SearchParametersHNSW hnswParams; + faiss::SearchParametersIVF ivfParams; + std::unique_ptr idGrouper; + std::vector idGrouperBitmap; + auto hnswReader = dynamic_cast(indexReader->index); + if(hnswReader) { + // Query param efsearch supersedes ef_search provided during index setting. + hnswParams.efSearch = knn_jni::commons::getIntegerMethodParameter(env, jniUtil, methodParams, EF_SEARCH, hnswReader->hnsw.efSearch); + hnswParams.sel = idSelector.get(); + if (parentIdsJ != nullptr) { + idGrouper = buildIDGrouperBitmap(jniUtil, env, parentIdsJ, &idGrouperBitmap); + hnswParams.grp = idGrouper.get(); + } + searchParameters = &hnswParams; + } else { + auto ivfReader = dynamic_cast(indexReader->index); + if(ivfReader) { + ivfParams.sel = idSelector.get(); + searchParameters = &ivfParams; + } + } + try { + indexReader->search(1, reinterpret_cast(rawQueryvector), kJ, dis.data(), ids.data(), searchParameters); + } catch (...) { + jniUtil->ReleaseByteArrayElements(env, queryVectorJ, rawQueryvector, JNI_ABORT); + jniUtil->ReleaseLongArrayElements(env, filterIdsJ, filteredIdsArray, JNI_ABORT); + throw; + } + jniUtil->ReleaseLongArrayElements(env, filterIdsJ, filteredIdsArray, JNI_ABORT); + } else { + faiss::SearchParameters *searchParameters = nullptr; + faiss::SearchParametersHNSW hnswParams; + std::unique_ptr idGrouper; + std::vector idGrouperBitmap; + auto hnswReader = dynamic_cast(indexReader->index); + // TODO currently, search parameter is not supported in binary index + // To avoid test failure, we skip setting ef search when methodPramsJ is null temporary + if(hnswReader!= nullptr && (methodParamsJ != nullptr || parentIdsJ != nullptr)) { + // Query param efsearch supersedes ef_search provided during index setting. + hnswParams.efSearch = knn_jni::commons::getIntegerMethodParameter(env, jniUtil, methodParams, EF_SEARCH, hnswReader->hnsw.efSearch); + if (parentIdsJ != nullptr) { + idGrouper = buildIDGrouperBitmap(jniUtil, env, parentIdsJ, &idGrouperBitmap); + hnswParams.grp = idGrouper.get(); + } + searchParameters = &hnswParams; + } + try { + indexReader->search(1, reinterpret_cast(rawQueryvector), kJ, dis.data(), ids.data(), searchParameters); + } catch (...) { + jniUtil->ReleaseByteArrayElements(env, queryVectorJ, rawQueryvector, JNI_ABORT); + throw; + } + } + jniUtil->ReleaseByteArrayElements(env, queryVectorJ, rawQueryvector, JNI_ABORT); + + // If there are not k results, the results will be padded with -1. Find the first -1, and set result size to that + // index + int resultSize = kJ; + auto it = std::find(ids.begin(), ids.end(), -1); + if (it != ids.end()) { + resultSize = it - ids.begin(); + } + + jclass resultClass = jniUtil->FindClass(env,"org/opensearch/knn/index/query/KNNQueryResult"); + jmethodID allArgs = jniUtil->FindMethod(env, "org/opensearch/knn/index/query/KNNQueryResult", ""); + + jobjectArray results = jniUtil->NewObjectArray(env, resultSize, resultClass, nullptr); + + jobject result; + for(int i = 0; i < resultSize; ++i) { + result = jniUtil->NewObject(env, resultClass, allArgs, ids[i], dis[i]); + jniUtil->SetObjectArrayElement(env, results, i, result); + } + return results; +} + void knn_jni::faiss_wrapper::Free(jlong indexPointer) { auto *indexWrapper = reinterpret_cast(indexPointer); delete indexWrapper; @@ -510,6 +626,11 @@ faiss::MetricType TranslateSpaceToMetric(const std::string& spaceType) { return faiss::METRIC_INNER_PRODUCT; } + // Space type is not used for binary index. Use L2 just to avoid an error. + if (spaceType == knn_jni::HAMMING_BIT) { + return faiss::METRIC_L2; + } + throw std::runtime_error("Invalid spaceType"); } diff --git a/jni/src/jni_util.cpp b/jni/src/jni_util.cpp index a1faa4894..919191596 100644 --- a/jni/src/jni_util.cpp +++ b/jni/src/jni_util.cpp @@ -261,6 +261,39 @@ void knn_jni::JNIUtil::Convert2dJavaObjectArrayAndStoreToFloatVector(JNIEnv *env env->DeleteLocalRef(array2dJ); } +void knn_jni::JNIUtil::Convert2dJavaObjectArrayAndStoreToByteVector(JNIEnv *env, jobjectArray array2dJ, + int dim, std::vector *vect) { + + if (array2dJ == nullptr) { + throw std::runtime_error("Array cannot be null"); + } + + int numVectors = env->GetArrayLength(array2dJ); + this->HasExceptionInStack(env); + + for (int i = 0; i < numVectors; ++i) { + auto vectorArray = (jbyteArray)env->GetObjectArrayElement(array2dJ, i); + this->HasExceptionInStack(env, "Unable to get object array element"); + + if (dim != env->GetArrayLength(vectorArray)) { + throw std::runtime_error("Dimension of vectors is inconsistent"); + } + + uint8_t* vector = reinterpret_cast(env->GetByteArrayElements(vectorArray, nullptr)); + if (vector == nullptr) { + this->HasExceptionInStack(env); + throw std::runtime_error("Unable to get byte array elements"); + } + + for(int j = 0; j < dim; ++j) { + vect->push_back(vector[j]); + } + env->ReleaseByteArrayElements(vectorArray, reinterpret_cast(vector), JNI_ABORT); + } + this->HasExceptionInStack(env); + env->DeleteLocalRef(array2dJ); +} + std::vector knn_jni::JNIUtil::ConvertJavaIntArrayToCppIntVector(JNIEnv *env, jintArray arrayJ) { if (arrayJ == nullptr) { @@ -302,6 +335,23 @@ int knn_jni::JNIUtil::GetInnerDimensionOf2dJavaFloatArray(JNIEnv *env, jobjectAr return dim; } +int knn_jni::JNIUtil::GetInnerDimensionOf2dJavaByteArray(JNIEnv *env, jobjectArray array2dJ) { + + if (array2dJ == nullptr) { + throw std::runtime_error("Array cannot be null"); + } + + if (env->GetArrayLength(array2dJ) <= 0) { + return 0; + } + + auto vectorArray = (jbyteArray)env->GetObjectArrayElement(array2dJ, 0); + this->HasExceptionInStack(env); + int dim = env->GetArrayLength(vectorArray); + this->HasExceptionInStack(env); + return dim; +} + int knn_jni::JNIUtil::GetJavaObjectArrayLength(JNIEnv *env, jobjectArray arrayJ) { if (arrayJ == nullptr) { @@ -490,6 +540,7 @@ const std::string knn_jni::LINF = "linf"; const std::string knn_jni::COSINESIMIL = "cosinesimil"; const std::string knn_jni::INNER_PRODUCT = "innerproduct"; const std::string knn_jni::NEG_DOT_PRODUCT = "negdotprod"; +const std::string knn_jni::HAMMING_BIT = "hammingbit"; const std::string knn_jni::NPROBES = "nprobes"; const std::string knn_jni::COARSE_QUANTIZER = "coarse_quantizer"; diff --git a/jni/src/org_opensearch_knn_jni_FaissService.cpp b/jni/src/org_opensearch_knn_jni_FaissService.cpp index d8cf2f9cf..5f9c83ea8 100644 --- a/jni/src/org_opensearch_knn_jni_FaissService.cpp +++ b/jni/src/org_opensearch_knn_jni_FaissService.cpp @@ -44,7 +44,32 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndex(JNIE jstring indexPathJ, jobject parametersJ) { try { - knn_jni::faiss_wrapper::CreateIndex(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, parametersJ); + std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); + knn_jni::faiss_wrapper::IndexService indexService(std::move(faissMethods)); + knn_jni::faiss_wrapper::CreateIndex(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, parametersJ, &indexService); + + // Releasing the vectorsAddressJ memory as that is not required once we have created the index. + // This is not the ideal approach, please refer this gh issue for long term solution: + // https://github.com/opensearch-project/k-NN/issues/1600 + delete reinterpret_cast*>(vectorsAddressJ); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } +} + +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryIndex(JNIEnv * env, jclass cls, jintArray idsJ, + jlong vectorsAddressJ, jint dimJ, + jstring indexPathJ, jobject parametersJ) +{ + try { + std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); + knn_jni::faiss_wrapper::BinaryIndexService binaryIndexService(std::move(faissMethods)); + knn_jni::faiss_wrapper::CreateIndex(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, parametersJ, &binaryIndexService); + + // Releasing the vectorsAddressJ memory as that is not required once we have created the index. + // This is not the ideal approach, please refer this gh issue for long term solution: + // https://github.com/opensearch-project/k-NN/issues/1600 + delete reinterpret_cast*>(vectorsAddressJ); } catch (...) { jniUtil.CatchCppExceptionAndThrowJava(env); } @@ -75,6 +100,16 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadIndex(JNIEn return NULL; } +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadBinaryIndex(JNIEnv * env, jclass cls, jstring indexPathJ) +{ + try { + return knn_jni::faiss_wrapper::LoadBinaryIndex(&jniUtil, env, indexPathJ); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } + return NULL; +} + JNIEXPORT jboolean JNICALL Java_org_opensearch_knn_jni_FaissService_isSharedIndexStateRequired (JNIEnv * env, jclass cls, jlong indexPointerJ) { @@ -132,6 +167,18 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryInd } +JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryBinaryIndexWithFilter + (JNIEnv * env, jclass cls, jlong indexPointerJ, jbyteArray queryVectorJ, jint kJ, jobject methodParamsJ, jlongArray filteredIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ) { + + try { + return knn_jni::faiss_wrapper::QueryBinaryIndex_WithFilter(&jniUtil, env, indexPointerJ, queryVectorJ, kJ, methodParamsJ, filteredIdsJ, filterIdsTypeJ, parentIdsJ); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } + return nullptr; + +} + JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_free(JNIEnv * env, jclass cls, jlong indexPointerJ) { try { diff --git a/jni/src/org_opensearch_knn_jni_JNICommons.cpp b/jni/src/org_opensearch_knn_jni_JNICommons.cpp index ccdd11882..0bc2e4633 100644 --- a/jni/src/org_opensearch_knn_jni_JNICommons.cpp +++ b/jni/src/org_opensearch_knn_jni_JNICommons.cpp @@ -49,6 +49,18 @@ jlong memoryAddressJ, jobjectArray dataJ, jlong initialCapacityJ) return (long)memoryAddressJ; } +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeByteVectorData(JNIEnv * env, jclass cls, +jlong memoryAddressJ, jobjectArray dataJ, jlong initialCapacityJ) + +{ + try { + return knn_jni::commons::storeByteVectorData(&jniUtil, env, memoryAddressJ, dataJ, initialCapacityJ); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } + return (long)memoryAddressJ; +} + JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_JNICommons_freeVectorData(JNIEnv * env, jclass cls, jlong memoryAddressJ) { @@ -58,3 +70,14 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_JNICommons_freeVectorData(JNI jniUtil.CatchCppExceptionAndThrowJava(env); } } + + +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_JNICommons_freeByteVectorData(JNIEnv * env, jclass cls, + jlong memoryAddressJ) +{ + try { + return knn_jni::commons::freeByteVectorData(memoryAddressJ); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } +} diff --git a/jni/tests/faiss_index_service_test.cpp b/jni/tests/faiss_index_service_test.cpp new file mode 100644 index 000000000..f876edced --- /dev/null +++ b/jni/tests/faiss_index_service_test.cpp @@ -0,0 +1,134 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + + +#include "faiss_index_service.h" +#include "mocks/faiss_methods_mock.h" +#include "mocks/faiss_index_mock.h" +#include "test_util.h" +#include +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "commons.h" + +using ::testing::_; +using ::testing::NiceMock; +using ::testing::Return; + +TEST(CreateIndexTest, BasicAssertions) { + // Define the data + faiss::idx_t numIds = 200; + std::vector ids; + std::vector vectors; + int dim = 2; + vectors.reserve(dim * numIds); + for (int64_t i = 0; i < numIds; ++i) { + ids.push_back(i); + for (int j = 0; j < dim; ++j) { + vectors.push_back(test_util::RandomFloat(-500.0, 500.0)); + } + } + + std::string indexPath = test_util::RandomString(10, "tmp/", ".faiss"); + faiss::MetricType metricType = faiss::METRIC_L2; + std::string indexDescription = "HNSW32,Flat"; + int threadCount = 1; + std::unordered_map parametersMap; + + // Set up jni + JNIEnv *jniEnv = nullptr; + NiceMock mockJNIUtil; + + // Setup faiss method mock + // This object is handled by unique_ptr inside indexService.createIndex() + MockIndex* index = new MockIndex(); + EXPECT_CALL(*index, add(numIds, vectors.data())) + .Times(1); + // This object is handled by unique_ptr inside indexService.createIndex() + faiss::IndexIDMap* indexIdMap = new faiss::IndexIDMap(index); + std::unique_ptr mockFaissMethods(new MockFaissMethods()); + EXPECT_CALL(*mockFaissMethods, indexFactory(dim, ::testing::StrEq(indexDescription.c_str()), metricType)) + .WillOnce(Return(index)); + EXPECT_CALL(*mockFaissMethods, indexIdMap(index)) + .WillOnce(Return(indexIdMap)); + EXPECT_CALL(*mockFaissMethods, writeIndex(indexIdMap, ::testing::StrEq(indexPath.c_str()))) + .Times(1); + + // Create the index + knn_jni::faiss_wrapper::IndexService indexService(std::move(mockFaissMethods)); + indexService.createIndex( + &mockJNIUtil, + jniEnv, + metricType, + indexDescription, + dim, + numIds, + threadCount, + (int64_t) &vectors, + ids, + indexPath, + parametersMap); +} + +TEST(CreateBinaryIndexTest, BasicAssertions) { + // Define the data + faiss::idx_t numIds = 200; + std::vector ids; + std::vector vectors; + int dim = 128; + vectors.reserve(numIds); + for (int64_t i = 0; i < numIds; ++i) { + ids.push_back(i); + for (int j = 0; j < dim / 8; ++j) { + vectors.push_back(test_util::RandomInt(0, 255)); + } + } + + std::string indexPath = test_util::RandomString(10, "tmp/", ".faiss"); + faiss::MetricType metricType = faiss::METRIC_L2; + std::string indexDescription = "BHNSW32"; + int threadCount = 1; + std::unordered_map parametersMap; + + // Set up jni + JNIEnv *jniEnv = nullptr; + NiceMock mockJNIUtil; + + // Setup faiss method mock + // This object is handled by unique_ptr inside indexService.createIndex() + MockIndexBinary* index = new MockIndexBinary(); + EXPECT_CALL(*index, add(numIds, vectors.data())) + .Times(1); + // This object is handled by unique_ptr inside indexService.createIndex() + faiss::IndexBinaryIDMap* indexIdMap = new faiss::IndexBinaryIDMap(index); + std::unique_ptr mockFaissMethods(new MockFaissMethods()); + EXPECT_CALL(*mockFaissMethods, indexBinaryFactory(dim, ::testing::StrEq(indexDescription.c_str()))) + .WillOnce(Return(index)); + EXPECT_CALL(*mockFaissMethods, indexBinaryIdMap(index)) + .WillOnce(Return(indexIdMap)); + EXPECT_CALL(*mockFaissMethods, writeIndexBinary(indexIdMap, ::testing::StrEq(indexPath.c_str()))) + .Times(1); + + // Create the index + knn_jni::faiss_wrapper::BinaryIndexService indexService(std::move(mockFaissMethods)); + indexService.createIndex( + &mockJNIUtil, + jniEnv, + metricType, + indexDescription, + dim, + numIds, + threadCount, + (int64_t) &vectors, + ids, + indexPath, + parametersMap); +} \ No newline at end of file diff --git a/jni/tests/faiss_wrapper_test.cpp b/jni/tests/faiss_wrapper_test.cpp index f1d2ee7f4..c6663a19a 100644 --- a/jni/tests/faiss_wrapper_test.cpp +++ b/jni/tests/faiss_wrapper_test.cpp @@ -19,9 +19,12 @@ #include "test_util.h" #include "faiss/IndexHNSW.h" #include "faiss/IndexIVFPQ.h" +#include "mocks/faiss_index_service_mock.h" +using ::testing::_; using ::testing::NiceMock; using ::testing::Return; +using ::testing::Mock; float randomDataMin = -500.0; float randomDataMax = 500.0; @@ -33,44 +36,81 @@ TEST(FaissCreateIndexTest, BasicAssertions) { // Define the data faiss::idx_t numIds = 200; std::vector ids; - auto *vectors = new std::vector(); + std::vector vectors; int dim = 2; - vectors->reserve(dim * numIds); + vectors.reserve(dim * numIds); for (int64_t i = 0; i < numIds; ++i) { ids.push_back(i); for (int j = 0; j < dim; ++j) { - vectors->push_back(test_util::RandomFloat(-500.0, 500.0)); + vectors.push_back(test_util::RandomFloat(-500.0, 500.0)); } } std::string indexPath = test_util::RandomString(10, "tmp/", ".faiss"); std::string spaceType = knn_jni::L2; - std::string index_description = "HNSW32,Flat"; + std::string indexDescription = "HNSW32,Flat"; std::unordered_map parametersMap; parametersMap[knn_jni::SPACE_TYPE] = (jobject)&spaceType; - parametersMap[knn_jni::INDEX_DESCRIPTION] = (jobject)&index_description; + parametersMap[knn_jni::INDEX_DESCRIPTION] = (jobject)&indexDescription; + std::unordered_map subParametersMap; + parametersMap[knn_jni::PARAMETERS] = (jobject)&subParametersMap; // Set up jni JNIEnv *jniEnv = nullptr; NiceMock mockJNIUtil; - EXPECT_CALL(mockJNIUtil, - GetJavaObjectArrayLength( - jniEnv, reinterpret_cast(&vectors))) - .WillRepeatedly(Return(vectors->size())); - // Create the index + std::unique_ptr faissMethods(new FaissMethods()); + NiceMock mockIndexService(std::move(faissMethods)); + EXPECT_CALL(mockIndexService, createIndex(_, _, faiss::METRIC_L2, indexDescription, dim, (int)numIds, 0, (int64_t)&vectors, ids, indexPath, subParametersMap)) + .Times(1); + knn_jni::faiss_wrapper::CreateIndex( &mockJNIUtil, jniEnv, reinterpret_cast(&ids), - (jlong) vectors, dim , (jstring)&indexPath, - (jobject)¶metersMap); + (jlong) &vectors, dim , (jstring)&indexPath, + (jobject)¶metersMap, &mockIndexService); +} - // Make sure index can be loaded - std::unique_ptr index(test_util::FaissLoadIndex(indexPath)); +TEST(FaissCreateBinaryIndexTest, BasicAssertions) { + // Define the data + faiss::idx_t numIds = 200; + std::vector ids; + std::vector vectors; + int dim = 128; + vectors.reserve(numIds); + for (int64_t i = 0; i < numIds; ++i) { + ids.push_back(i); + for (int j = 0; j < dim / 8; ++j) { + vectors.push_back(test_util::RandomInt(0, 255)); + } + } - // Clean up - std::remove(indexPath.c_str()); + std::string indexPath = test_util::RandomString(10, "tmp/", ".faiss"); + std::string spaceType = knn_jni::HAMMING_BIT; + std::string indexDescription = "BHNSW32"; + + std::unordered_map parametersMap; + parametersMap[knn_jni::SPACE_TYPE] = (jobject)&spaceType; + parametersMap[knn_jni::INDEX_DESCRIPTION] = (jobject)&indexDescription; + std::unordered_map subParametersMap; + parametersMap[knn_jni::PARAMETERS] = (jobject)&subParametersMap; + + // Set up jni + JNIEnv *jniEnv = nullptr; + NiceMock mockJNIUtil; + + // Create the index + std::unique_ptr faissMethods(new FaissMethods()); + NiceMock mockIndexService(std::move(faissMethods)); + EXPECT_CALL(mockIndexService, createIndex(_, _, faiss::METRIC_L2, indexDescription, dim, (int)numIds, 0, (int64_t)&vectors, ids, indexPath, subParametersMap)) + .Times(1); + + // This method calls delete vectors at the end + knn_jni::faiss_wrapper::CreateIndex( + &mockJNIUtil, jniEnv, reinterpret_cast(&ids), + (jlong) &vectors, dim , (jstring)&indexPath, + (jobject)¶metersMap, &mockIndexService); } TEST(FaissCreateIndexFromTemplateTest, BasicAssertions) { @@ -168,6 +208,58 @@ TEST(FaissLoadIndexTest, BasicAssertions) { std::remove(indexPath.c_str()); } +TEST(FaissLoadBinaryIndexTest, BasicAssertions) { + // Define the data + faiss::idx_t numIds = 200; + std::vector ids; + auto vectors = std::vector(numIds); + int dim = 128; + for (int64_t i = 0; i < numIds; ++i) { + ids.push_back(i); + for (int j = 0; j < dim / 8; ++j) { + vectors.push_back(test_util::RandomInt(0, 255)); + } + } + + std::string indexPath = test_util::RandomString(10, "tmp/", ".faiss"); + std::string spaceType = knn_jni::HAMMING_BIT; + std::string method = "BHNSW32"; + + // Create the index + std::unique_ptr createdIndex( + test_util::FaissCreateBinaryIndex(dim, method)); + auto createdIndexWithData = + test_util::FaissAddBinaryData(createdIndex.get(), ids, vectors); + + test_util::FaissWriteBinaryIndex(&createdIndexWithData, indexPath); + + // Setup jni + JNIEnv *jniEnv = nullptr; + NiceMock mockJNIUtil; + + std::unique_ptr loadedIndexPointer( + reinterpret_cast(knn_jni::faiss_wrapper::LoadBinaryIndex( + &mockJNIUtil, jniEnv, (jstring)&indexPath))); + + // Compare serialized versions + auto createIndexSerialization = + test_util::FaissGetSerializedBinaryIndex(&createdIndexWithData); + auto loadedIndexSerialization = test_util::FaissGetSerializedBinaryIndex( + reinterpret_cast(loadedIndexPointer.get())); + + ASSERT_NE(0, loadedIndexSerialization.data.size()); + ASSERT_EQ(createIndexSerialization.data.size(), + loadedIndexSerialization.data.size()); + + for (int i = 0; i < loadedIndexSerialization.data.size(); ++i) { + ASSERT_EQ(createIndexSerialization.data[i], + loadedIndexSerialization.data[i]); + } + + // Clean up + std::remove(indexPath.c_str()); +} + TEST(FaissLoadIndexTest, HNSWPQDisableSdcTable) { // Check that when we load an HNSWPQ index, the sdc table is not present. faiss::idx_t numIds = 256; @@ -289,6 +381,61 @@ TEST(FaissQueryIndexTest, BasicAssertions) { } } +TEST(FaissQueryBinaryIndexTest, BasicAssertions) { + // Define the data + faiss::idx_t numIds = 200; + std::vector ids; + auto vectors = std::vector(numIds); + int dim = 128; + for (int64_t i = 0; i < numIds; ++i) { + ids.push_back(i); + for (int j = 0; j < dim / 8; ++j) { + vectors.push_back(test_util::RandomInt(0, 255)); + } + } + + // Define query data + int k = 10; + int numQueries = 100; + std::vector> queries; + + for (int i = 0; i < numQueries; i++) { + std::vector query; + query.reserve(dim); + for (int j = 0; j < dim; j++) { + query.push_back(test_util::RandomInt(0, 255)); + } + queries.push_back(query); + } + + // Create the index + std::string method = "BHNSW32"; + std::unique_ptr createdIndex( + test_util::FaissCreateBinaryIndex(dim, method)); + auto createdIndexWithData = + test_util::FaissAddBinaryData(createdIndex.get(), ids, vectors); + + // Setup jni + JNIEnv *jniEnv = nullptr; + NiceMock mockJNIUtil; + + for (auto query : queries) { + std::unique_ptr *>> results( + reinterpret_cast *> *>( + knn_jni::faiss_wrapper::QueryBinaryIndex_WithFilter( + &mockJNIUtil, jniEnv, + reinterpret_cast(&createdIndexWithData), + reinterpret_cast(&query), k, nullptr, nullptr, 0, nullptr))); + + ASSERT_EQ(k, results->size()); + + // Need to free up each result + for (auto it : *results.get()) { + delete it; + } + } +} + //Test for a bug reported in https://github.com/opensearch-project/k-NN/issues/1435 TEST(FaissQueryIndexWithFilterTest1435, BasicAssertions) { // Define the index data @@ -409,6 +556,10 @@ TEST(FaissQueryIndexWithParentFilterTest, BasicAssertions) { // Setup jni JNIEnv *jniEnv = nullptr; NiceMock mockJNIUtil; + EXPECT_CALL(mockJNIUtil, + GetJavaIntArrayLength( + jniEnv, reinterpret_cast(&parentIds))) + .WillRepeatedly(Return(parentIds.size())); for (auto query : queries) { std::unique_ptr *>> results( reinterpret_cast *> *>( @@ -488,13 +639,13 @@ TEST(FaissCreateHnswSQfp16IndexTest, BasicAssertions) { // Define the data faiss::idx_t numIds = 200; std::vector ids; - auto *vectors = new std::vector(); + std::vector vectors; int dim = 2; - vectors->reserve(dim * numIds); + vectors.reserve(dim * numIds); for (int64_t i = 0; i < numIds; ++i) { ids.push_back(i); for (int j = 0; j < dim; ++j) { - vectors->push_back(test_util::RandomFloat(-500.0, 500.0)); + vectors.push_back(test_util::RandomFloat(-500.0, 500.0)); } } @@ -513,13 +664,15 @@ TEST(FaissCreateHnswSQfp16IndexTest, BasicAssertions) { EXPECT_CALL(mockJNIUtil, GetJavaObjectArrayLength( jniEnv, reinterpret_cast(&vectors))) - .WillRepeatedly(Return(vectors->size())); + .WillRepeatedly(Return(vectors.size())); // Create the index + std::unique_ptr faissMethods(new FaissMethods()); + knn_jni::faiss_wrapper::IndexService IndexService(std::move(faissMethods)); knn_jni::faiss_wrapper::CreateIndex( &mockJNIUtil, jniEnv, reinterpret_cast(&ids), - (jlong)vectors, dim, (jstring)&indexPath, - (jobject)¶metersMap); + (jlong)&vectors, dim, (jstring)&indexPath, + (jobject)¶metersMap, &IndexService); // Make sure index can be loaded std::unique_ptr index(test_util::FaissLoadIndex(indexPath)); diff --git a/jni/tests/faiss_wrapper_unit_test.cpp b/jni/tests/faiss_wrapper_unit_test.cpp index d9fdac23f..d68ec69c6 100644 --- a/jni/tests/faiss_wrapper_unit_test.cpp +++ b/jni/tests/faiss_wrapper_unit_test.cpp @@ -25,12 +25,12 @@ using ::testing::NiceMock; using idx_t = faiss::idx_t; -struct MockIndex : faiss::IndexHNSW { - explicit MockIndex(idx_t d) : faiss::IndexHNSW(d, 32) { +struct FaissMockIndex : faiss::IndexHNSW { + explicit FaissMockIndex(idx_t d) : faiss::IndexHNSW(d, 32) { } }; -struct MockIdMap : faiss::IndexIDMap { +struct FaissMockIdMap : faiss::IndexIDMap { mutable idx_t nCalled{}; mutable const float *xCalled{}; mutable int kCalled{}; @@ -40,7 +40,7 @@ struct MockIdMap : faiss::IndexIDMap { mutable const faiss::SearchParametersHNSW *paramsCalled{}; mutable faiss::RangeSearchResult *resCalled{}; - explicit MockIdMap(MockIndex *index) : faiss::IndexIDMapTemplate(index) { + explicit FaissMockIdMap(FaissMockIndex *index) : faiss::IndexIDMapTemplate(index) { } void search( @@ -108,8 +108,8 @@ class FaissWrappeterParametrizedTestFixture : public testing::TestWithParam { @@ -119,8 +119,8 @@ class FaissWrapperParametrizedRangeSearchTestFixture : public testing::TestWithP } protected: - MockIndex index_; - MockIdMap id_map_; + FaissMockIndex index_; + FaissMockIdMap id_map_; }; namespace query_index_test { @@ -369,4 +369,3 @@ namespace range_search_test { ) ); } - diff --git a/jni/tests/mocks/faiss_index_mock.h b/jni/tests/mocks/faiss_index_mock.h new file mode 100644 index 000000000..521cbb2d3 --- /dev/null +++ b/jni/tests/mocks/faiss_index_mock.h @@ -0,0 +1,35 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + + #ifndef OPENSEARCH_KNN_FAISS_INDEX_MOCK_H + #define OPENSEARCH_KNN_FAISS_INDEX_MOCK_H + +#include "faiss/Index.h" +#include "faiss/IndexBinary.h" +#include + +using idx_t = int64_t; + +class MockIndex : public faiss::Index { +public: + MOCK_METHOD(void, add, (idx_t n, const float* x), (override)); + MOCK_METHOD(void, search, (idx_t n, const float* x, idx_t k, float* distances, idx_t* labels, const faiss::SearchParameters* params), (const, override)); + MOCK_METHOD(void, reset, (), (override)); +}; + +class MockIndexBinary : public faiss::IndexBinary { +public: + MOCK_METHOD(void, add, (idx_t n, const uint8_t* x), (override)); + MOCK_METHOD(void, search, (idx_t n, const uint8_t* x, idx_t k, int32_t* distances, idx_t* labels, const faiss::SearchParameters* params), (const, override)); + MOCK_METHOD(void, reset, (), (override)); +}; + +#endif // OPENSEARCH_KNN_FAISS_INDEX_MOCK_H \ No newline at end of file diff --git a/jni/tests/mocks/faiss_index_service_mock.h b/jni/tests/mocks/faiss_index_service_mock.h new file mode 100644 index 000000000..7af08c82e --- /dev/null +++ b/jni/tests/mocks/faiss_index_service_mock.h @@ -0,0 +1,44 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +#ifndef OPENSEARCH_KNN_FAISS_INDEX_SERVICE_MOCK_H +#define OPENSEARCH_KNN_FAISS_INDEX_SERVICE_MOCK_H + +#include "faiss_index_service.h" +#include + +using ::knn_jni::faiss_wrapper::FaissMethods; +using ::knn_jni::faiss_wrapper::IndexService; +typedef std::unordered_map StringToJObjectMap; + +class MockIndexService : public IndexService { +public: + MockIndexService(std::unique_ptr faissMethods) : IndexService(std::move(faissMethods)) {}; + MOCK_METHOD( + void, + createIndex, + ( + knn_jni::JNIUtilInterface * jniUtil, + JNIEnv * env, + faiss::MetricType metric, + std::string indexDescription, + int dim, + int numIds, + int threadCount, + int64_t vectorsAddress, + std::vector ids, + std::string indexPath, + StringToJObjectMap parameters + ), + (override)); +}; + +#endif // OPENSEARCH_KNN_FAISS_INDEX_SERVICE_MOCK_H \ No newline at end of file diff --git a/jni/tests/mocks/faiss_methods_mock.h b/jni/tests/mocks/faiss_methods_mock.h new file mode 100644 index 000000000..64a23b895 --- /dev/null +++ b/jni/tests/mocks/faiss_methods_mock.h @@ -0,0 +1,28 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + + #ifndef OPENSEARCH_KNN_FAISS_METHODS_MOCK_H + #define OPENSEARCH_KNN_FAISS_METHODS_MOCK_H + +#include "faiss_methods.h" +#include + +class MockFaissMethods : public knn_jni::faiss_wrapper::FaissMethods { +public: + MOCK_METHOD(faiss::Index*, indexFactory, (int d, const char* description, faiss::MetricType metric), (override)); + MOCK_METHOD(faiss::IndexBinary*, indexBinaryFactory, (int d, const char* description), (override)); + MOCK_METHOD(faiss::IndexIDMapTemplate*, indexIdMap, (faiss::Index* index), (override)); + MOCK_METHOD(faiss::IndexIDMapTemplate*, indexBinaryIdMap, (faiss::IndexBinary* index), (override)); + MOCK_METHOD(void, writeIndex, (const faiss::Index* idx, const char* fname), (override)); + MOCK_METHOD(void, writeIndexBinary, (const faiss::IndexBinary* idx, const char* fname), (override)); +}; + +#endif // OPENSEARCH_KNN_FAISS_METHODS_MOCK_H \ No newline at end of file diff --git a/jni/tests/test_util.cpp b/jni/tests/test_util.cpp index 92532b9e2..2149f8a1a 100644 --- a/jni/tests/test_util.cpp +++ b/jni/tests/test_util.cpp @@ -51,6 +51,12 @@ test_util::MockJNIUtil::MockJNIUtil() { (*reinterpret_cast> *>(array2dJ))) for (auto item : v) data->push_back(item); }); + ON_CALL(*this, Convert2dJavaObjectArrayAndStoreToByteVector) + .WillByDefault([this](JNIEnv *env, jobjectArray array2dJ, int dim, std::vector* data) { + for (const auto &v : + (*reinterpret_cast> *>(array2dJ))) + for (auto item : v) data->push_back(item); + }); // arrayJ is re-interpreted as std::vector * @@ -150,6 +156,15 @@ test_util::MockJNIUtil::MockJNIUtil() { .size(); }); + // array2dJ is re-interpreted as a std::vector> * and then + // the size of the first element is returned + ON_CALL(*this, GetInnerDimensionOf2dJavaByteArray) + .WillByDefault([this](JNIEnv *env, jobjectArray array2dJ) { + return (*reinterpret_cast> *>( + array2dJ))[0] + .size(); + }); + // arrayJ is re-interpreted as a std::vector * and the size is returned ON_CALL(*this, GetJavaFloatArrayLength) .WillByDefault([this](JNIEnv *env, jfloatArray arrayJ) { @@ -249,12 +264,22 @@ faiss::Index *test_util::FaissCreateIndex(int dim, const std::string &method, return faiss::index_factory(dim, method.c_str(), metric); } +faiss::IndexBinary *test_util::FaissCreateBinaryIndex(int dim, const std::string &method) { + return faiss::index_binary_factory(dim, method.c_str()); +} + faiss::VectorIOWriter test_util::FaissGetSerializedIndex(faiss::Index *index) { faiss::VectorIOWriter vectorIoWriter; faiss::write_index(index, &vectorIoWriter); return vectorIoWriter; } +faiss::VectorIOWriter test_util::FaissGetSerializedBinaryIndex(faiss::IndexBinary *index) { + faiss::VectorIOWriter vectorIoWriter; + faiss::write_index_binary(index, &vectorIoWriter); + return vectorIoWriter; +} + faiss::Index *test_util::FaissLoadFromSerializedIndex( std::vector *indexSerial) { faiss::VectorIOReader vectorIoReader; @@ -262,6 +287,13 @@ faiss::Index *test_util::FaissLoadFromSerializedIndex( return faiss::read_index(&vectorIoReader, 0); } +faiss::IndexBinary *test_util::FaissLoadFromSerializedBinaryIndex( + std::vector *indexSerial) { + faiss::VectorIOReader vectorIoReader; + vectorIoReader.data = *indexSerial; + return faiss::read_index_binary(&vectorIoReader, 0); +} + faiss::IndexIDMap test_util::FaissAddData(faiss::Index *index, std::vector ids, std::vector dataset) { @@ -270,15 +302,32 @@ faiss::IndexIDMap test_util::FaissAddData(faiss::Index *index, return idMap; } +faiss::IndexBinaryIDMap test_util::FaissAddBinaryData(faiss::IndexBinary *index, + std::vector ids, + std::vector dataset) { + faiss::IndexBinaryIDMap idMap = faiss::IndexBinaryIDMap(index); + idMap.add_with_ids(ids.size(), dataset.data(), ids.data()); + return idMap; +} + void test_util::FaissWriteIndex(faiss::Index *index, const std::string &indexPath) { faiss::write_index(index, indexPath.c_str()); } +void test_util::FaissWriteBinaryIndex(faiss::IndexBinary *index, + const std::string &indexPath) { + faiss::write_index_binary(index, indexPath.c_str()); +} + faiss::Index *test_util::FaissLoadIndex(const std::string &indexPath) { return faiss::read_index(indexPath.c_str(), faiss::IO_FLAG_READ_ONLY); } +faiss::IndexBinary *test_util::FaissLoadBinaryIndex(const std::string &indexPath) { + return faiss::read_index_binary(indexPath.c_str(), faiss::IO_FLAG_READ_ONLY); +} + void test_util::FaissQueryIndex(faiss::Index *index, float *query, int k, float *distances, faiss::idx_t *ids) { index->search(1, query, k, distances, ids); @@ -377,6 +426,13 @@ float test_util::RandomFloat(float min, float max) { return distribution(e1); } +int test_util::RandomInt(int min, int max) { + std::random_device r; + std::default_random_engine e1(r()); + std::uniform_int_distribution distribution(min, max); + return distribution(e1); +} + std::vector test_util::RandomVectors(int dim, int64_t numVectors, float min, float max) { std::vector vectors(dim*numVectors); for (int64_t i = 0; i < dim*numVectors; i++) { diff --git a/jni/tests/test_util.h b/jni/tests/test_util.h index 8e73a8ab0..ba773fad3 100644 --- a/jni/tests/test_util.h +++ b/jni/tests/test_util.h @@ -46,6 +46,8 @@ namespace test_util { (JNIEnv * env, jobjectArray array2dJ, int dim)); MOCK_METHOD(void, Convert2dJavaObjectArrayAndStoreToFloatVector, (JNIEnv * env, jobjectArray array2dJ, int dim, std::vector*vect)); + MOCK_METHOD(void, Convert2dJavaObjectArrayAndStoreToByteVector, + (JNIEnv * env, jobjectArray array2dJ, int dim, std::vector*vect)); MOCK_METHOD(std::vector, ConvertJavaIntArrayToCppIntVector, (JNIEnv * env, jintArray arrayJ)); MOCK_METHOD2(ConvertJavaMapToCppMap, @@ -64,6 +66,8 @@ namespace test_util { (JNIEnv * env, jfloatArray array, jboolean* isCopy)); MOCK_METHOD(int, GetInnerDimensionOf2dJavaFloatArray, (JNIEnv * env, jobjectArray array2dJ)); + MOCK_METHOD(int, GetInnerDimensionOf2dJavaByteArray, + (JNIEnv * env, jobjectArray array2dJ)); MOCK_METHOD(jint*, GetIntArrayElements, (JNIEnv * env, jintArray array, jboolean* isCopy)); MOCK_METHOD(jlong*, GetLongArrayElements, @@ -109,18 +113,25 @@ namespace test_util { faiss::Index* FaissCreateIndex(int dim, const std::string& method, faiss::MetricType metric); + faiss::IndexBinary* FaissCreateBinaryIndex(int dim, const std::string& method); faiss::VectorIOWriter FaissGetSerializedIndex(faiss::Index* index); + faiss::VectorIOWriter FaissGetSerializedBinaryIndex(faiss::IndexBinary* index); faiss::Index* FaissLoadFromSerializedIndex(std::vector* indexSerial); + faiss::IndexBinary* FaissLoadFromSerializedBinaryIndex(std::vector* indexSerial); faiss::IndexIDMap FaissAddData(faiss::Index* index, std::vector ids, std::vector dataset); - + faiss::IndexBinaryIDMap FaissAddBinaryData(faiss::IndexBinary* index, + std::vector ids, + std::vector dataset); void FaissWriteIndex(faiss::Index* index, const std::string& indexPath); + void FaissWriteBinaryIndex(faiss::IndexBinary* index, const std::string& indexPath); faiss::Index* FaissLoadIndex(const std::string& indexPath); + faiss::IndexBinary* FaissLoadBinaryIndex(const std::string &indexPath); void FaissQueryIndex(faiss::Index* index, float* query, int k, float* distances, faiss::idx_t* ids); @@ -156,6 +167,7 @@ namespace test_util { std::string RandomString(size_t length, const std::string& prefix, const std::string& suffix); float RandomFloat(float min, float max); + int RandomInt(int min, int max); std::vector RandomVectors(int dim, int64_t numVectors, float min, float max); diff --git a/src/main/java/org/opensearch/knn/index/IndexUtil.java b/src/main/java/org/opensearch/knn/index/IndexUtil.java index 5cd4aaf81..7a5e93e14 100644 --- a/src/main/java/org/opensearch/knn/index/IndexUtil.java +++ b/src/main/java/org/opensearch/knn/index/IndexUtil.java @@ -35,6 +35,7 @@ import static org.opensearch.knn.common.KNNConstants.BYTES_PER_KILOBYTES; import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_EF_SEARCH; +import static org.opensearch.knn.common.KNNConstants.INDEX_DESCRIPTION_PARAMETER; import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE; public class IndexUtil { @@ -258,9 +259,15 @@ public static ValidationException validateKnnField( * @param spaceType Space for this particular segment * @param knnEngine Engine used for the native library indices being loaded in * @param indexName Name of OpenSearch index that the segment files belong to + * @param indexDescription Index description of OpenSearch index with faiss that the segment files belong to * @return load parameters that will be passed to the JNI. */ - public static Map getParametersAtLoading(SpaceType spaceType, KNNEngine knnEngine, String indexName) { + public static Map getParametersAtLoading( + SpaceType spaceType, + KNNEngine knnEngine, + String indexName, + String indexDescription + ) { Map loadParameters = Maps.newHashMap(ImmutableMap.of(SPACE_TYPE, spaceType.getValue())); // For nmslib, we need to add the dynamic ef_search parameter that needs to be passed in when the @@ -268,6 +275,9 @@ public static Map getParametersAtLoading(SpaceType spaceType, KN if (KNNEngine.NMSLIB.equals(knnEngine)) { loadParameters.put(HNSW_ALGO_EF_SEARCH, KNNSettings.getEfSearchParam(indexName)); } + if (KNNEngine.FAISS.equals(knnEngine)) { + loadParameters.put(INDEX_DESCRIPTION_PARAMETER, indexDescription); + } return Collections.unmodifiableMap(loadParameters); } @@ -302,5 +312,4 @@ public static boolean isSharedIndexStateRequired(KNNEngine knnEngine, String mod } return JNIService.isSharedIndexStateRequired(indexAddr, knnEngine); } - } diff --git a/src/main/java/org/opensearch/knn/index/KNNIndexShard.java b/src/main/java/org/opensearch/knn/index/KNNIndexShard.java index efa09662c..e00d36d2e 100644 --- a/src/main/java/org/opensearch/knn/index/KNNIndexShard.java +++ b/src/main/java/org/opensearch/knn/index/KNNIndexShard.java @@ -23,6 +23,7 @@ import org.opensearch.knn.index.memory.NativeMemoryCacheManager; import org.opensearch.knn.index.memory.NativeMemoryEntryContext; import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy; +import org.opensearch.knn.index.util.FieldInfoExtractor; import org.opensearch.knn.index.util.KNNEngine; import java.io.IOException; @@ -96,7 +97,8 @@ public void warmup() throws IOException { getParametersAtLoading( engineFileContext.getSpaceType(), KNNEngine.getEngineNameFromPath(engineFileContext.getIndexPath()), - getIndexName() + getIndexName(), + engineFileContext.indexDescription ), getIndexName(), engineFileContext.getModelId() @@ -171,7 +173,6 @@ List getEngineFileContexts(IndexReader indexReader, KNNEngine String spaceTypeName = fieldInfo.attributes().getOrDefault(SPACE_TYPE, SpaceType.L2.getValue()); SpaceType spaceType = SpaceType.getSpace(spaceTypeName); String modelId = fieldInfo.attributes().getOrDefault(MODEL_ID, null); - engineFiles.addAll( getEngineFileContexts( reader.getSegmentInfo().files(), @@ -180,7 +181,8 @@ List getEngineFileContexts(IndexReader indexReader, KNNEngine fileExtension, shardPath, spaceType, - modelId + modelId, + FieldInfoExtractor.getIndexDescription(fieldInfo) ) ); } @@ -197,7 +199,8 @@ List getEngineFileContexts( String fileExtension, Path shardPath, SpaceType spaceType, - String modelId + String modelId, + String indexDescription ) { String prefix = buildEngineFilePrefix(segmentName); String suffix = buildEngineFileSuffix(fieldName, fileExtension); @@ -205,7 +208,7 @@ List getEngineFileContexts( .filter(fileName -> fileName.startsWith(prefix)) .filter(fileName -> fileName.endsWith(suffix)) .map(fileName -> shardPath.resolve(fileName).toString()) - .map(fileName -> new EngineFileContext(spaceType, modelId, fileName)) + .map(fileName -> new EngineFileContext(spaceType, modelId, fileName, indexDescription)) .collect(Collectors.toList()); } @@ -216,5 +219,6 @@ static class EngineFileContext { private final SpaceType spaceType; private final String modelId; private final String indexPath; + private final String indexDescription; } } diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index 539a08a02..fce8e8e04 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -37,6 +37,7 @@ import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy; import org.opensearch.knn.index.query.filtered.FilteredIdsKNNIterator; import org.opensearch.knn.index.query.filtered.NestedFilteredIdsKNNIterator; +import org.opensearch.knn.index.util.FieldInfoExtractor; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; @@ -256,7 +257,12 @@ private Map doANNSearch(final LeafReaderContext context, final B new NativeMemoryEntryContext.IndexEntryContext( indexPath.toString(), NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance(), - getParametersAtLoading(spaceType, knnEngine, knnQuery.getIndexName()), + getParametersAtLoading( + spaceType, + knnEngine, + knnQuery.getIndexName(), + FieldInfoExtractor.getIndexDescription(fieldInfo) + ), knnQuery.getIndexName(), modelId ), diff --git a/src/main/java/org/opensearch/knn/index/util/FieldInfoExtractor.java b/src/main/java/org/opensearch/knn/index/util/FieldInfoExtractor.java new file mode 100644 index 000000000..5ad271969 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/util/FieldInfoExtractor.java @@ -0,0 +1,37 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.util; + +import org.apache.lucene.index.FieldInfo; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.xcontent.DeprecationHandler; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.knn.common.KNNConstants; + +import java.io.IOException; + +import static org.opensearch.knn.common.KNNConstants.INDEX_DESCRIPTION_PARAMETER; + +/** + * Class having methods to extract a value from field info + */ +public class FieldInfoExtractor { + public static String getIndexDescription(FieldInfo fieldInfo) throws IOException { + String parameters = fieldInfo.attributes().get(KNNConstants.PARAMETERS); + if (parameters == null) { + return null; + } + + return (String) XContentHelper.createParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.THROW_UNSUPPORTED_OPERATION, + new BytesArray(parameters), + MediaTypeRegistry.getDefaultMediaType() + ).map().getOrDefault(INDEX_DESCRIPTION_PARAMETER, null); + } +} diff --git a/src/main/java/org/opensearch/knn/jni/FaissService.java b/src/main/java/org/opensearch/knn/jni/FaissService.java index 2c537cf00..f718ce6d5 100644 --- a/src/main/java/org/opensearch/knn/jni/FaissService.java +++ b/src/main/java/org/opensearch/knn/jni/FaissService.java @@ -63,6 +63,20 @@ class FaissService { */ public static native void createIndex(int[] ids, long vectorsAddress, int dim, String indexPath, Map parameters); + /** + * Create a binary index for the native library The memory occupied by the vectorsAddress will be freed up during the + * function call. So Java layer doesn't need to free up the memory. This is not an ideal behavior because Java layer + * created the memory address and that should only free up the memory. We are tracking the proper fix for this on this + * issue + * + * @param ids array of ids mapping to the data passed in + * @param vectorsAddress address of native memory where vectors are stored + * @param dim dimension of the vector to be indexed + * @param indexPath path to save index file to + * @param parameters parameters to build index + */ + public static native void createBinaryIndex(int[] ids, long vectorsAddress, int dim, String indexPath, Map parameters); + /** * Create an index for the native library with a provided template index * @@ -90,6 +104,14 @@ public static native void createIndexFromTemplate( */ public static native long loadIndex(String indexPath); + /** + * Load a binary index into memory + * + * @param indexPath path to index file + * @return pointer to location in memory the index resides in + */ + public static native long loadBinaryIndex(String indexPath); + /** * Determine if index contains shared state. * @@ -126,6 +148,7 @@ public static native void createIndexFromTemplate( * @param indexPointer pointer to index in memory * @param queryVector vector to be used for query * @param k neighbors to be returned + * @param methodParameters method parameter * @param parentIds list of parent doc ids when the knn field is a nested field * @return KNNQueryResult array of k neighbors */ @@ -143,6 +166,7 @@ public static native KNNQueryResult[] queryIndex( * @param indexPointer pointer to index in memory * @param queryVector vector to be used for query * @param k neighbors to be returned + * @param methodParameters method parameter * @param filterIds list of doc ids to include in the query result * @param parentIds list of parent doc ids when the knn field is a nested field * @return KNNQueryResult array of k neighbors @@ -157,6 +181,27 @@ public static native KNNQueryResult[] queryIndexWithFilter( int[] parentIds ); + /** + * Query a binary index with filter + * + * @param indexPointer pointer to index in memory + * @param queryVector vector to be used for query + * @param k neighbors to be returned + * @param methodParameters method parameter + * @param filterIds list of doc ids to include in the query result + * @param parentIds list of parent doc ids when the knn field is a nested field + * @return KNNQueryResult array of k neighbors + */ + public static native KNNQueryResult[] queryBinaryIndexWithFilter( + long indexPointer, + byte[] queryVector, + int k, + Map methodParameters, + long[] filterIds, + int filterIdsType, + int[] parentIds + ); + /** * Free native memory pointer */ diff --git a/src/main/java/org/opensearch/knn/jni/JNICommons.java b/src/main/java/org/opensearch/knn/jni/JNICommons.java index 90ad70c3d..d0111b115 100644 --- a/src/main/java/org/opensearch/knn/jni/JNICommons.java +++ b/src/main/java/org/opensearch/knn/jni/JNICommons.java @@ -47,6 +47,25 @@ public class JNICommons { */ public static native long storeVectorData(long memoryAddress, float[][] data, long initialCapacity); + /** + * This is utility function that can be used to store data in native memory. This function will allocate memory for + * the data(rows*columns) with initialCapacity and return the memory address where the data is stored. + * If you are using this function for first time use memoryAddress = 0 to ensure that a new memory location is created. + * For subsequent calls you can pass the same memoryAddress. If the data cannot be stored in the memory location + * will throw Exception. + * + *

+ * The function is not threadsafe. If multiple threads are trying to insert on same memory location, then it can + * lead to data corruption. + *

+ * + * @param memoryAddress The address of the memory location where data will be stored. + * @param data 2D byte array containing data to be stored in native memory. + * @param initialCapacity The initial capacity of the memory location. + * @return memory address where the data is stored. + */ + public static native long storeByteVectorData(long memoryAddress, byte[][] data, long initialCapacity); + /** * Free up the memory allocated for the data stored in memory address. This function should be used with the memory * address returned by {@link JNICommons#storeVectorData(long, float[][], long)} diff --git a/src/main/java/org/opensearch/knn/jni/JNIService.java b/src/main/java/org/opensearch/knn/jni/JNIService.java index 83afc592f..ed6a169c1 100644 --- a/src/main/java/org/opensearch/knn/jni/JNIService.java +++ b/src/main/java/org/opensearch/knn/jni/JNIService.java @@ -13,6 +13,7 @@ import org.apache.commons.lang.ArrayUtils; import org.opensearch.common.Nullable; +import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.query.KNNQueryResult; import org.opensearch.knn.index.util.KNNEngine; @@ -22,6 +23,7 @@ * Service to distribute requests to the proper engine jni service */ public class JNIService { + private static final String FAISS_BINARY_INDEX_PREFIX = "B"; /** * Create an index for the native library. The memory occupied by the vectorsAddress will be freed up during the @@ -51,7 +53,12 @@ public static void createIndex( } if (KNNEngine.FAISS == knnEngine) { - FaissService.createIndex(ids, vectorsAddress, dim, indexPath, parameters); + if (parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER) != null + && parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER).toString().startsWith(FAISS_BINARY_INDEX_PREFIX)) { + FaissService.createBinaryIndex(ids, vectorsAddress, dim, indexPath, parameters); + } else { + FaissService.createIndex(ids, vectorsAddress, dim, indexPath, parameters); + } return; } @@ -102,7 +109,12 @@ public static long loadIndex(String indexPath, Map parameters, K } if (KNNEngine.FAISS == knnEngine) { - return FaissService.loadIndex(indexPath); + if (parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER) != null + && parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER).toString().startsWith(FAISS_BINARY_INDEX_PREFIX)) { + return FaissService.loadBinaryIndex(indexPath); + } else { + return FaissService.loadIndex(indexPath); + } } throw new IllegalArgumentException(String.format("LoadIndex not supported for provided engine : %s", knnEngine.getName())); @@ -162,12 +174,13 @@ public static void setSharedIndexState(long indexAddr, long shareIndexStateAddr, /** * Query an index * - * @param indexPointer pointer to index in memory - * @param queryVector vector to be used for query - * @param k neighbors to be returned - * @param knnEngine engine to query index - * @param filteredIds array of ints on which should be used for search. - * @param filterIdsType how to filter ids: Batch or BitMap + * @param indexPointer pointer to index in memory + * @param queryVector vector to be used for query + * @param k neighbors to be returned + * @param methodParameters method parameter + * @param knnEngine engine to query index + * @param filteredIds array of ints on which should be used for search. + * @param filterIdsType how to filter ids: Batch or BitMap * @return KNNQueryResult array of k neighbors */ public static KNNQueryResult[] queryIndex( @@ -205,6 +218,42 @@ public static KNNQueryResult[] queryIndex( throw new IllegalArgumentException(String.format("QueryIndex not supported for provided engine : %s", knnEngine.getName())); } + /** + * Query a binary index + * + * @param indexPointer pointer to index in memory + * @param queryVector vector to be used for query + * @param k neighbors to be returned + * @param methodParameters method parameter + * @param knnEngine engine to query index + * @param filteredIds array of ints on which should be used for search. + * @param filterIdsType how to filter ids: Batch or BitMap + * @return KNNQueryResult array of k neighbors + */ + public static KNNQueryResult[] queryBinaryIndex( + long indexPointer, + byte[] queryVector, + int k, + @Nullable Map methodParameters, + KNNEngine knnEngine, + long[] filteredIds, + int filterIdsType, + int[] parentIds + ) { + if (KNNEngine.FAISS == knnEngine) { + return FaissService.queryBinaryIndexWithFilter( + indexPointer, + queryVector, + k, + methodParameters, + ArrayUtils.isEmpty(filteredIds) ? null : filteredIds, + filterIdsType, + parentIds + ); + } + throw new IllegalArgumentException(String.format("QueryBinaryIndex not supported for provided engine : %s", knnEngine.getName())); + } + /** * Free native memory pointer * diff --git a/src/test/java/org/opensearch/knn/index/IndexUtilTests.java b/src/test/java/org/opensearch/knn/index/IndexUtilTests.java index 00493b293..e6c3e96ee 100644 --- a/src/test/java/org/opensearch/knn/index/IndexUtilTests.java +++ b/src/test/java/org/opensearch/knn/index/IndexUtilTests.java @@ -57,9 +57,10 @@ public void testGetLoadParameters() { SpaceType spaceType1 = SpaceType.COSINESIMIL; KNNEngine knnEngine1 = KNNEngine.FAISS; String indexName = "my-test-index"; + String indexDescription = "HNSW32Flat"; - Map loadParameters = getParametersAtLoading(spaceType1, knnEngine1, indexName); - assertEquals(1, loadParameters.size()); + Map loadParameters = getParametersAtLoading(spaceType1, knnEngine1, indexName, indexDescription); + assertEquals(2, loadParameters.size()); assertEquals(spaceType1.getValue(), loadParameters.get(SPACE_TYPE)); // Test nmslib to ensure both space type and ef search are properly set @@ -84,7 +85,7 @@ public void testGetLoadParameters() { when(clusterService.state()).thenReturn(clusterState); KNNSettings.state().setClusterService(clusterService); - loadParameters = getParametersAtLoading(spaceType2, knnEngine2, indexName); + loadParameters = getParametersAtLoading(spaceType2, knnEngine2, indexName, null); assertEquals(2, loadParameters.size()); assertEquals(spaceType2.getValue(), loadParameters.get(SPACE_TYPE)); assertEquals(efSearchValue, loadParameters.get(HNSW_ALGO_EF_SEARCH)); diff --git a/src/test/java/org/opensearch/knn/index/KNNIndexShardTests.java b/src/test/java/org/opensearch/knn/index/KNNIndexShardTests.java index 7b6f96d5a..42a59d26f 100644 --- a/src/test/java/org/opensearch/knn/index/KNNIndexShardTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNIndexShardTests.java @@ -125,6 +125,7 @@ public void testGetEngineFileContexts() { String fileExt = ".test"; SpaceType spaceType = SpaceType.L2; String modelId = "test-model"; + String indexDescription = "test-description"; Set includedFileNames = ImmutableSet.of( String.format("%s_111_%s%s", segmentName, fieldName, fileExt), @@ -150,7 +151,8 @@ public void testGetEngineFileContexts() { fileExt, path, spaceType, - modelId + modelId, + indexDescription ); assertEquals(includedFileNames.size(), included.size()); diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index 75cd6e7a9..e73e86e90 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -54,6 +54,7 @@ import java.util.Arrays; import java.util.Comparator; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Set; import java.util.function.Function; @@ -71,9 +72,11 @@ import static org.mockito.Mockito.mockStatic; import static org.mockito.Mockito.when; import static org.opensearch.knn.KNNRestTestCase.INDEX_NAME; +import static org.opensearch.knn.common.KNNConstants.INDEX_DESCRIPTION_PARAMETER; import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH; import static org.opensearch.knn.common.KNNConstants.MODEL_ID; +import static org.opensearch.knn.common.KNNConstants.PARAMETERS; import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE; public class KNNWeightTests extends KNNTestCase { @@ -147,13 +150,27 @@ public void testQueryResultScoreFaiss() { testQueryScore( SpaceType.L2::scoreTranslation, SEGMENT_FILES_FAISS, - Map.of(SPACE_TYPE, SpaceType.L2.getValue(), KNN_ENGINE, KNNEngine.FAISS.getName()) + Map.of( + SPACE_TYPE, + SpaceType.L2.getValue(), + KNN_ENGINE, + KNNEngine.FAISS.getName(), + PARAMETERS, + String.format(Locale.ROOT, "{\"%s\":\"%s\"}", INDEX_DESCRIPTION_PARAMETER, "HNSW32") + ) ); // score translation for Faiss and inner product is different from default defined in Space enum testQueryScore( rawScore -> SpaceType.INNER_PRODUCT.scoreTranslation(-1 * rawScore), SEGMENT_FILES_FAISS, - Map.of(SPACE_TYPE, SpaceType.INNER_PRODUCT.getValue(), KNN_ENGINE, KNNEngine.FAISS.getName()) + Map.of( + SPACE_TYPE, + SpaceType.INNER_PRODUCT.getValue(), + KNN_ENGINE, + KNNEngine.FAISS.getName(), + PARAMETERS, + String.format(Locale.ROOT, "{\"%s\":\"%s\"}", INDEX_DESCRIPTION_PARAMETER, "HNSW32") + ) ); } @@ -421,7 +438,12 @@ public void testANNWithFilterQuery_whenDoingANN_thenSuccess() { when(directory.getDirectory()).thenReturn(path); final FieldInfos fieldInfos = mock(FieldInfos.class); final FieldInfo fieldInfo = mock(FieldInfo.class); - final Map attributesMap = ImmutableMap.of(KNN_ENGINE, KNNEngine.FAISS.getName()); + final Map attributesMap = ImmutableMap.of( + KNN_ENGINE, + KNNEngine.FAISS.getName(), + PARAMETERS, + String.format(Locale.ROOT, "{\"%s\":\"%s\"}", INDEX_DESCRIPTION_PARAMETER, "HNSW32") + ); when(reader.getFieldInfos()).thenReturn(fieldInfos); when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo); @@ -480,7 +502,14 @@ public void testANNWithFilterQuery_whenExactSearch_thenSuccess() { final float boost = (float) randomDoubleBetween(0, 10, true); final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight); - final Map attributesMap = ImmutableMap.of(KNN_ENGINE, KNNEngine.FAISS.getName(), SPACE_TYPE, SpaceType.L2.name()); + final Map attributesMap = ImmutableMap.of( + KNN_ENGINE, + KNNEngine.FAISS.getName(), + SPACE_TYPE, + SpaceType.L2.name(), + PARAMETERS, + String.format(Locale.ROOT, "{\"%s\":\"%s\"}", INDEX_DESCRIPTION_PARAMETER, "HNSW32") + ); final FieldInfos fieldInfos = mock(FieldInfos.class); final FieldInfo fieldInfo = mock(FieldInfo.class); final BinaryDocValues binaryDocValues = mock(BinaryDocValues.class); @@ -531,7 +560,14 @@ public void testANNWithFilterQuery_whenExactSearchAndThresholdComputations_thenS final float boost = (float) randomDoubleBetween(0, 10, true); final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight); - final Map attributesMap = ImmutableMap.of(KNN_ENGINE, KNNEngine.FAISS.getName(), SPACE_TYPE, SpaceType.L2.name()); + final Map attributesMap = ImmutableMap.of( + KNN_ENGINE, + KNNEngine.FAISS.getName(), + SPACE_TYPE, + SpaceType.L2.name(), + PARAMETERS, + String.format(Locale.ROOT, "{\"%s\":\"%s\"}", INDEX_DESCRIPTION_PARAMETER, "HNSW32") + ); final FieldInfos fieldInfos = mock(FieldInfos.class); final FieldInfo fieldInfo = mock(FieldInfo.class); final BinaryDocValues binaryDocValues = mock(BinaryDocValues.class); @@ -592,7 +628,14 @@ public void testANNWithFilterQuery_whenExactSearchViaThresholdSetting_thenSucces final float boost = (float) randomDoubleBetween(0, 10, true); final KNNWeight knnWeight = new KNNWeight(query, boost, filterQueryWeight); - final Map attributesMap = ImmutableMap.of(KNN_ENGINE, KNNEngine.FAISS.getName(), SPACE_TYPE, SpaceType.L2.name()); + final Map attributesMap = ImmutableMap.of( + KNN_ENGINE, + KNNEngine.FAISS.getName(), + SPACE_TYPE, + SpaceType.L2.name(), + PARAMETERS, + String.format(Locale.ROOT, "{\"%s\":\"%s\"}", INDEX_DESCRIPTION_PARAMETER, "HNSW32") + ); final FieldInfos fieldInfos = mock(FieldInfos.class); final FieldInfo fieldInfo = mock(FieldInfo.class); final BinaryDocValues binaryDocValues = mock(BinaryDocValues.class); @@ -818,7 +861,16 @@ public void testDoANNSearch_whenRadialIsDefined_thenCallJniRadiusQueryIndex() { final FieldInfo fieldInfo = mock(FieldInfo.class); when(reader.getFieldInfos()).thenReturn(fieldInfos); when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo); - when(fieldInfo.attributes()).thenReturn(Map.of(SPACE_TYPE, SpaceType.L2.getValue(), KNN_ENGINE, KNNEngine.FAISS.getName())); + when(fieldInfo.attributes()).thenReturn( + Map.of( + SPACE_TYPE, + SpaceType.L2.getValue(), + KNN_ENGINE, + KNNEngine.FAISS.getName(), + PARAMETERS, + String.format(Locale.ROOT, "{\"%s\":\"%s\"}", INDEX_DESCRIPTION_PARAMETER, "HNSW32") + ) + ); final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); assertNotNull(knnScorer); @@ -881,7 +933,14 @@ private SegmentReader getMockedSegmentReader() { when(reader.getSegmentInfo()).thenReturn(segmentCommitInfo); // Prepare fieldInfo - final Map attributesMap = ImmutableMap.of(KNN_ENGINE, KNNEngine.FAISS.getName(), SPACE_TYPE, SpaceType.L2.name()); + final Map attributesMap = ImmutableMap.of( + KNN_ENGINE, + KNNEngine.FAISS.getName(), + SPACE_TYPE, + SpaceType.L2.name(), + PARAMETERS, + String.format(Locale.ROOT, "{\"%s\":\"%s\"}", INDEX_DESCRIPTION_PARAMETER, "HNSW32") + ); final FieldInfo fieldInfo = mock(FieldInfo.class); when(fieldInfo.attributes()).thenReturn(attributesMap); when(fieldInfo.getAttribute(SPACE_TYPE)).thenReturn(SpaceType.L2.name()); diff --git a/src/test/java/org/opensearch/knn/index/util/FieldInfoExtractorTests.java b/src/test/java/org/opensearch/knn/index/util/FieldInfoExtractorTests.java new file mode 100644 index 000000000..a0facefbd --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/util/FieldInfoExtractorTests.java @@ -0,0 +1,42 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.util; + +import junit.framework.TestCase; +import lombok.SneakyThrows; +import org.apache.lucene.index.FieldInfo; +import org.mockito.Mockito; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.knn.common.KNNConstants; + +import java.util.Collections; +import java.util.Map; + +import static org.mockito.Mockito.mock; +import static org.opensearch.knn.common.KNNConstants.INDEX_DESCRIPTION_PARAMETER; + +public class FieldInfoExtractorTests extends TestCase { + @SneakyThrows + public void testGetIndexDescription_whenNoDescription_thenReturnNull() { + FieldInfo fieldInfo = mock(FieldInfo.class); + Mockito.when(fieldInfo.attributes()).thenReturn(Collections.emptyMap(), Map.of(KNNConstants.PARAMETERS, "{}")); + assertNull(FieldInfoExtractor.getIndexDescription(fieldInfo)); + assertNull(FieldInfoExtractor.getIndexDescription(fieldInfo)); + } + + @SneakyThrows + public void testGetIndexDescription_whenDescriptionExist_thenReturnIndexDescription() { + String indexDescription = "HNSW"; + XContentBuilder parameters = XContentFactory.jsonBuilder() + .startObject() + .field(INDEX_DESCRIPTION_PARAMETER, indexDescription) + .endObject(); + FieldInfo fieldInfo = mock(FieldInfo.class); + Mockito.when(fieldInfo.attributes()).thenReturn(Map.of(KNNConstants.PARAMETERS, parameters.toString())); + assertEquals(indexDescription, FieldInfoExtractor.getIndexDescription(fieldInfo)); + } +} diff --git a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java index e71930d48..e17ee5077 100644 --- a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java +++ b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java @@ -62,6 +62,7 @@ public class JNIServiceTests extends KNNTestCase { static TestUtils.TestData testData; static TestUtils.TestData testDataNested; private String faissMethod = "HNSW32,Flat"; + private String faissBinaryMethod = "BHNSW32"; @BeforeClass public static void setUpClass() throws IOException { @@ -657,6 +658,21 @@ public void testCreateIndex_faiss_valid() throws IOException { } } + @SneakyThrows + public void testCreateIndex_binary_faiss_valid() { + Path tmpFile1 = createTempFile(); + long memoryAddr = testData.loadBinaryDataToMemoryAddress(); + JNIService.createIndex( + testData.indexData.docs, + memoryAddr, + testData.indexData.getDimension(), + tmpFile1.toAbsolutePath().toString(), + ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissBinaryMethod, KNNConstants.SPACE_TYPE, SpaceType.HAMMING_BIT.getValue()), + KNNEngine.FAISS + ); + assertTrue(tmpFile1.toFile().length() > 0); + } + public void testLoadIndex_invalidEngine() { expectThrows(IllegalArgumentException.class, () -> JNIService.loadIndex("test", Collections.emptyMap(), KNNEngine.LUCENE)); } @@ -943,6 +959,37 @@ public void testQueryIndex_faiss_parentIds() throws IOException { } } + @SneakyThrows + public void testQueryBinaryIndex_faiss_valid() { + int k = 10; + List methods = ImmutableList.of(faissBinaryMethod); + for (String method : methods) { + Path tmpFile = createTempFile(); + long memoryAddr = testData.loadBinaryDataToMemoryAddress(); + JNIService.createIndex( + testData.indexData.docs, + memoryAddr, + testData.indexData.getDimension(), + tmpFile.toAbsolutePath().toString(), + ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, method, KNNConstants.SPACE_TYPE, SpaceType.HAMMING_BIT.getValue()), + KNNEngine.FAISS + ); + assertTrue(tmpFile.toFile().length() > 0); + + long pointer = JNIService.loadIndex( + tmpFile.toAbsolutePath().toString(), + ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, method), + KNNEngine.FAISS + ); + assertNotEquals(0, pointer); + + for (byte[] query : testData.binaryQueries) { + KNNQueryResult[] results = JNIService.queryBinaryIndex(pointer, query, k, null, KNNEngine.FAISS, null, 0, null); + assertEquals(k, results.length); + } + } + } + private Set toParentIdSet(KNNQueryResult[] results, Map idToParentIdMap) { return Arrays.stream(results).map(result -> idToParentIdMap.get(result.getId())).collect(Collectors.toSet()); } diff --git a/src/testFixtures/java/org/opensearch/knn/TestUtils.java b/src/testFixtures/java/org/opensearch/knn/TestUtils.java index 37e35f062..d41bbc0fd 100644 --- a/src/testFixtures/java/org/opensearch/knn/TestUtils.java +++ b/src/testFixtures/java/org/opensearch/knn/TestUtils.java @@ -19,6 +19,8 @@ import org.opensearch.knn.index.codec.util.SerializationMode; import org.opensearch.knn.jni.JNICommons; import org.opensearch.knn.plugin.script.KNNScoringUtil; + +import java.util.Collections; import java.util.Comparator; import java.util.Random; import java.util.Set; @@ -250,11 +252,14 @@ public static PriorityQueue insertWithOverflow(PriorityQueue flattenedVectors = new ArrayList<>(indexData.vectors.length * indexData.vectors[0].length); + for (int i = 0; i < indexData.vectors.length; i++) { + for (int j = 0; j < indexData.vectors[i].length; j++) { + flattenedVectors.add(indexData.vectors[i][j]); + } + } + Collections.sort(flattenedVectors); + Float median = flattenedVectors.get(flattenedVectors.size() / 2); + + // Quantize(indexData.vectors[i][j] >= median ? 1 : 0) and + // packing(8 bits to 1 byte) for index data + indexBinaryData = new byte[indexData.vectors.length][(indexData.vectors[0].length + 7) / 8]; + for (int i = 0; i < indexData.vectors.length; i++) { + for (int j = 0; j < indexData.vectors[i].length; j++) { + int byteIndex = j / 8; + int bitIndex = 7 - (j % 8); + indexBinaryData[i][byteIndex] |= (indexData.vectors[i][j] >= median ? 1 : 0) << bitIndex; + } + } + + // Quantize(queries[i][j] >= median ? 1 : 0) and + // packing(8 bits to 1 byte) for query data + binaryQueries = new byte[queries.length][(queries[0].length + 7) / 8]; + for (int i = 0; i < queries.length; i++) { + for (int j = 0; j < queries[i].length; j++) { + int byteIndex = j / 8; + int bitIndex = 7 - (j % 8); + binaryQueries[i][byteIndex] |= (queries[i][j] >= median ? 1 : 0) << bitIndex; + } + } + } + public long loadDataToMemoryAddress() { return JNICommons.storeVectorData(0, indexData.vectors, (long) indexData.vectors.length * indexData.vectors[0].length); } + public long loadBinaryDataToMemoryAddress() { + return JNICommons.storeByteVectorData(0, indexBinaryData, (long) indexBinaryData.length * indexBinaryData[0].length); + } + @AllArgsConstructor public static class Pair { public int[] docs;