Skip to content

Commit 37c4a9d

Browse files
committed
Add jni interface to use a binary hnsw index with faiss (#1747)
Signed-off-by: Heemin Kim <heemin@amazon.com>
1 parent e2c0a36 commit 37c4a9d

35 files changed

+1600
-87
lines changed

jni/CMakeLists.txt

+3-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ set(TARGET_LIBS "") # Libs to be installed
2020

2121
set(CMAKE_CXX_STANDARD 11)
2222
set(CMAKE_CXX_STANDARD_REQUIRED True)
23-
2423
option(CONFIG_FAISS "Configure faiss library build when this is on")
2524
option(CONFIG_NMSLIB "Configure nmslib library build when this is on")
2625
option(CONFIG_TEST "Configure tests when this is on")
@@ -112,6 +111,8 @@ if (${CONFIG_FAISS} STREQUAL ON OR ${CONFIG_ALL} STREQUAL ON OR ${CONFIG_TEST} S
112111
${CMAKE_CURRENT_SOURCE_DIR}/src/org_opensearch_knn_jni_FaissService.cpp
113112
${CMAKE_CURRENT_SOURCE_DIR}/src/faiss_wrapper.cpp
114113
${CMAKE_CURRENT_SOURCE_DIR}/src/faiss_util.cpp
114+
${CMAKE_CURRENT_SOURCE_DIR}/src/faiss_index_service.cpp
115+
${CMAKE_CURRENT_SOURCE_DIR}/src/faiss_methods.cpp
115116
)
116117
target_link_libraries(${TARGET_LIB_FAISS} ${TARGET_LINK_FAISS_LIB} ${TARGET_LIB_UTIL} OpenMP::OpenMP_CXX)
117118
target_include_directories(${TARGET_LIB_FAISS} PRIVATE
@@ -151,6 +152,7 @@ if ("${WIN32}" STREQUAL "")
151152
tests/nmslib_wrapper_test.cpp
152153
tests/test_util.cpp
153154
tests/commons_test.cpp
155+
tests/faiss_index_service_test.cpp
154156
)
155157

156158
target_link_libraries(

jni/include/commons.h

+23-1
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,38 @@ namespace knn_jni {
2222
* @param memoryAddress The address of the memory location where data will be stored.
2323
* @param data 2D float array containing data to be stored in native memory.
2424
* @param initialCapacity The initial capacity of the memory location.
25-
* @return memory address where the data is stored.
25+
* @return memory address of std::vector<float> where the data is stored.
2626
*/
2727
jlong storeVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong);
2828

29+
/**
30+
* This is a utility function that can be used to store data in native memory. This function will allocate memory for
31+
* the data (rows*columns) with initialCapacity and return the memory address where the data is stored.
32+
* If you are using this function for the first time, use memoryAddress = 0 to ensure that a new memory location is created.
33+
* For subsequent calls you can pass the same memoryAddress. If the data cannot be stored in the memory location,
34+
* an exception will be thrown.
35+
*
36+
* @param memoryAddress The address of the memory location where data will be stored.
37+
* @param data 2D byte array containing data to be stored in native memory.
38+
* @param initialCapacity The initial capacity of the memory location.
39+
* @return memory address of std::vector<uint8_t> where the data is stored.
40+
*/
41+
jlong storeByteVectorData(knn_jni::JNIUtilInterface *, JNIEnv *, jlong , jobjectArray, jlong);
42+
2943
/**
3044
* Free up the memory allocated for the data stored in memory address. This function should be used with the memory
3145
* address returned by {@link JNICommons#storeVectorData(long, float[][], long, long)}
3246
*
3347
* @param memoryAddress address to be freed.
3448
*/
3549
void freeVectorData(jlong);
50+
51+
/**
52+
* Free up the memory allocated for the data stored in memory address. This function should be used with the memory
53+
* address returned by {@link JNICommons#storeByteVectorData(long, byte[][], long, long)}
54+
*
55+
* @param memoryAddress address to be freed.
56+
*/
57+
void freeByteVectorData(jlong);
3658
}
3759
}

jni/include/faiss_index_service.h

+113
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
//
3+
// The OpenSearch Contributors require contributions made to
4+
// this file be licensed under the Apache-2.0 license or a
5+
// compatible open source license.
6+
//
7+
// Modifications Copyright OpenSearch Contributors. See
8+
// GitHub history for details.
9+
10+
/**
11+
* This file contains classes for index operations which are free of JNI
12+
*/
13+
14+
#ifndef OPENSEARCH_KNN_FAISS_INDEX_SERVICE_H
15+
#define OPENSEARCH_KNN_FAISS_INDEX_SERVICE_H
16+
17+
#include <jni.h>
18+
#include "faiss/MetricType.h"
19+
#include "jni_util.h"
20+
#include "faiss_methods.h"
21+
#include <memory>
22+
23+
namespace knn_jni {
24+
namespace faiss_wrapper {
25+
26+
27+
/**
28+
* A class to provide operations on an index (currently only index creation).
29+
* It holds a FaissMethods instance, the wrapper around faiss calls.
30+
* This class should evolve to depend only on C++ objects, not on JNI objects.
*/
31+
class IndexService {
32+
public:
33+
/**
* Construct the service with the FaissMethods instance it will hold.
*
* @param faissMethods wrapper around faiss calls; passed by value as std::unique_ptr,
*                     so the caller hands over ownership
*/
IndexService(std::unique_ptr<FaissMethods> faissMethods);
34+
//TODO Remove dependency on JNIUtilInterface and JNIEnv
35+
//TODO Reduce the number of parameters
36+
37+
/**
38+
* Create index
39+
*
40+
* @param jniUtil jni util
41+
* @param env jni environment
42+
* @param metric space type for distance calculation
43+
* @param indexDescription index description to be used by faiss index factory
44+
* @param dim dimension of vectors
45+
* @param numIds number of vectors
46+
* @param threadCount number of threads to be used while adding data
47+
* @param vectorsAddress memory address which is holding vector data
*                       (presumably the address of a std::vector<float> - TODO confirm against the caller)
48+
* @param ids a list of document ids for corresponding vectors
49+
* @param indexPath path to write index
50+
* @param parameters parameters to be applied to faiss index
51+
*/
52+
virtual void createIndex(
53+
knn_jni::JNIUtilInterface * jniUtil,
54+
JNIEnv * env,
55+
faiss::MetricType metric,
56+
std::string indexDescription,
57+
int dim,
58+
int numIds,
59+
int threadCount,
60+
int64_t vectorsAddress,
61+
std::vector<int64_t> ids,
62+
std::string indexPath,
63+
std::unordered_map<std::string, jobject> parameters);
64+
// Virtual destructor: the class is designed to be subclassed (see BinaryIndexService).
virtual ~IndexService() = default;
65+
protected:
66+
// Wrapped faiss calls; protected so that subclasses can use the same instance.
std::unique_ptr<FaissMethods> faissMethods;
67+
};
68+
69+
/**
70+
* A class to provide operations on a binary index.
71+
* It derives from IndexService and overrides createIndex.
72+
* This class should evolve to depend only on C++ objects, not on JNI objects.
*/
73+
class BinaryIndexService : public IndexService {
74+
public:
75+
//TODO Remove dependency on JNIUtilInterface and JNIEnv
76+
//TODO Reduce the number of parameters
77+
/**
* Construct the service with the FaissMethods instance it will hold.
*
* @param faissMethods wrapper around faiss calls; passed by value as std::unique_ptr,
*                     so the caller hands over ownership
*/
BinaryIndexService(std::unique_ptr<FaissMethods> faissMethods);
78+
/**
79+
* Create binary index
80+
*
81+
* @param jniUtil jni util
82+
* @param env jni environment
83+
* @param metric space type for distance calculation
*               NOTE(review): FaissMethods::indexBinaryFactory takes no metric argument, so this
*               may be unused for binary indexes - confirm in the implementation
84+
* @param indexDescription index description to be used by faiss index factory
85+
* @param dim dimension of vectors
86+
* @param numIds number of vectors
87+
* @param threadCount number of threads to be used while adding data
88+
* @param vectorsAddress memory address which is holding vector data
*                       (presumably the address of a std::vector<uint8_t> - TODO confirm against the caller)
89+
* @param ids a list of document ids for corresponding vectors
90+
* @param indexPath path to write index
91+
* @param parameters parameters to be applied to faiss index
92+
*/
93+
virtual void createIndex(
94+
knn_jni::JNIUtilInterface * jniUtil,
95+
JNIEnv * env,
96+
faiss::MetricType metric,
97+
std::string indexDescription,
98+
int dim,
99+
int numIds,
100+
int threadCount,
101+
int64_t vectorsAddress,
102+
std::vector<int64_t> ids,
103+
std::string indexPath,
104+
std::unordered_map<std::string, jobject> parameters
105+
) override;
106+
virtual ~BinaryIndexService() = default;
107+
};
108+
109+
}
110+
}
111+
112+
113+
#endif //OPENSEARCH_KNN_FAISS_INDEX_SERVICE_H

jni/include/faiss_methods.h

+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
//
3+
// The OpenSearch Contributors require contributions made to
4+
// this file be licensed under the Apache-2.0 license or a
5+
// compatible open source license.
6+
//
7+
// Modifications Copyright OpenSearch Contributors. See
8+
// GitHub history for details.
9+
10+
#ifndef OPENSEARCH_KNN_FAISS_METHODS_H
11+
#define OPENSEARCH_KNN_FAISS_METHODS_H
12+
13+
#include "faiss/Index.h"
14+
#include "faiss/IndexBinary.h"
15+
#include "faiss/IndexIDMap.h"
16+
#include "faiss/index_io.h"
17+
18+
namespace knn_jni {
19+
namespace faiss_wrapper {
20+
21+
/**
22+
* A class that wraps the faiss methods used for index creation and serialization.
23+
*
24+
* Every method is virtual so that the faiss calls can be mocked during unit tests.
* The implementations are not visible in this header; the per-method notes below
* describe only what the signatures show and flag what needs confirming.
25+
*/
26+
class FaissMethods {
27+
public:
28+
FaissMethods() = default;
29+
// Creates a float index for dimension d from a factory description string and a metric.
// Presumably forwards to faiss::index_factory; ownership of the returned raw pointer
// is not expressed by the type - TODO confirm who frees it.
virtual faiss::Index* indexFactory(int d, const char* description, faiss::MetricType metric);
30+
// Binary counterpart of indexFactory; note that it takes no metric argument.
// Presumably forwards to faiss::index_binary_factory - TODO confirm.
virtual faiss::IndexBinary* indexBinaryFactory(int d, const char* description);
31+
// Returns an ID-map index built from the given float index.
// NOTE(review): whether the map takes ownership of `index` cannot be told from here.
virtual faiss::IndexIDMapTemplate<faiss::Index>* indexIdMap(faiss::Index* index);
32+
// Returns an ID-map index built from the given binary index.
// NOTE(review): whether the map takes ownership of `index` cannot be told from here.
virtual faiss::IndexIDMapTemplate<faiss::IndexBinary>* indexBinaryIdMap(faiss::IndexBinary* index);
33+
// Writes the float index to the file named fname (presumably via faiss::write_index).
virtual void writeIndex(const faiss::Index* idx, const char* fname);
34+
// Writes the binary index to the file named fname (presumably via faiss::write_index_binary).
virtual void writeIndexBinary(const faiss::IndexBinary* idx, const char* fname);
35+
// Virtual destructor so that mock subclasses are destroyed correctly through a base pointer.
virtual ~FaissMethods() = default;
36+
};
37+
38+
} //namespace faiss_wrapper
39+
} //namespace knn_jni
40+
41+
42+
#endif //OPENSEARCH_KNN_FAISS_METHODS_H

jni/include/faiss_wrapper.h

+13-1
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,15 @@
1313
#define OPENSEARCH_KNN_FAISS_WRAPPER_H
1414

1515
#include "jni_util.h"
16+
#include "faiss_index_service.h"
1617
#include <jni.h>
1718

1819
namespace knn_jni {
1920
namespace faiss_wrapper {
2021
// Create an index with ids and vectors. The configuration is defined by values in the Java map, parametersJ.
2122
// The index is serialized to indexPathJ.
2223
void CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ,
23-
jstring indexPathJ, jobject parametersJ);
24+
jstring indexPathJ, jobject parametersJ, IndexService* indexService);
2425

2526
// Create an index with ids and vectors. Instead of creating a new index, this function creates the index
2627
// based off of the template index passed in. The index is serialized to indexPathJ.
@@ -33,6 +34,11 @@ namespace knn_jni {
3334
// Return a pointer to the loaded index
3435
jlong LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ);
3536

37+
// Load a binary index from indexPathJ into memory.
38+
//
39+
// Return a pointer to the loaded index
40+
jlong LoadBinaryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ);
41+
3642
// Check if a loaded index requires shared state
3743
bool IsSharedIndexStateRequired(jlong indexPointerJ);
3844

@@ -58,6 +64,12 @@ namespace knn_jni {
5864
jfloatArray queryVectorJ, jint kJ, jlongArray filterIdsJ,
5965
jint filterIdsTypeJ, jintArray parentIdsJ);
6066

67+
// Execute a query against the binary index located in memory at indexPointerJ along with Filters
68+
//
69+
// Return an array of KNNQueryResults
70+
jobjectArray QueryBinaryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ,
71+
jbyteArray queryVectorJ, jint kJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ);
72+
6173
// Free the index located in memory at indexPointerJ
6274
void Free(jlong indexPointer);
6375

jni/include/jni_util.h

+7
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ namespace knn_jni {
7171

7272
virtual void Convert2dJavaObjectArrayAndStoreToFloatVector(JNIEnv *env, jobjectArray array2dJ,
7373
int dim, std::vector<float> *vect ) = 0;
74+
virtual void Convert2dJavaObjectArrayAndStoreToByteVector(JNIEnv *env, jobjectArray array2dJ,
75+
int dim, std::vector<uint8_t> *vect ) = 0;
7476

7577
virtual std::vector<int64_t> ConvertJavaIntArrayToCppIntVector(JNIEnv *env, jintArray arrayJ) = 0;
7678

@@ -79,6 +81,8 @@ namespace knn_jni {
7981
// ------------------------------ MISC HELPERS ------------------------------
8082
virtual int GetInnerDimensionOf2dJavaFloatArray(JNIEnv *env, jobjectArray array2dJ) = 0;
8183

84+
virtual int GetInnerDimensionOf2dJavaByteArray(JNIEnv *env, jobjectArray array2dJ) = 0;
85+
8286
virtual int GetJavaObjectArrayLength(JNIEnv *env, jobjectArray arrayJ) = 0;
8387

8488
virtual int GetJavaIntArrayLength(JNIEnv *env, jintArray arrayJ) = 0;
@@ -146,6 +150,7 @@ namespace knn_jni {
146150
std::vector<float> Convert2dJavaObjectArrayToCppFloatVector(JNIEnv *env, jobjectArray array2dJ, int dim);
147151
std::vector<int64_t> ConvertJavaIntArrayToCppIntVector(JNIEnv *env, jintArray arrayJ);
148152
int GetInnerDimensionOf2dJavaFloatArray(JNIEnv *env, jobjectArray array2dJ);
153+
int GetInnerDimensionOf2dJavaByteArray(JNIEnv *env, jobjectArray array2dJ);
149154
int GetJavaObjectArrayLength(JNIEnv *env, jobjectArray arrayJ);
150155
int GetJavaIntArrayLength(JNIEnv *env, jintArray arrayJ);
151156
int GetJavaLongArrayLength(JNIEnv *env, jlongArray arrayJ);
@@ -168,6 +173,7 @@ namespace knn_jni {
168173
void SetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index, jobject val);
169174
void SetByteArrayRegion(JNIEnv *env, jbyteArray array, jsize start, jsize len, const jbyte * buf);
170175
void Convert2dJavaObjectArrayAndStoreToFloatVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector<float> *vect);
176+
void Convert2dJavaObjectArrayAndStoreToByteVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector<uint8_t> *vect);
171177

172178
private:
173179
std::unordered_map<std::string, jclass> cachedClasses;
@@ -193,6 +199,7 @@ namespace knn_jni {
193199
extern const std::string COSINESIMIL;
194200
extern const std::string INNER_PRODUCT;
195201
extern const std::string NEG_DOT_PRODUCT;
202+
extern const std::string HAMMING_BIT;
196203

197204
extern const std::string NPROBES;
198205
extern const std::string COARSE_QUANTIZER;

jni/include/org_opensearch_knn_jni_FaissService.h

+25
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#ifdef __cplusplus
1919
extern "C" {
2020
#endif
21+
2122
/*
2223
* Class: org_opensearch_knn_jni_FaissService
2324
* Method: createIndex
@@ -26,6 +27,14 @@ extern "C" {
2627
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndex
2728
(JNIEnv *, jclass, jintArray, jlong, jint, jstring, jobject);
2829

30+
/*
31+
* Class: org_opensearch_knn_jni_FaissService
32+
* Method: createBinaryIndex
33+
* Signature: ([IJILjava/lang/String;Ljava/util/Map;)V
34+
*/
35+
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryIndex
36+
(JNIEnv *, jclass, jintArray, jlong, jint, jstring, jobject);
37+
2938
/*
3039
* Class: org_opensearch_knn_jni_FaissService
3140
* Method: createIndexFromTemplate
@@ -42,6 +51,14 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromT
4251
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadIndex
4352
(JNIEnv *, jclass, jstring);
4453

54+
/*
55+
* Class: org_opensearch_knn_jni_FaissService
56+
* Method: loadBinaryIndex
57+
* Signature: (Ljava/lang/String;)J
58+
*/
59+
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadBinaryIndex
60+
(JNIEnv *, jclass, jstring);
61+
4562
/*
4663
* Class: org_opensearch_knn_jni_FaissService
4764
* Method: isSharedIndexStateRequired
@@ -82,6 +99,14 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryInd
8299
JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndexWithFilter
83100
(JNIEnv *, jclass, jlong, jfloatArray, jint, jlongArray, jint, jintArray);
84101

102+
/*
103+
* Class: org_opensearch_knn_jni_FaissService
104+
 * Method:    queryBinaryIndexWithFilter
105+
* Signature: (J[BI[JI[I)[Lorg/opensearch/knn/index/query/KNNQueryResult;
106+
*/
107+
JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryBinaryIndexWithFilter
108+
(JNIEnv *, jclass, jlong, jbyteArray, jint, jlongArray, jint, jintArray);
109+
85110
/*
86111
* Class: org_opensearch_knn_jni_FaissService
87112
* Method: free

jni/include/org_opensearch_knn_jni_JNICommons.h

+16
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,14 @@ extern "C" {
2626
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeVectorData
2727
(JNIEnv *, jclass, jlong, jobjectArray, jlong);
2828

29+
/*
30+
* Class: org_opensearch_knn_jni_JNICommons
31+
 * Method:    storeByteVectorData
32+
 * Signature: (J[[BJ)J
33+
*/
34+
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeByteVectorData
35+
(JNIEnv *, jclass, jlong, jobjectArray, jlong);
36+
2937
/*
3038
* Class: org_opensearch_knn_jni_JNICommons
3139
* Method: freeVectorData
@@ -34,6 +42,14 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_JNICommons_storeVectorData
3442
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_JNICommons_freeVectorData
3543
(JNIEnv *, jclass, jlong);
3644

45+
/*
46+
* Class: org_opensearch_knn_jni_JNICommons
47+
 * Method:    freeByteVectorData
48+
* Signature: (J)V
49+
*/
50+
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_JNICommons_freeByteVectorData
51+
(JNIEnv *, jclass, jlong);
52+
3753
#ifdef __cplusplus
3854
}
3955
#endif

0 commit comments

Comments
 (0)