Skip to content

Commit bf240e3

Browse files
committed
Initial commit enabling float vector values for vector search.
Known issues: 1. Filter queries are not working yet. 2. Training index creation has not been tested.
1 parent 623b610 commit bf240e3

16 files changed

+693
-56
lines changed

src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java

+4-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
package org.opensearch.knn.index;
77

88
import org.apache.lucene.index.DocValues;
9+
import org.apache.lucene.index.DocValuesType;
910
import org.apache.lucene.index.FieldInfo;
1011
import org.apache.lucene.index.LeafReader;
1112
import org.apache.lucene.search.DocIdSetIterator;
@@ -57,8 +58,10 @@ public ScriptDocValues<float[]> getScriptValues() {
5758
default:
5859
throw new IllegalStateException("Unsupported Lucene vector encoding: " + fieldInfo.getVectorEncoding());
5960
}
60-
} else {
61+
} else if (fieldInfo.getDocValuesType() == DocValuesType.BINARY) {
6162
values = DocValues.getBinary(reader, fieldName);
63+
} else {
64+
return KNNVectorScriptDocValues.emptyValues(fieldName, vectorDataType);
6265
}
6366
return KNNVectorScriptDocValues.create(values, fieldName, vectorDataType);
6467
} catch (IOException e) {

src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java

+35-7
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@
1111
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
1212
import org.opensearch.index.mapper.MapperService;
1313
import org.opensearch.knn.common.KNNConstants;
14+
import org.opensearch.knn.index.codec.KNN990Codec.NativeEngines99KnnVectorsFormat;
1415
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
16+
import org.opensearch.knn.index.util.KNNEngine;
17+
import org.opensearch.knn.indices.ModelCache;
1518

1619
import java.util.Map;
1720
import java.util.Optional;
@@ -25,7 +28,7 @@
2528
@Log4j2
2629
public abstract class BasePerFieldKnnVectorsFormat extends PerFieldKnnVectorsFormat {
2730

28-
private final Optional<MapperService> mapperService;
31+
private final Optional<MapperService> optionalMapperService;
2932
private final int defaultMaxConnections;
3033
private final int defaultBeamWidth;
3134
private final Supplier<KnnVectorsFormat> defaultFormatSupplier;
@@ -42,12 +45,22 @@ public KnnVectorsFormat getKnnVectorsFormatForField(final String field) {
4245
);
4346
return defaultFormatSupplier.get();
4447
}
45-
var type = (KNNVectorFieldMapper.KNNVectorFieldType) mapperService.orElseThrow(
46-
() -> new IllegalStateException(
48+
if (optionalMapperService.isEmpty()) {
49+
throw new IllegalStateException(
4750
String.format("Cannot read field type for field [%s] because mapper service is not available", field)
48-
)
49-
).fieldType(field);
50-
var params = type.getKnnMethodContext().getMethodComponentContext().getParameters();
51+
);
52+
}
53+
final KNNVectorFieldMapper.KNNVectorFieldType mappedFieldType = (KNNVectorFieldMapper.KNNVectorFieldType) optionalMapperService
54+
.get()
55+
.fieldType(field);
56+
57+
final KNNEngine knnEngine = getKNNEngine(mappedFieldType);
58+
if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine)) {
59+
log.debug("Native Engine present hence using NativeEnginesKNNVectorsFormat. Engine found: {}", knnEngine);
60+
return new NativeEngines99KnnVectorsFormat();
61+
}
62+
63+
final Map<String, Object> params = mappedFieldType.getKnnMethodContext().getMethodComponentContext().getParameters();
5164
int maxConnections = getMaxConnections(params);
5265
int beamWidth = getBeamWidth(params);
5366
log.debug(
@@ -65,7 +78,8 @@ public int getMaxDimensions(String fieldName) {
6578
}
6679

6780
private boolean isKnnVectorFieldType(final String field) {
68-
return mapperService.isPresent() && mapperService.get().fieldType(field) instanceof KNNVectorFieldMapper.KNNVectorFieldType;
81+
return optionalMapperService.isPresent()
82+
&& optionalMapperService.get().fieldType(field) instanceof KNNVectorFieldMapper.KNNVectorFieldType;
6983
}
7084

7185
private int getMaxConnections(final Map<String, Object> params) {
@@ -81,4 +95,18 @@ private int getBeamWidth(final Map<String, Object> params) {
8195
}
8296
return defaultBeamWidth;
8397
}
98+
99+
private KNNEngine getKNNEngine(final KNNVectorFieldMapper.KNNVectorFieldType mappedFieldType) {
100+
final String modelId = mappedFieldType.getModelId();
101+
if (modelId != null) {
102+
var model = ModelCache.getInstance().get(modelId);
103+
return model.getModelMetadata().getKnnEngine();
104+
}
105+
106+
if (mappedFieldType.getKnnMethodContext() == null) {
107+
return KNNEngine.DEFAULT;
108+
} else {
109+
return mappedFieldType.getKnnMethodContext().getKnnEngine();
110+
}
111+
}
84112
}

src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java

+22-28
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
package org.opensearch.knn.index.codec.KNN80Codec;
77

88
import com.google.common.collect.ImmutableMap;
9-
import lombok.NonNull;
109
import lombok.extern.log4j.Log4j2;
1110
import org.apache.lucene.store.ChecksumIndexInput;
1211
import org.opensearch.common.StopWatch;
@@ -61,7 +60,7 @@
6160
* This class writes the KNN docvalues to the segments
6261
*/
6362
@Log4j2
64-
class KNN80DocValuesConsumer extends DocValuesConsumer implements Closeable {
63+
public class KNN80DocValuesConsumer extends DocValuesConsumer implements Closeable {
6564

6665
private final Logger logger = LogManager.getLogger(KNN80DocValuesConsumer.class);
6766

@@ -90,22 +89,14 @@ public void addBinaryField(FieldInfo field, DocValuesProducer valuesProducer) th
9089
}
9190

9291
private boolean isKNNBinaryFieldRequired(FieldInfo field) {
93-
final KNNEngine knnEngine = getKNNEngine(field);
92+
final KNNEngine knnEngine = KNNCodecUtil.getKNNEngine(field);
9493
log.debug(String.format("Read engine [%s] for field [%s]", knnEngine.getName(), field.getName()));
95-
return field.attributes().containsKey(KNNVectorFieldMapper.KNN_FIELD)
94+
// Only fields whose vector dimension is not set (field.getVectorDimension() <= 0) are treated as KNN binary fields here.
95+
return field.getVectorDimension() <= 0
96+
&& field.attributes().containsKey(KNNVectorFieldMapper.KNN_FIELD)
9697
&& KNNEngine.getEnginesThatCreateCustomSegmentFiles().stream().anyMatch(engine -> engine == knnEngine);
9798
}
9899

99-
private KNNEngine getKNNEngine(@NonNull FieldInfo field) {
100-
final String modelId = field.attributes().get(MODEL_ID);
101-
if (modelId != null) {
102-
var model = ModelCache.getInstance().get(modelId);
103-
return model.getModelMetadata().getKnnEngine();
104-
}
105-
final String engineName = field.attributes().getOrDefault(KNNConstants.KNN_ENGINE, KNNEngine.DEFAULT.getName());
106-
return KNNEngine.getEngine(engineName);
107-
}
108-
109100
public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, boolean isMerge, boolean isRefresh)
110101
throws IOException {
111102
// Get values to be indexed
@@ -123,7 +114,18 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer,
123114
}
124115
// Increment counter for number of graph index requests
125116
KNNCounter.GRAPH_INDEX_REQUESTS.increment();
126-
final KNNEngine knnEngine = getKNNEngine(field);
117+
if (isMerge) {
118+
recordMergeStats(pair.docs.length, arraySize);
119+
}
120+
121+
if (isRefresh) {
122+
recordRefreshStats();
123+
}
124+
createNativeIndex(state, field, pair);
125+
}
126+
127+
public static void createNativeIndex(SegmentWriteState state, FieldInfo field, KNNCodecUtil.Pair pair) throws IOException {
128+
final KNNEngine knnEngine = KNNCodecUtil.getKNNEngine(field);
127129
final String engineFileName = buildEngineFileName(
128130
state.segmentInfo.name,
129131
knnEngine.getVersion(),
@@ -147,20 +149,12 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer,
147149
indexCreator = () -> createKNNIndexFromScratch(field, pair, knnEngine, indexPath);
148150
}
149151

150-
if (isMerge) {
151-
recordMergeStats(pair.docs.length, arraySize);
152-
}
153-
154-
if (isRefresh) {
155-
recordRefreshStats();
156-
}
157-
158152
// This is a bit of a hack. We have to create an output here and then immediately close it to ensure that
159153
// engineFileName is added to the tracked files by Lucene's TrackingDirectoryWrapper. Otherwise, the file will
160154
// not be marked as added to the directory.
161155
state.directory.createOutput(engineFileName, state.context).close();
162156
indexCreator.createIndex();
163-
writeFooter(indexPath, engineFileName);
157+
writeFooter(state, indexPath, engineFileName);
164158
}
165159

166160
private void recordMergeStats(int length, long arraySize) {
@@ -176,7 +170,7 @@ private void recordRefreshStats() {
176170
KNNGraphValue.REFRESH_TOTAL_OPERATIONS.increment();
177171
}
178172

179-
private void createKNNIndexFromTemplate(byte[] model, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath) {
173+
private static void createKNNIndexFromTemplate(byte[] model, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath) {
180174
Map<String, Object> parameters = ImmutableMap.of(
181175
KNNConstants.INDEX_THREAD_QTY,
182176
KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)
@@ -195,7 +189,7 @@ private void createKNNIndexFromTemplate(byte[] model, KNNCodecUtil.Pair pair, KN
195189
});
196190
}
197191

198-
private void createKNNIndexFromScratch(FieldInfo fieldInfo, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath)
192+
private static void createKNNIndexFromScratch(FieldInfo fieldInfo, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath)
199193
throws IOException {
200194
Map<String, Object> parameters = new HashMap<>();
201195
Map<String, String> fieldAttributes = fieldInfo.attributes();
@@ -295,7 +289,7 @@ private interface NativeIndexCreator {
295289
void createIndex() throws IOException;
296290
}
297291

298-
private void writeFooter(String indexPath, String engineFileName) throws IOException {
292+
private static void writeFooter(SegmentWriteState state, String indexPath, String engineFileName) throws IOException {
299293
// Opens the engine file that was created and appends a footer to it. The footer consists of
300294
// 1. A Footer magic number (int - 4 bytes)
301295
// 2. A checksum algorithm id (int - 4 bytes)
@@ -325,7 +319,7 @@ private void writeFooter(String indexPath, String engineFileName) throws IOExcep
325319
os.close();
326320
}
327321

328-
private boolean isChecksumValid(long value) {
322+
private static boolean isChecksumValid(long value) {
329323
// Check pulled from
330324
// https://github.com/apache/lucene/blob/branch_9_0/lucene/core/src/java/org/apache/lucene/codecs/CodecUtil.java#L644-L647
331325
return (value & CRC32_CHECKSUM_SANITY) != 0;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*
4+
* The OpenSearch Contributors require contributions made to
5+
* this file be licensed under the Apache-2.0 license or a
6+
* compatible open source license.
7+
*
8+
* Modifications Copyright OpenSearch Contributors. See
9+
* GitHub history for details.
10+
*/
11+
12+
package org.opensearch.knn.index.codec.KNN990Codec;
13+
14+
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
15+
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
16+
import org.apache.lucene.codecs.KnnVectorsFormat;
17+
import org.apache.lucene.codecs.KnnVectorsReader;
18+
import org.apache.lucene.codecs.KnnVectorsWriter;
19+
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
20+
import org.apache.lucene.index.SegmentReadState;
21+
import org.apache.lucene.index.SegmentWriteState;
22+
23+
import java.io.IOException;
24+
25+
public class NativeEngines99KnnVectorsFormat extends KnnVectorsFormat {
26+
27+
/** The format for storing, reading, merging vectors on disk */
28+
private static final FlatVectorsFormat flatVectorsFormat = new Lucene99FlatVectorsFormat(new DefaultFlatVectorScorer());
29+
30+
/**
31+
* Sole constructor
32+
*
33+
*/
34+
public NativeEngines99KnnVectorsFormat() {
35+
super("NativeEngines99KnnVectorsFormat");
36+
}
37+
38+
/**
39+
* Returns a {@link KnnVectorsWriter} to write the vectors to the index.
40+
*
41+
* @param state the segment write state, also passed to the underlying flat vectors format to create its writer
42+
*/
43+
@Override
44+
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
45+
return new NativeEnginesKNNVectorsWriter(state, flatVectorsFormat.fieldsWriter(state));
46+
}
47+
48+
/**
49+
* Returns a {@link KnnVectorsReader} to read the vectors from the index.
50+
*
51+
* @param state
52+
*/
53+
@Override
54+
public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
55+
return new NativeEnginesKNNVectorsReader(state, flatVectorsFormat.fieldsReader(state));
56+
}
57+
58+
@Override
59+
public String toString() {
60+
return "NativeEngines99KnnVectorsFormat(name=NativeEngines99KnnVectorsFormat, flatVectorsFormat=" + flatVectorsFormat + ")";
61+
}
62+
63+
}

0 commit comments

Comments
 (0)