diff --git a/google/cloud/aiplatform/models.py b/google/cloud/aiplatform/models.py
index d3ddc86962..dc8e105ef6 100644
--- a/google/cloud/aiplatform/models.py
+++ b/google/cloud/aiplatform/models.py
@@ -3666,6 +3666,95 @@ def raw_predict(
             headers=headers_with_token,
         )
 
+    def stream_raw_predict(
+        self,
+        body: bytes,
+        headers: Dict[str, str],
+        endpoint_override: Optional[str] = None,
+    ) -> Iterator[bytes]:
+        """Make a streaming prediction request using arbitrary headers.
+
+        Example usage:
+            my_endpoint = aiplatform.PrivateEndpoint(ENDPOINT_ID)
+
+            # Prepare the request body
+            request_body = json.dumps({...}).encode('utf-8')
+
+            # Define the headers
+            headers = {
+                'Content-Type': 'application/json',
+            }
+
+            # Use stream_raw_predict to send the request and process the response
+            for stream_response in my_endpoint.stream_raw_predict(
+                body=request_body,
+                headers=headers,
+                endpoint_override="10.128.0.26"  # Replace with your actual endpoint
+            ):
+                stream_response_text = stream_response.decode('utf-8')
+
+        Args:
+            body (bytes):
+                The body of the prediction request in bytes. This must not
+                exceed 10 MB per request.
+            headers (Dict[str, str]):
+                The header of the request as a dictionary. There are no
+                restrictions on the header.
+            endpoint_override (Optional[str]):
+                The Private Service Connect endpoint's IP address or DNS that
+                points to the endpoint's service attachment.
+
+        Yields:
+            predictions (Iterator[bytes]):
+                The streaming prediction results as lines of bytes.
+
+        Raises:
+            ValueError: If an endpoint override is not provided for a PSC based
+                endpoint.
+            ValueError: If the endpoint override is invalid for a PSC based endpoint.
+        """
+        self.wait()
+        if self.network or not self.private_service_connect_config:
+            raise ValueError(
+                "PSA based private endpoint does not support streaming prediction."
+            )
+
+        if self.private_service_connect_config:
+            if not endpoint_override:
+                raise ValueError(
+                    "Cannot make a predict request because endpoint override is "
+                    "not provided. Please ensure an endpoint override is "
+                    "provided."
+                )
+            if not self._validate_endpoint_override(endpoint_override):
+                raise ValueError(
+                    "Invalid endpoint override provided. Please only use IP "
+                    "address or DNS."
+                )
+            if not self.credentials.valid:
+                self.credentials.refresh(google_auth_requests.Request())
+
+            token = self.credentials.token
+            headers_with_token = dict(headers)
+            headers_with_token["Authorization"] = f"Bearer {token}"
+
+            if not self.authorized_session:
+                self.credentials._scopes = constants.base.DEFAULT_AUTHED_SCOPES
+                self.authorized_session = google_auth_requests.AuthorizedSession(
+                    self.credentials
+                )
+
+            url = f"https://{endpoint_override}/v1/projects/{self.project}/locations/{self.location}/endpoints/{self.name}:streamRawPredict"
+            with self.authorized_session.post(
+                url=url,
+                data=body,
+                headers=headers_with_token,
+                stream=True,
+                verify=False,
+            ) as resp:
+                for line in resp.iter_lines():
+                    yield line
+
     def explain(self):
         raise NotImplementedError(
             f"{self.__class__.__name__} class does not support 'explain' as of now."
diff --git a/tests/unit/aiplatform/test_endpoints.py b/tests/unit/aiplatform/test_endpoints.py
index a13de3158d..25553d8ad7 100644
--- a/tests/unit/aiplatform/test_endpoints.py
+++ b/tests/unit/aiplatform/test_endpoints.py
@@ -18,8 +18,8 @@
 import copy
 from datetime import datetime, timedelta
 from importlib import reload
-import requests
 import json
+import requests
 from unittest import mock
 
 from google.api_core import operation as ga_operation
@@ -920,6 +920,49 @@ def predict_private_endpoint_mock():
     yield predict_mock
 
 
+@pytest.fixture
+def stream_raw_predict_private_endpoint_mock():
+    with mock.patch.object(
+        google_auth_requests.AuthorizedSession, "post"
+    ) as stream_raw_predict_mock:
+        # Create a mock response object
+        mock_response = mock.Mock(spec=requests.Response)
+
+        # Configure the mock to be used as a context manager
+        stream_raw_predict_mock.return_value.__enter__.return_value = mock_response
+
+        # Set the status code to 200 (OK)
+        mock_response.status_code = 200
+
+        # Simulate streaming data with iter_lines
+        mock_response.iter_lines = mock.Mock(
+            return_value=iter(
+                [
+                    json.dumps(
+                        {
+                            "predictions": [1.0, 2.0, 3.0],
+                            "metadata": {"key": "value"},
+                            "deployedModelId": "model-id-123",
+                            "model": "model-name",
+                            "modelVersionId": "1",
+                        }
+                    ).encode("utf-8"),
+                    json.dumps(
+                        {
+                            "predictions": [4.0, 5.0, 6.0],
+                            "metadata": {"key": "value"},
+                            "deployedModelId": "model-id-123",
+                            "model": "model-name",
+                            "modelVersionId": "1",
+                        }
+                    ).encode("utf-8"),
+                ]
+            )
+        )
+
+        yield stream_raw_predict_mock
+
+
 @pytest.fixture
 def health_check_private_endpoint_mock():
     with mock.patch.object(urllib3.PoolManager, "request") as health_check_mock:
@@ -3195,6 +3238,57 @@ def test_psc_predict(self, predict_private_endpoint_mock):
             },
         )
 
+    @pytest.mark.usefixtures("get_psc_private_endpoint_mock")
+    def test_psc_stream_raw_predict(self, stream_raw_predict_private_endpoint_mock):
+        test_endpoint = models.PrivateEndpoint(
+            project=_TEST_PROJECT, location=_TEST_LOCATION, endpoint_name=_TEST_ID
+        )
+
+        test_prediction_iterator = test_endpoint.stream_raw_predict(
+            body='{"instances": [[1.0, 2.0, 3.0], [1.0, 3.0, 4.0]]}',
+            headers={
+                "Content-Type": "application/json",
+                "Authorization": "Bearer None",
+            },
+            endpoint_override=_TEST_ENDPOINT_OVERRIDE,
+        )
+
+        test_prediction = list(test_prediction_iterator)
+
+        stream_raw_predict_private_endpoint_mock.assert_called_once_with(
+            url=f"https://{_TEST_ENDPOINT_OVERRIDE}/v1/projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/endpoints/{_TEST_ID}:streamRawPredict",
+            data='{"instances": [[1.0, 2.0, 3.0], [1.0, 3.0, 4.0]]}',
+            headers={
+                "Content-Type": "application/json",
+                "Authorization": "Bearer None",
+            },
+            stream=True,
+            verify=False,
+        )
+
+        # Validate the content of the returned predictions
+        expected_predictions = [
+            json.dumps(
+                {
+                    "predictions": [1.0, 2.0, 3.0],
+                    "metadata": {"key": "value"},
+                    "deployedModelId": "model-id-123",
+                    "model": "model-name",
+                    "modelVersionId": "1",
+                }
+            ).encode("utf-8"),
+            json.dumps(
+                {
+                    "predictions": [4.0, 5.0, 6.0],
+                    "metadata": {"key": "value"},
+                    "deployedModelId": "model-id-123",
+                    "model": "model-name",
+                    "modelVersionId": "1",
+                }
+            ).encode("utf-8"),
+        ]
+        assert test_prediction == expected_predictions
+
     @pytest.mark.usefixtures("get_psc_private_endpoint_mock")
     def test_psc_predict_without_endpoint_override(self):
         test_endpoint = models.PrivateEndpoint(
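Reviewer note: for context, below is a minimal usage sketch of the new stream_raw_predict method against a PSC-based private endpoint. It mirrors the docstring example and the new unit test rather than documenting additional behavior; the project, region, endpoint ID, request payload, and endpoint_override IP are all placeholders.

    import json

    from google.cloud import aiplatform

    # Placeholder project/region/endpoint values; replace with real ones.
    aiplatform.init(project="my-project", location="us-central1")
    psc_endpoint = aiplatform.PrivateEndpoint("1234567890")

    # The instances payload depends on the deployed model's input schema.
    request_body = json.dumps(
        {"instances": [[1.0, 2.0, 3.0], [1.0, 3.0, 4.0]]}
    ).encode("utf-8")

    # endpoint_override is the PSC endpoint's IP address or DNS name that
    # points to the endpoint's service attachment.
    for line in psc_endpoint.stream_raw_predict(
        body=request_body,
        headers={"Content-Type": "application/json"},
        endpoint_override="10.128.0.26",
    ):
        if line:  # requests' iter_lines() can yield empty keep-alive lines
            print(line.decode("utf-8"))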