Skip to content

Commit edf50bb

Browse files
committed
Add a memory-mapped RandomAccessReader using MemorySegment api
I'd prefer this to be in `jvector-twenty` module. But `--enable-preview` flag is only allowed for the Java release version used to compile the code. When building with Java 22, `--enable-preview` is not allowed on `twenty` module because it builds for Java 20.
1 parent 70a6df8 commit edf50bb

File tree

4 files changed

+333
-0
lines changed

4 files changed

+333
-0
lines changed

jvector-native/pom.xml

+8
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,14 @@
3535
</nonFilteredFileExtensions>
3636
</configuration>
3737
</plugin>
38+
<plugin>
39+
<groupId>org.apache.maven.plugins</groupId>
40+
<artifactId>maven-surefire-plugin</artifactId>
41+
<version>3.1.2</version>
42+
<configuration>
43+
<skip>false</skip>
44+
</configuration>
45+
</plugin>
3846
</plugins>
3947
</build>
4048
<profiles>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
/*
2+
* Copyright DataStax, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package io.github.jbellis.jvector.disk;
18+
19+
import java.io.IOException;
20+
import java.lang.foreign.Arena;
21+
import java.lang.foreign.MemorySegment;
22+
import java.lang.foreign.ValueLayout;
23+
import java.lang.foreign.ValueLayout.OfFloat;
24+
import java.lang.foreign.ValueLayout.OfInt;
25+
import java.lang.foreign.ValueLayout.OfLong;
26+
import java.nio.ByteBuffer;
27+
import java.nio.ByteOrder;
28+
import java.nio.channels.FileChannel;
29+
import java.nio.channels.FileChannel.MapMode;
30+
import java.nio.file.Path;
31+
import java.nio.file.StandardOpenOption;
32+
33+
/**
34+
* {@link MemorySegment} based implementation of RandomAccessReader.
35+
* MemorySegmentReader doesn't have 2GB file size limitation of {@link SimpleMappedReader}.
36+
*/
37+
public class MemorySegmentReader implements RandomAccessReader {
38+
39+
private static final OfInt intLayout = ValueLayout.JAVA_INT_UNALIGNED.withOrder(ByteOrder.BIG_ENDIAN);
40+
private static final OfFloat floatLayout = ValueLayout.JAVA_FLOAT_UNALIGNED.withOrder(ByteOrder.BIG_ENDIAN);
41+
private static final OfLong longLayout = ValueLayout.JAVA_LONG_UNALIGNED.withOrder(ByteOrder.BIG_ENDIAN);
42+
43+
private final Arena arena;
44+
private final MemorySegment memory;
45+
private long position = 0;
46+
47+
public MemorySegmentReader(Path path) throws IOException {
48+
arena = Arena.ofShared();
49+
try (var ch = FileChannel.open(path, StandardOpenOption.READ)) {
50+
memory = ch.map(MapMode.READ_ONLY, 0L, ch.size(), arena);
51+
} catch (Exception e) {
52+
arena.close();
53+
throw e;
54+
}
55+
}
56+
57+
private MemorySegmentReader(Arena arena, MemorySegment memory) {
58+
this.arena = arena;
59+
this.memory = memory;
60+
}
61+
62+
@Override
63+
public void seek(long offset) {
64+
this.position = offset;
65+
}
66+
67+
@Override
68+
public long getPosition() {
69+
return position;
70+
}
71+
72+
@Override
73+
public void readFully(float[] buffer) {
74+
MemorySegment.copy(memory, floatLayout, position, buffer, 0, buffer.length);
75+
position += buffer.length * 4L;
76+
}
77+
78+
@Override
79+
public void readFully(byte[] b) {
80+
MemorySegment.copy(memory, ValueLayout.JAVA_BYTE, position, b, 0, b.length);
81+
position += b.length;
82+
}
83+
84+
@Override
85+
public void readFully(ByteBuffer buffer) {
86+
var remaining = buffer.remaining();
87+
var slice = memory.asSlice(position, remaining).asByteBuffer();
88+
buffer.put(slice);
89+
position += remaining;
90+
}
91+
92+
@Override
93+
public void readFully(long[] vector) {
94+
MemorySegment.copy(memory, longLayout, position, vector, 0, vector.length);
95+
position += vector.length * 8L;
96+
}
97+
98+
@Override
99+
public int readInt() {
100+
var k = memory.get(intLayout, position);
101+
position += 4;
102+
return k;
103+
}
104+
105+
@Override
106+
public float readFloat() {
107+
var f = memory.get(floatLayout, position);
108+
position += 4;
109+
return f;
110+
}
111+
112+
@Override
113+
public void read(int[] ints, int offset, int count) {
114+
MemorySegment.copy(memory, intLayout, position, ints, offset, count);
115+
position += count * 4L;
116+
}
117+
118+
@Override
119+
public void read(float[] floats, int offset, int count) {
120+
MemorySegment.copy(memory, floatLayout, position, floats, offset, count);
121+
position += count * 4L;
122+
}
123+
124+
/**
125+
* Loads the contents of the mapped segment into physical memory.
126+
* This is a best-effort mechanism.
127+
*/
128+
public void loadMemory() {
129+
memory.load();
130+
}
131+
132+
@Override
133+
public void close() {
134+
arena.close();
135+
}
136+
137+
public MemorySegmentReader duplicate() {
138+
return new MemorySegmentReader(arena, memory);
139+
}
140+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
/*
2+
* Copyright DataStax, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package io.github.jbellis.jvector.disk;
17+
18+
import java.io.IOException;
19+
import java.nio.file.Path;
20+
21+
public class MemorySegmentReaderSupplier implements ReaderSupplier {
22+
private final MemorySegmentReader reader;
23+
24+
public MemorySegmentReaderSupplier(Path path) throws IOException {
25+
reader = new MemorySegmentReader(path);
26+
}
27+
28+
@Override
29+
public RandomAccessReader get() {
30+
return reader.duplicate();
31+
}
32+
33+
@Override
34+
public void close() {
35+
reader.close();
36+
}
37+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
/*
2+
* Copyright DataStax, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package io.github.jbellis.jvector.disk;
18+
19+
import com.carrotsearch.randomizedtesting.RandomizedTest;
20+
import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope;
21+
import org.junit.After;
22+
import org.junit.Before;
23+
import org.junit.Test;
24+
25+
import java.io.DataOutputStream;
26+
import java.io.FileOutputStream;
27+
import java.io.IOException;
28+
import java.nio.ByteBuffer;
29+
import java.nio.file.Files;
30+
import java.nio.file.Path;
31+
32+
import static org.junit.Assert.assertEquals;
33+
import static org.junit.Assert.fail;
34+
35+
@ThreadLeakScope(ThreadLeakScope.Scope.NONE)
36+
public class MemorySegmentReaderTest extends RandomizedTest {
37+
38+
private Path tempFile;
39+
40+
@Before
41+
public void setup() throws IOException {
42+
tempFile = Files.createTempFile(getClass().getSimpleName(), ".data");
43+
44+
try (var out = new DataOutputStream(new FileOutputStream(tempFile.toFile()))) {
45+
out.write(new byte[] {1, 2, 3, 4, 5, 6, 7});
46+
for (int i = 0; i < 5; i++) {
47+
out.writeInt((i + 1) * 19);
48+
}
49+
for (int i = 0; i < 5; i++) {
50+
out.writeLong((i + 1) * 19L);
51+
}
52+
for (int i = 0; i < 5; i++) {
53+
out.writeFloat((i + 1) * 19);
54+
}
55+
}
56+
}
57+
58+
@After
59+
public void tearDown() throws IOException {
60+
Files.deleteIfExists(tempFile);
61+
}
62+
63+
@Test
64+
public void testReader() throws Exception {
65+
try (var r = new MemorySegmentReader(tempFile)) {
66+
verifyReader(r);
67+
68+
// read 2nd time from beginning
69+
verifyReader(r);
70+
}
71+
}
72+
73+
@Test
74+
public void testReaderDuplicate() throws Exception {
75+
try (var r = new MemorySegmentReader(tempFile)) {
76+
for (int i = 0; i < 3; i++) {
77+
var r2 = r.duplicate();
78+
verifyReader(r2);
79+
}
80+
}
81+
}
82+
83+
@Test
84+
public void testReaderClose() throws Exception {
85+
var r = new MemorySegmentReader(tempFile);
86+
var r2 = r.duplicate();
87+
88+
r.close();
89+
90+
try {
91+
r.readInt();
92+
fail("Should have thrown an exception");
93+
} catch (IllegalStateException _) {
94+
}
95+
96+
try {
97+
r2.readInt();
98+
fail("Should have thrown an exception");
99+
} catch (IllegalStateException _) {
100+
}
101+
}
102+
103+
private void verifyReader(MemorySegmentReader r) {
104+
r.seek(0);
105+
var bytes = new byte[7];
106+
r.readFully(bytes);
107+
for (int i = 0; i < bytes.length; i++) {
108+
assertEquals(i + 1, bytes[i]);
109+
}
110+
111+
r.seek(0);
112+
var buff = ByteBuffer.allocate(6);
113+
r.readFully(buff);
114+
for (int i = 0; i < buff.remaining(); i++) {
115+
assertEquals(i + 1, buff.get(i));
116+
}
117+
118+
r.seek(7);
119+
assertEquals(19, r.readInt());
120+
121+
r.seek(7);
122+
var ints = new int[5];
123+
r.read(ints, 0, ints.length);
124+
for (int i = 0; i < ints.length; i++) {
125+
var k = ints[i];
126+
assertEquals((i + 1) * 19, k);
127+
}
128+
129+
r.seek(7 + (4 * 5));
130+
var longs = new long[5];
131+
r.readFully(longs);
132+
for (int i = 0; i < longs.length; i++) {
133+
var l = longs[i];
134+
assertEquals((i + 1) * 19, l);
135+
}
136+
137+
r.seek(7 + (4 * 5) + (8 * 5));
138+
assertEquals(19, r.readFloat(), 0.01);
139+
140+
r.seek(7 + (4 * 5) + (8 * 5));
141+
var floats = new float[5];
142+
r.readFully(floats);
143+
for (int i = 0; i < floats.length; i++) {
144+
var f = floats[i];
145+
assertEquals((i + 1) * 19f, f, 0.01);
146+
}
147+
}
148+
}

0 commit comments

Comments
 (0)