diff --git a/tests/unit/vertex_rag/test_rag_constants_preview.py b/tests/unit/vertex_rag/test_rag_constants_preview.py index 441471269f..4548f33a1d 100644 --- a/tests/unit/vertex_rag/test_rag_constants_preview.py +++ b/tests/unit/vertex_rag/test_rag_constants_preview.py @@ -40,6 +40,9 @@ VertexAiSearchConfig, VertexVectorSearch, VertexFeatureStore, + RagEmbeddingModelConfig, + VertexPredictionEndpoint, + RagVectorDbConfig, ) from google.cloud.aiplatform_v1beta1 import ( GoogleDriveSource, @@ -56,7 +59,7 @@ SlackSource as GapicSlackSource, RagContexts, RetrieveContextsResponse, - RagVectorDbConfig, + RagVectorDbConfig as GapicRagVectorDbConfig, VertexAiSearchConfig as GapicVertexAiSearchConfig, ) from google.cloud.aiplatform_v1beta1.types import api_auth @@ -112,8 +115,8 @@ name=TEST_RAG_CORPUS_RESOURCE_NAME, display_name=TEST_CORPUS_DISPLAY_NAME, description=TEST_CORPUS_DISCRIPTION, - rag_vector_db_config=RagVectorDbConfig( - weaviate=RagVectorDbConfig.Weaviate( + rag_vector_db_config=GapicRagVectorDbConfig( + weaviate=GapicRagVectorDbConfig.Weaviate( http_endpoint=TEST_WEAVIATE_HTTP_ENDPOINT, collection_name=TEST_WEAVIATE_COLLECTION_NAME, ), @@ -128,8 +131,8 @@ name=TEST_RAG_CORPUS_RESOURCE_NAME, display_name=TEST_CORPUS_DISPLAY_NAME, description=TEST_CORPUS_DISCRIPTION, - rag_vector_db_config=RagVectorDbConfig( - vertex_feature_store=RagVectorDbConfig.VertexFeatureStore( + rag_vector_db_config=GapicRagVectorDbConfig( + vertex_feature_store=GapicRagVectorDbConfig.VertexFeatureStore( feature_view_resource_name=TEST_VERTEX_FEATURE_STORE_RESOURCE_NAME ), ), @@ -138,8 +141,8 @@ name=TEST_RAG_CORPUS_RESOURCE_NAME, display_name=TEST_CORPUS_DISPLAY_NAME, description=TEST_CORPUS_DISCRIPTION, - rag_vector_db_config=RagVectorDbConfig( - vertex_vector_search=RagVectorDbConfig.VertexVectorSearch( + rag_vector_db_config=GapicRagVectorDbConfig( + vertex_vector_search=GapicRagVectorDbConfig.VertexVectorSearch( index_endpoint=TEST_VERTEX_VECTOR_SEARCH_INDEX_ENDPOINT, index=TEST_VERTEX_VECTOR_SEARCH_INDEX, ), @@ -149,8 +152,8 @@ name=TEST_RAG_CORPUS_RESOURCE_NAME, display_name=TEST_CORPUS_DISPLAY_NAME, description=TEST_CORPUS_DISCRIPTION, - rag_vector_db_config=RagVectorDbConfig( - pinecone=RagVectorDbConfig.Pinecone(index_name=TEST_PINECONE_INDEX_NAME), + rag_vector_db_config=GapicRagVectorDbConfig( + pinecone=GapicRagVectorDbConfig.Pinecone(index_name=TEST_PINECONE_INDEX_NAME), api_auth=api_auth.ApiAuth( api_key_config=api_auth.ApiAuth.ApiKeyConfig( api_key_secret_version=TEST_PINECONE_API_KEY_SECRET_VERSION @@ -161,6 +164,14 @@ TEST_EMBEDDING_MODEL_CONFIG = EmbeddingModelConfig( publisher_model="publishers/google/models/textembedding-gecko", ) +TEST_RAG_EMBEDDING_MODEL_CONFIG = RagEmbeddingModelConfig( + vertex_prediction_endpoint=VertexPredictionEndpoint( + publisher_model="publishers/google/models/textembedding-gecko", + ), +) +TEST_BACKEND_CONFIG_EMBEDDING_MODEL_CONFIG = RagVectorDbConfig( + rag_embedding_model_config=TEST_RAG_EMBEDDING_MODEL_CONFIG, +) TEST_VERTEX_FEATURE_STORE_CONFIG = VertexFeatureStore( resource_name=TEST_VERTEX_FEATURE_STORE_RESOURCE_NAME, ) @@ -195,6 +206,62 @@ vector_db=TEST_VERTEX_VECTOR_SEARCH_CONFIG, ) TEST_PAGE_TOKEN = "test-page-token" +# Backend Config +TEST_GAPIC_RAG_CORPUS_BACKEND_CONFIG = GapicRagCorpus( + name=TEST_RAG_CORPUS_RESOURCE_NAME, + display_name=TEST_CORPUS_DISPLAY_NAME, + description=TEST_CORPUS_DISCRIPTION, +) +TEST_GAPIC_RAG_CORPUS_BACKEND_CONFIG.vector_db_config.rag_embedding_model_config.vertex_prediction_endpoint.endpoint = "projects/{}/locations/{}/publishers/google/models/textembedding-gecko".format( + TEST_PROJECT, TEST_REGION +) +TEST_GAPIC_RAG_CORPUS_VERTEX_VECTOR_SEARCH_BACKEND_CONFIG = GapicRagCorpus( + name=TEST_RAG_CORPUS_RESOURCE_NAME, + display_name=TEST_CORPUS_DISPLAY_NAME, + description=TEST_CORPUS_DISCRIPTION, + vector_db_config=GapicRagVectorDbConfig( + vertex_vector_search=GapicRagVectorDbConfig.VertexVectorSearch( + index_endpoint=TEST_VERTEX_VECTOR_SEARCH_INDEX_ENDPOINT, + index=TEST_VERTEX_VECTOR_SEARCH_INDEX, + ), + ), +) +TEST_GAPIC_RAG_CORPUS_PINECONE_BACKEND_CONFIG = GapicRagCorpus( + name=TEST_RAG_CORPUS_RESOURCE_NAME, + display_name=TEST_CORPUS_DISPLAY_NAME, + description=TEST_CORPUS_DISCRIPTION, + vector_db_config=GapicRagVectorDbConfig( + pinecone=GapicRagVectorDbConfig.Pinecone(index_name=TEST_PINECONE_INDEX_NAME), + api_auth=api_auth.ApiAuth( + api_key_config=api_auth.ApiAuth.ApiKeyConfig( + api_key_secret_version=TEST_PINECONE_API_KEY_SECRET_VERSION + ), + ), + ), +) +TEST_RAG_CORPUS_BACKEND = RagCorpus( + name=TEST_RAG_CORPUS_RESOURCE_NAME, + display_name=TEST_CORPUS_DISPLAY_NAME, + backend_config=TEST_BACKEND_CONFIG_EMBEDDING_MODEL_CONFIG, +) +TEST_BACKEND_CONFIG_PINECONE_CONFIG = RagVectorDbConfig( + vector_db=TEST_PINECONE_CONFIG, +) +TEST_RAG_CORPUS_PINECONE_BACKEND = RagCorpus( + name=TEST_RAG_CORPUS_RESOURCE_NAME, + display_name=TEST_CORPUS_DISPLAY_NAME, + description=TEST_CORPUS_DISCRIPTION, + backend_config=TEST_BACKEND_CONFIG_PINECONE_CONFIG, +) +TEST_BACKEND_CONFIG_VERTEX_VECTOR_SEARCH_CONFIG = RagVectorDbConfig( + vector_db=TEST_VERTEX_VECTOR_SEARCH_CONFIG, +) +TEST_RAG_CORPUS_VERTEX_VECTOR_SEARCH_BACKEND = RagCorpus( + name=TEST_RAG_CORPUS_RESOURCE_NAME, + display_name=TEST_CORPUS_DISPLAY_NAME, + description=TEST_CORPUS_DISCRIPTION, + backend_config=TEST_BACKEND_CONFIG_VERTEX_VECTOR_SEARCH_CONFIG, +) # Vertex AI Search Config TEST_VERTEX_AI_SEARCH_ENGINE_SERVING_CONFIG = f"projects/{TEST_PROJECT_NUMBER}/locations/{TEST_REGION}/collections/test-collection/engines/test-engine/servingConfigs/test-serving-config" TEST_VERTEX_AI_SEARCH_DATASTORE_SERVING_CONFIG = f"projects/{TEST_PROJECT_NUMBER}/locations/{TEST_REGION}/collections/test-collection/dataStores/test-datastore/servingConfigs/test-serving-config" diff --git a/tests/unit/vertex_rag/test_rag_data_preview.py b/tests/unit/vertex_rag/test_rag_data_preview.py index 43a311e0d2..0b8776edff 100644 --- a/tests/unit/vertex_rag/test_rag_data_preview.py +++ b/tests/unit/vertex_rag/test_rag_data_preview.py @@ -53,6 +53,21 @@ def create_rag_corpus_mock(): yield create_rag_corpus_mock +@pytest.fixture +def create_rag_corpus_mock_backend(): + with mock.patch.object( + VertexRagDataServiceClient, + "create_rag_corpus", + ) as create_rag_corpus_mock: + create_rag_corpus_lro_mock = mock.Mock(ga_operation.Operation) + create_rag_corpus_lro_mock.done.return_value = True + create_rag_corpus_lro_mock.result.return_value = ( + test_rag_constants_preview.TEST_GAPIC_RAG_CORPUS_BACKEND_CONFIG + ) + create_rag_corpus_mock.return_value = create_rag_corpus_lro_mock + yield create_rag_corpus_mock + + @pytest.fixture def create_rag_corpus_mock_weaviate(): with mock.patch.object( @@ -102,6 +117,23 @@ def create_rag_corpus_mock_vertex_vector_search(): yield create_rag_corpus_mock_vertex_vector_search +@pytest.fixture +def create_rag_corpus_mock_vertex_vector_search_backend(): + with mock.patch.object( + VertexRagDataServiceClient, + "create_rag_corpus", + ) as create_rag_corpus_mock_vertex_vector_search: + create_rag_corpus_lro_mock = mock.Mock(ga_operation.Operation) + create_rag_corpus_lro_mock.done.return_value = True + create_rag_corpus_lro_mock.result.return_value = ( + test_rag_constants_preview.TEST_GAPIC_RAG_CORPUS_VERTEX_VECTOR_SEARCH_BACKEND_CONFIG + ) + create_rag_corpus_mock_vertex_vector_search.return_value = ( + create_rag_corpus_lro_mock + ) + yield create_rag_corpus_mock_vertex_vector_search + + @pytest.fixture def create_rag_corpus_mock_pinecone(): with mock.patch.object( @@ -117,6 +149,21 @@ def create_rag_corpus_mock_pinecone(): yield create_rag_corpus_mock_pinecone +@pytest.fixture +def create_rag_corpus_mock_pinecone_backend(): + with mock.patch.object( + VertexRagDataServiceClient, + "create_rag_corpus", + ) as create_rag_corpus_mock_pinecone: + create_rag_corpus_lro_mock = mock.Mock(ga_operation.Operation) + create_rag_corpus_lro_mock.done.return_value = True + create_rag_corpus_lro_mock.result.return_value = ( + test_rag_constants_preview.TEST_GAPIC_RAG_CORPUS_PINECONE_BACKEND_CONFIG + ) + create_rag_corpus_mock_pinecone.return_value = create_rag_corpus_lro_mock + yield create_rag_corpus_mock_pinecone + + @pytest.fixture def create_rag_corpus_mock_vertex_ai_engine_search_config(): with mock.patch.object( @@ -407,6 +454,15 @@ def test_create_corpus_success(self): rag_corpus_eq(rag_corpus, test_rag_constants_preview.TEST_RAG_CORPUS) + @pytest.mark.usefixtures("create_rag_corpus_mock_backend") + def test_create_corpus_backend_success(self): + rag_corpus = rag.create_corpus( + display_name=test_rag_constants_preview.TEST_CORPUS_DISPLAY_NAME, + backend_config=test_rag_constants_preview.TEST_BACKEND_CONFIG_EMBEDDING_MODEL_CONFIG, + ) + + rag_corpus_eq(rag_corpus, test_rag_constants_preview.TEST_RAG_CORPUS_BACKEND) + @pytest.mark.usefixtures("create_rag_corpus_mock_weaviate") def test_create_corpus_weaviate_success(self): rag_corpus = rag.create_corpus( @@ -438,6 +494,18 @@ def test_create_corpus_vertex_vector_search_success(self): rag_corpus, test_rag_constants_preview.TEST_RAG_CORPUS_VERTEX_VECTOR_SEARCH ) + @pytest.mark.usefixtures("create_rag_corpus_mock_vertex_vector_search_backend") + def test_create_corpus_vertex_vector_search_backend_success(self): + rag_corpus = rag.create_corpus( + display_name=test_rag_constants_preview.TEST_CORPUS_DISPLAY_NAME, + backend_config=test_rag_constants_preview.TEST_BACKEND_CONFIG_VERTEX_VECTOR_SEARCH_CONFIG, + ) + + rag_corpus_eq( + rag_corpus, + test_rag_constants_preview.TEST_RAG_CORPUS_VERTEX_VECTOR_SEARCH_BACKEND, + ) + @pytest.mark.usefixtures("create_rag_corpus_mock_pinecone") def test_create_corpus_pinecone_success(self): rag_corpus = rag.create_corpus( @@ -447,6 +515,43 @@ def test_create_corpus_pinecone_success(self): rag_corpus_eq(rag_corpus, test_rag_constants_preview.TEST_RAG_CORPUS_PINECONE) + @pytest.mark.usefixtures("create_rag_corpus_mock_pinecone_backend") + def test_create_corpus_pinecone_backend_success(self): + rag_corpus = rag.create_corpus( + display_name=test_rag_constants_preview.TEST_CORPUS_DISPLAY_NAME, + backend_config=test_rag_constants_preview.TEST_BACKEND_CONFIG_PINECONE_CONFIG, + ) + + rag_corpus_eq( + rag_corpus, test_rag_constants_preview.TEST_RAG_CORPUS_PINECONE_BACKEND + ) + + def test_create_corpus_backend_config_with_embedding_model_config_failure( + self, + ): + with pytest.raises(ValueError) as e: + rag.create_corpus( + display_name=test_rag_constants_preview.TEST_CORPUS_DISPLAY_NAME, + backend_config=test_rag_constants_preview.TEST_BACKEND_CONFIG_EMBEDDING_MODEL_CONFIG, + embedding_model_config=test_rag_constants_preview.TEST_EMBEDDING_MODEL_CONFIG, + ) + e.match( + "Only one of backend_config or embedding_model_config and vector_db can be set. embedding_model_config and vector_db are deprecated, use backend_config instead." + ) + + def test_create_corpus_backend_config_with_vector_db_failure( + self, + ): + with pytest.raises(ValueError) as e: + rag.create_corpus( + display_name=test_rag_constants_preview.TEST_CORPUS_DISPLAY_NAME, + backend_config=test_rag_constants_preview.TEST_BACKEND_CONFIG_EMBEDDING_MODEL_CONFIG, + vector_db=test_rag_constants_preview.TEST_PINECONE_CONFIG, + ) + e.match( + "Only one of backend_config or embedding_model_config and vector_db can be set. embedding_model_config and vector_db are deprecated, use backend_config instead." + ) + @pytest.mark.usefixtures("create_rag_corpus_mock_vertex_ai_engine_search_config") def test_create_corpus_vais_engine_search_config_success(self): rag_corpus = rag.create_corpus( @@ -480,6 +585,17 @@ def test_create_corpus_vais_datastore_search_config_with_vector_db_failure(self) ) e.match("Only one of vertex_ai_search_config or vector_db can be set.") + def test_create_corpus_vais_datastore_search_config_with_backend_config_failure( + self, + ): + with pytest.raises(ValueError) as e: + rag.create_corpus( + display_name=test_rag_constants_preview.TEST_CORPUS_DISPLAY_NAME, + vertex_ai_search_config=test_rag_constants_preview.TEST_VERTEX_AI_SEARCH_CONFIG_DATASTORE, + backend_config=test_rag_constants_preview.TEST_BACKEND_CONFIG_EMBEDDING_MODEL_CONFIG, + ) + e.match("Only one of vertex_ai_search_config or backend_config can be set.") + def test_create_corpus_vais_datastore_search_config_with_embedding_model_config_failure( self, ): diff --git a/vertexai/preview/rag/__init__.py b/vertexai/preview/rag/__init__.py index b1f2c26756..5b1ee632e6 100644 --- a/vertexai/preview/rag/__init__.py +++ b/vertexai/preview/rag/__init__.py @@ -58,6 +58,9 @@ VertexFeatureStore, VertexVectorSearch, Weaviate, + RagEmbeddingModelConfig, + VertexPredictionEndpoint, + RagVectorDbConfig, ) __all__ = ( @@ -87,6 +90,9 @@ "VertexRagStore", "VertexVectorSearch", "Weaviate", + "RagEmbeddingModelConfig", + "VertexPredictionEndpoint", + "RagVectorDbConfig", "create_corpus", "delete_corpus", "delete_file", diff --git a/vertexai/preview/rag/rag_data.py b/vertexai/preview/rag/rag_data.py index 9aca68296b..a6139f00b9 100644 --- a/vertexai/preview/rag/rag_data.py +++ b/vertexai/preview/rag/rag_data.py @@ -49,6 +49,7 @@ RagCorpus, RagFile, RagManagedDb, + RagVectorDbConfig, SharePointSources, SlackChannelsSource, VertexAiSearchConfig, @@ -67,6 +68,7 @@ def create_corpus( Union[Weaviate, VertexFeatureStore, VertexVectorSearch, Pinecone, RagManagedDb] ] = None, vertex_ai_search_config: Optional[VertexAiSearchConfig] = None, + backend_config: Optional[RagVectorDbConfig] = None, ) -> RagCorpus: """Creates a new RagCorpus resource. @@ -88,11 +90,15 @@ def create_corpus( consist of any UTF-8 characters. description: The description of the RagCorpus. embedding_model_config: The embedding model config. + Note: Deprecated. Use backend_config instead. vector_db: The vector db config of the RagCorpus. If unspecified, the default database Spanner is used. + Note: Deprecated. Use backend_config instead. vertex_ai_search_config: The Vertex AI Search config of the RagCorpus. Note: embedding_model_config or vector_db cannot be set if vertex_ai_search_config is specified. + backend_config: The backend config of the RagCorpus. It can specify a + Vector DB and/or the embedding model config. Returns: RagCorpus. Raises: @@ -115,6 +121,22 @@ def create_corpus( "Only one of vertex_ai_search_config or embedding_model_config can be set." ) + if vertex_ai_search_config and backend_config: + raise ValueError( + "Only one of vertex_ai_search_config or backend_config can be set." + ) + + if backend_config and (embedding_model_config or vector_db): + raise ValueError( + "Only one of backend_config or embedding_model_config and vector_db can be set. embedding_model_config and vector_db are deprecated, use backend_config instead." + ) + + if backend_config: + _gapic_utils.set_backend_config( + backend_config=backend_config, + rag_corpus=rag_corpus, + ) + if vertex_ai_search_config and vector_db: raise ValueError("Only one of vertex_ai_search_config or vector_db can be set.") @@ -156,6 +178,7 @@ def update_corpus( ] ] = None, vertex_ai_search_config: Optional[VertexAiSearchConfig] = None, + backend_config: Optional[RagVectorDbConfig] = None, ) -> RagCorpus: """Updates a RagCorpus resource. @@ -187,6 +210,8 @@ def update_corpus( If not provided, the Vertex AI Search config will not be updated. Note: embedding_model_config or vector_db cannot be set if vertex_ai_search_config is specified. + backend_config: The backend config of the RagCorpus. Specifies a Vector + DB and/or the embedding model config. Returns: RagCorpus. @@ -209,6 +234,12 @@ def update_corpus( if vertex_ai_search_config and vector_db: raise ValueError("Only one of vertex_ai_search_config or vector_db can be set.") + if backend_config: + _gapic_utils.set_backend_config( + backend_config=backend_config, + rag_corpus=rag_corpus, + ) + if vertex_ai_search_config: _gapic_utils.set_vertex_ai_search_config( vertex_ai_search_config=vertex_ai_search_config, diff --git a/vertexai/preview/rag/utils/_gapic_utils.py b/vertexai/preview/rag/utils/_gapic_utils.py index 1a35e9819e..d11254f2ba 100644 --- a/vertexai/preview/rag/utils/_gapic_utils.py +++ b/vertexai/preview/rag/utils/_gapic_utils.py @@ -18,7 +18,7 @@ from typing import Any, Dict, Optional, Sequence, Union from google.cloud.aiplatform_v1beta1.types import api_auth from google.cloud.aiplatform_v1beta1 import ( - RagEmbeddingModelConfig, + RagEmbeddingModelConfig as GapicRagEmbeddingModelConfig, GoogleDriveSource, ImportRagFilesConfig, ImportRagFilesRequest, @@ -30,8 +30,8 @@ SharePointSources as GapicSharePointSources, SlackSource as GapicSlackSource, JiraSource as GapicJiraSource, - RagVectorDbConfig, VertexAiSearchConfig as GapicVertexAiSearchConfig, + RagVectorDbConfig as GapicRagVectorDbConfig, ) from google.cloud.aiplatform import initializer from google.cloud.aiplatform.utils import ( @@ -41,6 +41,7 @@ ) from vertexai.preview.rag.utils.resources import ( EmbeddingModelConfig, + VertexPredictionEndpoint, Pinecone, RagCorpus, RagFile, @@ -53,6 +54,8 @@ VertexFeatureStore, VertexVectorSearch, Weaviate, + RagVectorDbConfig, + RagEmbeddingModelConfig, ) @@ -78,7 +81,7 @@ def create_rag_service_client(): def convert_gapic_to_embedding_model_config( - gapic_embedding_model_config: RagEmbeddingModelConfig, + gapic_embedding_model_config: GapicRagEmbeddingModelConfig, ) -> EmbeddingModelConfig: """Convert GapicRagEmbeddingModelConfig to EmbeddingModelConfig.""" embedding_model_config = EmbeddingModelConfig() @@ -105,45 +108,54 @@ def convert_gapic_to_embedding_model_config( return embedding_model_config -def _check_weaviate(gapic_vector_db: RagVectorDbConfig) -> bool: +def _check_weaviate(gapic_vector_db: GapicRagVectorDbConfig) -> bool: try: return gapic_vector_db.__contains__("weaviate") except AttributeError: return gapic_vector_db.weaviate.ByteSize() > 0 -def _check_rag_managed_db(gapic_vector_db: RagVectorDbConfig) -> bool: +def _check_rag_managed_db(gapic_vector_db: GapicRagVectorDbConfig) -> bool: try: return gapic_vector_db.__contains__("rag_managed_db") except AttributeError: return gapic_vector_db.rag_managed_db.ByteSize() > 0 -def _check_vertex_feature_store(gapic_vector_db: RagVectorDbConfig) -> bool: +def _check_vertex_feature_store(gapic_vector_db: GapicRagVectorDbConfig) -> bool: try: return gapic_vector_db.__contains__("vertex_feature_store") except AttributeError: return gapic_vector_db.vertex_feature_store.ByteSize() > 0 -def _check_pinecone(gapic_vector_db: RagVectorDbConfig) -> bool: +def _check_pinecone(gapic_vector_db: GapicRagVectorDbConfig) -> bool: try: return gapic_vector_db.__contains__("pinecone") except AttributeError: return gapic_vector_db.pinecone.ByteSize() > 0 -def _check_vertex_vector_search(gapic_vector_db: RagVectorDbConfig) -> bool: +def _check_vertex_vector_search(gapic_vector_db: GapicRagVectorDbConfig) -> bool: try: return gapic_vector_db.__contains__("vertex_vector_search") except AttributeError: return gapic_vector_db.vertex_vector_search.ByteSize() > 0 +def _check_rag_embedding_model_config( + gapic_vector_db: GapicRagVectorDbConfig, +) -> bool: + try: + return gapic_vector_db.__contains__("rag_embedding_model_config") + except AttributeError: + return gapic_vector_db.rag_embedding_model_config.ByteSize() > 0 + + def convert_gapic_to_vector_db( - gapic_vector_db: RagVectorDbConfig, + gapic_vector_db: GapicRagVectorDbConfig, ) -> Union[Weaviate, VertexFeatureStore, VertexVectorSearch, Pinecone, RagManagedDb]: - """Convert Gapic RagVectorDbConfig to Weaviate, VertexFeatureStore, VertexVectorSearch, RagManagedDb, or Pinecone.""" + """Convert Gapic GapicRagVectorDbConfig to Weaviate, VertexFeatureStore, VertexVectorSearch, RagManagedDb, or Pinecone.""" if _check_weaviate(gapic_vector_db): return Weaviate( weaviate_http_endpoint=gapic_vector_db.weaviate.http_endpoint, @@ -181,6 +193,59 @@ def convert_gapic_to_vertex_ai_search_config( return None +def convert_gapic_to_rag_embedding_model_config( + gapic_embedding_model_config: GapicRagEmbeddingModelConfig, +) -> RagEmbeddingModelConfig: + """Convert GapicRagEmbeddingModelConfig to RagEmbeddingModelConfig.""" + embedding_model_config = RagEmbeddingModelConfig() + path = gapic_embedding_model_config.vertex_prediction_endpoint.endpoint + publisher_model = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/publishers/google/models/(?P.+?)$", + path, + ) + endpoint = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", + path, + ) + if publisher_model: + embedding_model_config.vertex_prediction_endpoint = VertexPredictionEndpoint( + publisher_model=path + ) + if endpoint: + embedding_model_config.vertex_prediction_endpoint = VertexPredictionEndpoint( + endpoint=path, + model=gapic_embedding_model_config.vertex_prediction_endpoint.model, + model_version_id=gapic_embedding_model_config.vertex_prediction_endpoint.model_version_id, + ) + return embedding_model_config + + +def convert_gapic_to_backend_config( + gapic_vector_db: GapicRagVectorDbConfig, +) -> RagVectorDbConfig: + """Convert Gapic RagVectorDbConfig to VertexVectorSearch, Pinecone, or RagManagedDb.""" + vector_config = RagVectorDbConfig() + if _check_pinecone(gapic_vector_db): + vector_config.vector_db = Pinecone( + index_name=gapic_vector_db.pinecone.index_name, + api_key=gapic_vector_db.api_auth.api_key_config.api_key_secret_version, + ) + elif _check_vertex_vector_search(gapic_vector_db): + vector_config.vector_db = VertexVectorSearch( + index_endpoint=gapic_vector_db.vertex_vector_search.index_endpoint, + index=gapic_vector_db.vertex_vector_search.index, + ) + elif _check_rag_managed_db(gapic_vector_db): + vector_config.vector_db = RagManagedDb() + if _check_rag_embedding_model_config(gapic_vector_db): + vector_config.rag_embedding_model_config = ( + convert_gapic_to_rag_embedding_model_config( + gapic_vector_db.rag_embedding_model_config + ) + ) + return vector_config + + def convert_gapic_to_rag_corpus(gapic_rag_corpus: GapicRagCorpus) -> RagCorpus: """Convert GapicRagCorpus to RagCorpus.""" rag_corpus = RagCorpus( @@ -194,6 +259,9 @@ def convert_gapic_to_rag_corpus(gapic_rag_corpus: GapicRagCorpus) -> RagCorpus: vertex_ai_search_config=convert_gapic_to_vertex_ai_search_config( gapic_rag_corpus.vertex_ai_search_config ), + backend_config=convert_gapic_to_backend_config( + gapic_rag_corpus.rag_vector_db_config + ), ) return rag_corpus @@ -202,6 +270,8 @@ def convert_gapic_to_rag_corpus_no_embedding_model_config( gapic_rag_corpus: GapicRagCorpus, ) -> RagCorpus: """Convert GapicRagCorpus without embedding model config (for UpdateRagCorpus) to RagCorpus.""" + rag_vector_db_config_no_embedding_model_config = gapic_rag_corpus.vector_db_config + rag_vector_db_config_no_embedding_model_config.rag_embedding_model_config = None rag_corpus = RagCorpus( name=gapic_rag_corpus.name, display_name=gapic_rag_corpus.display_name, @@ -210,6 +280,9 @@ def convert_gapic_to_rag_corpus_no_embedding_model_config( vertex_ai_search_config=convert_gapic_to_vertex_ai_search_config( gapic_rag_corpus.vertex_ai_search_config ), + backend_config=convert_gapic_to_backend_config( + rag_vector_db_config_no_embedding_model_config + ), ) return rag_corpus @@ -563,16 +636,16 @@ def set_vector_db( ) -> None: """Sets the vector db configuration for the rag corpus.""" if vector_db is None or isinstance(vector_db, RagManagedDb): - rag_corpus.rag_vector_db_config = RagVectorDbConfig( - rag_managed_db=RagVectorDbConfig.RagManagedDb(), + rag_corpus.rag_vector_db_config = GapicRagVectorDbConfig( + rag_managed_db=GapicRagVectorDbConfig.RagManagedDb(), ) elif isinstance(vector_db, Weaviate): http_endpoint = vector_db.weaviate_http_endpoint collection_name = vector_db.collection_name api_key = vector_db.api_key - rag_corpus.rag_vector_db_config = RagVectorDbConfig( - weaviate=RagVectorDbConfig.Weaviate( + rag_corpus.rag_vector_db_config = GapicRagVectorDbConfig( + weaviate=GapicRagVectorDbConfig.Weaviate( http_endpoint=http_endpoint, collection_name=collection_name, ), @@ -585,8 +658,8 @@ def set_vector_db( elif isinstance(vector_db, VertexFeatureStore): resource_name = vector_db.resource_name - rag_corpus.rag_vector_db_config = RagVectorDbConfig( - vertex_feature_store=RagVectorDbConfig.VertexFeatureStore( + rag_corpus.rag_vector_db_config = GapicRagVectorDbConfig( + vertex_feature_store=GapicRagVectorDbConfig.VertexFeatureStore( feature_view_resource_name=resource_name, ), ) @@ -594,8 +667,8 @@ def set_vector_db( index_endpoint = vector_db.index_endpoint index = vector_db.index - rag_corpus.rag_vector_db_config = RagVectorDbConfig( - vertex_vector_search=RagVectorDbConfig.VertexVectorSearch( + rag_corpus.rag_vector_db_config = GapicRagVectorDbConfig( + vertex_vector_search=GapicRagVectorDbConfig.VertexVectorSearch( index_endpoint=index_endpoint, index=index, ), @@ -604,8 +677,8 @@ def set_vector_db( index_name = vector_db.index_name api_key = vector_db.api_key - rag_corpus.rag_vector_db_config = RagVectorDbConfig( - pinecone=RagVectorDbConfig.Pinecone( + rag_corpus.rag_vector_db_config = GapicRagVectorDbConfig( + pinecone=GapicRagVectorDbConfig.Pinecone( index_name=index_name, ), api_auth=api_auth.ApiAuth( @@ -642,3 +715,49 @@ def set_vertex_ai_search_config( raise ValueError( "serving_config must be of the format `projects/{project}/locations/{location}/collections/{collection}/engines/{engine}/servingConfigs/{serving_config}` or `projects/{project}/locations/{location}/collections/{collection}/dataStores/{data_store}/servingConfigs/{serving_config}`" ) + + +def set_backend_config( + backend_config: Optional[ + Union[ + RagVectorDbConfig, + None, + ] + ], + rag_corpus: GapicRagCorpus, +) -> None: + """Sets the vector db configuration for the rag corpus.""" + if backend_config is None: + return + + if backend_config.vector_db is not None: + vector_config = backend_config.vector_db + if vector_config is None or isinstance(vector_config, RagManagedDb): + rag_corpus.vector_db_config.rag_managed_db.CopyFrom( + GapicRagVectorDbConfig.RagManagedDb() + ) + elif isinstance(vector_config, VertexVectorSearch): + index_endpoint = vector_config.index_endpoint + index = vector_config.index + + rag_corpus.vector_db_config.vertex_vector_search.index_endpoint = ( + index_endpoint + ) + rag_corpus.vector_db_config.vertex_vector_search.index = index + elif isinstance(vector_config, Pinecone): + index_name = vector_config.index_name + api_key = vector_config.api_key + + rag_corpus.vector_db_config.pinecone.index_name = index_name + rag_corpus.vector_db_config.api_auth.api_key_config.api_key_secret_version = ( + api_key + ) + else: + raise TypeError( + "backend_config must be a VertexFeatureStore," + "RagManagedDb, or Pinecone." + ) + if backend_config.rag_embedding_model_config: + set_embedding_model_config( + backend_config.rag_embedding_model_config, rag_corpus + ) diff --git a/vertexai/preview/rag/utils/resources.py b/vertexai/preview/rag/utils/resources.py index 6cbe6b8977..8bc3304d99 100644 --- a/vertexai/preview/rag/utils/resources.py +++ b/vertexai/preview/rag/utils/resources.py @@ -71,6 +71,47 @@ class EmbeddingModelConfig: model_version_id: Optional[str] = None +@dataclasses.dataclass +class VertexPredictionEndpoint: + """VertexPredictionEndpoint. + + Attributes: + publisher_model: 1P publisher model resource name. Format: + ``publishers/google/models/{model}`` or + ``projects/{project}/locations/{location}/publishers/google/models/{model}`` + endpoint: 1P fine tuned embedding model resource name. Format: + ``endpoints/{endpoint}`` or + ``projects/{project}/locations/{location}/endpoints/{endpoint}``. + model: + Output only. The resource name of the model that is deployed + on the endpoint. Present only when the endpoint is not a + publisher model. Pattern: + ``projects/{project}/locations/{location}/models/{model}`` + model_version_id: + Output only. Version ID of the model that is + deployed on the endpoint. Present only when the + endpoint is not a publisher model. + """ + + endpoint: Optional[str] = None + publisher_model: Optional[str] = None + model: Optional[str] = None + model_version_id: Optional[str] = None + + +@dataclasses.dataclass +class RagEmbeddingModelConfig: + """RagEmbeddingModelConfig. + + Attributes: + vertex_prediction_endpoint: The Vertex AI Prediction Endpoint resource + name. Format: + ``projects/{project}/locations/{location}/endpoints/{endpoint}`` + """ + + vertex_prediction_endpoint: Optional[VertexPredictionEndpoint] = None + + @dataclasses.dataclass class Weaviate: """Weaviate. @@ -151,6 +192,22 @@ class VertexAiSearchConfig: serving_config: Optional[str] = None +@dataclasses.dataclass +class RagVectorDbConfig: + """RagVectorDbConfig. + + Attributes: + vector_db: Can be one of the following: Weaviate, VertexFeatureStore, + VertexVectorSearch, Pinecone, RagManagedDb. + rag_embedding_model_config: The embedding model config of the Vector DB. + """ + + vector_db: Optional[ + Union[Weaviate, VertexFeatureStore, VertexVectorSearch, Pinecone, RagManagedDb] + ] = None + rag_embedding_model_config: Optional[RagEmbeddingModelConfig] = None + + @dataclasses.dataclass class RagCorpus: """RAG corpus(output only). @@ -161,8 +218,12 @@ class RagCorpus: display_name: Display name that was configured at client side. description: The description of the RagCorpus. embedding_model_config: The embedding model config of the RagCorpus. + Note: Deprecated. Use backend_config instead. vector_db: The Vector DB of the RagCorpus. + Note: Deprecated. Use backend_config instead. vertex_ai_search_config: The Vertex AI Search config of the RagCorpus. + backend_config: The backend config of the RagCorpus. It can specify a + Vector DB and/or the embedding model config. """ name: Optional[str] = None @@ -173,6 +234,7 @@ class RagCorpus: Union[Weaviate, VertexFeatureStore, VertexVectorSearch, Pinecone, RagManagedDb] ] = None vertex_ai_search_config: Optional[VertexAiSearchConfig] = None + backend_config: Optional[RagVectorDbConfig] = None @dataclasses.dataclass