Commit d623b1f

Merge pull request #402 from datastax/hnsw-3
Add hierarchical structure to the graph index
2 parents 5ab68d8 + 2c24d22 commit d623b1f

56 files changed

+2648
-1461
lines changed

README.md

+13-9
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,19 @@ There are two broad categories of ANN index:
1010

1111
Graph-based indexes tend to be simpler to implement and faster, but more importantly they can be constructed and updated incrementally. This makes them a much better fit for a general-purpose index than partitioning approaches that only work on static datasets that are completely specified up front. That is why all the major commercial vector indexes use graph approaches.
1212

13-
JVector is a graph index in the DiskANN family tree.
13+
JVector is a graph index that merges the DiskANN and HNSW family trees.
14+
JVector borrows the hierarchical structure from HNSW, and uses Vamana (the algorithm behind DiskANN) within each layer.
1415

1516

1617
## JVector Architecture
1718

18-
JVector is a graph-based index that builds on the DiskANN design with composeable extensions.
19+
JVector is a graph-based index that builds on the HNSW and DiskANN designs with composable extensions.
1920

20-
JVector implements a single-layer graph with nonblocking concurrency control, allowing construction to scale linearly with the number of cores:
21+
JVector implements a multi-layer graph with nonblocking concurrency control, allowing construction to scale linearly with the number of cores:
2122
![JVector scales linearly as thread count increases](https://github.com/jbellis/jvector/assets/42158/f0127bfc-6c45-48b9-96ea-95b2120da0d9)
2223

23-
The graph is represented by an on-disk adjacency list per node, with additional data stored inline to support two-pass searches, with the first pass powered by lossily compressed representations of the vectors kept in memory, and the second by a more accurate representation read from disk. The first pass can be performed with
24+
The upper layers of the hierarchy are represented by an in-memory adjacency list per node. This allows for quick navigation with no I/O.
25+
The bottom layer of the graph is represented by an on-disk adjacency list per node. JVector uses additional data stored inline to support two-pass searches, with the first pass powered by lossily compressed representations of the vectors kept in memory, and the second by a more accurate representation read from disk. The first pass can be performed with
2426
* Product quantization (PQ), optionally with [anisotropic weighting](https://arxiv.org/abs/1908.10396)
2527
* [Binary quantization](https://huggingface.co/blog/embedding-quantization) (BQ)
2628
* Fused ADC, where PQ codebooks are transposed and written inline with the graph adjacency list
@@ -51,15 +53,16 @@ First the code:
5153
int originalDimension = baseVectors.get(0).length();
5254
// wrap the raw vectors in a RandomAccessVectorValues
5355
RandomAccessVectorValues ravv = new ListRandomAccessVectorValues(baseVectors, originalDimension);
54-
56+
5557
// score provider using the raw, in-memory vectors
5658
BuildScoreProvider bsp = BuildScoreProvider.randomAccessScoreProvider(ravv, VectorSimilarityFunction.EUCLIDEAN);
5759
try (GraphIndexBuilder builder = new GraphIndexBuilder(bsp,
5860
ravv.dimension(),
5961
16, // graph degree
6062
100, // construction search depth
6163
1.2f, // allow degree overflow during construction by this factor
62-
1.2f)) // relax neighbor diversity requirement by this factor
64+
1.2f, // relax neighbor diversity requirement by this factor (alpha)
65+
true)) // use a hierarchical index
6366
{
6467
// build the index (in memory)
6568
OnHeapGraphIndex index = builder.build(ravv);
@@ -86,6 +89,7 @@ Commentary:
8689
* For the overflow Builder parameter, the sweet spot is about 1.2 for in-memory construction and 1.5 for on-disk. (The more overflow is allowed, the fewer recomputations of best edges are required, but the more neighbors will be consulted in every search.)
8790
* The alpha parameter controls the tradeoff between edge distance and diversity; usually 1.2 is sufficient for high-dimensional vectors; 2.0 is recommended for 2D or 3D datasets. See [the DiskANN paper](https://suhasjs.github.io/files/diskann_neurips19.pdf) for more details.
8891
* The Bits parameter to GraphSearcher is intended for controlling your resultset based on external predicates and won’t be used in this tutorial.
92+
* Setting the addHierarchy parameter to true builds a multi-layer index. This approach has proven more robust in highly challenging scenarios.
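For intuition about what a multi-layer index looks like, the sketch below assigns each node a top layer using the exponentially decaying rule described in the HNSW paper. It is not taken from JVector: the class `LayerAssignmentSketch`, the methods `randomLevel` and `layerSizes`, and the normalization factor `1 / ln(maxDegree)` are assumptions made for illustration only.

```java
import java.util.Arrays;
import java.util.Random;

final class LayerAssignmentSketch {
    /** Draws a top layer so that P(level >= l) = maxDegree^(-l), as in the HNSW paper. */
    static int randomLevel(Random rng, int maxDegree) {
        double u = 1.0 - rng.nextDouble();          // uniform in (0, 1], avoids log(0)
        double scale = 1.0 / Math.log(maxDegree);   // assumed normalization; requires maxDegree > 1
        return (int) Math.floor(-Math.log(u) * scale);
    }

    /** Counts how many of n nodes appear in each layer (a node with top layer k is in layers 0..k). */
    static int[] layerSizes(int n, int maxDegree, long seed) {
        Random rng = new Random(seed);
        int[] sizes = new int[32];
        int top = 0;
        for (int i = 0; i < n; i++) {
            int level = Math.min(randomLevel(rng, maxDegree), sizes.length - 1);
            top = Math.max(top, level);
            for (int l = 0; l <= level; l++) {
                sizes[l]++;
            }
        }
        return Arrays.copyOf(sizes, top + 1);
    }

    public static void main(String[] args) {
        int[] sizes = layerSizes(100_000, 16, 42L);
        for (int l = 0; l < sizes.length; l++) {
            System.out.println("layer " + l + ": " + sizes[l] + " nodes");
        }
    }
}
```

Every node lands in layer 0, and each higher layer holds roughly a 1/maxDegree fraction of the layer below it, which is why the upper layers are small enough to keep in memory.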
8993

9094

9195
#### Step 2: more control over GraphSearcher
@@ -129,7 +133,7 @@ This is expected given the approximate nature of the index being created and the
129133
The code:
130134
```java
131135
Path indexPath = Files.createTempFile("siftsmall", ".inline");
132-
try (GraphIndexBuilder builder = new GraphIndexBuilder(bsp, ravv.dimension(), 16, 100, 1.2f, 1.2f)) {
136+
try (GraphIndexBuilder builder = new GraphIndexBuilder(bsp, ravv.dimension(), 16, 100, 1.2f, 1.2f, true)) {
133137
// build the index (in memory)
134138
OnHeapGraphIndex index = builder.build(ravv);
135139
// write the index to disk with default options
@@ -218,7 +222,7 @@ Then we need to set up an OnDiskGraphIndexWriter with full control over the cons
218222
Path indexPath = Files.createTempFile("siftsmall", ".inline");
219223
Path pqPath = Files.createTempFile("siftsmall", ".pq");
220224
// Builder creation looks mostly the same
221-
try (GraphIndexBuilder builder = new GraphIndexBuilder(bsp, ravv.dimension(), 16, 100, 1.2f, 1.2f);
225+
try (GraphIndexBuilder builder = new GraphIndexBuilder(bsp, ravv.dimension(), 16, 100, 1.2f, 1.2f, true);
222226
// explicit Writer for the first time, this is what's behind OnDiskGraphIndex.write
223227
OnDiskGraphIndexWriter writer = new OnDiskGraphIndexWriter.Builder(builder.getGraph(), indexPath)
224228
.with(new InlineVectors(ravv.dimension()))
@@ -259,7 +263,7 @@ Commentary:
259263

260264
### Less-obvious points
261265

262-
* Embeddings models product output from a consistent distribution of vectors. This means that you can save and re-use ProductQuantization codebooks, even for a different set of vectors, as long as you had a sufficiently large training set to build it the first time around. ProductQuantization.MAX_PQ_TRAINING_SET_SIZE (128,000 vectors) has proven to be sufficiently large.
266+
* Embeddings models produce output from a consistent distribution of vectors. This means that you can save and re-use ProductQuantization codebooks, even for a different set of vectors, as long as you had a sufficiently large training set to build it the first time around. ProductQuantization.MAX_PQ_TRAINING_SET_SIZE (128,000 vectors) has proven to be sufficiently large.
263267
* JDK ThreadLocal objects cannot be referenced except from the thread that created them. This is a difficult design into which to fit caching of Closeable objects like GraphSearcher. JVector provides the ExplicitThreadLocal class to solve this.
264268
* Fused ADC is only compatible with Product Quantization, not Binary Quantization. This is no great loss since [very few models generate embeddings that are best suited for BQ](https://thenewstack.io/why-vector-size-matters/). That said, BQ continues to be supported with non-Fused indexes.
265269
* JVector heavily utilizes the Panama Vector API (SIMD) for ANN indexing and search. We have seen cases where the memory bandwidth is saturated during indexing and product quantization and can cause the process to slow down. To avoid this, the batch methods for index and PQ builds use a [PhysicalCoreExecutor](https://javadoc.io/doc/io.github.jbellis/jvector/latest/io/github/jbellis/jvector/util/PhysicalCoreExecutor.html) to limit the number of operations to the physical core count. The default value is 1/2 the processor count seen by Java. This may not be correct in all setups (e.g. no hyperthreading or hybrid architectures) so if you wish to override the default use the `-Djvector.physical_core_count` property, or pass in your own ForkJoinPool instance.

UPGRADING.md

+11
@@ -5,11 +5,22 @@
55
in each vector with high accuracy by first applying a nonlinear transformation that is individually fit to each
66
vector. These nonlinearities are designed to be lightweight and have a negligible impact on distance computation
77
performance.
8+
- Support for hierarchical graph indices. This new type of index blends HNSW and DiskANN in a novel way. An
9+
HNSW-like hierarchy resides in memory for quickly seeding the search. This also reduces the need for caching the
10+
DiskANN graph near the entrypoint. The base layer of the hierarchy is a DiskANN-like index and inherits its
11+
properties. This hierarchical structure can be disabled, ending up with just the base DiskANN layer.
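The "seeding" role of the in-memory hierarchy can be sketched as a greedy walk through the upper layers that hands its final node to the base-layer search. This is an illustrative sketch, not JVector's implementation: the class `HierarchyDescentSketch`, the per-layer `Map<Integer, int[]>` adjacency layout, and the use of squared Euclidean distance are all assumptions.

```java
import java.util.List;
import java.util.Map;

final class HierarchyDescentSketch {
    static float squaredDistance(float[] a, float[] b) {
        float sum = 0f;
        for (int i = 0; i < a.length; i++) {
            float d = a[i] - b[i];
            sum += d * d;
        }
        return sum;
    }

    /** Greedy walk within one layer: keep moving to a closer neighbor until none improves. */
    static int greedyStep(Map<Integer, int[]> layer, float[][] vectors, float[] query, int start) {
        int current = start;
        float best = squaredDistance(vectors[current], query);
        boolean improved = true;
        while (improved) {
            improved = false;
            for (int neighbor : layer.getOrDefault(current, new int[0])) {
                float d = squaredDistance(vectors[neighbor], query);
                if (d < best) {
                    best = d;
                    current = neighbor;
                    improved = true;
                }
            }
        }
        return current;
    }

    /** Walks from the top layer down to layer 1 and returns the seed for the base-layer search. */
    static int findBaseEntryPoint(List<Map<Integer, int[]>> upperLayers, float[][] vectors,
                                  float[] query, int topEntryPoint) {
        int entry = topEntryPoint;
        for (int l = upperLayers.size() - 1; l >= 0; l--) {
            entry = greedyStep(upperLayers.get(l), vectors, query, entry);
        }
        return entry;
    }

    public static void main(String[] args) {
        float[][] vectors = { {0f}, {1f}, {2f}, {3f}, {4f}, {5f}, {6f}, {7f} };
        // upperLayers.get(0) is layer 1; upperLayers.get(1) is layer 2 (the top)
        Map<Integer, int[]> layer1 = Map.of(
                0, new int[]{2}, 2, new int[]{0, 4}, 4, new int[]{2, 6}, 6, new int[]{4});
        Map<Integer, int[]> layer2 = Map.of(0, new int[]{4}, 4, new int[]{0});
        int seed = findBaseEntryPoint(List.of(layer1, layer2), vectors, new float[]{5.2f}, 0);
        System.out.println("base-layer search starts at node " + seed);
    }
}
```

In this toy data the walk ends at node 6, one hop away from the true nearest neighbor (node 5), which exists only in the base layer. A full search would continue from that seed in the base layer.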
812

913
## API changes
1014
- MemorySegmentReader.Supplier and SimpleMappedReader.Supplier must now be explicitly closed, instead of being
1115
closed by the first Reader created from them.
1216
- OnDiskGraphIndex no longer closes its ReaderSupplier
17+
- The constructor of GraphIndexBuilder takes an additional parameter that enables or disables the use of the
18+
hierarchy.
19+
- GraphSearcher can be configured to run pruned searches using GraphSearcher.usePruning. When this is set to true,
20+
we do early termination of the search. In certain cases, this can accelerate the search at the potential cost of some
21+
accuracy. It is set to false by default.
22+
- The constructors of GraphIndexBuilder allow specifying different maximum out-degrees for the graphs in each layer.
23+
However, this feature does not work with FusedADC in this version.
1324

1425
### API changes in 3.0.6
1526

benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/IndexConstructionWithStaticSetBenchmark.java

+1-1
@@ -78,7 +78,7 @@ public void tearDown() throws IOException {
7878
@Benchmark
7979
public void buildIndexBenchmark(Blackhole blackhole) throws IOException {
8080
// score provider using the raw, in-memory vectors
81-
try (final var graphIndexBuilder = new GraphIndexBuilder(bsp, ravv.dimension(), M, beamWidth, 1.2f, 1.2f)) {
81+
try (final var graphIndexBuilder = new GraphIndexBuilder(bsp, ravv.dimension(), M, beamWidth, 1.2f, 1.2f, true)) {
8282
final var graphIndex = graphIndexBuilder.build(ravv);
8383
blackhole.consume(graphIndex);
8484
}

benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/RandomVectorsBenchmark.java

+2-1
@@ -81,7 +81,8 @@ public void setup() throws IOException {
8181
16, // graph degree
8282
100, // construction search depth
8383
1.2f, // allow degree overflow during construction by this factor
84-
1.2f); // relax neighbor diversity requirement by this factor
84+
1.2f, // relax neighbor diversity requirement by this factor
85+
true); // add the hierarchy
8586
graphIndex = graphIndexBuilder.build(ravv);
8687
}
8788

benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/StaticSetVectorsBenchmark.java

+2-1
@@ -69,7 +69,8 @@ public void setup() throws IOException {
6969
16, // graph degree
7070
100, // construction search depth
7171
1.2f, // allow degree overflow during construction by this factor
72-
1.2f); // relax neighbor diversity requirement by this factor
72+
1.2f, // relax neighbor diversity requirement by this factor
73+
true); // add the hierarchy
7374
graphIndex = graphIndexBuilder.build(ravv);
7475
}
7576

jvector-base/src/main/java/io/github/jbellis/jvector/disk/ByteBufferReader.java

+6-1
@@ -74,7 +74,12 @@ public int readInt() {
7474
}
7575

7676
@Override
77-
public float readFloat() throws IOException {
77+
public long readLong() {
78+
return bb.getLong();
79+
}
80+
81+
@Override
82+
public float readFloat() {
7883
return bb.getFloat();
7984
}
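The new `readLong` in ByteBufferReader delegates to `ByteBuffer.getLong()`, which is a relative read: it consumes eight bytes at the current position, in the buffer's byte order (big-endian unless changed), and advances the position. The JDK-only sketch below shows that behavior. The class name `ByteBufferLongSketch` and the method `roundTrip` are hypothetical and not part of JVector.

```java
import java.nio.ByteBuffer;

final class ByteBufferLongSketch {
    /** Writes an int followed by a long, then reads both back with relative gets. */
    static long roundTrip(int header, long payload) {
        ByteBuffer bb = ByteBuffer.allocate(Integer.BYTES + Long.BYTES);
        bb.putInt(header).putLong(payload);
        bb.flip();                        // switch the buffer from writing to reading
        int readHeader = bb.getInt();     // position advances by 4 bytes
        if (readHeader != header) {
            throw new IllegalStateException("unexpected header " + readHeader);
        }
        return bb.getLong();              // position advances by 8 bytes
    }

    public static void main(String[] args) {
        System.out.println(roundTrip(7, 1L << 40));
    }
}
```

Because these gets operate on memory already held by the buffer, they do not throw `IOException`, which is consistent with the diff dropping the `throws` clause from `readFloat`.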
8085

jvector-base/src/main/java/io/github/jbellis/jvector/disk/RandomAccessReader.java

+2
@@ -40,6 +40,8 @@ public interface RandomAccessReader extends AutoCloseable {
4040

4141
float readFloat() throws IOException;
4242

43+
long readLong() throws IOException;
44+
4345
void readFully(byte[] bytes) throws IOException;
4446

4547
void readFully(ByteBuffer buffer) throws IOException;

jvector-base/src/main/java/io/github/jbellis/jvector/disk/SimpleReader.java

+5
@@ -48,6 +48,11 @@ public int readInt() throws IOException {
4848
return raf.readInt();
4949
}
5050

51+
@Override
52+
public long readLong() throws IOException {
53+
return raf.readLong();
54+
}
55+
5156
@Override
5257
public float readFloat() throws IOException {
5358
return raf.readFloat();

jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborMap.java

+20-11
@@ -24,32 +24,38 @@
2424
import io.github.jbellis.jvector.util.DenseIntMap;
2525
import io.github.jbellis.jvector.util.DocIdSetIterator;
2626
import io.github.jbellis.jvector.util.FixedBitSet;
27+
import io.github.jbellis.jvector.util.IntMap;
2728

2829
import static java.lang.Math.min;
2930

3031
/**
3132
* Encapsulates operations on a graph's neighbors.
3233
*/
3334
public class ConcurrentNeighborMap {
34-
private final DenseIntMap<Neighbors> neighbors;
35+
final IntMap<Neighbors> neighbors;
3536

3637
/** the diversity threshold; 1.0 is equivalent to HNSW; Vamana uses 1.2 or more */
37-
private final float alpha;
38+
final float alpha;
3839

3940
/** used to compute diversity */
40-
private final BuildScoreProvider scoreProvider;
41+
final BuildScoreProvider scoreProvider;
4142

4243
/** the maximum number of neighbors desired per node */
4344
public final int maxDegree;
4445
/** the maximum number of neighbors a node can have temporarily during construction */
4546
public final int maxOverflowDegree;
4647

4748
public ConcurrentNeighborMap(BuildScoreProvider scoreProvider, int maxDegree, int maxOverflowDegree, float alpha) {
49+
this(new DenseIntMap<>(1024), scoreProvider, maxDegree, maxOverflowDegree, alpha);
50+
}
51+
52+
public <T> ConcurrentNeighborMap(IntMap<Neighbors> neighbors, BuildScoreProvider scoreProvider, int maxDegree, int maxOverflowDegree, float alpha) {
53+
assert maxDegree <= maxOverflowDegree : String.format("maxDegree %d exceeds maxOverflowDegree %d", maxDegree, maxOverflowDegree);
54+
this.neighbors = neighbors;
4855
this.alpha = alpha;
4956
this.scoreProvider = scoreProvider;
5057
this.maxDegree = maxDegree;
5158
this.maxOverflowDegree = maxOverflowDegree;
52-
neighbors = new DenseIntMap<>(1024);
5359
}
5460

5561
public void insertEdge(int fromId, int toId, float score, float overflow) {
@@ -103,6 +109,7 @@ public void replaceDeletedNeighbors(int nodeId, BitSet toDelete, NodeArray candi
103109
public Neighbors insertDiverse(int nodeId, NodeArray candidates) {
104110
while (true) {
105111
var old = neighbors.get(nodeId);
112+
assert old != null : nodeId; // graph.addNode should always be called before this method
106113
var next = old.insertDiverse(candidates, this);
107114
if (next == old || neighbors.compareAndPut(nodeId, old, next)) {
108115
return next;
@@ -132,10 +139,6 @@ public void addNode(int nodeId) {
132139
addNode(nodeId, new NodeArray(0));
133140
}
134141

135-
public NodesIterator nodesIterator() {
136-
return neighbors.keysIterator();
137-
}
138-
139142
public Neighbors remove(int node) {
140143
return neighbors.remove(node);
141144
}
@@ -262,7 +265,9 @@ private Neighbors insertDiverse(NodeArray toMerge, ConcurrentNeighborMap map) {
262265
retainDiverse(merged, 0, map);
263266
}
264267
// insertDiverse usually gets called with a LOT of candidates, so trim down the resulting NodeArray
265-
var nextNodes = merged.getArrayLength() <= map.nodeArrayLength() ? merged : merged.copy(map.nodeArrayLength());
268+
var nextNodes = merged.getArrayLength() <= map.nodeArrayLength()
269+
? merged
270+
: merged.copy(map.nodeArrayLength());
266271
return new Neighbors(nodeId, nextNodes);
267272
}
268273

@@ -402,16 +407,20 @@ public NeighborWithShortEdges(Neighbors neighbors, double shortEdges) {
402407
}
403408
}
404409

405-
private static class NeighborIterator extends NodesIterator {
410+
private static class NeighborIterator implements NodesIterator {
406411
private final NodeArray neighbors;
407412
private int i;
408413

409414
private NeighborIterator(NodeArray neighbors) {
410-
super(neighbors.size());
411415
this.neighbors = neighbors;
412416
i = 0;
413417
}
414418

419+
@Override
420+
public int size() {
421+
return neighbors.size();
422+
}
423+
415424
@Override
416425
public boolean hasNext() {
417426
return i < neighbors.size();
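The `insertDiverse` loop in this file follows an optimistic pattern: read the current value, compute a replacement, and install it with a compare-and-put, retrying if another thread got there first. The sketch below reproduces that pattern with JDK classes only. It is an assumption-laden stand-in: `CasRetrySketch` is a hypothetical class, an immutable `List<Integer>` plays the role of `Neighbors`, and `ConcurrentHashMap.replace(key, old, next)` substitutes for `compareAndPut` (note that `replace` compares with `equals` rather than by reference).

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;

final class CasRetrySketch {
    private final ConcurrentHashMap<Integer, List<Integer>> neighbors = new ConcurrentHashMap<>();

    void addNode(int nodeId) {
        neighbors.putIfAbsent(nodeId, List.of());
    }

    /** Adds one neighbor without locking: recompute and retry if another thread won the race. */
    List<Integer> addNeighbor(int nodeId, int neighborId) {
        while (true) {
            List<Integer> old = neighbors.get(nodeId);
            if (old == null) {
                throw new IllegalStateException("addNode must be called first for node " + nodeId);
            }
            if (old.contains(neighborId)) {
                return old;                        // nothing to change
            }
            List<Integer> copy = new ArrayList<>(old);
            copy.add(neighborId);
            List<Integer> next = List.copyOf(copy); // immutable snapshot
            if (neighbors.replace(nodeId, old, next)) {
                return next;                       // our update was installed atomically
            }
            // another thread replaced the value first; loop and try again
        }
    }

    List<Integer> get(int nodeId) {
        return neighbors.get(nodeId);
    }

    public static void main(String[] args) throws InterruptedException {
        CasRetrySketch map = new CasRetrySketch();
        map.addNode(0);
        Thread[] threads = new Thread[4];
        for (int t = 0; t < threads.length; t++) {
            final int base = t * 100;
            threads[t] = new Thread(() -> {
                for (int i = 1; i <= 100; i++) {
                    map.addNeighbor(0, base + i);
                }
            });
            threads[t].start();
        }
        for (Thread thread : threads) {
            thread.join();
        }
        System.out.println(map.get(0).size());
    }
}
```

The null check mirrors the assertion added in the diff: a node must be registered before its neighbor list can be updated.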
