Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RedisVectorStore and index/collection management #95

Merged
merged 3 commits into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package com.microsoft.semantickernel.tests.connectors.memory.redis;

import com.microsoft.semantickernel.connectors.memory.redis.RedisVectorStoreRecordCollection;
import com.microsoft.semantickernel.connectors.memory.redis.RedisVectorStoreOptions;
import com.microsoft.semantickernel.connectors.memory.redis.RedisVectorStoreRecordCollectionOptions;
import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDataField;
import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition;
import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordField;
Expand All @@ -11,7 +11,10 @@
import com.microsoft.semantickernel.tests.connectors.memory.Hotel;
import com.redis.testcontainers.RedisContainer;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.MethodOrderer;
import org.junit.jupiter.api.Order;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestMethodOrder;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;
import org.testcontainers.junit.jupiter.Container;
Expand All @@ -26,23 +29,25 @@
import java.util.Map;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;

@Testcontainers
@TestMethodOrder(MethodOrderer.OrderAnnotation.class)
public class RedisVectorStoreRecordCollectionTest {

@Container private static final RedisContainer redisContainer = new RedisContainer("redis/redis-stack:latest");

private static final Map<Options, RedisVectorStoreOptions<Hotel>> optionsMap = new HashMap<>();
private static final Map<RecordCollectionOptions, RedisVectorStoreRecordCollectionOptions<Hotel>> optionsMap = new HashMap<>();

public enum Options {
public enum RecordCollectionOptions {
DEFAULT, WITH_CUSTOM_DEFINITION
}

@BeforeAll
static void setup() {
optionsMap.put(Options.DEFAULT, RedisVectorStoreOptions.<Hotel>builder()
optionsMap.put(RecordCollectionOptions.DEFAULT, RedisVectorStoreRecordCollectionOptions.<Hotel>builder()
.withRecordClass(Hotel.class)
.build());

Expand Down Expand Up @@ -74,27 +79,21 @@ static void setup() {
.build());
VectorStoreRecordDefinition recordDefinition = VectorStoreRecordDefinition.fromFields(fields);

optionsMap.put(Options.WITH_CUSTOM_DEFINITION, RedisVectorStoreOptions.<Hotel>builder()
optionsMap.put(RecordCollectionOptions.WITH_CUSTOM_DEFINITION, RedisVectorStoreRecordCollectionOptions.<Hotel>builder()
.withRecordClass(Hotel.class)
.withRecordDefinition(recordDefinition)
.build());
}

private RedisVectorStoreRecordCollection<Hotel> buildRecordStore(@Nonnull RedisVectorStoreOptions<Hotel> options, @Nonnull String collectionName) {
return new RedisVectorStoreRecordCollection<>(new JedisPooled(redisContainer.getRedisURI()), collectionName, RedisVectorStoreOptions.<Hotel>builder()
private RedisVectorStoreRecordCollection<Hotel> buildrecordCollection(@Nonnull RedisVectorStoreRecordCollectionOptions<Hotel> options, @Nonnull String collectionName) {
return new RedisVectorStoreRecordCollection<>(new JedisPooled(redisContainer.getRedisURI()), collectionName, RedisVectorStoreRecordCollectionOptions.<Hotel>builder()
.withRecordClass(options.getRecordClass())
.withVectorStoreRecordMapper(options.getVectorStoreRecordMapper())
.withRecordDefinition(options.getRecordDefinition())
.withPrefixCollectionName(options.prefixCollectionName())
.withPrefixCollectionName(options.isPrefixCollectionName())
.build());
}

@ParameterizedTest
@EnumSource(Options.class)
public void buildRecordStore(Options options) {
assertNotNull(buildRecordStore(optionsMap.get(options), "buildTest"));
}

private List<Hotel> getHotels() {
return List.of(
new Hotel("id_1", "Hotel 1", 1, "Hotel 1 description", Arrays.asList(1.0f, 2.0f, 3.0f), 4.0),
Expand All @@ -105,37 +104,65 @@ private List<Hotel> getHotels() {
);
}

@Order(1)
@ParameterizedTest
@EnumSource(RecordCollectionOptions.class)
public void buildrecordCollection(RecordCollectionOptions options) {
assertNotNull(buildrecordCollection(optionsMap.get(options), options.name()));
}

@Order(2)
@ParameterizedTest
@EnumSource(RecordCollectionOptions.class)
public void createCollectionAsync(RecordCollectionOptions options) {
RedisVectorStoreRecordCollection<Hotel> recordCollection = buildrecordCollection(optionsMap.get(options), options.name());

assertEquals(false, recordCollection.collectionExistsAsync().block());
recordCollection.createCollectionAsync().block();
assertEquals(true, recordCollection.collectionExistsAsync().block());
}

@Test
public void deleteCollectionAsync() {
RedisVectorStoreRecordCollection<Hotel> recordCollection = buildrecordCollection(optionsMap.get(RecordCollectionOptions.DEFAULT), "deleteCollectionAsync");

assertEquals(false, recordCollection.collectionExistsAsync().block());
recordCollection.createCollectionAsync().block();
recordCollection.deleteCollectionAsync().block();
assertEquals(false, recordCollection.collectionExistsAsync().block());
}

@ParameterizedTest
@EnumSource(Options.class)
public void upsertAndGetRecordAsync(Options options) {
RedisVectorStoreRecordCollection<Hotel> recordStore = buildRecordStore(optionsMap.get(options), "upsertAndGetRecordAsync");
@EnumSource(RecordCollectionOptions.class)
public void upsertAndGetRecordAsync(RecordCollectionOptions options) {
RedisVectorStoreRecordCollection<Hotel> recordCollection = buildrecordCollection(optionsMap.get(options), options.name());

List<Hotel> hotels = getHotels();
for (Hotel hotel : hotels) {
recordStore.upsertAsync(hotel, null).block();
recordCollection.upsertAsync(hotel, null).block();
}

for (Hotel hotel : hotels) {
Hotel retrievedHotel = recordStore.getAsync(hotel.getId(), null).block();
Hotel retrievedHotel = recordCollection.getAsync(hotel.getId(), null).block();
assertNotNull(retrievedHotel);
assertEquals(hotel.getId(), retrievedHotel.getId());
}
}

@ParameterizedTest
@EnumSource(Options.class)
public void getBatchAsync(Options options) {
RedisVectorStoreRecordCollection<Hotel> recordStore = buildRecordStore(optionsMap.get(options), "getBatchAsync");
@EnumSource(RecordCollectionOptions.class)
public void getBatchAsync(RecordCollectionOptions options) {
RedisVectorStoreRecordCollection<Hotel> recordCollection = buildrecordCollection(optionsMap.get(options), options.name());

List<Hotel> hotels = getHotels();
for (Hotel hotel : hotels) {
recordStore.upsertAsync(hotel, null).block();
recordCollection.upsertAsync(hotel, null).block();
}

List<String> ids = new ArrayList<>();
hotels.forEach(hotel -> ids.add(hotel.getId()));

List<Hotel> retrievedHotels = recordStore.getBatchAsync(ids, null).block();
List<Hotel> retrievedHotels = recordCollection.getBatchAsync(ids, null).block();

assertNotNull(retrievedHotels);
assertEquals(hotels.size(), retrievedHotels.size());
Expand All @@ -145,15 +172,15 @@ public void getBatchAsync(Options options) {
}

@ParameterizedTest
@EnumSource(Options.class)
public void upsertBatchAsync(Options options) {
RedisVectorStoreRecordCollection<Hotel> recordStore = buildRecordStore(optionsMap.get(options), "upsertBatchAsync");
@EnumSource(RecordCollectionOptions.class)
public void upsertBatchAsync(RecordCollectionOptions options) {
RedisVectorStoreRecordCollection<Hotel> recordCollection = buildrecordCollection(optionsMap.get(options), options.name());

List<Hotel> hotels = getHotels();
List<String> keys = recordStore.upsertBatchAsync(hotels, null).block();
List<String> keys = recordCollection.upsertBatchAsync(hotels, null).block();
assertNotNull(keys);

List<Hotel> retrievedHotels = (List<Hotel>) recordStore.getBatchAsync(keys, null).block();
List<Hotel> retrievedHotels = (List<Hotel>) recordCollection.getBatchAsync(keys, null).block();

assertNotNull(retrievedHotels);
assertEquals(hotels.size(), retrievedHotels.size());
Expand All @@ -163,66 +190,68 @@ public void upsertBatchAsync(Options options) {
}

@ParameterizedTest
@EnumSource(Options.class)
public void deleteAsync(Options options) {
RedisVectorStoreRecordCollection<Hotel> recordStore = buildRecordStore(optionsMap.get(options), "deleteAsync");
@EnumSource(RecordCollectionOptions.class)
public void deleteAsync(RecordCollectionOptions options) {
RedisVectorStoreRecordCollection<Hotel> recordCollection = buildrecordCollection(optionsMap.get(options), options.name());

List<Hotel> hotels = getHotels();
recordStore.upsertBatchAsync(hotels, null).block();
recordCollection.upsertBatchAsync(hotels, null).block();

for (Hotel hotel : hotels) {
recordStore.deleteAsync(hotel.getId(), null).block();
Hotel retrievedHotel = recordStore.getAsync(hotel.getId(), null).block();
recordCollection.deleteAsync(hotel.getId(), null).block();
Hotel retrievedHotel = recordCollection.getAsync(hotel.getId(), null).block();
assertNull(retrievedHotel);
}
}

@ParameterizedTest
@EnumSource(Options.class)
public void deleteBatchAsync(Options options) {
RedisVectorStoreRecordCollection<Hotel> recordStore = buildRecordStore(optionsMap.get(options), "deleteBatchAsync");
@EnumSource(RecordCollectionOptions.class)
public void deleteBatchAsync(RecordCollectionOptions options) {
RedisVectorStoreRecordCollection<Hotel> recordCollection = buildrecordCollection(optionsMap.get(options), options.name());

List<Hotel> hotels = getHotels();
recordStore.upsertBatchAsync(hotels, null).block();
recordCollection.upsertBatchAsync(hotels, null).block();

List<String> ids = new ArrayList<>();
hotels.forEach(hotel -> ids.add(hotel.getId()));

recordStore.deleteBatchAsync(ids, null).block();
recordCollection.deleteBatchAsync(ids, null).block();

for (String id : ids) {
Hotel retrievedHotel = recordStore.getAsync(id, null).block();
Hotel retrievedHotel = recordCollection.getAsync(id, null).block();
assertNull(retrievedHotel);
}
}

@Test
public void getAsyncWithVectors() {
RedisVectorStoreRecordCollection<Hotel> recordStore = buildRecordStore(optionsMap.get(Options.DEFAULT), "getAsyncWithVectors");
@ParameterizedTest
@EnumSource(RecordCollectionOptions.class)
public void getAsyncWithVectors(RecordCollectionOptions options) {
RedisVectorStoreRecordCollection<Hotel> recordCollection = buildrecordCollection(optionsMap.get(options), options.name());

List<Hotel> hotels = getHotels();
recordStore.upsertBatchAsync(hotels, null).block();
recordCollection.upsertBatchAsync(hotels, null).block();

for (Hotel hotel : hotels) {
Hotel retrievedHotel = recordStore.getAsync(hotel.getId(), null).block();
Hotel retrievedHotel = recordCollection.getAsync(hotel.getId(), null).block();
assertNotNull(retrievedHotel);
assertNotNull(retrievedHotel.getDescriptionEmbedding());
assertEquals(hotel.getId(), retrievedHotel.getId());
assertEquals(hotel.getDescription(), retrievedHotel.getDescription());
}
}

@Test
public void getBatchAsyncWithVectors() {
RedisVectorStoreRecordCollection<Hotel> recordStore = buildRecordStore(optionsMap.get(Options.DEFAULT), "getBatchAsyncWithVectors");
@ParameterizedTest
@EnumSource(RecordCollectionOptions.class)
public void getBatchAsyncWithVectors(RecordCollectionOptions options) {
RedisVectorStoreRecordCollection<Hotel> recordCollection = buildrecordCollection(optionsMap.get(options), options.name());

List<Hotel> hotels = getHotels();
recordStore.upsertBatchAsync(hotels, null).block();
recordCollection.upsertBatchAsync(hotels, null).block();

List<String> ids = new ArrayList<>();
hotels.forEach(hotel -> ids.add(hotel.getId()));

List<Hotel> retrievedHotels = recordStore.getBatchAsync(ids, null).block();
List<Hotel> retrievedHotels = recordCollection.getBatchAsync(ids, null).block();

assertNotNull(retrievedHotels);
assertEquals(hotels.size(), retrievedHotels.size());
Expand All @@ -233,35 +262,37 @@ public void getBatchAsyncWithVectors() {
}
}

@Test
public void getAsyncWithNoVectors() {
RedisVectorStoreRecordCollection<Hotel> recordStore = buildRecordStore(optionsMap.get(Options.WITH_CUSTOM_DEFINITION), "getAsyncWithNoVectors");
@ParameterizedTest
@EnumSource(RecordCollectionOptions.class)
public void getAsyncWithNoVectors(RecordCollectionOptions options) {
RedisVectorStoreRecordCollection<Hotel> recordCollection = buildrecordCollection(optionsMap.get(options), options.name());

List<Hotel> hotels = getHotels();
recordStore.upsertBatchAsync(hotels, null).block();
recordCollection.upsertBatchAsync(hotels, null).block();

GetRecordOptions getRecordOptions = GetRecordOptions.builder().includeVectors(false).build();
for (Hotel hotel : hotels) {
Hotel retrievedHotel = recordStore.getAsync(hotel.getId(), getRecordOptions).block();
Hotel retrievedHotel = recordCollection.getAsync(hotel.getId(), getRecordOptions).block();
assertNotNull(retrievedHotel);
assertNull(retrievedHotel.getDescriptionEmbedding());
assertEquals(hotel.getId(), retrievedHotel.getId());
assertEquals(hotel.getDescription(), retrievedHotel.getDescription());
}
}

@Test
public void getBatchAsyncWithNoVectors() {
RedisVectorStoreRecordCollection<Hotel> recordStore = buildRecordStore(optionsMap.get(Options.WITH_CUSTOM_DEFINITION), "getBatchAsyncWithNoVectors");
@ParameterizedTest
@EnumSource(RecordCollectionOptions.class)
public void getBatchAsyncWithNoVectors(RecordCollectionOptions options) {
RedisVectorStoreRecordCollection<Hotel> recordCollection = buildrecordCollection(optionsMap.get(options), options.name());

List<Hotel> hotels = getHotels();
recordStore.upsertBatchAsync(hotels, null).block();
recordCollection.upsertBatchAsync(hotels, null).block();

GetRecordOptions getRecordOptions = GetRecordOptions.builder().includeVectors(false).build();
List<String> ids = new ArrayList<>();
hotels.forEach(hotel -> ids.add(hotel.getId()));

List<Hotel> retrievedHotels = recordStore.getBatchAsync(ids, getRecordOptions).block();
List<Hotel> retrievedHotels = recordCollection.getBatchAsync(ids, getRecordOptions).block();

assertNotNull(retrievedHotels);
assertEquals(hotels.size(), retrievedHotels.size());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package com.microsoft.semantickernel.tests.connectors.memory.redis;

import com.microsoft.semantickernel.connectors.memory.redis.RedisVectorStore;
import com.microsoft.semantickernel.connectors.memory.redis.RedisVectorStoreOptions;
import com.microsoft.semantickernel.data.recorddefinition.VectorStoreRecordDefinition;
import com.microsoft.semantickernel.tests.connectors.memory.Hotel;
import com.redis.testcontainers.RedisContainer;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
import reactor.core.publisher.Mono;
import redis.clients.jedis.JedisPooled;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

@Testcontainers
public class RedisVectorStoreTest {
@Container
private static final RedisContainer redisContainer = new RedisContainer("redis/redis-stack:latest");
private static JedisPooled jedis;

@BeforeAll
public static void setUp() {
jedis = new JedisPooled(redisContainer.getRedisURI());
}

@Test
public void getCollectionNamesAsync() {
RedisVectorStore<Hotel> vectorStore = new RedisVectorStore<>(jedis, new RedisVectorStoreOptions<>(Hotel.class, null));
List<String> collectionNames = Arrays.asList("collection1", "collection2", "collection3");

for (String collectionName : collectionNames) {
vectorStore.getCollection(collectionName, VectorStoreRecordDefinition.fromRecordClass(Hotel.class)).createCollectionAsync().block();
}

List<String> retrievedCollectionNames = vectorStore.getCollectionNamesAsync().block();
assertNotNull(retrievedCollectionNames);
assertEquals(collectionNames.size(), retrievedCollectionNames.size());
for (String collectionName : collectionNames) {
assertTrue(retrievedCollectionNames.contains(collectionName));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ public AzureAISearchVectorStoreRecordCollection(
// Validate supported types
VectorStoreRecordDefinition.validateSupportedTypes(
this.options.getRecordClass(), this.recordDefinition, supportedKeyTypes,
supportedDataTypes, supportedVectorTypes);
supportedVectorTypes, supportedDataTypes);

// Add non-vector fields to the list
nonVectorFields.add(this.recordDefinition.getKeyField().getName());
Expand Down
Loading
Loading