Skip to content

Commit

Permalink
Merge pull request #266 from johnoliver/fix-db
Browse files Browse the repository at this point in the history
Add thread safety on database creation
  • Loading branch information
johnoliver authored Nov 19, 2024
2 parents 612e5a5 + 564b272 commit 2c43b23
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 66 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,12 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.annotation.Nonnull;
import javax.annotation.concurrent.GuardedBy;
import javax.sql.DataSource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -54,6 +57,8 @@ public class JDBCVectorStoreQueryProvider
private final String collectionsTable;
private final String prefixForCollectionTables;

private final Object dbCreationLock = new Object();

@SuppressFBWarnings("EI_EXPOSE_REP2") // DataSource is not exposed
protected JDBCVectorStoreQueryProvider(
@Nonnull DataSource dataSource,
Expand Down Expand Up @@ -89,12 +94,13 @@ protected JDBCVectorStoreQueryProvider(

/**
* Creates a new instance of the JDBCVectorStoreQueryProvider class.
*
* @param dataSource the data source
* @param collectionsTable the collections table
* @param prefixForCollectionTables the prefix for collection tables
* @param supportedKeyTypes the supported key types
* @param supportedDataTypes the supported data types
* @param supportedVectorTypes the supported vector types
*/
public JDBCVectorStoreQueryProvider(
@SuppressFBWarnings("EI_EXPOSE_REP2") @Nonnull DataSource dataSource,
Expand Down Expand Up @@ -276,48 +282,57 @@ public boolean collectionExists(String collectionName) {
*/
@Override
@SuppressFBWarnings("SQL_PREPARED_STATEMENT_GENERATED_FROM_NONCONSTANT_STRING")
@GuardedBy("dbCreationLock")
// SQL query is generated dynamically with valid identifiers
public void createCollection(String collectionName,
VectorStoreRecordDefinition recordDefinition) {

// No approximate search is supported in JDBCVectorStoreQueryProvider
if (recordDefinition.getVectorFields().stream()
.anyMatch(
field -> field.getIndexKind() != null && field.getIndexKind() != IndexKind.FLAT
&& field.getIndexKind() != IndexKind.UNDEFINED)) {
LOGGER
.warn(String.format("Indexes are not supported in %s. Ignoring indexKind property.",
this.getClass().getName()));
}

String createStorageTable = formatQuery("CREATE TABLE IF NOT EXISTS %s ("
+ "%s VARCHAR(255) PRIMARY KEY, "
+ "%s, "
+ "%s);",
getCollectionTableName(collectionName),
getKeyColumnName(recordDefinition.getKeyField()),
getColumnNamesAndTypes(new ArrayList<>(recordDefinition.getDataFields()),
getSupportedDataTypes()),
getColumnNamesAndTypes(new ArrayList<>(recordDefinition.getVectorFields()),
getSupportedVectorTypes()));
synchronized (dbCreationLock) {
// No approximate search is supported in JDBCVectorStoreQueryProvider
if (recordDefinition.getVectorFields().stream()
.anyMatch(
field -> field.getIndexKind() != null && field.getIndexKind() != IndexKind.FLAT
&& field.getIndexKind() != IndexKind.UNDEFINED)) {
LOGGER
.warn(String.format(
"Indexes are not supported in %s. Ignoring indexKind property.",
this.getClass().getName()));
}

String insertCollectionQuery = formatQuery("INSERT INTO %s (collectionId) VALUES (?)",
validateSQLidentifier(collectionsTable));
String createStorageTable = formatQuery("CREATE TABLE IF NOT EXISTS %s ("
+ "%s VARCHAR(255) PRIMARY KEY, "
+ "%s, "
+ "%s);",
getCollectionTableName(collectionName),
getKeyColumnName(recordDefinition.getKeyField()),
getColumnNamesAndTypes(new ArrayList<>(recordDefinition.getDataFields()),
getSupportedDataTypes()),
getColumnNamesAndTypes(new ArrayList<>(recordDefinition.getVectorFields()),
getSupportedVectorTypes()));

String insertCollectionQuery = this.getInsertCollectionQuery(collectionsTable);

try (Connection connection = dataSource.getConnection();
PreparedStatement createTable = connection.prepareStatement(createStorageTable)) {
createTable.execute();
} catch (SQLException e) {
throw new SKException("Failed to create collection", e);
}

try (Connection connection = dataSource.getConnection();
PreparedStatement createTable = connection.prepareStatement(createStorageTable)) {
createTable.execute();
} catch (SQLException e) {
throw new SKException("Failed to create collection", e);
try (Connection connection = dataSource.getConnection();
PreparedStatement insert = connection.prepareStatement(insertCollectionQuery)) {
insert.setObject(1, collectionName);
insert.execute();
} catch (SQLException e) {
throw new SKException("Failed to insert collection", e);
}
}
}

try (Connection connection = dataSource.getConnection();
PreparedStatement insert = connection.prepareStatement(insertCollectionQuery)) {
insert.setObject(1, collectionName);
insert.execute();
} catch (SQLException e) {
throw new SKException("Failed to insert collection", e);
}
/**
 * Builds the parameterized statement that registers a collection in the collections table.
 *
 * <p>This implementation emits the {@code INSERT IGNORE} form. The method is protected so
 * that subclasses can override it with a different statement.
 *
 * @param collectionsTable the name of the collections table; it is passed through
 *                         {@code validateSQLidentifier} before being placed in the query
 * @return the insert query, with a single {@code ?} placeholder for the collection id
 */
protected String getInsertCollectionQuery(String collectionsTable) {
    final String insertTemplate = "INSERT IGNORE INTO %s (collectionId) VALUES (?)";
    return formatQuery(insertTemplate, validateSQLidentifier(collectionsTable));
}

/**
Expand All @@ -327,26 +342,29 @@ public void createCollection(String collectionName,
* @throws SKException if an error occurs while deleting the collection
*/
@Override
@GuardedBy("dbCreationLock")
public void deleteCollection(String collectionName) {
String deleteCollectionOperation = formatQuery("DELETE FROM %s WHERE collectionId = ?",
validateSQLidentifier(collectionsTable));
String dropTableOperation = formatQuery("DROP TABLE %s",
getCollectionTableName(collectionName));

try (Connection connection = dataSource.getConnection();
PreparedStatement deleteCollection = connection
.prepareStatement(deleteCollectionOperation)) {
deleteCollection.setObject(1, collectionName);
deleteCollection.execute();
} catch (SQLException e) {
throw new SKException("Failed to delete collection", e);
}
synchronized (dbCreationLock) {
String deleteCollectionOperation = formatQuery("DELETE FROM %s WHERE collectionId = ?",
validateSQLidentifier(collectionsTable));
String dropTableOperation = formatQuery("DROP TABLE %s",
getCollectionTableName(collectionName));

try (Connection connection = dataSource.getConnection();
PreparedStatement deleteCollection = connection
.prepareStatement(deleteCollectionOperation)) {
deleteCollection.setObject(1, collectionName);
deleteCollection.execute();
} catch (SQLException e) {
throw new SKException("Failed to delete collection", e);
}

try (Connection connection = dataSource.getConnection();
PreparedStatement dropTable = connection.prepareStatement(dropTableOperation)) {
dropTable.execute();
} catch (SQLException e) {
throw new SKException("Failed to drop table", e);
try (Connection connection = dataSource.getConnection();
PreparedStatement dropTable = connection.prepareStatement(dropTableOperation)) {
dropTable.execute();
} catch (SQLException e) {
throw new SKException("Failed to drop table", e);
}
}
}

Expand Down Expand Up @@ -518,8 +536,8 @@ protected <Record> List<Record> getRecordsWithFilter(String collectionName,
*
* @param <Record> the record type
* @param collectionName the collection name
* @param vector the vector to search with
* @param options the search options
* @param recordDefinition the record definition
* @param mapper the mapper, responsible for mapping the result set to the record
* type.
Expand Down Expand Up @@ -622,8 +640,8 @@ public String getFilter(VectorSearchFilter filter,
}

/**
* Gets the filter parameters for the given vector search filter to associate with the filter
* string generated by the getFilter method.
*
* @param filter The filter to get the filter parameters for.
* @return The filter parameters.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import com.microsoft.semantickernel.data.jdbc.postgres.PostgreSQLVectorStoreQueryProvider;
import com.microsoft.semantickernel.data.jdbc.postgres.PostgreSQLVectorStoreRecordMapper;
import com.microsoft.semantickernel.data.vectorsearch.VectorSearchResults;
import com.microsoft.semantickernel.data.vectorstorage.VectorStoreRecordMapper;
import com.microsoft.semantickernel.data.vectorstorage.VectorStoreRecordCollection;
import com.microsoft.semantickernel.data.vectorstorage.VectorStoreRecordMapper;
import com.microsoft.semantickernel.data.vectorstorage.definition.VectorStoreRecordDefinition;
import com.microsoft.semantickernel.data.vectorstorage.options.DeleteRecordOptions;
import com.microsoft.semantickernel.data.vectorstorage.options.GetRecordOptions;
Expand All @@ -27,10 +27,9 @@
import reactor.core.scheduler.Schedulers;

/**
* The JDBCVectorStoreRecordCollection class represents a collection of records in a JDBC vector
* store. It implements the SQLVectorStoreRecordCollection interface and provides methods for
* managing the collection, such as creating, deleting, and upserting records.
*
* @param <Record> the type of the records in the collection
*/
Expand Down Expand Up @@ -322,6 +321,7 @@ public Mono<VectorSearchResults<Record>> searchAsync(List<Float> vector,

/**
* Builder for a JDBCVectorStoreRecordCollection.
*
* @param <Record> the type of the records in the collection
*/
public static class Builder<Record>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,13 @@ public void upsertRecords(String collectionName, List<?> records,
}
}

/**
 * Builds the parameterized statement that registers a collection in the collections table,
 * using the {@code INSERT OR IGNORE} form.
 *
 * @param collectionsTable the name of the collections table; it is passed through
 *                         {@code validateSQLidentifier} before being placed in the query
 * @return the insert query, with a single {@code ?} placeholder for the collection id
 */
@Override
protected String getInsertCollectionQuery(String collectionsTable) {
    final String insertTemplate = "INSERT OR IGNORE INTO %s (collectionId) VALUES (?)";
    return formatQuery(insertTemplate, validateSQLidentifier(collectionsTable));
}

/**
* A builder for {@code SQLiteVectorStoreQueryProvider}.
*/
Expand Down

0 comments on commit 2c43b23

Please sign in to comment.