Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate AutoMLBatchPredictOperator and refactor AutoML system tests #42260

Merged
merged 1 commit into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion airflow/providers/google/cloud/operators/automl.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
TableSpec,
)

from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
from airflow.providers.google.cloud.hooks.vertex_ai.prediction_service import PredictionServiceHook
from airflow.providers.google.cloud.links.translate import (
Expand All @@ -45,6 +45,7 @@
TranslationLegacyModelTrainLink,
)
from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
from airflow.providers.google.common.deprecated import deprecated
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID

if TYPE_CHECKING:
Expand Down Expand Up @@ -338,6 +339,11 @@ def execute(self, context: Context):
return PredictResponse.to_dict(result)


@deprecated(
planned_removal_date="January 01, 2025",
use_instead="airflow.providers.google.cloud.operators.vertex_ai.batch_prediction_job",
category=AirflowProviderDeprecationWarning,
)
class AutoMLBatchPredictOperator(GoogleCloudBaseOperator):
"""
Perform a batch prediction on Google Cloud AutoML.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ datasets. To create and import data to the dataset please use
and
:class:`~airflow.providers.google.cloud.operators.vertex_ai.dataset.ImportDataOperator`

.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_model.py
.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_translation.py
:language: python
:dedent: 4
:start-after: [START howto_operator_automl_create_model]
Expand Down Expand Up @@ -195,17 +195,12 @@ To obtain predictions from Google Cloud AutoML model you can use
:class:`~airflow.providers.google.cloud.operators.automl.AutoMLBatchPredictOperator`. In the first case
the model must be deployed.

.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_model.py
.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_translation.py
:language: python
:dedent: 4
:start-after: [START howto_operator_prediction]
:end-before: [END howto_operator_prediction]

.. exampleinclude:: /../../tests/system/providers/google/cloud/automl/example_automl_model.py
:language: python
:dedent: 4
:start-after: [START howto_operator_batch_prediction]
:end-before: [END howto_operator_batch_prediction]

Th :class:`~airflow.providers.google.cloud.operators.automl.AutoMLBatchPredictOperator` deprecated for tables,
video intelligence, vision and natural language is deprecated and will be removed after 31.03.2024. Please use
Expand Down
1 change: 1 addition & 0 deletions tests/always/test_project_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,7 @@ class TestGoogleProviderProjectStructure(ExampleCoverageTest, AssetsCoverageTest
"airflow.providers.google.cloud.operators.vertex_ai.generative_model.GenerateTextEmbeddingsOperator",
"airflow.providers.google.cloud.operators.vertex_ai.generative_model.PromptMultimodalModelOperator",
"airflow.providers.google.cloud.operators.vertex_ai.generative_model.PromptMultimodalModelWithMediaOperator",
"airflow.providers.google.cloud.operators.automl.AutoMLBatchPredictOperator",
}

ASSETS_NOT_REQUIRED = {
Expand Down
22 changes: 12 additions & 10 deletions tests/providers/google/cloud/links/test_translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from google.cloud.automl_v1beta1 import Model

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.google.cloud.links.translate import (
TRANSLATION_BASE_LINK,
TranslationDatasetListLink,
Expand Down Expand Up @@ -146,16 +147,17 @@ def test_get_link(self, create_task_instance_of_operator, session):
f"predict;modelId={MODEL}?project={GCP_PROJECT_ID}"
)
link = TranslationLegacyModelPredictLink()
ti = create_task_instance_of_operator(
AutoMLBatchPredictOperator,
dag_id="test_legacy_model_predict_link_dag",
task_id="test_legacy_model_predict_link_task",
model_id=MODEL,
project_id=GCP_PROJECT_ID,
location=GCP_LOCATION,
input_config="input_config",
output_config="input_config",
)
with pytest.warns(AirflowProviderDeprecationWarning):
ti = create_task_instance_of_operator(
AutoMLBatchPredictOperator,
dag_id="test_legacy_model_predict_link_dag",
task_id="test_legacy_model_predict_link_task",
model_id=MODEL,
project_id=GCP_PROJECT_ID,
location=GCP_LOCATION,
input_config="input_config",
output_config="input_config",
)
ti.task.model = Model(dataset_id=DATASET, display_name=MODEL)
session.add(ti)
session.commit()
Expand Down
70 changes: 36 additions & 34 deletions tests/providers/google/cloud/operators/test_automl.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from google.api_core.gapic_v1.method import DEFAULT
from google.cloud.automl_v1beta1 import BatchPredictResult, Dataset, Model, PredictResponse

from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
from airflow.providers.google.cloud.hooks.vertex_ai.prediction_service import PredictionServiceHook
from airflow.providers.google.cloud.operators.automl import (
Expand Down Expand Up @@ -148,15 +148,16 @@ def test_execute(self, mock_hook, mock_link_persist):
mock_hook.return_value.extract_object_id = extract_object_id
mock_hook.return_value.wait_for_operation.return_value = BatchPredictResult()
mock_context = {"ti": mock.MagicMock()}
op = AutoMLBatchPredictOperator(
model_id=MODEL_ID,
location=GCP_LOCATION,
project_id=GCP_PROJECT_ID,
input_config=INPUT_CONFIG,
output_config=OUTPUT_CONFIG,
task_id=TASK_ID,
prediction_params={},
)
with pytest.warns(AirflowProviderDeprecationWarning):
op = AutoMLBatchPredictOperator(
model_id=MODEL_ID,
location=GCP_LOCATION,
project_id=GCP_PROJECT_ID,
input_config=INPUT_CONFIG,
output_config=OUTPUT_CONFIG,
task_id=TASK_ID,
prediction_params={},
)
op.execute(context=mock_context)
mock_hook.return_value.batch_predict.assert_called_once_with(
input_config=INPUT_CONFIG,
Expand All @@ -182,16 +183,16 @@ def test_execute_deprecated(self, mock_hook):
del returned_model.translation_model_metadata
mock_hook.return_value.get_model.return_value = returned_model
mock_hook.return_value.extract_object_id = extract_object_id

op = AutoMLBatchPredictOperator(
model_id=MODEL_ID,
location=GCP_LOCATION,
project_id=GCP_PROJECT_ID,
input_config=INPUT_CONFIG,
output_config=OUTPUT_CONFIG,
task_id=TASK_ID,
prediction_params={},
)
with pytest.warns(AirflowProviderDeprecationWarning):
op = AutoMLBatchPredictOperator(
model_id=MODEL_ID,
location=GCP_LOCATION,
project_id=GCP_PROJECT_ID,
input_config=INPUT_CONFIG,
output_config=OUTPUT_CONFIG,
task_id=TASK_ID,
prediction_params={},
)
expected_exception_str = (
"AutoMLBatchPredictOperator for text, image, and video prediction has been "
"deprecated and no longer available"
Expand All @@ -210,20 +211,21 @@ def test_execute_deprecated(self, mock_hook):

@pytest.mark.db_test
def test_templating(self, create_task_instance_of_operator, session):
ti = create_task_instance_of_operator(
AutoMLBatchPredictOperator,
# Templated fields
model_id="{{ 'model' }}",
input_config="{{ 'input-config' }}",
output_config="{{ 'output-config' }}",
location="{{ 'location' }}",
project_id="{{ 'project-id' }}",
impersonation_chain="{{ 'impersonation-chain' }}",
# Other parameters
dag_id="test_template_body_templating_dag",
task_id="test_template_body_templating_task",
execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
)
with pytest.warns(AirflowProviderDeprecationWarning):
ti = create_task_instance_of_operator(
AutoMLBatchPredictOperator,
# Templated fields
model_id="{{ 'model' }}",
input_config="{{ 'input-config' }}",
output_config="{{ 'output-config' }}",
location="{{ 'location' }}",
project_id="{{ 'project-id' }}",
impersonation_chain="{{ 'impersonation-chain' }}",
# Other parameters
dag_id="test_template_body_templating_dag",
task_id="test_template_body_templating_task",
execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc),
)
session.add(ti)
session.commit()
ti.render_templates()
Expand Down
Loading