diff --git a/airflow/providers/amazon/aws/hooks/redshift_cluster.py b/airflow/providers/amazon/aws/hooks/redshift_cluster.py index 43d7993af73c5..d85929d062024 100644 --- a/airflow/providers/amazon/aws/hooks/redshift_cluster.py +++ b/airflow/providers/amazon/aws/hooks/redshift_cluster.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +import warnings from typing import Any, Sequence from botocore.exceptions import ClientError @@ -157,16 +158,24 @@ def create_cluster_snapshot( ) return response["Snapshot"] if response["Snapshot"] else None - def get_cluster_snapshot_status(self, snapshot_identifier: str, cluster_identifier: str): + def get_cluster_snapshot_status(self, snapshot_identifier: str, cluster_identifier: str | None = None): """ Return Redshift cluster snapshot status. If cluster snapshot not found return ``None`` :param snapshot_identifier: A unique identifier for the snapshot that you are requesting - :param cluster_identifier: The unique identifier of the cluster the snapshot was created from + :param cluster_identifier: (deprecated) The unique identifier of the cluster + the snapshot was created from """ + if cluster_identifier: + warnings.warn( + "Parameter `cluster_identifier` is deprecated." + "This option will be removed in a future version.", + DeprecationWarning, + stacklevel=2, + ) + try: response = self.get_conn().describe_cluster_snapshots( - ClusterIdentifier=cluster_identifier, SnapshotIdentifier=snapshot_identifier, ) snapshot = response.get("Snapshots")[0] diff --git a/airflow/providers/amazon/aws/operators/redshift_cluster.py b/airflow/providers/amazon/aws/operators/redshift_cluster.py index 39515b713781c..2da0fbf23a73d 100644 --- a/airflow/providers/amazon/aws/operators/redshift_cluster.py +++ b/airflow/providers/amazon/aws/operators/redshift_cluster.py @@ -367,7 +367,6 @@ def execute(self, context: Context) -> Any: def get_status(self) -> str: return self.redshift_hook.get_cluster_snapshot_status( snapshot_identifier=self.snapshot_identifier, - cluster_identifier=self.cluster_identifier, ) @@ -397,15 +396,27 @@ def __init__( super().__init__(**kwargs) self.cluster_identifier = cluster_identifier self.aws_conn_id = aws_conn_id + # These parameters are added to address an issue with the boto3 API where the API + # prematurely reports the cluster as available to receive requests. This causes the cluster + # to reject initial attempts to resume the cluster despite reporting the correct state. + self._attempts = 10 + self._attempt_interval = 15 def execute(self, context: Context): redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id) - cluster_state = redshift_hook.cluster_status(cluster_identifier=self.cluster_identifier) - if cluster_state == "paused": - self.log.info("Starting Redshift cluster %s", self.cluster_identifier) - redshift_hook.get_conn().resume_cluster(ClusterIdentifier=self.cluster_identifier) - else: - raise Exception(f"Unable to resume cluster - cluster state is {cluster_state}") + + while self._attempts >= 1: + try: + redshift_hook.get_conn().resume_cluster(ClusterIdentifier=self.cluster_identifier) + return + except redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error: + self._attempts = self._attempts - 1 + + if self._attempts > 0: + self.log.error("Unable to resume cluster. %d attempts remaining.", self._attempts) + time.sleep(self._attempt_interval) + else: + raise error class RedshiftPauseClusterOperator(BaseOperator): @@ -434,15 +445,27 @@ def __init__( super().__init__(**kwargs) self.cluster_identifier = cluster_identifier self.aws_conn_id = aws_conn_id + # These parameters are added to address an issue with the boto3 API where the API + # prematurely reports the cluster as available to receive requests. This causes the cluster + # to reject initial attempts to pause the cluster despite reporting the correct state. + self._attempts = 10 + self._attempt_interval = 15 def execute(self, context: Context): redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id) - cluster_state = redshift_hook.cluster_status(cluster_identifier=self.cluster_identifier) - if cluster_state == "available": - self.log.info("Pausing Redshift cluster %s", self.cluster_identifier) - redshift_hook.get_conn().pause_cluster(ClusterIdentifier=self.cluster_identifier) - else: - raise Exception(f"Unable to pause cluster - cluster state is {cluster_state}") + + while self._attempts >= 1: + try: + redshift_hook.get_conn().pause_cluster(ClusterIdentifier=self.cluster_identifier) + return + except redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error: + self._attempts = self._attempts - 1 + + if self._attempts > 0: + self.log.error("Unable to pause cluster. %d attempts remaining.", self._attempts) + time.sleep(self._attempt_interval) + else: + raise error class RedshiftDeleteClusterOperator(BaseOperator): diff --git a/tests/providers/amazon/aws/operators/test_redshift_cluster.py b/tests/providers/amazon/aws/operators/test_redshift_cluster.py index c10daa9a70f37..5a9322c9c9b96 100644 --- a/tests/providers/amazon/aws/operators/test_redshift_cluster.py +++ b/tests/providers/amazon/aws/operators/test_redshift_cluster.py @@ -18,6 +18,7 @@ from unittest import mock +import boto3 import pytest from airflow.exceptions import AirflowException @@ -172,7 +173,6 @@ def test_delete_cluster_snapshot_wait(self, mock_get_conn, mock_get_cluster_snap ) mock_get_cluster_snapshot_status.assert_called_once_with( - cluster_identifier="test_cluster", snapshot_identifier="test_snapshot", ) @@ -205,26 +205,47 @@ def test_init(self): assert redshift_operator.cluster_identifier == "test_cluster" assert redshift_operator.aws_conn_id == "aws_conn_test" - @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.cluster_status") @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_conn") - def test_resume_cluster_is_called_when_cluster_is_paused(self, mock_get_conn, mock_cluster_status): - mock_cluster_status.return_value = "paused" + def test_resume_cluster_is_called_when_cluster_is_paused(self, mock_get_conn): redshift_operator = RedshiftResumeClusterOperator( task_id="task_test", cluster_identifier="test_cluster", aws_conn_id="aws_conn_test" ) redshift_operator.execute(None) mock_get_conn.return_value.resume_cluster.assert_called_once_with(ClusterIdentifier="test_cluster") - @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.cluster_status") - @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_conn") - def test_resume_cluster_not_called_when_cluster_is_not_paused(self, mock_get_conn, mock_cluster_status): - mock_cluster_status.return_value = "available" + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.conn") + @mock.patch("time.sleep", return_value=None) + def test_resume_cluster_multiple_attempts(self, mock_sleep, mock_conn): + exception = boto3.client("redshift").exceptions.InvalidClusterStateFault({}, "test") + returned_exception = type(exception) + + mock_conn.exceptions.InvalidClusterStateFault = returned_exception + mock_conn.resume_cluster.side_effect = [exception, exception, True] redshift_operator = RedshiftResumeClusterOperator( - task_id="task_test", cluster_identifier="test_cluster", aws_conn_id="aws_conn_test" + task_id="task_test", + cluster_identifier="test_cluster", + aws_conn_id="aws_conn_test", ) - with pytest.raises(Exception): + redshift_operator.execute(None) + assert mock_conn.resume_cluster.call_count == 3 + + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.conn") + @mock.patch("time.sleep", return_value=None) + def test_resume_cluster_multiple_attempts_fail(self, mock_sleep, mock_conn): + exception = boto3.client("redshift").exceptions.InvalidClusterStateFault({}, "test") + returned_exception = type(exception) + + mock_conn.exceptions.InvalidClusterStateFault = returned_exception + mock_conn.resume_cluster.side_effect = exception + + redshift_operator = RedshiftResumeClusterOperator( + task_id="task_test", + cluster_identifier="test_cluster", + aws_conn_id="aws_conn_test", + ) + with pytest.raises(returned_exception): redshift_operator.execute(None) - mock_get_conn.return_value.resume_cluster.assert_not_called() + assert mock_conn.resume_cluster.call_count == 10 class TestPauseClusterOperator: @@ -236,26 +257,49 @@ def test_init(self): assert redshift_operator.cluster_identifier == "test_cluster" assert redshift_operator.aws_conn_id == "aws_conn_test" - @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.cluster_status") @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_conn") - def test_pause_cluster_is_called_when_cluster_is_available(self, mock_get_conn, mock_cluster_status): - mock_cluster_status.return_value = "available" + def test_pause_cluster_is_called_when_cluster_is_available(self, mock_get_conn): redshift_operator = RedshiftPauseClusterOperator( task_id="task_test", cluster_identifier="test_cluster", aws_conn_id="aws_conn_test" ) redshift_operator.execute(None) mock_get_conn.return_value.pause_cluster.assert_called_once_with(ClusterIdentifier="test_cluster") - @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.cluster_status") - @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_conn") - def test_pause_cluster_not_called_when_cluster_is_not_available(self, mock_get_conn, mock_cluster_status): - mock_cluster_status.return_value = "paused" + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.conn") + @mock.patch("time.sleep", return_value=None) + def test_pause_cluster_multiple_attempts(self, mock_sleep, mock_conn): + exception = boto3.client("redshift").exceptions.InvalidClusterStateFault({}, "test") + returned_exception = type(exception) + + mock_conn.exceptions.InvalidClusterStateFault = returned_exception + mock_conn.pause_cluster.side_effect = [exception, exception, True] + redshift_operator = RedshiftPauseClusterOperator( - task_id="task_test", cluster_identifier="test_cluster", aws_conn_id="aws_conn_test" + task_id="task_test", + cluster_identifier="test_cluster", + aws_conn_id="aws_conn_test", + ) + + redshift_operator.execute(None) + assert mock_conn.pause_cluster.call_count == 3 + + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.conn") + @mock.patch("time.sleep", return_value=None) + def test_pause_cluster_multiple_attempts_fail(self, mock_sleep, mock_conn): + exception = boto3.client("redshift").exceptions.InvalidClusterStateFault({}, "test") + returned_exception = type(exception) + + mock_conn.exceptions.InvalidClusterStateFault = returned_exception + mock_conn.pause_cluster.side_effect = exception + + redshift_operator = RedshiftPauseClusterOperator( + task_id="task_test", + cluster_identifier="test_cluster", + aws_conn_id="aws_conn_test", ) - with pytest.raises(Exception): + with pytest.raises(returned_exception): redshift_operator.execute(None) - mock_get_conn.return_value.pause_cluster.assert_not_called() + assert mock_conn.pause_cluster.call_count == 10 class TestDeleteClusterOperator: diff --git a/tests/system/providers/amazon/aws/example_redshift.py b/tests/system/providers/amazon/aws/example_redshift.py index 0b8b573c89ead..5a355f15f4e04 100644 --- a/tests/system/providers/amazon/aws/example_redshift.py +++ b/tests/system/providers/amazon/aws/example_redshift.py @@ -26,7 +26,6 @@ from airflow.decorators import task from airflow.models import Connection from airflow.models.baseoperator import chain -from airflow.operators.python import get_current_context from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook from airflow.providers.amazon.aws.operators.redshift_cluster import ( RedshiftCreateClusterOperator, @@ -90,8 +89,7 @@ def setup_security_group(sec_group_name: str, ip_permissions: list[dict]): client.authorize_security_group_ingress( GroupId=security_group["GroupId"], GroupName=sec_group_name, IpPermissions=ip_permissions ) - ti = get_current_context()["ti"] - ti.xcom_push(key="security_group_id", value=security_group["GroupId"]) + return security_group["GroupId"] @task(trigger_rule=TriggerRule.ALL_DONE) @@ -99,17 +97,6 @@ def delete_security_group(sec_group_id: str, sec_group_name: str): boto3.client("ec2").delete_security_group(GroupId=sec_group_id, GroupName=sec_group_name) -@task -def await_cluster_snapshot(cluster_identifier): - waiter = boto3.client("redshift").get_waiter("snapshot_available") - waiter.wait( - ClusterIdentifier=cluster_identifier, - WaiterConfig={ - "MaxAttempts": 100, - }, - ) - - with DAG( dag_id=DAG_ID, start_date=datetime(2021, 1, 1), @@ -130,7 +117,7 @@ def await_cluster_snapshot(cluster_identifier): create_cluster = RedshiftCreateClusterOperator( task_id="create_cluster", cluster_identifier=redshift_cluster_identifier, - vpc_security_group_ids=[set_up_sg["security_group_id"]], + vpc_security_group_ids=[set_up_sg], publicly_accessible=True, cluster_type="single-node", node_type="dc2.large", @@ -145,7 +132,7 @@ def await_cluster_snapshot(cluster_identifier): cluster_identifier=redshift_cluster_identifier, target_status="available", poke_interval=15, - timeout=60 * 30, + timeout=60 * 15, ) # [END howto_sensor_redshift_cluster] @@ -161,6 +148,14 @@ def await_cluster_snapshot(cluster_identifier): ) # [END howto_operator_redshift_create_cluster_snapshot] + wait_cluster_available_before_pause = RedshiftClusterSensor( + task_id="wait_cluster_available_before_pause", + cluster_identifier=redshift_cluster_identifier, + target_status="available", + poke_interval=15, + timeout=60 * 15, + ) + # [START howto_operator_redshift_pause_cluster] pause_cluster = RedshiftPauseClusterOperator( task_id="pause_cluster", @@ -173,7 +168,7 @@ def await_cluster_snapshot(cluster_identifier): cluster_identifier=redshift_cluster_identifier, target_status="paused", poke_interval=15, - timeout=60 * 30, + timeout=60 * 15, ) # [START howto_operator_redshift_resume_cluster] @@ -188,7 +183,7 @@ def await_cluster_snapshot(cluster_identifier): cluster_identifier=redshift_cluster_identifier, target_status="available", poke_interval=15, - timeout=60 * 30, + timeout=60 * 15, ) set_up_connection = create_connection(conn_id_name, cluster_id=redshift_cluster_identifier) @@ -269,7 +264,7 @@ def await_cluster_snapshot(cluster_identifier): # [END howto_operator_redshift_delete_cluster_snapshot] delete_sg = delete_security_group( - sec_group_id=set_up_sg["security_group_id"], + sec_group_id=set_up_sg, sec_group_name=sg_name, ) chain( @@ -280,7 +275,7 @@ def await_cluster_snapshot(cluster_identifier): create_cluster, wait_cluster_available, create_cluster_snapshot, - await_cluster_snapshot(redshift_cluster_identifier), + wait_cluster_available_before_pause, pause_cluster, wait_cluster_paused, resume_cluster,