Skip to content

Commit 9ffd328

Browse files
committed
Add a memory-mapped RandomAccessReader using MemorySegment api
I'd prefer this to be in `jvector-twenty` module. But `--enable-preview` flag is only allowed for the Java release version used to compile the code. When building with Java 22, `--enable-preview` is not allowed on `twenty` module because it builds for Java 20.
1 parent 70a6df8 commit 9ffd328

File tree

4 files changed

+317
-0
lines changed

4 files changed

+317
-0
lines changed

jvector-native/pom.xml

+8
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,14 @@
3535
</nonFilteredFileExtensions>
3636
</configuration>
3737
</plugin>
38+
<plugin>
39+
<groupId>org.apache.maven.plugins</groupId>
40+
<artifactId>maven-surefire-plugin</artifactId>
41+
<version>3.1.2</version>
42+
<configuration>
43+
<skip>false</skip>
44+
</configuration>
45+
</plugin>
3846
</plugins>
3947
</build>
4048
<profiles>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
/*
2+
* Copyright DataStax, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package io.github.jbellis.jvector.disk;
18+
19+
import java.io.IOException;
20+
import java.lang.foreign.Arena;
21+
import java.lang.foreign.MemorySegment;
22+
import java.lang.foreign.ValueLayout;
23+
import java.lang.foreign.ValueLayout.OfFloat;
24+
import java.lang.foreign.ValueLayout.OfInt;
25+
import java.lang.foreign.ValueLayout.OfLong;
26+
import java.nio.ByteBuffer;
27+
import java.nio.ByteOrder;
28+
import java.nio.channels.FileChannel;
29+
import java.nio.channels.FileChannel.MapMode;
30+
import java.nio.file.Path;
31+
import java.nio.file.StandardOpenOption;
32+
33+
/**
34+
* {@link MemorySegment} based implementation of RandomAccessReader.
35+
* MemorySegmentReader doesn't have 2GB file size limitation of {@link SimpleMappedReader}.
36+
*/
37+
public class MemorySegmentReader implements RandomAccessReader {
38+
39+
private static final OfInt intLayout = ValueLayout.JAVA_INT_UNALIGNED.withOrder(ByteOrder.BIG_ENDIAN);
40+
private static final OfFloat floatLayout = ValueLayout.JAVA_FLOAT_UNALIGNED.withOrder(ByteOrder.BIG_ENDIAN);
41+
private static final OfLong longLayout = ValueLayout.JAVA_LONG_UNALIGNED.withOrder(ByteOrder.BIG_ENDIAN);
42+
43+
private final Arena arena;
44+
private final MemorySegment memory;
45+
private long position = 0;
46+
47+
public MemorySegmentReader(Path path) throws IOException {
48+
arena = Arena.ofShared();
49+
try (var ch = FileChannel.open(path, StandardOpenOption.READ)) {
50+
memory = ch.map(MapMode.READ_ONLY, 0L, ch.size(), arena);
51+
} catch (Exception e) {
52+
arena.close();
53+
throw e;
54+
}
55+
}
56+
57+
private MemorySegmentReader(Arena arena, MemorySegment memory) {
58+
this.arena = arena;
59+
this.memory = memory;
60+
}
61+
62+
@Override
63+
public void seek(long offset) {
64+
this.position = offset;
65+
}
66+
67+
@Override
68+
public long getPosition() {
69+
return position;
70+
}
71+
72+
@Override
73+
public void readFully(float[] buffer) {
74+
MemorySegment.copy(memory, floatLayout, position, buffer, 0, buffer.length);
75+
position += buffer.length * 4L;
76+
}
77+
78+
@Override
79+
public void readFully(byte[] b) {
80+
MemorySegment.copy(memory, ValueLayout.JAVA_BYTE, position, b, 0, b.length);
81+
position += b.length;
82+
}
83+
84+
@Override
85+
public void readFully(ByteBuffer buffer) {
86+
var remaining = buffer.remaining();
87+
var slice = memory.asSlice(position, remaining).asByteBuffer();
88+
buffer.put(slice);
89+
position += remaining;
90+
}
91+
92+
@Override
93+
public void readFully(long[] vector) {
94+
MemorySegment.copy(memory, longLayout, position, vector, 0, vector.length);
95+
position += vector.length * 8L;
96+
}
97+
98+
@Override
99+
public int readInt() {
100+
var k = memory.get(intLayout, position);
101+
position += 4;
102+
return k;
103+
}
104+
105+
@Override
106+
public float readFloat() {
107+
var f = memory.get(floatLayout, position);
108+
position += 4;
109+
return f;
110+
}
111+
112+
@Override
113+
public void read(int[] ints, int offset, int count) {
114+
MemorySegment.copy(memory, intLayout, position, ints, offset, count);
115+
position += count * 4L;
116+
}
117+
118+
@Override
119+
public void read(float[] floats, int offset, int count) {
120+
MemorySegment.copy(memory, floatLayout, position, floats, offset, count);
121+
position += count * 4L;
122+
}
123+
124+
/**
125+
* Loads the contents of the mapped segment into physical memory.
126+
* This is a best-effort mechanism.
127+
*/
128+
public void loadMemory() {
129+
memory.load();
130+
}
131+
132+
@Override
133+
public void close() {
134+
arena.close();
135+
}
136+
137+
public MemorySegmentReader duplicate() {
138+
return new MemorySegmentReader(arena, memory);
139+
}
140+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
/*
2+
* Copyright DataStax, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package io.github.jbellis.jvector.disk;
17+
18+
import java.io.IOException;
19+
import java.nio.file.Path;
20+
21+
public class MemorySegmentReaderSupplier implements ReaderSupplier {
22+
private final MemorySegmentReader reader;
23+
24+
public MemorySegmentReaderSupplier(Path path) throws IOException {
25+
reader = new MemorySegmentReader(path);
26+
}
27+
28+
@Override
29+
public RandomAccessReader get() {
30+
return reader.duplicate();
31+
}
32+
33+
@Override
34+
public void close() {
35+
reader.close();
36+
}
37+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
package io.github.jbellis.jvector.disk;
2+
3+
import com.carrotsearch.randomizedtesting.RandomizedTest;
4+
import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope;
5+
import org.junit.After;
6+
import org.junit.Before;
7+
import org.junit.Test;
8+
9+
import java.io.DataOutputStream;
10+
import java.io.FileOutputStream;
11+
import java.io.IOException;
12+
import java.nio.ByteBuffer;
13+
import java.nio.file.Files;
14+
import java.nio.file.Path;
15+
16+
import static org.junit.Assert.assertEquals;
17+
import static org.junit.Assert.fail;
18+
19+
@ThreadLeakScope(ThreadLeakScope.Scope.NONE)
20+
public class MemorySegmentReaderTest extends RandomizedTest {
21+
22+
private Path tempFile;
23+
24+
@Before
25+
public void setup() throws IOException {
26+
tempFile = Files.createTempFile(getClass().getSimpleName(), ".data");
27+
28+
try (var out = new DataOutputStream(new FileOutputStream(tempFile.toFile()))) {
29+
out.write(new byte[] {1, 2, 3, 4, 5, 6, 7});
30+
for (int i = 0; i < 5; i++) {
31+
out.writeInt((i + 1) * 19);
32+
}
33+
for (int i = 0; i < 5; i++) {
34+
out.writeLong((i + 1) * 19L);
35+
}
36+
for (int i = 0; i < 5; i++) {
37+
out.writeFloat((i + 1) * 19);
38+
}
39+
}
40+
}
41+
42+
@After
43+
public void tearDown() throws IOException {
44+
Files.deleteIfExists(tempFile);
45+
}
46+
47+
@Test
48+
public void testReader() throws Exception {
49+
try (var r = new MemorySegmentReader(tempFile)) {
50+
verifyReader(r);
51+
52+
// read 2nd time from beginning
53+
verifyReader(r);
54+
}
55+
}
56+
57+
@Test
58+
public void testReaderDuplicate() throws Exception {
59+
try (var r = new MemorySegmentReader(tempFile)) {
60+
for (int i = 0; i < 3; i++) {
61+
var r2 = r.duplicate();
62+
verifyReader(r2);
63+
}
64+
}
65+
}
66+
67+
@Test
68+
public void testReaderClose() throws Exception {
69+
var r = new MemorySegmentReader(tempFile);
70+
var r2 = r.duplicate();
71+
72+
r.close();
73+
74+
try {
75+
r.readInt();
76+
fail("Should have thrown an exception");
77+
} catch (IllegalStateException _) {
78+
}
79+
80+
try {
81+
r2.readInt();
82+
fail("Should have thrown an exception");
83+
} catch (IllegalStateException _) {
84+
}
85+
}
86+
87+
private void verifyReader(MemorySegmentReader r) {
88+
r.seek(0);
89+
var bytes = new byte[7];
90+
r.readFully(bytes);
91+
for (int i = 0; i < bytes.length; i++) {
92+
assertEquals(i + 1, bytes[i]);
93+
}
94+
95+
r.seek(0);
96+
var buff = ByteBuffer.allocate(6);
97+
r.readFully(buff);
98+
for (int i = 0; i < buff.remaining(); i++) {
99+
assertEquals(i + 1, buff.get(i));
100+
}
101+
102+
r.seek(7);
103+
assertEquals(19, r.readInt());
104+
105+
r.seek(7);
106+
var ints = new int[5];
107+
r.read(ints, 0, ints.length);
108+
for (int i = 0; i < ints.length; i++) {
109+
var k = ints[i];
110+
assertEquals((i + 1) * 19, k);
111+
}
112+
113+
r.seek(7 + (4 * 5));
114+
var longs = new long[5];
115+
r.readFully(longs);
116+
for (int i = 0; i < longs.length; i++) {
117+
var l = longs[i];
118+
assertEquals((i + 1) * 19, l);
119+
}
120+
121+
r.seek(7 + (4 * 5) + (8 * 5));
122+
assertEquals(19, r.readFloat(), 0.01);
123+
124+
r.seek(7 + (4 * 5) + (8 * 5));
125+
var floats = new float[5];
126+
r.readFully(floats);
127+
for (int i = 0; i < floats.length; i++) {
128+
var f = floats[i];
129+
assertEquals((i + 1) * 19f, f, 0.01);
130+
}
131+
}
132+
}

0 commit comments

Comments
 (0)