Skip to content

Commit 7a6a538

Browse files
authored
Merge pull request #164 from milderhc/postgres
Add PostgreSQL query provider and mapper
2 parents c102f66 + f8f0cdc commit 7a6a538

File tree

13 files changed

+841
-170
lines changed

13 files changed

+841
-170
lines changed

api-test/integration-tests/pom.xml

+5
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,11 @@
7373
<version>8.0.33</version>
7474
<scope>test</scope>
7575
</dependency>
76+
<dependency>
77+
<groupId>org.postgresql</groupId>
78+
<artifactId>postgresql</artifactId>
79+
<version>42.7.2</version> <!-- Use the latest version -->
80+
</dependency>
7681

7782
<dependency>
7883
<groupId>org.testcontainers</groupId>

api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/Hotel.java

+5-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ public class Hotel {
1818
@VectorStoreRecordVectorAttribute(dimensions = 3)
1919
private final List<Float> descriptionEmbedding;
2020
@VectorStoreRecordDataAttribute
21-
private final double rating;
21+
private double rating;
2222

2323
public Hotel() {
2424
this(null, null, 0, null, null, 0.0);
@@ -56,4 +56,8 @@ public List<Float> getDescriptionEmbedding() {
5656
public double getRating() {
5757
return rating;
5858
}
59+
60+
public void setRating(double rating) {
61+
this.rating = rating;
62+
}
5963
}

api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreRecordCollectionTest.java

+145-92
Large diffs are not rendered by default.

api-test/integration-tests/src/test/java/com/microsoft/semantickernel/tests/connectors/memory/jdbc/JDBCVectorStoreTest.java

+63-27
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,71 @@
11
package com.microsoft.semantickernel.tests.connectors.memory.jdbc;
22

3-
import static org.junit.jupiter.api.Assertions.assertEquals;
4-
import static org.junit.jupiter.api.Assertions.assertNotNull;
5-
import static org.junit.jupiter.api.Assertions.assertTrue;
6-
73
import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStore;
84
import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreOptions;
9-
import com.microsoft.semantickernel.connectors.data.jdbc.MySQLVectorStoreQueryProvider;
5+
import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreQueryProvider;
6+
import com.microsoft.semantickernel.connectors.data.mysql.MySQLVectorStoreQueryProvider;
7+
import com.microsoft.semantickernel.connectors.data.postgres.PostgreSQLVectorStoreQueryProvider;
108
import com.microsoft.semantickernel.tests.connectors.memory.Hotel;
119
import com.mysql.cj.jdbc.MysqlDataSource;
12-
import java.util.Arrays;
13-
import java.util.List;
14-
import org.junit.jupiter.api.BeforeAll;
15-
import org.junit.jupiter.api.Test;
10+
import org.junit.jupiter.params.ParameterizedTest;
11+
import org.junit.jupiter.params.provider.EnumSource;
12+
import org.postgresql.ds.PGSimpleDataSource;
1613
import org.testcontainers.containers.MySQLContainer;
14+
import org.testcontainers.containers.PostgreSQLContainer;
1715
import org.testcontainers.junit.jupiter.Container;
1816
import org.testcontainers.junit.jupiter.Testcontainers;
17+
import org.testcontainers.utility.DockerImageName;
18+
19+
import javax.annotation.Nonnull;
20+
import javax.sql.DataSource;
21+
import java.util.Arrays;
22+
import java.util.List;
23+
24+
import com.microsoft.semantickernel.tests.connectors.memory.jdbc.JDBCVectorStoreRecordCollectionTest.QueryProvider;
25+
import static org.junit.jupiter.api.Assertions.assertEquals;
26+
import static org.junit.jupiter.api.Assertions.assertNotNull;
27+
import static org.junit.jupiter.api.Assertions.assertTrue;
1928

2029
@Testcontainers
2130
public class JDBCVectorStoreTest {
2231
@Container
23-
private static final MySQLContainer<?> CONTAINER = new MySQLContainer<>("mysql:5.7.34");
24-
private static final String MYSQL_USER = "test";
25-
private static final String MYSQL_PASSWORD = "test";
26-
private static MysqlDataSource dataSource;
27-
28-
@BeforeAll
29-
static void setup() {
30-
dataSource = new MysqlDataSource();
31-
dataSource.setUrl(CONTAINER.getJdbcUrl());
32-
dataSource.setUser(MYSQL_USER);
33-
dataSource.setPassword(MYSQL_PASSWORD);
34-
}
32+
private static final MySQLContainer<?> MYSQL_CONTAINER = new MySQLContainer<>("mysql:5.7.34");
33+
34+
private static final DockerImageName PGVECTOR = DockerImageName.parse("pgvector/pgvector:pg16").asCompatibleSubstituteFor("postgres");
35+
@Container
36+
private static final PostgreSQLContainer<?> POSTGRESQL_CONTAINER = new PostgreSQLContainer<>(PGVECTOR);
37+
38+
private JDBCVectorStore buildVectorStore(QueryProvider provider) {
39+
JDBCVectorStoreQueryProvider queryProvider;
40+
DataSource dataSource;
41+
42+
switch (provider) {
43+
case MySQL:
44+
MysqlDataSource mysqlDataSource = new MysqlDataSource();
45+
mysqlDataSource.setUrl(MYSQL_CONTAINER.getJdbcUrl());
46+
mysqlDataSource.setUser(MYSQL_CONTAINER.getUsername());
47+
mysqlDataSource.setPassword(MYSQL_CONTAINER.getPassword());
48+
dataSource = mysqlDataSource;
49+
queryProvider = MySQLVectorStoreQueryProvider.builder()
50+
.withDataSource(dataSource)
51+
.build();
52+
break;
53+
case PostgreSQL:
54+
PGSimpleDataSource pgSimpleDataSource = new PGSimpleDataSource();
55+
pgSimpleDataSource.setUrl(POSTGRESQL_CONTAINER.getJdbcUrl());
56+
pgSimpleDataSource.setUser(POSTGRESQL_CONTAINER.getUsername());
57+
pgSimpleDataSource.setPassword(POSTGRESQL_CONTAINER.getPassword());
58+
dataSource = pgSimpleDataSource;
59+
queryProvider = PostgreSQLVectorStoreQueryProvider.builder()
60+
.withDataSource(dataSource)
61+
.build();
62+
break;
63+
default:
64+
throw new IllegalArgumentException("Unknown query provider: " + provider);
65+
}
3566

36-
@Test
37-
public void getCollectionNamesAsync() {
38-
MySQLVectorStoreQueryProvider queryProvider = MySQLVectorStoreQueryProvider.builder()
39-
.withDataSource(dataSource)
40-
.build();
4167

42-
JDBCVectorStore vectorStore = JDBCVectorStore.builder()
68+
JDBCVectorStore vectorStore = JDBCVectorStore.builder()
4369
.withDataSource(dataSource)
4470
.withOptions(
4571
JDBCVectorStoreOptions.builder()
@@ -48,6 +74,16 @@ public void getCollectionNamesAsync() {
4874
)
4975
.build();
5076

77+
vectorStore.prepareAsync().block();
78+
return vectorStore;
79+
}
80+
81+
82+
@ParameterizedTest
83+
@EnumSource(QueryProvider.class)
84+
public void getCollectionNamesAsync(QueryProvider provider) {
85+
JDBCVectorStore vectorStore = buildVectorStore(provider);
86+
5187
vectorStore.getCollectionNamesAsync().block();
5288

5389
List<String> collectionNames = Arrays.asList("collection1", "collection2", "collection3");

samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/syntaxexamples/memory/JDBC_DataStorage.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import com.microsoft.semantickernel.aiservices.openai.textembedding.OpenAITextEmbeddingGenerationService;
99
import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStore;
1010
import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreOptions;
11-
import com.microsoft.semantickernel.connectors.data.jdbc.MySQLVectorStoreQueryProvider;
11+
import com.microsoft.semantickernel.connectors.data.mysql.MySQLVectorStoreQueryProvider;
1212
import com.microsoft.semantickernel.data.VectorStoreRecordCollection;
1313
import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordDataAttribute;
1414
import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordKeyAttribute;

semantickernel-experimental/pom.xml

+5
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,11 @@
109109
</exclusions>
110110
</dependency>
111111

112+
<dependency>
113+
<groupId>org.postgresql</groupId>
114+
<artifactId>postgresql</artifactId>
115+
<version>42.7.2</version> <!-- Use the latest version -->
116+
</dependency>
112117

113118
</dependencies>
114119

semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreDefaultQueryProvider.java

+65-31
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// Copyright (c) Microsoft. All rights reserved.
22
package com.microsoft.semantickernel.connectors.data.jdbc;
33

4+
import com.microsoft.semantickernel.data.VectorStoreRecordMapper;
45
import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition;
56
import com.microsoft.semantickernel.exceptions.SKException;
67
import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordField;
@@ -24,14 +25,27 @@
2425
import java.util.List;
2526
import java.util.Map;
2627
import java.util.stream.Collectors;
28+
import java.util.stream.Stream;
2729

2830
public class JDBCVectorStoreDefaultQueryProvider
2931
implements JDBCVectorStoreQueryProvider {
30-
private static final Map<Class<?>, String> supportedKeyTypes;
31-
private static final Map<Class<?>, String> supportedDataTypes;
32-
private static final Map<Class<?>, String> supportedVectorTypes;
3332

34-
static {
33+
private Map<Class<?>, String> supportedKeyTypes;
34+
private Map<Class<?>, String> supportedDataTypes;
35+
private Map<Class<?>, String> supportedVectorTypes;
36+
private final DataSource dataSource;
37+
private final String collectionsTable;
38+
private final String prefixForCollectionTables;
39+
40+
@SuppressFBWarnings("EI_EXPOSE_REP2") // DataSource is not exposed
41+
protected JDBCVectorStoreDefaultQueryProvider(
42+
@Nonnull DataSource dataSource,
43+
@Nonnull String collectionsTable,
44+
@Nonnull String prefixForCollectionTables) {
45+
this.dataSource = dataSource;
46+
this.collectionsTable = collectionsTable;
47+
this.prefixForCollectionTables = prefixForCollectionTables;
48+
3549
supportedKeyTypes = new HashMap<>();
3650
supportedKeyTypes.put(String.class, "VARCHAR(255)");
3751

@@ -54,19 +68,6 @@ public class JDBCVectorStoreDefaultQueryProvider
5468
supportedVectorTypes.put(List.class, "TEXT");
5569
supportedVectorTypes.put(Collection.class, "TEXT");
5670
}
57-
private final DataSource dataSource;
58-
private final String collectionsTable;
59-
private final String prefixForCollectionTables;
60-
61-
@SuppressFBWarnings("EI_EXPOSE_REP2") // DataSource is not exposed
62-
protected JDBCVectorStoreDefaultQueryProvider(
63-
@Nonnull DataSource dataSource,
64-
@Nonnull String collectionsTable,
65-
@Nonnull String prefixForCollectionTables) {
66-
this.dataSource = dataSource;
67-
this.collectionsTable = collectionsTable;
68-
this.prefixForCollectionTables = prefixForCollectionTables;
69-
}
7071

7172
/**
7273
* Creates a new builder.
@@ -82,14 +83,9 @@ public static Builder builder() {
8283
* @return the formatted wildcard string
8384
*/
8485
protected String getWildcardString(int wildcards) {
85-
StringBuilder wildcardString = new StringBuilder();
86-
for (int i = 0; i < wildcards; ++i) {
87-
wildcardString.append("?");
88-
if (i < wildcards - 1) {
89-
wildcardString.append(", ");
90-
}
91-
}
92-
return wildcardString.toString();
86+
return Stream.generate(() -> "?")
87+
.limit(wildcards)
88+
.collect(Collectors.joining(", "));
9389
}
9490

9591
/**
@@ -102,6 +98,12 @@ protected String getQueryColumnsFromFields(List<VectorStoreRecordField> fields)
10298
.collect(Collectors.joining(", "));
10399
}
104100

101+
/**
102+
* Formats the column names and types for a table.
103+
* @param fields the fields
104+
* @param types the types
105+
* @return the formatted column names and types
106+
*/
105107
protected String getColumnNamesAndTypes(List<Field> fields, Map<Class<?>, String> types) {
106108
List<String> columns = fields.stream()
107109
.map(field -> field.getName() + " " + types.get(field.getType()))
@@ -114,6 +116,36 @@ protected String getCollectionTableName(String collectionName) {
114116
return validateSQLidentifier(prefixForCollectionTables + collectionName);
115117
}
116118

119+
/**
120+
* Gets the supported key types and their corresponding SQL types.
121+
*
122+
* @return the supported key types
123+
*/
124+
@Override
125+
public Map<Class<?>, String> getSupportedKeyTypes() {
126+
return new HashMap<>(this.supportedKeyTypes);
127+
}
128+
129+
/**
130+
* Gets the supported data types and their corresponding SQL types.
131+
*
132+
* @return the supported data types
133+
*/
134+
@Override
135+
public Map<Class<?>, String> getSupportedDataTypes() {
136+
return new HashMap<>(this.supportedDataTypes);
137+
}
138+
139+
/**
140+
* Gets the supported vector types and their corresponding SQL types.
141+
*
142+
* @return the supported vector types
143+
*/
144+
@Override
145+
public Map<Class<?>, String> getSupportedVectorTypes() {
146+
return new HashMap<>(this.supportedVectorTypes);
147+
}
148+
117149
/**
118150
* Prepares the vector store.
119151
* Executes any necessary setup steps for the vector store.
@@ -146,11 +178,12 @@ public void validateSupportedTypes(Class<?> recordClass,
146178
VectorStoreRecordDefinition recordDefinition) {
147179
VectorStoreRecordDefinition.validateSupportedTypes(
148180
Collections.singletonList(recordDefinition.getKeyDeclaredField(recordClass)),
149-
supportedKeyTypes.keySet());
181+
getSupportedKeyTypes().keySet());
150182
VectorStoreRecordDefinition.validateSupportedTypes(
151-
recordDefinition.getDataDeclaredFields(recordClass), supportedDataTypes.keySet());
183+
recordDefinition.getDataDeclaredFields(recordClass), getSupportedDataTypes().keySet());
152184
VectorStoreRecordDefinition.validateSupportedTypes(
153-
recordDefinition.getVectorDeclaredFields(recordClass), supportedVectorTypes.keySet());
185+
recordDefinition.getVectorDeclaredFields(recordClass),
186+
getSupportedVectorTypes().keySet());
154187
}
155188

156189
/**
@@ -194,8 +227,8 @@ public void createCollection(String collectionName, Class<?> recordClass,
194227
String createStorageTable = "CREATE TABLE IF NOT EXISTS "
195228
+ getCollectionTableName(collectionName)
196229
+ " (" + keyDeclaredField.getName() + " VARCHAR(255) PRIMARY KEY, "
197-
+ getColumnNamesAndTypes(dataDeclaredFields, supportedDataTypes) + ", "
198-
+ getColumnNamesAndTypes(vectorDeclaredFields, supportedVectorTypes) + ");";
230+
+ getColumnNamesAndTypes(dataDeclaredFields, getSupportedDataTypes()) + ", "
231+
+ getColumnNamesAndTypes(vectorDeclaredFields, getSupportedVectorTypes()) + ");";
199232

200233
String insertCollectionQuery = "INSERT INTO " + validateSQLidentifier(collectionsTable)
201234
+ " (collectionId) VALUES (?)";
@@ -284,7 +317,8 @@ public List<String> getCollectionNames() {
284317
*/
285318
@Override
286319
public <Record> List<Record> getRecords(String collectionName, List<String> keys,
287-
VectorStoreRecordDefinition recordDefinition, JDBCVectorStoreRecordMapper<Record> mapper,
320+
VectorStoreRecordDefinition recordDefinition,
321+
VectorStoreRecordMapper<Record, ResultSet> mapper,
288322
GetRecordOptions options) {
289323
List<VectorStoreRecordField> fields;
290324
if (options == null || options.includeVectors()) {

semantickernel-experimental/src/main/java/com/microsoft/semantickernel/connectors/data/jdbc/JDBCVectorStoreQueryProvider.java

+26-1
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,15 @@
22
package com.microsoft.semantickernel.connectors.data.jdbc;
33

44
import com.microsoft.semantickernel.builders.SemanticKernelBuilder;
5+
import com.microsoft.semantickernel.data.VectorStoreRecordMapper;
56
import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition;
67
import com.microsoft.semantickernel.data.recordoptions.DeleteRecordOptions;
78
import com.microsoft.semantickernel.data.recordoptions.GetRecordOptions;
89
import com.microsoft.semantickernel.data.recordoptions.UpsertRecordOptions;
910

11+
import java.sql.ResultSet;
1012
import java.util.List;
13+
import java.util.Map;
1114

1215
/**
1316
* The JDBC vector store query provider.
@@ -24,6 +27,27 @@ public interface JDBCVectorStoreQueryProvider {
2427
*/
2528
String DEFAULT_PREFIX_FOR_COLLECTION_TABLES = "SKCollection_";
2629

30+
/**
31+
* Gets the supported key types and their corresponding SQL types.
32+
*
33+
* @return the supported key types
34+
*/
35+
Map<Class<?>, String> getSupportedKeyTypes();
36+
37+
/**
38+
* Gets the supported data types and their corresponding SQL types.
39+
*
40+
* @return the supported data types
41+
*/
42+
Map<Class<?>, String> getSupportedDataTypes();
43+
44+
/**
45+
* Gets the supported vector types and their corresponding SQL types.
46+
*
47+
* @return the supported vector types
48+
*/
49+
Map<Class<?>, String> getSupportedVectorTypes();
50+
2751
/**
2852
* Prepares the vector store.
2953
* Executes any necessary setup steps for the vector store.
@@ -81,7 +105,8 @@ void createCollection(String collectionName, Class<?> recordClass,
81105
* @return the records
82106
*/
83107
<Record> List<Record> getRecords(String collectionName, List<String> keys,
84-
VectorStoreRecordDefinition recordDefinition, JDBCVectorStoreRecordMapper<Record> mapper,
108+
VectorStoreRecordDefinition recordDefinition,
109+
VectorStoreRecordMapper<Record, ResultSet> mapper,
85110
GetRecordOptions options);
86111

87112
/**

0 commit comments

Comments
 (0)