Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add API classes for vector search and Azure AI implementation #202

Merged
merged 17 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/_typos.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ ans = "ans" # Short for answers
arange = "arange" # Method in Python numpy package
prompty = "prompty" # prompty is a format name.
ist = "ist" # German language
Prelease = "Prelease" # Prelease is a format name.

[default.extend-identifiers]
ags = "ags" # Azure Graph Service
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/java-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ jobs:
run: ./mvnw -B -Pbug-check -Pcompile-jdk${{ matrix.java-versions }} test --file pom.xml

# Uploads test artifacts for each JDK version
- uses: actions/upload-artifact@v2
- uses: actions/upload-artifact@v4
if: always()
with:
name: test_output_sk_jdk${{ matrix.java-versions }}u
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordDataAttribute;
import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordKeyAttribute;
import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordVectorAttribute;
import com.microsoft.semantickernel.data.record.attributes.VectorStoreRecordDataAttribute;
import com.microsoft.semantickernel.data.record.attributes.VectorStoreRecordKeyAttribute;
import com.microsoft.semantickernel.data.record.attributes.VectorStoreRecordVectorAttribute;

import java.util.List;

Expand All @@ -16,7 +16,7 @@ public class Hotel {
@VectorStoreRecordDataAttribute
private final int code;
@JsonProperty("summary")
@VectorStoreRecordDataAttribute(hasEmbedding = true, embeddingFieldName = "descriptionEmbedding")
@VectorStoreRecordDataAttribute()
private final String description;
@JsonProperty("summaryEmbedding")
@VectorStoreRecordVectorAttribute(dimensions = 3)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreRecordCollectionOptions;
import com.microsoft.semantickernel.connectors.data.mysql.MySQLVectorStoreQueryProvider;
import com.microsoft.semantickernel.connectors.data.postgres.PostgreSQLVectorStoreQueryProvider;
import com.microsoft.semantickernel.data.recordoptions.GetRecordOptions;
import com.microsoft.semantickernel.data.record.options.GetRecordOptions;
import com.microsoft.semantickernel.tests.connectors.memory.Hotel;
import com.mysql.cj.jdbc.MysqlDataSource;
import org.junit.jupiter.params.ParameterizedTest;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

import com.microsoft.semantickernel.connectors.data.redis.RedisHashSetVectorStoreRecordCollection;
import com.microsoft.semantickernel.connectors.data.redis.RedisHashSetVectorStoreRecordCollectionOptions;
import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDataField;
import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition;
import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordField;
import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordKeyField;
import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordVectorField;
import com.microsoft.semantickernel.data.recordoptions.GetRecordOptions;
import com.microsoft.semantickernel.data.record.definition.VectorStoreRecordDataField;
import com.microsoft.semantickernel.data.record.definition.VectorStoreRecordDefinition;
import com.microsoft.semantickernel.data.record.definition.VectorStoreRecordField;
import com.microsoft.semantickernel.data.record.definition.VectorStoreRecordKeyField;
import com.microsoft.semantickernel.data.record.definition.VectorStoreRecordVectorField;
import com.microsoft.semantickernel.data.record.options.GetRecordOptions;
import com.microsoft.semantickernel.tests.connectors.memory.Hotel;
import com.redis.testcontainers.RedisContainer;
import org.junit.jupiter.api.BeforeAll;
Expand Down Expand Up @@ -67,8 +67,6 @@ static void setup() {
.withName("description")
.withStorageName("summary")
.withFieldType(String.class)
.withHasEmbedding(true)
.withEmbeddingFieldName("descriptionEmbedding")
.build());
fields.add(VectorStoreRecordVectorField.builder()
.withName("descriptionEmbedding")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

import com.microsoft.semantickernel.connectors.data.redis.RedisJsonVectorStoreRecordCollection;
import com.microsoft.semantickernel.connectors.data.redis.RedisJsonVectorStoreRecordCollectionOptions;
import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDataField;
import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition;
import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordField;
import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordKeyField;
import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordVectorField;
import com.microsoft.semantickernel.data.recordoptions.GetRecordOptions;
import com.microsoft.semantickernel.data.record.definition.VectorStoreRecordDataField;
import com.microsoft.semantickernel.data.record.definition.VectorStoreRecordDefinition;
import com.microsoft.semantickernel.data.record.definition.VectorStoreRecordField;
import com.microsoft.semantickernel.data.record.definition.VectorStoreRecordKeyField;
import com.microsoft.semantickernel.data.record.definition.VectorStoreRecordVectorField;
import com.microsoft.semantickernel.data.record.options.GetRecordOptions;
import com.microsoft.semantickernel.tests.connectors.memory.Hotel;
import com.redis.testcontainers.RedisContainer;
import org.junit.jupiter.api.BeforeAll;
Expand Down Expand Up @@ -68,8 +68,6 @@ static void setup() {
.withName("description")
.withStorageName("summary")
.withFieldType(String.class)
.withHasEmbedding(true)
.withEmbeddingFieldName("descriptionEmbedding")
.build());
fields.add(VectorStoreRecordVectorField.builder()
.withName("descriptionEmbedding")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import com.microsoft.semantickernel.samples.syntaxexamples.functions.Example59_OpenAIFunctionCalling;
import com.microsoft.semantickernel.samples.syntaxexamples.functions.Example60_AdvancedMethodFunctions;
import com.microsoft.semantickernel.samples.syntaxexamples.java.KernelFunctionYaml_Example;
import com.microsoft.semantickernel.samples.syntaxexamples.memory.AzureAISearch_DataStorage;
import com.microsoft.semantickernel.samples.syntaxexamples.memory.AzureAISearchVectorStore;
import com.microsoft.semantickernel.samples.syntaxexamples.plugins.Example10_DescribeAllPluginsAndFunctions;
import com.microsoft.semantickernel.samples.syntaxexamples.plugins.Example13_ConversationSummaryPlugin;
import com.microsoft.semantickernel.samples.syntaxexamples.template.Example06_TemplateLanguage;
Expand All @@ -38,7 +38,7 @@ public class RunAll {

public static void main(String[] args) {
List<MainMethod> mains = Arrays.asList(
AzureAISearch_DataStorage::main,
AzureAISearchVectorStore::main,
Example01_NativeFunctions::main,
Example03_Arguments::main,
Example05_InlineFunctionDefinition::main,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,27 @@
import com.azure.core.util.TracingOptions;
import com.azure.search.documents.indexes.SearchIndexAsyncClient;
import com.azure.search.documents.indexes.SearchIndexClientBuilder;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.microsoft.semantickernel.aiservices.openai.textembedding.OpenAITextEmbeddingGenerationService;
import com.microsoft.semantickernel.connectors.data.azureaisearch.AzureAISearchVectorStore;
import com.microsoft.semantickernel.connectors.data.azureaisearch.AzureAISearchVectorStoreOptions;
import com.microsoft.semantickernel.connectors.data.azureaisearch.AzureAISearchVectorStoreRecordCollection;
import com.microsoft.semantickernel.connectors.data.azureaisearch.AzureAISearchVectorStoreRecordCollectionOptions;
import com.microsoft.semantickernel.data.VectorStoreRecordCollection;
import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordDataAttribute;
import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordKeyAttribute;
import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordVectorAttribute;
import com.microsoft.semantickernel.data.VectorSearchResult;
import com.microsoft.semantickernel.data.record.attributes.VectorStoreRecordDataAttribute;
import com.microsoft.semantickernel.data.record.attributes.VectorStoreRecordKeyAttribute;
import com.microsoft.semantickernel.data.record.attributes.VectorStoreRecordVectorAttribute;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

public class AzureAISearch_DataStorage {
public class AzureAISearchVectorStore {

private static final String CLIENT_KEY = System.getenv("CLIENT_KEY");
private static final String AZURE_CLIENT_KEY = System.getenv("AZURE_CLIENT_KEY");
Expand All @@ -51,7 +51,7 @@ static class GitHubFile {
@JsonProperty("fileId") // Set a different name for the storage field if needed
@VectorStoreRecordKeyAttribute()
private final String id;
@VectorStoreRecordDataAttribute(hasEmbedding = true, embeddingFieldName = "embedding")
@VectorStoreRecordDataAttribute()
private final String description;
@VectorStoreRecordDataAttribute
private final String link;
Expand Down Expand Up @@ -118,13 +118,13 @@ public static void dataStorageWithAzureAISearch(
OpenAITextEmbeddingGenerationService embeddingGeneration) {

// Create a new Azure AI Search vector store
var azureAISearchVectorStore = AzureAISearchVectorStore.builder()
.withClient(searchClient)
var azureAISearchVectorStore = com.microsoft.semantickernel.connectors.data.azureaisearch.AzureAISearchVectorStore.builder()
.withSearchIndexAsyncClient(searchClient)
.withOptions(new AzureAISearchVectorStoreOptions())
.build();

String collectionName = "skgithubfiles";
var collection = azureAISearchVectorStore.getCollection(
var collection = (AzureAISearchVectorStoreRecordCollection<GitHubFile>) azureAISearchVectorStore.getCollection(
collectionName,
AzureAISearchVectorStoreRecordCollectionOptions.<GitHubFile>builder()
.withRecordClass(GitHubFile.class)
Expand All @@ -136,18 +136,26 @@ public static void dataStorageWithAzureAISearch(
.then(storeData(collection, embeddingGeneration, sampleData()))
.block();

// Query the Azure AI Search client for results
// This might take a few seconds to return the best result
var result = searchClient.getSearchAsyncClient(collectionName)
.search("How to get started with the Semantic Kernel?")
.blockFirst();
// Search for results
// Might need to wait for the data to be indexed
var results = search("How to get started", collection, embeddingGeneration).block();
var searchResult = results.get(0);
System.out.printf("Search result with score: %f.%n Link: %s, Description: %s%n",
searchResult.getScore(), searchResult.getRecord().link, searchResult.getRecord().description);
}


private static Mono<List<VectorSearchResult<GitHubFile>>> search(
String searchText,
AzureAISearchVectorStoreRecordCollection<GitHubFile> recordCollection,
OpenAITextEmbeddingGenerationService embeddingGeneration) {

GitHubFile gitHubFile = result.getDocument(GitHubFile.class);
System.out.println("Best result: " + gitHubFile.description + ". Link: " + gitHubFile.link);
return embeddingGeneration.generateEmbeddingsAsync(Collections.singletonList(searchText))
.flatMap(r -> recordCollection.searchAsync(r.get(0).getVector(), null));
}

private static Mono<List<String>> storeData(
VectorStoreRecordCollection<String, GitHubFile> recordStore,
AzureAISearchVectorStoreRecordCollection<GitHubFile> recordCollection,
OpenAITextEmbeddingGenerationService embeddingGeneration,
Map<String, String> data) {

Expand All @@ -163,7 +171,7 @@ private static Mono<List<String>> storeData(
entry.getValue(),
entry.getKey(),
embeddings.get(0).getVector());
return recordStore.upsertAsync(gitHubFile, null);
return recordCollection.upsertAsync(gitHubFile, null);
});
})
.collectList();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
import com.microsoft.semantickernel.data.VectorStoreRecordCollection;
import com.microsoft.semantickernel.data.VolatileVectorStore;
import com.microsoft.semantickernel.data.VolatileVectorStoreRecordCollectionOptions;
import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordDataAttribute;
import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordKeyAttribute;
import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordVectorAttribute;
import com.microsoft.semantickernel.data.record.attributes.VectorStoreRecordDataAttribute;
import com.microsoft.semantickernel.data.record.attributes.VectorStoreRecordKeyAttribute;
import com.microsoft.semantickernel.data.record.attributes.VectorStoreRecordVectorAttribute;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
Expand All @@ -37,7 +37,7 @@ static class GitHubFile {

@VectorStoreRecordKeyAttribute()
private final String id;
@VectorStoreRecordDataAttribute(hasEmbedding = true, embeddingFieldName = "embedding")
@VectorStoreRecordDataAttribute()
private final String description;
@VectorStoreRecordDataAttribute
private final String link;
Expand All @@ -64,7 +64,7 @@ public String getDescription() {
}

static String encodeId(String realId) {
return AzureAISearch_DataStorage.GitHubFile.encodeId(realId);
return AzureAISearchVectorStore.GitHubFile.encodeId(realId);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
import com.microsoft.semantickernel.connectors.data.jdbc.JDBCVectorStoreRecordCollectionOptions;
import com.microsoft.semantickernel.connectors.data.mysql.MySQLVectorStoreQueryProvider;
import com.microsoft.semantickernel.data.VectorStoreRecordCollection;
import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordDataAttribute;
import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordKeyAttribute;
import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordVectorAttribute;
import com.microsoft.semantickernel.data.record.attributes.VectorStoreRecordDataAttribute;
import com.microsoft.semantickernel.data.record.attributes.VectorStoreRecordKeyAttribute;
import com.microsoft.semantickernel.data.record.attributes.VectorStoreRecordVectorAttribute;
import com.mysql.cj.jdbc.MysqlDataSource;
import java.nio.charset.StandardCharsets;
import java.sql.SQLException;
Expand Down Expand Up @@ -45,7 +45,7 @@ static class GitHubFile {

@VectorStoreRecordKeyAttribute()
private final String id;
@VectorStoreRecordDataAttribute(hasEmbedding = true, embeddingFieldName = "embedding")
@VectorStoreRecordDataAttribute()
private final String description;
@VectorStoreRecordDataAttribute
private final String link;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
import com.microsoft.semantickernel.connectors.data.redis.RedisVectorStore;
import com.microsoft.semantickernel.connectors.data.redis.RedisVectorStoreOptions;
import com.microsoft.semantickernel.data.VectorStoreRecordCollection;
import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordDataAttribute;
import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordKeyAttribute;
import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordVectorAttribute;
import com.microsoft.semantickernel.data.record.attributes.VectorStoreRecordDataAttribute;
import com.microsoft.semantickernel.data.record.attributes.VectorStoreRecordKeyAttribute;
import com.microsoft.semantickernel.data.record.attributes.VectorStoreRecordVectorAttribute;

import java.util.Arrays;
import java.util.Collections;
Expand Down Expand Up @@ -47,7 +47,7 @@ public static class GitHubFile {

@VectorStoreRecordKeyAttribute()
private final String id;
@VectorStoreRecordDataAttribute(hasEmbedding = true, embeddingFieldName = "embedding")
@VectorStoreRecordDataAttribute()
private final String description;
@VectorStoreRecordDataAttribute
private final String link;
Expand Down Expand Up @@ -78,7 +78,7 @@ public String getDescription() {
}

static String encodeId(String realId) {
return AzureAISearch_DataStorage.GitHubFile.encodeId(realId);
return AzureAISearchVectorStore.GitHubFile.encodeId(realId);
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Copyright (c) Microsoft. All rights reserved.
package com.microsoft.semantickernel.connectors.data.azureaisearch;

import com.microsoft.semantickernel.data.vectorsearch.filtering.EqualityFilterClause;
import com.microsoft.semantickernel.exceptions.SKException;

import java.time.OffsetDateTime;
import java.time.format.DateTimeFormatter;

public class AzureAISearchEqualityFilterClause extends EqualityFilterClause {

/**
* Initializes a new instance of the AzureAISearchEqualityFilterClause class.
*
* @param fieldName The field name to filter on.
* @param value The value.
*/
public AzureAISearchEqualityFilterClause(String fieldName, Object value) {
super(fieldName, value);
}

/**
* Gets the filter string.
*
* @return The filter string.
*/
@Override
public String getFilter() {
String fieldName = getFieldName();
Object value = getValue();

if (value instanceof String) {
return String.format("%s eq '%s'", fieldName, value);
} else if (value instanceof Boolean) {
return String.format("%s eq %s", fieldName,
value.toString().toLowerCase());
} else if (value instanceof Integer) {
return String.format("%s eq %d", fieldName, (Integer) value);
} else if (value instanceof Long) {
return String.format("%s eq %d", fieldName, (Long) value);
} else if (value instanceof Float) {
return String.format("%s eq %f", fieldName, (Float) value);
} else if (value instanceof Double) {
return String.format("%s eq %f", fieldName, (Double) value);
} else if (value instanceof OffsetDateTime) {
return String.format("%s eq %s", fieldName, ((OffsetDateTime) value)
.format(DateTimeFormatter.ISO_OFFSET_DATE_TIME));
} else if (value == null) {
return String.format("%s eq null", fieldName);
} else {
throw new SKException("Unsupported filter value type '"
+ value.getClass().getSimpleName() + "'.");
}
}
}
Loading