diff --git a/airflow/providers/google/cloud/operators/speech_to_text.py b/airflow/providers/google/cloud/operators/speech_to_text.py
index de26a8ba8216f..f8c3e4703f9e0 100644
--- a/airflow/providers/google/cloud/operators/speech_to_text.py
+++ b/airflow/providers/google/cloud/operators/speech_to_text.py
@@ -113,15 +113,14 @@ def execute(self, context: Context):
             gcp_conn_id=self.gcp_conn_id,
             impersonation_chain=self.impersonation_chain,
         )
-
-        FileDetailsLink.persist(
-            context=context,
-            task_instance=self,
-            # Slice from: "gs://{BUCKET_NAME}/{FILE_NAME}" to: "{BUCKET_NAME}/{FILE_NAME}"
-            uri=self.audio["uri"][5:],
-            project_id=self.project_id or hook.project_id,
-        )
-
+        if self.audio.uri:
+            FileDetailsLink.persist(
+                context=context,
+                task_instance=self,
+                # Slice from: "gs://{BUCKET_NAME}/{FILE_NAME}" to: "{BUCKET_NAME}/{FILE_NAME}"
+                uri=self.audio.uri[5:],
+                project_id=self.project_id or hook.project_id,
+            )
         response = hook.recognize_speech(
             config=self.config, audio=self.audio, retry=self.retry, timeout=self.timeout
         )
diff --git a/airflow/providers/google/cloud/operators/translate_speech.py b/airflow/providers/google/cloud/operators/translate_speech.py
index fb3bdccb1abee..b0b540c31e086 100644
--- a/airflow/providers/google/cloud/operators/translate_speech.py
+++ b/airflow/providers/google/cloud/operators/translate_speech.py
@@ -169,7 +169,14 @@ def execute(self, context: Context) -> dict:
             raise AirflowException(
                 f"Wrong response '{recognize_dict}' returned - it should contain {key} field"
             )
-
+        if self.audio.uri:
+            FileDetailsLink.persist(
+                context=context,
+                task_instance=self,
+                # Slice from: "gs://{BUCKET_NAME}/{FILE_NAME}" to: "{BUCKET_NAME}/{FILE_NAME}"
+                uri=self.audio.uri[5:],
+                project_id=self.project_id or translate_hook.project_id,
+            )
         try:
             translation = translate_hook.translate(
                 values=transcript,
@@ -179,12 +186,6 @@ def execute(self, context: Context) -> dict:
                 model=self.model,
             )
             self.log.info("Translated output: %s", translation)
-            FileDetailsLink.persist(
-                context=context,
-                task_instance=self,
-                uri=self.audio["uri"][5:],
-                project_id=self.project_id or translate_hook.project_id,
-            )
             return translation
         except ValueError as e:
             self.log.error("An error has been thrown from translate speech method:")
diff --git a/tests/providers/google/cloud/operators/test_speech_to_text.py b/tests/providers/google/cloud/operators/test_speech_to_text.py
index 51dd6dd8db7c0..1d7fa9ca37fea 100644
--- a/tests/providers/google/cloud/operators/test_speech_to_text.py
+++ b/tests/providers/google/cloud/operators/test_speech_to_text.py
@@ -21,7 +21,7 @@
 
 import pytest
 from google.api_core.gapic_v1.method import DEFAULT
-from google.cloud.speech_v1 import RecognizeResponse
+from google.cloud.speech_v1 import RecognitionAudio, RecognitionConfig, RecognizeResponse
 
 from airflow.exceptions import AirflowException
 from airflow.providers.google.cloud.operators.speech_to_text import CloudSpeechToTextRecognizeSpeechOperator
@@ -29,8 +29,8 @@
 PROJECT_ID = "project-id"
 GCP_CONN_ID = "gcp-conn-id"
 IMPERSONATION_CHAIN = ["ACCOUNT_1", "ACCOUNT_2", "ACCOUNT_3"]
-CONFIG = {"encoding": "LINEAR16"}
-AUDIO = {"uri": "gs://bucket/object"}
+CONFIG = RecognitionConfig({"encoding": "LINEAR16"})
+AUDIO = RecognitionAudio({"uri": "gs://bucket/object"})
 
 
 class TestCloudSpeechToTextRecognizeSpeechOperator:
@@ -80,3 +80,25 @@ def test_missing_audio(self, mock_hook):
         err = ctx.value
         assert "audio" in str(err)
         mock_hook.assert_not_called()
+
+    @patch("airflow.providers.google.cloud.operators.speech_to_text.FileDetailsLink.persist")
+    @patch("airflow.providers.google.cloud.operators.speech_to_text.CloudSpeechToTextHook")
+    def test_no_audio_uri(self, mock_hook, mock_file_link):
+        mock_hook.return_value.recognize_speech.return_value = RecognizeResponse()
+        AUDIO_NO_URI = RecognitionAudio({"content": b"set content data instead of uri"})
+
+        op = CloudSpeechToTextRecognizeSpeechOperator(
+            project_id=PROJECT_ID,
+            gcp_conn_id=GCP_CONN_ID,
+            config=CONFIG,
+            audio=AUDIO_NO_URI,
+            task_id="id",
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+        op.execute(context=MagicMock())
+
+        mock_hook.return_value.recognize_speech.assert_called_once_with(
+            config=CONFIG, audio=AUDIO_NO_URI, retry=DEFAULT, timeout=None
+        )
+        assert op.audio.uri == ""
+        mock_file_link.assert_not_called()
diff --git a/tests/providers/google/cloud/operators/test_translate_speech.py b/tests/providers/google/cloud/operators/test_translate_speech.py
index 6dd000504cef5..8e6beb79b9702 100644
--- a/tests/providers/google/cloud/operators/test_translate_speech.py
+++ b/tests/providers/google/cloud/operators/test_translate_speech.py
@@ -21,6 +21,8 @@
 
 import pytest
 from google.cloud.speech_v1 import (
+    RecognitionAudio,
+    RecognitionConfig,
     RecognizeResponse,
     SpeechRecognitionAlternative,
     SpeechRecognitionResult,
@@ -54,8 +56,8 @@ def test_minimal_green_path(self, mock_translate_hook, mock_speech_hook):
         ]
 
         op = CloudTranslateSpeechOperator(
-            audio={"uri": "gs://bucket/object"},
-            config={"encoding": "LINEAR16"},
+            audio=RecognitionAudio({"uri": "gs://bucket/object"}),
+            config=RecognitionConfig({"encoding": "LINEAR16"}),
             target_language="pl",
             format_="text",
             source_language=None,
@@ -77,8 +79,8 @@ def test_minimal_green_path(self, mock_translate_hook, mock_speech_hook):
         )
 
         mock_speech_hook.return_value.recognize_speech.assert_called_once_with(
-            audio={"uri": "gs://bucket/object"},
-            config={"encoding": "LINEAR16"},
+            audio=RecognitionAudio({"uri": "gs://bucket/object"}),
+            config=RecognitionConfig({"encoding": "LINEAR16"}),
         )
 
         mock_translate_hook.return_value.translate.assert_called_once_with(
@@ -104,8 +106,8 @@ def test_bad_recognition_response(self, mock_translate_hook, mock_speech_hook):
             results=[SpeechRecognitionResult()]
         )
         op = CloudTranslateSpeechOperator(
-            audio={"uri": "gs://bucket/object"},
-            config={"encoding": "LINEAR16"},
+            audio=RecognitionAudio({"uri": "gs://bucket/object"}),
+            config=RecognitionConfig({"encoding": "LINEAR16"}),
             target_language="pl",
             format_="text",
             source_language=None,
@@ -128,8 +130,47 @@ def test_bad_recognition_response(self, mock_translate_hook, mock_speech_hook):
         )
 
         mock_speech_hook.return_value.recognize_speech.assert_called_once_with(
-            audio={"uri": "gs://bucket/object"},
-            config={"encoding": "LINEAR16"},
+            audio=RecognitionAudio({"uri": "gs://bucket/object"}),
+            config=RecognitionConfig({"encoding": "LINEAR16"}),
         )
 
         mock_translate_hook.return_value.translate.assert_not_called()
+
+    @mock.patch("airflow.providers.google.cloud.operators.translate_speech.FileDetailsLink.persist")
+    @mock.patch("airflow.providers.google.cloud.operators.translate_speech.CloudSpeechToTextHook")
+    @mock.patch("airflow.providers.google.cloud.operators.translate_speech.CloudTranslateHook")
+    def test_no_audio_uri(self, mock_translate_hook, mock_speech_hook, file_link_mock):
+        mock_speech_hook.return_value.recognize_speech.return_value = RecognizeResponse(
+            results=[
+                SpeechRecognitionResult(
+                    alternatives=[SpeechRecognitionAlternative(transcript="test speech recognition result")]
+                )
+            ]
+        )
+        mock_translate_hook.return_value.translate.return_value = [
+            {
+                "translatedText": "sprawdzić wynik rozpoznawania mowy",
+                "detectedSourceLanguage": "en",
+                "model": "base",
+                "input": "test speech recognition result",
+            }
+        ]
+        op = CloudTranslateSpeechOperator(
+            audio=RecognitionAudio({"content": b"set content data instead of uri"}),
+            config=RecognitionConfig({"encoding": "LINEAR16"}),
+            target_language="pl",
+            format_="text",
+            source_language=None,
+            model="base",
+            gcp_conn_id=GCP_CONN_ID,
+            task_id="id",
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+        op.execute(context=mock.MagicMock())
+
+        mock_speech_hook.return_value.recognize_speech.assert_called_once_with(
+            audio=RecognitionAudio({"content": b"set content data instead of uri"}),
+            config=RecognitionConfig({"encoding": "LINEAR16"}),
+        )
+        assert op.audio.uri == ""
+        file_link_mock.assert_not_called()
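Note (not part of the patch): the new `if self.audio.uri:` guard works for both GCS-based and inline audio because proto-plus messages such as RecognitionAudio return an empty string for an unset `uri` field, which is exactly what the new tests assert via `op.audio.uri == ""`. A minimal sketch of that behaviour, with illustrative variable names:

# Illustrative sketch only; mirrors the behaviour the guarded persist() call relies on.
from google.cloud.speech_v1 import RecognitionAudio

uri_audio = RecognitionAudio({"uri": "gs://bucket/object"})
content_audio = RecognitionAudio({"content": b"raw audio bytes"})

assert uri_audio.uri == "gs://bucket/object"
assert uri_audio.uri[5:] == "bucket/object"  # the "{BUCKET_NAME}/{FILE_NAME}" slice passed to FileDetailsLink
assert content_audio.uri == ""               # unset proto3 string field is empty and falsy, so persist() is skipped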