Skip to content

Commit

Permalink
feat: LLM - Grounding - Added support for the disable_attribution g…
Browse files Browse the repository at this point in the history
…rounding parameter

PiperOrigin-RevId: 580285757
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Nov 7, 2023
1 parent 791eff5 commit 91e985a
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 32 deletions.
93 changes: 71 additions & 22 deletions tests/unit/aiplatform/test_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,10 @@
"citations": [
{"url": "url1", "startIndex": 1, "endIndex": 2},
{"url": "url2", "startIndex": 3, "endIndex": 4},
]
],
"searchQueries": [
"searchQuery",
],
},
"content": """
Ingredients:
Expand Down Expand Up @@ -211,7 +214,8 @@
"license": None,
"publication_date": None,
},
]
],
"search_queries": ["searchQuery"],
}

_TEST_TEXT_GENERATION_PREDICTION = {
Expand Down Expand Up @@ -332,7 +336,8 @@
"endIndex": 2,
"url": "url1",
}
]
],
"searchQueries": ["searchQuery1"],
},
{
"citations": [
Expand All @@ -341,7 +346,8 @@
"endIndex": 4,
"url": "url2",
}
]
],
"searchQueries": ["searchQuery2"],
},
],
"candidates": [
Expand Down Expand Up @@ -396,10 +402,12 @@
"publication_date": None,
},
],
"search_queries": ["searchQuery1"],
}

_EXPECTED_PARSED_GROUNDING_METADATA_CHAT_NONE = {
"citations": [],
"search_queries": [],
}

_TEST_CHAT_PREDICTION_STREAMING = [
Expand Down Expand Up @@ -1567,12 +1575,13 @@ def test_text_generation_multiple_candidates_grounding(self):
"collections/default_collection/dataStores/test_datastore"
)
expected_grounding_sources = [
{"sources": [{"type": "WEB"}]},
{"sources": [{"type": "WEB", "disableAttribution": False}]},
{
"sources": [
{
"type": "ENTERPRISE",
"enterpriseDatastore": datastore_path,
"type": "VERTEX_AI_SEARCH",
"vertexAiSearchDatastore": datastore_path,
"disableAttribution": False,
}
]
},
Expand Down Expand Up @@ -1680,12 +1689,20 @@ async def test_text_generation_multiple_candidates_grounding_async(self):
"collections/default_collection/dataStores/test_datastore"
)
expected_grounding_sources = [
{"sources": [{"type": "WEB"}]},
{
"sources": [
{
"type": "ENTERPRISE",
"enterpriseDatastore": datastore_path,
"type": "WEB",
"disableAttribution": False,
}
]
},
{
"sources": [
{
"type": "VERTEX_AI_SEARCH",
"vertexAiSearchDatastore": datastore_path,
"disableAttribution": False,
}
]
},
Expand Down Expand Up @@ -2416,12 +2433,20 @@ def test_chat(self):
"collections/default_collection/dataStores/test_datastore"
)
expected_grounding_sources = [
{"sources": [{"type": "WEB"}]},
{
"sources": [
{
"type": "ENTERPRISE",
"enterpriseDatastore": datastore_path,
"type": "WEB",
"disableAttribution": False,
}
]
},
{
"sources": [
{
"type": "VERTEX_AI_SEARCH",
"vertexAiSearchDatastore": datastore_path,
"disableAttribution": False,
}
]
},
Expand Down Expand Up @@ -2461,12 +2486,20 @@ def test_chat(self):
"collections/default_collection/dataStores/test_datastore"
)
expected_grounding_sources = [
{"sources": [{"type": "WEB"}]},
{
"sources": [
{
"type": "ENTERPRISE",
"enterpriseDatastore": datastore_path,
"type": "WEB",
"disableAttribution": False,
}
]
},
{
"sources": [
{
"type": "VERTEX_AI_SEARCH",
"vertexAiSearchDatastore": datastore_path,
"disableAttribution": False,
}
]
},
Expand Down Expand Up @@ -2537,12 +2570,20 @@ async def test_chat_async(self):
"collections/default_collection/dataStores/test_datastore"
)
expected_grounding_sources = [
{"sources": [{"type": "WEB"}]},
{
"sources": [
{
"type": "ENTERPRISE",
"enterpriseDatastore": datastore_path,
"type": "WEB",
"disableAttribution": False,
}
]
},
{
"sources": [
{
"type": "VERTEX_AI_SEARCH",
"vertexAiSearchDatastore": datastore_path,
"disableAttribution": False,
}
]
},
Expand Down Expand Up @@ -2586,12 +2627,20 @@ async def test_chat_async(self):
"collections/default_collection/dataStores/test_datastore"
)
expected_grounding_sources = [
{"sources": [{"type": "WEB"}]},
{
"sources": [
{
"type": "ENTERPRISE",
"enterpriseDatastore": datastore_path,
"type": "WEB",
"disableAttribution": False,
}
]
},
{
"sources": [
{
"type": "VERTEX_AI_SEARCH",
"vertexAiSearchDatastore": datastore_path,
"disableAttribution": False,
}
]
},
Expand Down
30 changes: 20 additions & 10 deletions vertexai/language_models/_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,12 +709,16 @@ def _to_grounding_source_dict(self) -> Dict[str, Any]:

@dataclasses.dataclass
class WebSearch(_GroundingSourceBase):
"""WebSearch represents a grounding source using public web search."""
"""WebSearch represents a grounding source using public web search.
Attributes:
disable_attribution: If set to `True`, skip finding claim attributions (i.e not generate grounding citation). Default: False.
"""

disable_attribution: bool = False
_type: str = dataclasses.field(default="WEB", init=False, repr=False)

def _to_grounding_source_dict(self) -> Dict[str, Any]:
return {"type": self._type}
return {"type": self._type, "disableAttribution": self.disable_attribution}


@dataclasses.dataclass
Expand All @@ -723,16 +727,18 @@ class VertexAISearch(_GroundingSourceBase):
Attributes:
data_store_id: Data store ID of the Vertex AI Search datastore.
location: GCP multi region where you have set up your Vertex AI Search data store. Possible values can be `global`, `us`, `eu`, etc.
Learn more about Vertex AI Search location here:
https://cloud.google.com/generative-ai-app-builder/docs/locations
Learn more about Vertex AI Search location here:
https://cloud.google.com/generative-ai-app-builder/docs/locations
project: The project where you have set up your Vertex AI Search.
If not specified, will assume that your Vertex AI Search is within your current project.
If not specified, will assume that your Vertex AI Search is within your current project.
disable_attribution: If set to `True`, skip finding claim attributions (i.e not generate grounding citation). Default: False.
"""

data_store_id: str
location: str
project: Optional[str] = None
_type: str = dataclasses.field(default="ENTERPRISE", init=False, repr=False)
disable_attribution: bool = False
_type: str = dataclasses.field(default="VERTEX_AI_SEARCH", init=False, repr=False)

def _get_datastore_path(self) -> str:
_project = self.project or aiplatform_initializer.global_config.project
Expand All @@ -742,7 +748,11 @@ def _get_datastore_path(self) -> str:
)

def _to_grounding_source_dict(self) -> Dict[str, Any]:
return {"type": self._type, "enterpriseDatastore": self._get_datastore_path()}
return {
"type": self._type,
"vertexAiSearchDatastore": self._get_datastore_path(),
"disableAttribution": self.disable_attribution,
}


@dataclasses.dataclass
Expand Down Expand Up @@ -790,6 +800,7 @@ class GroundingMetadata:
"""

citations: Optional[List[GroundingCitation]] = None
search_queries: Optional[List[str]] = None

def _parse_citation_from_dict(
self, citation_dict_camel: Dict[str, Any]
Expand Down Expand Up @@ -819,6 +830,7 @@ def __init__(self, response: Optional[Dict[str, Any]] = {}):
self._parse_citation_from_dict(citation)
for citation in response.get("citations", [])
]
self.search_queries = response.get("searchQueries", [])


@dataclasses.dataclass
Expand Down Expand Up @@ -1521,9 +1533,7 @@ def _prepare_text_embedding_request(
A `_MultiInstancePredictionRequest` object.
"""
if isinstance(texts, str) or not isinstance(texts, Sequence):
raise TypeError(
"The `texts` argument must be a list, not a single string."
)
raise TypeError("The `texts` argument must be a list, not a single string.")
instances = []
for text in texts:
if isinstance(text, TextEmbeddingInput):
Expand Down

0 comments on commit 91e985a

Please sign in to comment.