From 46fc0d7c11a40e152d4640a721b12754050194f1 Mon Sep 17 00:00:00 2001 From: Michael Marshall Date: Thu, 20 Mar 2025 15:17:19 -0500 Subject: [PATCH] Remove query-time usage of ByteSequence::slice to reduce object allocations --- .../jvector/quantization/PQDecoder.java | 13 ++-- .../jvector/quantization/PQVectors.java | 51 +++++++++++--- .../vector/DefaultVectorUtilSupport.java | 9 ++- .../jbellis/jvector/vector/VectorUtil.java | 8 +++ .../jvector/vector/VectorUtilSupport.java | 22 +++++- .../vector/NativeVectorUtilSupport.java | 8 +++ .../vector/PanamaVectorUtilSupport.java | 17 ++++- .../jbellis/jvector/vector/SimdOps.java | 68 +++++++++---------- 8 files changed, 141 insertions(+), 55 deletions(-) diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQDecoder.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQDecoder.java index 129eebee7..85244befd 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQDecoder.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQDecoder.java @@ -53,8 +53,8 @@ protected CachingDecoder(PQVectors cv, VectorFloat query, VectorSimilarityFun } } - protected float decodedSimilarity(ByteSequence encoded) { - return VectorUtil.assembleAndSum(partialSums, cv.pq.getClusterCount(), encoded); + protected float decodedSimilarity(ByteSequence encoded, int offset, int length) { + return VectorUtil.assembleAndSum(partialSums, cv.pq.getClusterCount(), encoded, offset, length); } } @@ -65,7 +65,7 @@ public DotProductDecoder(PQVectors cv, VectorFloat query) { @Override public float similarityTo(int node2) { - return (1 + decodedSimilarity(cv.get(node2))) / 2; + return (1 + decodedSimilarity(cv.getChunk(node2), cv.getOffsetInChunk(node2), cv.pq.getSubspaceCount())) / 2; } } @@ -76,7 +76,7 @@ public EuclideanDecoder(PQVectors cv, VectorFloat query) { @Override public float similarityTo(int node2) { - return 1 / (1 + decodedSimilarity(cv.get(node2))); + return 1 / (1 + decodedSimilarity(cv.getChunk(node2), cv.getOffsetInChunk(node2), cv.pq.getSubspaceCount())); } } @@ -132,9 +132,10 @@ public float similarityTo(int node2) { protected float decodedCosine(int node2) { - ByteSequence encoded = cv.get(node2); + ByteSequence encoded = cv.getChunk(node2); + int offset = cv.getOffsetInChunk(node2); - return VectorUtil.pqDecodedCosineSimilarity(encoded, cv.pq.getClusterCount(), partialSums, aMagnitude, bMagnitude); + return VectorUtil.pqDecodedCosineSimilarity(encoded, offset, cv.pq.getSubspaceCount(), cv.pq.getClusterCount(), partialSums, aMagnitude, bMagnitude); } } } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java index f6a41cefe..d1d5e73f0 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java @@ -229,11 +229,12 @@ public ScoreFunction.ApproximateScoreFunction scoreFunctionFor(VectorFloat q, switch (similarityFunction) { case DOT_PRODUCT: return (node2) -> { - var encoded = get(node2); + var encodedChunk = getChunk(node2); + var encodedOffset = getOffsetInChunk(node2); // compute the dot product of the query and the codebook centroids corresponding to the encoded points float dp = 0; for (int m = 0; m < pq.getSubspaceCount(); m++) { - int centroidIndex = Byte.toUnsignedInt(encoded.get(m)); + int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset)); int centroidLength = pq.subvectorSizesAndOffsets[m][0]; int centroidOffset = pq.subvectorSizesAndOffsets[m][1]; dp += VectorUtil.dotProduct(pq.codebooks[m], centroidIndex * centroidLength, centeredQuery, centroidOffset, centroidLength); @@ -244,12 +245,13 @@ public ScoreFunction.ApproximateScoreFunction scoreFunctionFor(VectorFloat q, case COSINE: float norm1 = VectorUtil.dotProduct(centeredQuery, centeredQuery); return (node2) -> { - var encoded = get(node2); + var encodedChunk = getChunk(node2); + var encodedOffset = getOffsetInChunk(node2); // compute the dot product of the query and the codebook centroids corresponding to the encoded points float sum = 0; float norm2 = 0; for (int m = 0; m < pq.getSubspaceCount(); m++) { - int centroidIndex = Byte.toUnsignedInt(encoded.get(m)); + int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset)); int centroidLength = pq.subvectorSizesAndOffsets[m][0]; int centroidOffset = pq.subvectorSizesAndOffsets[m][1]; var codebookOffset = centroidIndex * centroidLength; @@ -262,11 +264,12 @@ public ScoreFunction.ApproximateScoreFunction scoreFunctionFor(VectorFloat q, }; case EUCLIDEAN: return (node2) -> { - var encoded = get(node2); + var encodedChunk = getChunk(node2); + var encodedOffset = getOffsetInChunk(node2); // compute the euclidean distance between the query and the codebook centroids corresponding to the encoded points float sum = 0; for (int m = 0; m < pq.getSubspaceCount(); m++) { - int centroidIndex = Byte.toUnsignedInt(encoded.get(m)); + int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset)); int centroidLength = pq.subvectorSizesAndOffsets[m][0]; int centroidOffset = pq.subvectorSizesAndOffsets[m][1]; sum += VectorUtil.squareL2Distance(pq.codebooks[m], centroidIndex * centroidLength, centeredQuery, centroidOffset, centroidLength); @@ -279,6 +282,11 @@ public ScoreFunction.ApproximateScoreFunction scoreFunctionFor(VectorFloat q, } } + /** + * Returns a {@link ByteSequence} for the given ordinal. + * @param ordinal the vector's ordinal + * @return the {@link ByteSequence} + */ public ByteSequence get(int ordinal) { if (ordinal < 0 || ordinal >= count()) throw new IndexOutOfBoundsException("Ordinal " + ordinal + " out of bounds for vector count " + count()); @@ -286,10 +294,37 @@ public ByteSequence get(int ordinal) { } static ByteSequence get(ByteSequence[] chunks, int ordinal, int vectorsPerChunk, int subspaceCount) { - int chunkIndex = ordinal / vectorsPerChunk; int vectorIndexInChunk = ordinal % vectorsPerChunk; int start = vectorIndexInChunk * subspaceCount; - return chunks[chunkIndex].slice(start, subspaceCount); + return getChunk(chunks, ordinal, vectorsPerChunk).slice(start, subspaceCount); + } + + /** + * Returns a reference to the {@link ByteSequence} containing for the given ordinal. Only intended for use where + * the caller wants to avoid an allocation for the slice object. After getting the chunk, callers should use the + * {@link #getOffsetInChunk(int)} method to get the offset of the vector within the chunk and then use the pq's + * {@link ProductQuantization#getSubspaceCount()} to get the length of the vector. + * @param ordinal the vector's ordinal + * @return the {@link ByteSequence} chunk containing the vector + */ + ByteSequence getChunk(int ordinal) { + if (ordinal < 0 || ordinal >= count()) + throw new IndexOutOfBoundsException("Ordinal " + ordinal + " out of bounds for vector count " + count()); + + return getChunk(compressedDataChunks, ordinal, vectorsPerChunk); + } + + int getOffsetInChunk(int ordinal) { + if (ordinal < 0 || ordinal >= count()) + throw new IndexOutOfBoundsException("Ordinal " + ordinal + " out of bounds for vector count " + count()); + + int vectorIndexInChunk = ordinal % vectorsPerChunk; + return vectorIndexInChunk * pq.getSubspaceCount(); + } + + static ByteSequence getChunk(ByteSequence[] chunks, int ordinal, int vectorsPerChunk) { + int chunkIndex = ordinal / vectorsPerChunk; + return chunks[chunkIndex]; } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/DefaultVectorUtilSupport.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/DefaultVectorUtilSupport.java index 45f473912..28942d99b 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/DefaultVectorUtilSupport.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/DefaultVectorUtilSupport.java @@ -296,9 +296,14 @@ public void minInPlace(VectorFloat v1, VectorFloat v2) { @Override public float assembleAndSum(VectorFloat data, int dataBase, ByteSequence baseOffsets) { + return assembleAndSum(data, dataBase, baseOffsets, 0, baseOffsets.length()); + } + + @Override + public float assembleAndSum(VectorFloat data, int dataBase, ByteSequence baseOffsets, int baseOffsetsOffset, int baseOffsetsLength) { float sum = 0f; - for (int i = 0; i < baseOffsets.length(); i++) { - sum += data.get(dataBase * i + Byte.toUnsignedInt(baseOffsets.get(i))); + for (int i = 0; i < baseOffsetsLength; i++) { + sum += data.get(dataBase * i + Byte.toUnsignedInt(baseOffsets.get(i + baseOffsetsOffset))); } return sum; } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java index 0a7b25f9a..ebcddd7da 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java @@ -166,6 +166,10 @@ public static float assembleAndSum(VectorFloat data, int dataBase, ByteSequen return impl.assembleAndSum(data, dataBase, dataOffsets); } + public static float assembleAndSum(VectorFloat data, int dataBase, ByteSequence dataOffsets, int dataOffsetsOffset, int dataOffsetsLength) { + return impl.assembleAndSum(data, dataBase, dataOffsets, dataOffsetsOffset, dataOffsetsLength); + } + public static void bulkShuffleQuantizedSimilarity(ByteSequence shuffles, int codebookCount, ByteSequence quantizedPartials, float delta, float minDistance, VectorFloat results, VectorSimilarityFunction vsf) { impl.bulkShuffleQuantizedSimilarity(shuffles, codebookCount, quantizedPartials, delta, minDistance, vsf, results); } @@ -215,6 +219,10 @@ public static float pqDecodedCosineSimilarity(ByteSequence encoded, int clust return impl.pqDecodedCosineSimilarity(encoded, clusterCount, partialSums, aMagnitude, bMagnitude); } + public static float pqDecodedCosineSimilarity(ByteSequence encoded, int encodedOffset, int encodedLength, int clusterCount, VectorFloat partialSums, VectorFloat aMagnitude, float bMagnitude) { + return impl.pqDecodedCosineSimilarity(encoded, encodedOffset, encodedLength, clusterCount, partialSums, aMagnitude, bMagnitude); + } + public static float nvqDotProduct8bit(VectorFloat vector, ByteSequence bytes, float growthRate, float midpoint, float minValue, float maxValue) { return impl.nvqDotProduct8bit(vector, bytes, growthRate, midpoint, minValue, maxValue); } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java index 41d200739..c7206fae2 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java @@ -100,6 +100,19 @@ public interface VectorUtilSupport { */ float assembleAndSum(VectorFloat data, int baseIndex, ByteSequence baseOffsets); + /** + * Calculates the sum of sparse points in a vector. + * + * @param data the vector of all datapoints + * @param baseIndex the start of the data in the offset table + * (scaled by the index of the lookup table) + * @param baseOffsets bytes that represent offsets from the baseIndex + * @param baseOffsetsOffset the offset into the baseOffsets ByteSequence + * @param baseOffsetsLength the length of the baseOffsets ByteSequence to use + * @return the sum of the points + */ + float assembleAndSum(VectorFloat data, int baseIndex, ByteSequence baseOffsets, int baseOffsetsOffset, int baseOffsetsLength); + int hammingDistance(long[] v1, long[] v2); @@ -212,12 +225,17 @@ default void bulkShuffleQuantizedSimilarityCosine(ByteSequence shuffles, int float min(VectorFloat v); default float pqDecodedCosineSimilarity(ByteSequence encoded, int clusterCount, VectorFloat partialSums, VectorFloat aMagnitude, float bMagnitude) + { + return pqDecodedCosineSimilarity(encoded, 0, encoded.length(), clusterCount, partialSums, aMagnitude, bMagnitude); + } + + default float pqDecodedCosineSimilarity(ByteSequence encoded, int encodedOffset, int encodedLength, int clusterCount, VectorFloat partialSums, VectorFloat aMagnitude, float bMagnitude) { float sum = 0.0f; float aMag = 0.0f; - for (int m = 0; m < encoded.length(); ++m) { - int centroidIndex = Byte.toUnsignedInt(encoded.get(m)); + for (int m = 0; m < encodedLength; ++m) { + int centroidIndex = Byte.toUnsignedInt(encoded.get(m + encodedOffset)); var index = m * clusterCount + centroidIndex; sum += partialSums.get(index); aMag += aMagnitude.get(index); diff --git a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java index 13959f595..e189f86ef 100644 --- a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java +++ b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java @@ -121,6 +121,14 @@ public float assembleAndSum(VectorFloat data, int dataBase, ByteSequence b return NativeSimdOps.assemble_and_sum_f32_512(((MemorySegmentVectorFloat)data).get(), dataBase, ((MemorySegmentByteSequence)baseOffsets).get(), baseOffsets.length()); } + @Override + public float assembleAndSum(VectorFloat data, int dataBase, ByteSequence baseOffsets, int baseOffsetsOffset, int baseOffsetsLength) + { + assert baseOffsetsOffset == 0; + assert baseOffsetsLength == baseOffsets.length(); + return assembleAndSum(data, dataBase, baseOffsets); + } + @Override public int hammingDistance(long[] v1, long[] v2) { return VectorSimdOps.hammingDistance(v1, v2); diff --git a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java index 3146f360a..4e66ce42b 100644 --- a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java +++ b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java @@ -112,7 +112,13 @@ public void minInPlace(VectorFloat v1, VectorFloat v2) { @Override public float assembleAndSum(VectorFloat data, int dataBase, ByteSequence baseOffsets) { - return SimdOps.assembleAndSum(((ArrayVectorFloat) data).get(), dataBase, ((ByteSequence) baseOffsets)); + return SimdOps.assembleAndSum(((ArrayVectorFloat) data).get(), dataBase, ((ByteSequence) baseOffsets), + 0, baseOffsets.length()); + } + + @Override + public float assembleAndSum(VectorFloat data, int dataBase, ByteSequence baseOffsets, int baseOffsetsOffset, int baseOffsetsLength) { + return SimdOps.assembleAndSum(((ArrayVectorFloat) data).get(), dataBase, ((ByteSequence) baseOffsets), baseOffsetsOffset, baseOffsetsLength); } @Override @@ -177,9 +183,14 @@ public void quantizePartials(float delta, VectorFloat partials, VectorFloat encoded, int clusterCount, VectorFloat partialSums, VectorFloat aMagnitude, float bMagnitude) + public float pqDecodedCosineSimilarity(ByteSequence encoded, int clusterCount, VectorFloat partialSums, VectorFloat aMagnitude, float bMagnitude) { + return pqDecodedCosineSimilarity(encoded, 0, encoded.length(), clusterCount, partialSums, aMagnitude, bMagnitude); + } + + @Override + public float pqDecodedCosineSimilarity(ByteSequence encoded, int encodedOffset, int encodedLength, int clusterCount, VectorFloat partialSums, VectorFloat aMagnitude, float bMagnitude) { - return SimdOps.pqDecodedCosineSimilarity((ByteSequence) encoded, clusterCount, (ArrayVectorFloat) partialSums, (ArrayVectorFloat) aMagnitude, bMagnitude); + return SimdOps.pqDecodedCosineSimilarity((ByteSequence) encoded, encodedOffset, encodedLength, clusterCount, (ArrayVectorFloat) partialSums, (ArrayVectorFloat) aMagnitude, bMagnitude); } @Override diff --git a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java index 1239354de..7fa804777 100644 --- a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java +++ b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java @@ -632,25 +632,25 @@ static void minInPlace(ArrayVectorFloat v1, ArrayVectorFloat v2) { } } - static float assembleAndSum(float[] data, int dataBase, ByteSequence baseOffsets) { + static float assembleAndSum(float[] data, int dataBase, ByteSequence baseOffsets, int baseOffsetsOffset, int baseOffsetsLength) { return switch (PREFERRED_BIT_SIZE) { - case 512 -> assembleAndSum512(data, dataBase, baseOffsets); - case 256 -> assembleAndSum256(data, dataBase, baseOffsets); - case 128 -> assembleAndSum128(data, dataBase, baseOffsets); + case 512 -> assembleAndSum512(data, dataBase, baseOffsets, baseOffsetsOffset, baseOffsetsLength); + case 256 -> assembleAndSum256(data, dataBase, baseOffsets, baseOffsetsOffset, baseOffsetsLength); + case 128 -> assembleAndSum128(data, dataBase, baseOffsets, baseOffsetsOffset, baseOffsetsLength); default -> throw new IllegalStateException("Unsupported vector width: " + PREFERRED_BIT_SIZE); }; } - static float assembleAndSum512(float[] data, int dataBase, ByteSequence baseOffsets) { + static float assembleAndSum512(float[] data, int dataBase, ByteSequence baseOffsets, int baseOffsetsOffset, int baseOffsetsLength) { int[] convOffsets = scratchInt512.get(); FloatVector sum = FloatVector.zero(FloatVector.SPECIES_512); int i = 0; - int limit = ByteVector.SPECIES_128.loopBound(baseOffsets.length()); + int limit = ByteVector.SPECIES_128.loopBound(baseOffsetsLength); var scale = IntVector.zero(IntVector.SPECIES_512).addIndex(dataBase); for (; i < limit; i += ByteVector.SPECIES_128.length()) { - ByteVector.fromArray(ByteVector.SPECIES_128, baseOffsets.get(), i + baseOffsets.offset()) + ByteVector.fromArray(ByteVector.SPECIES_128, baseOffsets.get(), i + baseOffsets.offset() + baseOffsetsOffset) .convertShape(VectorOperators.B2I, IntVector.SPECIES_512, 0) .lanewise(VectorOperators.AND, BYTE_TO_INT_MASK_512) .reinterpretAsInts() @@ -664,22 +664,22 @@ static float assembleAndSum512(float[] data, int dataBase, ByteSequence float res = sum.reduceLanes(VectorOperators.ADD); //Process tail - for (; i < baseOffsets.length(); i++) - res += data[dataBase * i + Byte.toUnsignedInt(baseOffsets.get(i))]; + for (; i < baseOffsetsLength; i++) + res += data[dataBase * i + Byte.toUnsignedInt(baseOffsets.get(i + baseOffsetsOffset))]; return res; } - static float assembleAndSum256(float[] data, int dataBase, ByteSequence baseOffsets) { + static float assembleAndSum256(float[] data, int dataBase, ByteSequence baseOffsets, int baseOffsetsOffset, int baseOffsetsLength) { int[] convOffsets = scratchInt256.get(); FloatVector sum = FloatVector.zero(FloatVector.SPECIES_256); int i = 0; - int limit = ByteVector.SPECIES_64.loopBound(baseOffsets.length()); + int limit = ByteVector.SPECIES_64.loopBound(baseOffsetsLength); var scale = IntVector.zero(IntVector.SPECIES_256).addIndex(dataBase); for (; i < limit; i += ByteVector.SPECIES_64.length()) { - ByteVector.fromArray(ByteVector.SPECIES_64, baseOffsets.get(), i + baseOffsets.offset()) + ByteVector.fromArray(ByteVector.SPECIES_64, baseOffsets.get(), i + baseOffsets.offset() + baseOffsetsOffset) .convertShape(VectorOperators.B2I, IntVector.SPECIES_256, 0) .lanewise(VectorOperators.AND, BYTE_TO_INT_MASK_256) .reinterpretAsInts() @@ -693,17 +693,17 @@ static float assembleAndSum256(float[] data, int dataBase, ByteSequence float res = sum.reduceLanes(VectorOperators.ADD); // Process tail - for (; i < baseOffsets.length(); i++) - res += data[dataBase * i + Byte.toUnsignedInt(baseOffsets.get(i))]; + for (; i < baseOffsetsLength; i++) + res += data[dataBase * i + Byte.toUnsignedInt(baseOffsets.get(i + baseOffsetsOffset))]; return res; } - static float assembleAndSum128(float[] data, int dataBase, ByteSequence baseOffsets) { + static float assembleAndSum128(float[] data, int dataBase, ByteSequence baseOffsets, int baseOffsetsOffset, int baseOffsetsLength) { // benchmarking a 128-bit SIMD implementation showed it performed worse than scalar float sum = 0f; - for (int i = 0; i < baseOffsets.length(); i++) { - sum += data[dataBase * i + Byte.toUnsignedInt(baseOffsets.get(i))]; + for (int i = 0; i < baseOffsetsLength; i++) { + sum += data[dataBase * i + Byte.toUnsignedInt(baseOffsets.get(i + baseOffsetsOffset))]; } return sum; } @@ -791,16 +791,16 @@ public static void quantizePartials(float delta, ArrayVectorFloat partials, Arra } } - public static float pqDecodedCosineSimilarity(ByteSequence encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) { + public static float pqDecodedCosineSimilarity(ByteSequence encoded, int encodedOffset, int encodedLength, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) { return switch (PREFERRED_BIT_SIZE) { - case 512 -> pqDecodedCosineSimilarity512(encoded, clusterCount, partialSums, aMagnitude, bMagnitude); - case 256 -> pqDecodedCosineSimilarity256(encoded, clusterCount, partialSums, aMagnitude, bMagnitude); - case 128 -> pqDecodedCosineSimilarity128(encoded, clusterCount, partialSums, aMagnitude, bMagnitude); + case 512 -> pqDecodedCosineSimilarity512(encoded, encodedOffset, encodedLength, clusterCount, partialSums, aMagnitude, bMagnitude); + case 256 -> pqDecodedCosineSimilarity256(encoded, encodedOffset, encodedLength, clusterCount, partialSums, aMagnitude, bMagnitude); + case 128 -> pqDecodedCosineSimilarity128(encoded, encodedOffset, encodedLength, clusterCount, partialSums, aMagnitude, bMagnitude); default -> throw new IllegalStateException("Unsupported vector width: " + PREFERRED_BIT_SIZE); }; } - public static float pqDecodedCosineSimilarity512(ByteSequence baseOffsets, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) { + public static float pqDecodedCosineSimilarity512(ByteSequence baseOffsets, int baseOffsetsOffset, int baseOffsetsLength, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) { var sum = FloatVector.zero(FloatVector.SPECIES_512); var vaMagnitude = FloatVector.zero(FloatVector.SPECIES_512); var partialSumsArray = partialSums.get(); @@ -808,13 +808,13 @@ public static float pqDecodedCosineSimilarity512(ByteSequence baseOffset int[] convOffsets = scratchInt512.get(); int i = 0; - int limit = i + ByteVector.SPECIES_128.loopBound(baseOffsets.length()); + int limit = i + ByteVector.SPECIES_128.loopBound(baseOffsetsLength); var scale = IntVector.zero(IntVector.SPECIES_512).addIndex(clusterCount); for (; i < limit; i += ByteVector.SPECIES_128.length()) { - ByteVector.fromArray(ByteVector.SPECIES_128, baseOffsets.get(), i + baseOffsets.offset()) + ByteVector.fromArray(ByteVector.SPECIES_128, baseOffsets.get(), i + baseOffsets.offset() + baseOffsetsOffset) .convertShape(VectorOperators.B2I, IntVector.SPECIES_512, 0) .lanewise(VectorOperators.AND, BYTE_TO_INT_MASK_512) .reinterpretAsInts() @@ -829,8 +829,8 @@ public static float pqDecodedCosineSimilarity512(ByteSequence baseOffset float sumResult = sum.reduceLanes(VectorOperators.ADD); float aMagnitudeResult = vaMagnitude.reduceLanes(VectorOperators.ADD); - for (; i < baseOffsets.length(); i++) { - int offset = clusterCount * i + Byte.toUnsignedInt(baseOffsets.get(i)); + for (; i < baseOffsetsLength; i++) { + int offset = clusterCount * i + Byte.toUnsignedInt(baseOffsets.get(i + baseOffsetsOffset)); sumResult += partialSumsArray[offset]; aMagnitudeResult += aMagnitudeArray[offset]; } @@ -838,7 +838,7 @@ public static float pqDecodedCosineSimilarity512(ByteSequence baseOffset return (float) (sumResult / Math.sqrt(aMagnitudeResult * bMagnitude)); } - public static float pqDecodedCosineSimilarity256(ByteSequence baseOffsets, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) { + public static float pqDecodedCosineSimilarity256(ByteSequence baseOffsets, int baseOffsetsOffset, int baseOffsetsLength, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) { var sum = FloatVector.zero(FloatVector.SPECIES_256); var vaMagnitude = FloatVector.zero(FloatVector.SPECIES_256); var partialSumsArray = partialSums.get(); @@ -846,13 +846,13 @@ public static float pqDecodedCosineSimilarity256(ByteSequence baseOffset int[] convOffsets = scratchInt256.get(); int i = 0; - int limit = ByteVector.SPECIES_64.loopBound(baseOffsets.length()); + int limit = ByteVector.SPECIES_64.loopBound(baseOffsetsLength); var scale = IntVector.zero(IntVector.SPECIES_256).addIndex(clusterCount); for (; i < limit; i += ByteVector.SPECIES_64.length()) { - ByteVector.fromArray(ByteVector.SPECIES_64, baseOffsets.get(), i + baseOffsets.offset()) + ByteVector.fromArray(ByteVector.SPECIES_64, baseOffsets.get(), i + baseOffsets.offset() + baseOffsetsOffset) .convertShape(VectorOperators.B2I, IntVector.SPECIES_256, 0) .lanewise(VectorOperators.AND, BYTE_TO_INT_MASK_256) .reinterpretAsInts() @@ -867,8 +867,8 @@ public static float pqDecodedCosineSimilarity256(ByteSequence baseOffset float sumResult = sum.reduceLanes(VectorOperators.ADD); float aMagnitudeResult = vaMagnitude.reduceLanes(VectorOperators.ADD); - for (; i < baseOffsets.length(); i++) { - int offset = clusterCount * i + Byte.toUnsignedInt(baseOffsets.get(i)); + for (; i < baseOffsetsLength; i++) { + int offset = clusterCount * i + Byte.toUnsignedInt(baseOffsets.get(i + baseOffsetsOffset)); sumResult += partialSumsArray[offset]; aMagnitudeResult += aMagnitudeArray[offset]; } @@ -876,13 +876,13 @@ public static float pqDecodedCosineSimilarity256(ByteSequence baseOffset return (float) (sumResult / Math.sqrt(aMagnitudeResult * bMagnitude)); } - public static float pqDecodedCosineSimilarity128(ByteSequence baseOffsets, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) { + public static float pqDecodedCosineSimilarity128(ByteSequence baseOffsets, int baseOffsetsOffset, int baseOffsetsLength, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) { // benchmarking showed that a 128-bit SIMD implementation performed worse than scalar float sum = 0.0f; float aMag = 0.0f; - for (int m = 0; m < baseOffsets.length(); ++m) { - int centroidIndex = Byte.toUnsignedInt(baseOffsets.get(m)); + for (int m = 0; m < baseOffsetsLength; ++m) { + int centroidIndex = Byte.toUnsignedInt(baseOffsets.get(m + baseOffsetsOffset)); var index = m * clusterCount + centroidIndex; sum += partialSums.get(index); aMag += aMagnitude.get(index);