From 9a19545e864c6d4743156c737dd5bb8c4b86ab6f Mon Sep 17 00:00:00 2001
From: Matthew Tang
Date: Mon, 5 Feb 2024 10:56:14 -0800
Subject: [PATCH] feat: Switch Python generateContent to call Unary API
 endpoint

PiperOrigin-RevId: 604369375
---
 tests/unit/vertexai/test_generative_models.py | 62 +++++++++++++++++--
 .../generative_models/_generative_models.py   | 20 +-----
 2 files changed, 58 insertions(+), 24 deletions(-)

diff --git a/tests/unit/vertexai/test_generative_models.py b/tests/unit/vertexai/test_generative_models.py
index 29e717d065..c85de256d4 100644
--- a/tests/unit/vertexai/test_generative_models.py
+++ b/tests/unit/vertexai/test_generative_models.py
@@ -163,6 +163,56 @@ def mock_stream_generate_content(
     yield response
 
 
+def mock_generate_content(
+    self,
+    request: gapic_prediction_service_types.GenerateContentRequest,
+    *,
+    model: Optional[str] = None,
+    contents: Optional[MutableSequence[gapic_content_types.Content]] = None,
+) -> gapic_prediction_service_types.GenerateContentResponse:
+    is_continued_chat = len(request.contents) > 1
+    has_tools = bool(request.tools)
+
+    if has_tools:
+        has_function_response = any(
+            "function_response" in content.parts[0] for content in request.contents
+        )
+        needs_function_call = not has_function_response
+        if needs_function_call:
+            response_part_struct = _RESPONSE_FUNCTION_CALL_PART_STRUCT
+        else:
+            response_part_struct = _RESPONSE_AFTER_FUNCTION_CALL_PART_STRUCT
+    elif is_continued_chat:
+        response_part_struct = {"text": "Other planets may have different sky color."}
+    else:
+        response_part_struct = _RESPONSE_TEXT_PART_STRUCT
+
+    return gapic_prediction_service_types.GenerateContentResponse(
+        candidates=[
+            gapic_content_types.Candidate(
+                index=0,
+                content=gapic_content_types.Content(
+                    # Model currently does not identify itself
+                    # role="model",
+                    parts=[
+                        gapic_content_types.Part(response_part_struct),
+                    ],
+                ),
+                finish_reason=gapic_content_types.Candidate.FinishReason.STOP,
+                safety_ratings=[
+                    gapic_content_types.SafetyRating(rating)
+                    for rating in _RESPONSE_SAFETY_RATINGS_STRUCT
+                ],
+                citation_metadata=gapic_content_types.CitationMetadata(
+                    citations=[
+                        gapic_content_types.Citation(_RESPONSE_CITATION_STRUCT),
+                    ]
+                ),
+            ),
+        ],
+    )
+
+
 @pytest.mark.usefixtures("google_auth_mock")
 class TestGenerativeModels:
     """Unit tests for the generative models."""
@@ -178,8 +228,8 @@ def teardown_method(self):
 
     @mock.patch.object(
         target=prediction_service.PredictionServiceClient,
-        attribute="stream_generate_content",
-        new=mock_stream_generate_content,
+        attribute="generate_content",
+        new=mock_generate_content,
     )
     def test_generate_content(self):
         model = generative_models.GenerativeModel("gemini-pro")
@@ -212,8 +262,8 @@ def test_generate_content_streaming(self):
     @mock.patch.object(
         target=prediction_service.PredictionServiceClient,
-        attribute="stream_generate_content",
-        new=mock_stream_generate_content,
+        attribute="generate_content",
+        new=mock_generate_content,
     )
     def test_chat_send_message(self):
         model = generative_models.GenerativeModel("gemini-pro")
         chat = model.start_chat()
@@ -225,8 +275,8 @@ def test_chat_send_message(self):
 
     @mock.patch.object(
         target=prediction_service.PredictionServiceClient,
-        attribute="stream_generate_content",
-        new=mock_stream_generate_content,
+        attribute="generate_content",
+        new=mock_generate_content,
    )
     def test_chat_function_calling(self):
         get_current_weather_func = generative_models.FunctionDeclaration(
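Note on the test changes above: each @mock.patch.object now swaps the GAPIC
client's unary generate_content method for mock_generate_content, so the
SDK's non-streaming path runs entirely in-process with no RPC. A minimal
sketch of the same patching in context-manager form, illustrative only and
not part of the patch; prediction_service, generative_models, and
mock_generate_content are the names defined in
tests/unit/vertexai/test_generative_models.py:

    # Sketch: equivalent to the decorator form used in the tests above.
    # While the context manager is active, every unary GenerateContent
    # call is answered locally by mock_generate_content.
    from unittest import mock

    with mock.patch.object(
        target=prediction_service.PredictionServiceClient,
        attribute="generate_content",
        new=mock_generate_content,
    ):
        model = generative_models.GenerativeModel("gemini-pro")
        response = model.generate_content("Why is sky blue?")
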
diff --git a/vertexai/generative_models/_generative_models.py b/vertexai/generative_models/_generative_models.py
index aafdb3b3d2..a64875d6c6 100644
--- a/vertexai/generative_models/_generative_models.py
+++ b/vertexai/generative_models/_generative_models.py
@@ -431,15 +431,7 @@ def _generate_content(
             safety_settings=safety_settings,
             tools=tools,
         )
-        # generate_content is not available
-        # gapic_response = self._prediction_client.generate_content(request=request)
-        gapic_response = None
-        stream = self._prediction_client.stream_generate_content(request=request)
-        for gapic_chunk in stream:
-            if gapic_response:
-                _append_gapic_response(gapic_response, gapic_chunk)
-            else:
-                gapic_response = gapic_chunk
+        gapic_response = self._prediction_client.generate_content(request=request)
         return self._parse_response(gapic_response)
 
     async def _generate_content_async(
@@ -473,17 +465,9 @@ async def _generate_content_async(
             safety_settings=safety_settings,
             tools=tools,
         )
-        # generate_content is not available
-        # gapic_response = await self._prediction_async_client.generate_content(request=request)
-        gapic_response = None
-        stream = await self._prediction_async_client.stream_generate_content(
-            request=request
-        )
-        async for gapic_chunk in stream:
-            if gapic_response:
-                _append_gapic_response(gapic_response, gapic_chunk)
-            else:
-                gapic_response = gapic_chunk
+        gapic_response = await self._prediction_async_client.generate_content(
+            request=request
+        )
         return self._parse_response(gapic_response)
 
     def _generate_content_streaming(
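
Context for this change: the SDK's non-streaming generate_content call now
maps to a single unary GenerateContent RPC instead of calling the streaming
RPC and merging chunks client-side with _append_gapic_response; streaming
behavior is unchanged. A rough sketch of the two public call paths,
assuming the vertexai SDK surface current at the time of this commit; the
project and location values are placeholders:

    import vertexai
    from vertexai.generative_models import GenerativeModel

    # Placeholder project/location; "gemini-pro" as used in the tests above.
    vertexai.init(project="my-project", location="us-central1")
    model = GenerativeModel("gemini-pro")

    # Non-streaming: after this change, served by one unary GenerateContent
    # RPC, so the full response arrives as a single message.
    response = model.generate_content("Why is the sky blue?")
    print(response.text)

    # Streaming: still served by StreamGenerateContent; chunks arrive
    # incrementally.
    for chunk in model.generate_content("Why is the sky blue?", stream=True):
        print(chunk.text)

One practical effect: the non-streaming path no longer depends on
client-side chunk aggregation, so a unary call returns exactly the response
message the service produced.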