Skip to content

Commit c890d06

Browse files
committed
Added the POC code for the abstraction layer for index builds, and also added the initial interface for iterating over the vector values
Signed-off-by: Navneet Verma <navneev@amazon.com>
1 parent 37c4a9d commit c890d06

7 files changed

+409
-120
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*
4+
* The OpenSearch Contributors require contributions made to
5+
* this file be licensed under the Apache-2.0 license or a
6+
* compatible open source license.
7+
*
8+
* Modifications Copyright OpenSearch Contributors. See
9+
* GitHub history for details.
10+
*/
11+
12+
package org.opensearch.knn.index;
13+
14+
import lombok.extern.log4j.Log4j2;
15+
import org.apache.lucene.index.FieldInfo;
16+
import org.apache.lucene.index.SegmentWriteState;
17+
import org.apache.lucene.search.DocIdSetIterator;
18+
import org.opensearch.knn.index.codec.KNN80Codec.KNN80DocValuesConsumer;
19+
import org.opensearch.knn.index.codec.util.KNNCodecUtil;
20+
import org.opensearch.knn.index.codec.util.SerializationMode;
21+
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
22+
import org.opensearch.knn.index.vectorvalues.KNNVectorValuesIterator;
23+
import org.opensearch.knn.jni.JNICommons;
24+
25+
import java.io.IOException;
26+
import java.util.ArrayList;
27+
import java.util.List;
28+
29+
/**
30+
* This is a single layer that will be responsible for creating the native indices. Right now this is just a POC code,
31+
* this needs to be fixed. Its more of a testing to see if everything works correctly.
32+
*/
33+
@Log4j2
34+
public class NativeIndexCreationManager {
35+
36+
public static void startIndexCreation(
37+
final SegmentWriteState segmentWriteState,
38+
final KNNVectorValues<float[]> vectorValues,
39+
final FieldInfo fieldInfo
40+
) throws IOException {
41+
KNNCodecUtil.Pair pair = streamFloatVectors(vectorValues);
42+
if (pair.getVectorAddress() == 0 || pair.docs.length == 0) {
43+
log.info("Skipping engine index creation as there are no vectors or docs in the segment");
44+
return;
45+
}
46+
createNativeIndex(segmentWriteState, fieldInfo, pair);
47+
}
48+
49+
private static void createNativeIndex(
50+
final SegmentWriteState segmentWriteState,
51+
final FieldInfo fieldInfo,
52+
final KNNCodecUtil.Pair pair
53+
) throws IOException {
54+
KNN80DocValuesConsumer.createNativeIndex(segmentWriteState, fieldInfo, pair);
55+
}
56+
57+
private static KNNCodecUtil.Pair streamFloatVectors(final KNNVectorValues<float[]> kNNVectorValues) throws IOException {
58+
List<float[]> vectorList = new ArrayList<>();
59+
List<Integer> docIdList = new ArrayList<>();
60+
long vectorAddress = 0;
61+
int dimension = 0;
62+
long totalLiveDocs = kNNVectorValues.totalLiveDocs();
63+
long vectorsStreamingMemoryLimit = KNNSettings.getVectorStreamingMemoryLimit().getBytes();
64+
long vectorsPerTransfer = Integer.MIN_VALUE;
65+
66+
KNNVectorValuesIterator iterator = kNNVectorValues.getVectorValuesIterator();
67+
68+
for (int doc = iterator.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = iterator.nextDoc()) {
69+
float[] vector = kNNVectorValues.getVector();
70+
dimension = vector.length;
71+
if (vectorsPerTransfer == Integer.MIN_VALUE) {
72+
vectorsPerTransfer = (dimension * Float.BYTES * totalLiveDocs) / vectorsStreamingMemoryLimit;
73+
// This condition comes if vectorsStreamingMemoryLimit is higher than total number floats to transfer
74+
// Doing this will reduce 1 extra trip to JNI layer.
75+
if (vectorsPerTransfer == 0) {
76+
vectorsPerTransfer = totalLiveDocs;
77+
}
78+
}
79+
80+
if (vectorList.size() == vectorsPerTransfer) {
81+
vectorAddress = JNICommons.storeVectorData(vectorAddress, vectorList.toArray(new float[][] {}), totalLiveDocs * dimension);
82+
// We should probably come up with a better way to reuse the vectorList memory which we have
83+
// created. Problem here is doing like this can lead to a lot of list memory which is of no use and
84+
// will be garbage collected later on, but it creates pressure on JVM. We should revisit this.
85+
vectorList = new ArrayList<>();
86+
}
87+
88+
vectorList.add(vector);
89+
docIdList.add(doc);
90+
}
91+
92+
if (vectorList.isEmpty() == false) {
93+
vectorAddress = JNICommons.storeVectorData(vectorAddress, vectorList.toArray(new float[][] {}), totalLiveDocs * dimension);
94+
}
95+
// SerializationMode.COLLECTION_OF_FLOATS is not getting used. I just added it to ensure code successfully
96+
// works.
97+
return new KNNCodecUtil.Pair(
98+
docIdList.stream().mapToInt(Integer::intValue).toArray(),
99+
vectorAddress,
100+
dimension,
101+
SerializationMode.COLLECTION_OF_FLOATS
102+
);
103+
}
104+
}

src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEnginesKNNVectorsWriter.java

+13-118
Original file line numberDiff line numberDiff line change
@@ -22,18 +22,13 @@
2222
import org.apache.lucene.index.MergeState;
2323
import org.apache.lucene.index.SegmentWriteState;
2424
import org.apache.lucene.index.Sorter;
25-
import org.apache.lucene.search.DocIdSetIterator;
2625
import org.apache.lucene.util.IOUtils;
2726
import org.apache.lucene.util.InfoStream;
28-
import org.opensearch.knn.index.KNNSettings;
29-
import org.opensearch.knn.index.codec.KNN80Codec.KNN80DocValuesConsumer;
30-
import org.opensearch.knn.index.codec.util.KNNCodecUtil;
31-
import org.opensearch.knn.index.codec.util.SerializationMode;
32-
import org.opensearch.knn.jni.JNICommons;
27+
import org.opensearch.knn.index.NativeIndexCreationManager;
28+
import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory;
3329

3430
import java.io.IOException;
3531
import java.util.ArrayList;
36-
import java.util.Arrays;
3732
import java.util.List;
3833

3934
@RequiredArgsConstructor
@@ -69,6 +64,7 @@ public KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException
6964
* @param sortMap {@link Sorter.DocMap}
7065
*/
7166
@Override
67+
@SuppressWarnings("unchecked")
7268
public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
7369
// simply write data in the flat file
7470
flatVectorsWriter.flush(maxDoc, sortMap);
@@ -79,12 +75,11 @@ public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
7975
// on the disk.
8076
// getFloatsFromFloatVectorValues(fields);
8177
for (NativeEnginesKNNVectorsWriter.FieldWriter<?> fieldWriter : fields) {
82-
KNNCodecUtil.Pair pair = getFloatsFromFieldWriter(fieldWriter);
83-
if (pair.getVectorAddress() == 0 || pair.docs.length == 0) {
84-
log.info("Skipping engine index creation as there are no vectors or docs in the segment");
85-
continue;
86-
}
87-
KNN80DocValuesConsumer.createNativeIndex(segmentWriteState, fieldWriter.fieldInfo, pair);
78+
NativeIndexCreationManager.startIndexCreation(
79+
segmentWriteState,
80+
KNNVectorValuesFactory.getFloatVectorValues(fieldWriter.docsWithField.iterator(), (List<float[]>) fieldWriter.vectors),
81+
fieldWriter.fieldInfo
82+
);
8883
}
8984
}
9085

@@ -94,12 +89,11 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE
9489
flatVectorsWriter.mergeOneField(fieldInfo, mergeState);
9590
final FloatVectorValues floatVectorValues = KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
9691
// merging the graphs here
97-
final KNNCodecUtil.Pair pair = getFloatsFromFloatVectorValues(floatVectorValues);
98-
if (pair.getVectorAddress() == 0 || pair.docs.length == 0) {
99-
log.info("Skipping engine index creation as there are no vectors or docs to be merged");
100-
return;
101-
}
102-
KNN80DocValuesConsumer.createNativeIndex(segmentWriteState, fieldInfo, pair);
92+
NativeIndexCreationManager.startIndexCreation(
93+
segmentWriteState,
94+
KNNVectorValuesFactory.getFloatVectorValues(floatVectorValues),
95+
fieldInfo
96+
);
10397
}
10498

10599
/**
@@ -140,105 +134,6 @@ public long ramBytesUsed() {
140134
return 0;
141135
}
142136

143-
private KNNCodecUtil.Pair getFloatsFromFloatVectorValues(FloatVectorValues floatVectorValues) throws IOException {
144-
List<float[]> vectorList = new ArrayList<>();
145-
List<Integer> docIdList = new ArrayList<>();
146-
long vectorAddress = 0;
147-
int dimension = 0;
148-
149-
long totalLiveDocs = floatVectorValues.size();
150-
long vectorsStreamingMemoryLimit = KNNSettings.getVectorStreamingMemoryLimit().getBytes();
151-
long vectorsPerTransfer = Integer.MIN_VALUE;
152-
153-
for (int doc = floatVectorValues.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = floatVectorValues.nextDoc()) {
154-
float[] temp = floatVectorValues.vectorValue();
155-
// This temp object and copy of temp object is required because when we map floats we read to a memory
156-
// location in heap always for floatVectorValues. Ref: OffHeapFloatVectorValues.vectorValue.
157-
float[] vector = Arrays.copyOf(floatVectorValues.vectorValue(), temp.length);
158-
dimension = vector.length;
159-
if (vectorsPerTransfer == Integer.MIN_VALUE) {
160-
vectorsPerTransfer = (dimension * Float.BYTES * totalLiveDocs) / vectorsStreamingMemoryLimit;
161-
// This condition comes if vectorsStreamingMemoryLimit is higher than total number floats to transfer
162-
// Doing this will reduce 1 extra trip to JNI layer.
163-
if (vectorsPerTransfer == 0) {
164-
vectorsPerTransfer = totalLiveDocs;
165-
}
166-
}
167-
168-
if (vectorList.size() == vectorsPerTransfer) {
169-
vectorAddress = JNICommons.storeVectorData(vectorAddress, vectorList.toArray(new float[][] {}), totalLiveDocs * dimension);
170-
// We should probably come up with a better way to reuse the vectorList memory which we have
171-
// created. Problem here is doing like this can lead to a lot of list memory which is of no use and
172-
// will be garbage collected later on, but it creates pressure on JVM. We should revisit this.
173-
vectorList = new ArrayList<>();
174-
}
175-
vectorList.add(vector);
176-
docIdList.add(doc);
177-
}
178-
179-
if (vectorList.isEmpty() == false) {
180-
vectorAddress = JNICommons.storeVectorData(vectorAddress, vectorList.toArray(new float[][] {}), totalLiveDocs * dimension);
181-
}
182-
// SerializationMode.COLLECTION_OF_FLOATS is not getting used. I just added it to ensure code successfully
183-
// works.
184-
return new KNNCodecUtil.Pair(
185-
docIdList.stream().mapToInt(Integer::intValue).toArray(),
186-
vectorAddress,
187-
dimension,
188-
SerializationMode.COLLECTION_OF_FLOATS
189-
);
190-
}
191-
192-
private KNNCodecUtil.Pair getFloatsFromFieldWriter(NativeEnginesKNNVectorsWriter.FieldWriter<?> fieldWriter) throws IOException {
193-
List<float[]> vectorList = new ArrayList<>();
194-
List<Integer> docIdList = new ArrayList<>();
195-
long vectorAddress = 0;
196-
int dimension = 0;
197-
198-
long totalLiveDocs = fieldWriter.vectors.size();
199-
long vectorsStreamingMemoryLimit = KNNSettings.getVectorStreamingMemoryLimit().getBytes();
200-
long vectorsPerTransfer = Integer.MIN_VALUE;
201-
202-
DocIdSetIterator disi = fieldWriter.docsWithField.iterator();
203-
204-
for (int i = 0; i < fieldWriter.vectors.size(); i++) {
205-
float[] vector = (float[]) fieldWriter.vectors.get(i);
206-
dimension = vector.length;
207-
if (vectorsPerTransfer == Integer.MIN_VALUE) {
208-
vectorsPerTransfer = (dimension * Float.BYTES * totalLiveDocs) / vectorsStreamingMemoryLimit;
209-
// This condition comes if vectorsStreamingMemoryLimit is higher than total number floats to transfer
210-
// Doing this will reduce 1 extra trip to JNI layer.
211-
if (vectorsPerTransfer == 0) {
212-
vectorsPerTransfer = totalLiveDocs;
213-
}
214-
}
215-
216-
if (vectorList.size() == vectorsPerTransfer) {
217-
vectorAddress = JNICommons.storeVectorData(vectorAddress, vectorList.toArray(new float[][] {}), totalLiveDocs * dimension);
218-
// We should probably come up with a better way to reuse the vectorList memory which we have
219-
// created. Problem here is doing like this can lead to a lot of list memory which is of no use and
220-
// will be garbage collected later on, but it creates pressure on JVM. We should revisit this.
221-
vectorList = new ArrayList<>();
222-
}
223-
224-
vectorList.add(vector);
225-
docIdList.add(disi.nextDoc());
226-
227-
}
228-
229-
if (vectorList.isEmpty() == false) {
230-
vectorAddress = JNICommons.storeVectorData(vectorAddress, vectorList.toArray(new float[][] {}), totalLiveDocs * dimension);
231-
}
232-
// SerializationMode.COLLECTION_OF_FLOATS is not getting used. I just added it to ensure code successfully
233-
// works.
234-
return new KNNCodecUtil.Pair(
235-
docIdList.stream().mapToInt(Integer::intValue).toArray(),
236-
vectorAddress,
237-
dimension,
238-
SerializationMode.COLLECTION_OF_FLOATS
239-
);
240-
}
241-
242137
private static class FieldWriter<T> extends KnnFieldVectorsWriter<T> {
243138
private final FieldInfo fieldInfo;
244139
private final List<T> vectors;

src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ public static String buildEngineFileSuffix(String fieldName, String extension) {
145145
return String.format("_%s%s", fieldName, extension);
146146
}
147147

148-
private static long getTotalLiveDocsCount(final BinaryDocValues binaryDocValues) {
148+
public static long getTotalLiveDocsCount(final BinaryDocValues binaryDocValues) {
149149
long totalLiveDocs;
150150
if (binaryDocValues instanceof KNN80BinaryDocValues) {
151151
totalLiveDocs = ((KNN80BinaryDocValues) binaryDocValues).getTotalLiveDocs();

src/main/java/org/opensearch/knn/index/codec/util/KNNVectorSerializerFactory.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ public static KNNVectorSerializer getSerializerByStreamContent(final ByteArrayIn
5656
return getSerializerBySerializationMode(serializationMode);
5757
}
5858

59-
static SerializationMode serializerModeFromStream(ByteArrayInputStream byteStream) {
59+
public static SerializationMode serializerModeFromStream(ByteArrayInputStream byteStream) {
6060
int numberOfAvailableBytesInStream = byteStream.available();
6161
if (numberOfAvailableBytesInStream < ARRAY_HEADER_OFFSET) {
6262
return getSerializerOrThrowError(numberOfAvailableBytesInStream, COLLECTION_OF_FLOATS);

0 commit comments

Comments
 (0)