Skip to content

Commit

Permalink
Handle transient state errors in RedshiftResumeClusterOperator and …
Browse files Browse the repository at this point in the history
…`RedshiftPauseClusterOperator` (#27276)

* Modify RedshiftPauseClusterOperator and RedshiftResumeClusterOperator to attempt to pause and resume multiple times to avoid edge cases of state changes
  • Loading branch information
syedahsn authored Nov 17, 2022
1 parent 80ae49e commit 2063e14
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 57 deletions.
15 changes: 12 additions & 3 deletions airflow/providers/amazon/aws/hooks/redshift_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

import warnings
from typing import Any, Sequence

from botocore.exceptions import ClientError
Expand Down Expand Up @@ -157,16 +158,24 @@ def create_cluster_snapshot(
)
return response["Snapshot"] if response["Snapshot"] else None

def get_cluster_snapshot_status(self, snapshot_identifier: str, cluster_identifier: str):
def get_cluster_snapshot_status(self, snapshot_identifier: str, cluster_identifier: str | None = None):
"""
Return Redshift cluster snapshot status. If cluster snapshot not found return ``None``
:param snapshot_identifier: A unique identifier for the snapshot that you are requesting
:param cluster_identifier: The unique identifier of the cluster the snapshot was created from
:param cluster_identifier: (deprecated) The unique identifier of the cluster
the snapshot was created from
"""
if cluster_identifier:
warnings.warn(
"Parameter `cluster_identifier` is deprecated."
"This option will be removed in a future version.",
DeprecationWarning,
stacklevel=2,
)

try:
response = self.get_conn().describe_cluster_snapshots(
ClusterIdentifier=cluster_identifier,
SnapshotIdentifier=snapshot_identifier,
)
snapshot = response.get("Snapshots")[0]
Expand Down
49 changes: 36 additions & 13 deletions airflow/providers/amazon/aws/operators/redshift_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,6 @@ def execute(self, context: Context) -> Any:
def get_status(self) -> str:
return self.redshift_hook.get_cluster_snapshot_status(
snapshot_identifier=self.snapshot_identifier,
cluster_identifier=self.cluster_identifier,
)


Expand Down Expand Up @@ -397,15 +396,27 @@ def __init__(
super().__init__(**kwargs)
self.cluster_identifier = cluster_identifier
self.aws_conn_id = aws_conn_id
# These parameters are added to address an issue with the boto3 API where the API
# prematurely reports the cluster as available to receive requests. This causes the cluster
# to reject initial attempts to resume the cluster despite reporting the correct state.
self._attempts = 10
self._attempt_interval = 15

def execute(self, context: Context):
redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id)
cluster_state = redshift_hook.cluster_status(cluster_identifier=self.cluster_identifier)
if cluster_state == "paused":
self.log.info("Starting Redshift cluster %s", self.cluster_identifier)
redshift_hook.get_conn().resume_cluster(ClusterIdentifier=self.cluster_identifier)
else:
raise Exception(f"Unable to resume cluster - cluster state is {cluster_state}")

while self._attempts >= 1:
try:
redshift_hook.get_conn().resume_cluster(ClusterIdentifier=self.cluster_identifier)
return
except redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error:
self._attempts = self._attempts - 1

if self._attempts > 0:
self.log.error("Unable to resume cluster. %d attempts remaining.", self._attempts)
time.sleep(self._attempt_interval)
else:
raise error


class RedshiftPauseClusterOperator(BaseOperator):
Expand Down Expand Up @@ -434,15 +445,27 @@ def __init__(
super().__init__(**kwargs)
self.cluster_identifier = cluster_identifier
self.aws_conn_id = aws_conn_id
# These parameters are added to address an issue with the boto3 API where the API
# prematurely reports the cluster as available to receive requests. This causes the cluster
# to reject initial attempts to pause the cluster despite reporting the correct state.
self._attempts = 10
self._attempt_interval = 15

def execute(self, context: Context):
redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id)
cluster_state = redshift_hook.cluster_status(cluster_identifier=self.cluster_identifier)
if cluster_state == "available":
self.log.info("Pausing Redshift cluster %s", self.cluster_identifier)
redshift_hook.get_conn().pause_cluster(ClusterIdentifier=self.cluster_identifier)
else:
raise Exception(f"Unable to pause cluster - cluster state is {cluster_state}")

while self._attempts >= 1:
try:
redshift_hook.get_conn().pause_cluster(ClusterIdentifier=self.cluster_identifier)
return
except redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error:
self._attempts = self._attempts - 1

if self._attempts > 0:
self.log.error("Unable to pause cluster. %d attempts remaining.", self._attempts)
time.sleep(self._attempt_interval)
else:
raise error


class RedshiftDeleteClusterOperator(BaseOperator):
Expand Down
86 changes: 65 additions & 21 deletions tests/providers/amazon/aws/operators/test_redshift_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from unittest import mock

import boto3
import pytest

from airflow.exceptions import AirflowException
Expand Down Expand Up @@ -172,7 +173,6 @@ def test_delete_cluster_snapshot_wait(self, mock_get_conn, mock_get_cluster_snap
)

mock_get_cluster_snapshot_status.assert_called_once_with(
cluster_identifier="test_cluster",
snapshot_identifier="test_snapshot",
)

Expand Down Expand Up @@ -205,26 +205,47 @@ def test_init(self):
assert redshift_operator.cluster_identifier == "test_cluster"
assert redshift_operator.aws_conn_id == "aws_conn_test"

@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.cluster_status")
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_conn")
def test_resume_cluster_is_called_when_cluster_is_paused(self, mock_get_conn, mock_cluster_status):
mock_cluster_status.return_value = "paused"
def test_resume_cluster_is_called_when_cluster_is_paused(self, mock_get_conn):
redshift_operator = RedshiftResumeClusterOperator(
task_id="task_test", cluster_identifier="test_cluster", aws_conn_id="aws_conn_test"
)
redshift_operator.execute(None)
mock_get_conn.return_value.resume_cluster.assert_called_once_with(ClusterIdentifier="test_cluster")

@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.cluster_status")
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_conn")
def test_resume_cluster_not_called_when_cluster_is_not_paused(self, mock_get_conn, mock_cluster_status):
mock_cluster_status.return_value = "available"
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.conn")
@mock.patch("time.sleep", return_value=None)
def test_resume_cluster_multiple_attempts(self, mock_sleep, mock_conn):
exception = boto3.client("redshift").exceptions.InvalidClusterStateFault({}, "test")
returned_exception = type(exception)

mock_conn.exceptions.InvalidClusterStateFault = returned_exception
mock_conn.resume_cluster.side_effect = [exception, exception, True]
redshift_operator = RedshiftResumeClusterOperator(
task_id="task_test", cluster_identifier="test_cluster", aws_conn_id="aws_conn_test"
task_id="task_test",
cluster_identifier="test_cluster",
aws_conn_id="aws_conn_test",
)
with pytest.raises(Exception):
redshift_operator.execute(None)
assert mock_conn.resume_cluster.call_count == 3

@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.conn")
@mock.patch("time.sleep", return_value=None)
def test_resume_cluster_multiple_attempts_fail(self, mock_sleep, mock_conn):
exception = boto3.client("redshift").exceptions.InvalidClusterStateFault({}, "test")
returned_exception = type(exception)

mock_conn.exceptions.InvalidClusterStateFault = returned_exception
mock_conn.resume_cluster.side_effect = exception

redshift_operator = RedshiftResumeClusterOperator(
task_id="task_test",
cluster_identifier="test_cluster",
aws_conn_id="aws_conn_test",
)
with pytest.raises(returned_exception):
redshift_operator.execute(None)
mock_get_conn.return_value.resume_cluster.assert_not_called()
assert mock_conn.resume_cluster.call_count == 10


class TestPauseClusterOperator:
Expand All @@ -236,26 +257,49 @@ def test_init(self):
assert redshift_operator.cluster_identifier == "test_cluster"
assert redshift_operator.aws_conn_id == "aws_conn_test"

@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.cluster_status")
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_conn")
def test_pause_cluster_is_called_when_cluster_is_available(self, mock_get_conn, mock_cluster_status):
mock_cluster_status.return_value = "available"
def test_pause_cluster_is_called_when_cluster_is_available(self, mock_get_conn):
redshift_operator = RedshiftPauseClusterOperator(
task_id="task_test", cluster_identifier="test_cluster", aws_conn_id="aws_conn_test"
)
redshift_operator.execute(None)
mock_get_conn.return_value.pause_cluster.assert_called_once_with(ClusterIdentifier="test_cluster")

@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.cluster_status")
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_conn")
def test_pause_cluster_not_called_when_cluster_is_not_available(self, mock_get_conn, mock_cluster_status):
mock_cluster_status.return_value = "paused"
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.conn")
@mock.patch("time.sleep", return_value=None)
def test_pause_cluster_multiple_attempts(self, mock_sleep, mock_conn):
exception = boto3.client("redshift").exceptions.InvalidClusterStateFault({}, "test")
returned_exception = type(exception)

mock_conn.exceptions.InvalidClusterStateFault = returned_exception
mock_conn.pause_cluster.side_effect = [exception, exception, True]

redshift_operator = RedshiftPauseClusterOperator(
task_id="task_test", cluster_identifier="test_cluster", aws_conn_id="aws_conn_test"
task_id="task_test",
cluster_identifier="test_cluster",
aws_conn_id="aws_conn_test",
)

redshift_operator.execute(None)
assert mock_conn.pause_cluster.call_count == 3

@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.conn")
@mock.patch("time.sleep", return_value=None)
def test_pause_cluster_multiple_attempts_fail(self, mock_sleep, mock_conn):
exception = boto3.client("redshift").exceptions.InvalidClusterStateFault({}, "test")
returned_exception = type(exception)

mock_conn.exceptions.InvalidClusterStateFault = returned_exception
mock_conn.pause_cluster.side_effect = exception

redshift_operator = RedshiftPauseClusterOperator(
task_id="task_test",
cluster_identifier="test_cluster",
aws_conn_id="aws_conn_test",
)
with pytest.raises(Exception):
with pytest.raises(returned_exception):
redshift_operator.execute(None)
mock_get_conn.return_value.pause_cluster.assert_not_called()
assert mock_conn.pause_cluster.call_count == 10


class TestDeleteClusterOperator:
Expand Down
35 changes: 15 additions & 20 deletions tests/system/providers/amazon/aws/example_redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from airflow.decorators import task
from airflow.models import Connection
from airflow.models.baseoperator import chain
from airflow.operators.python import get_current_context
from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook
from airflow.providers.amazon.aws.operators.redshift_cluster import (
RedshiftCreateClusterOperator,
Expand Down Expand Up @@ -90,26 +89,14 @@ def setup_security_group(sec_group_name: str, ip_permissions: list[dict]):
client.authorize_security_group_ingress(
GroupId=security_group["GroupId"], GroupName=sec_group_name, IpPermissions=ip_permissions
)
ti = get_current_context()["ti"]
ti.xcom_push(key="security_group_id", value=security_group["GroupId"])
return security_group["GroupId"]


@task(trigger_rule=TriggerRule.ALL_DONE)
def delete_security_group(sec_group_id: str, sec_group_name: str):
boto3.client("ec2").delete_security_group(GroupId=sec_group_id, GroupName=sec_group_name)


@task
def await_cluster_snapshot(cluster_identifier):
waiter = boto3.client("redshift").get_waiter("snapshot_available")
waiter.wait(
ClusterIdentifier=cluster_identifier,
WaiterConfig={
"MaxAttempts": 100,
},
)


with DAG(
dag_id=DAG_ID,
start_date=datetime(2021, 1, 1),
Expand All @@ -130,7 +117,7 @@ def await_cluster_snapshot(cluster_identifier):
create_cluster = RedshiftCreateClusterOperator(
task_id="create_cluster",
cluster_identifier=redshift_cluster_identifier,
vpc_security_group_ids=[set_up_sg["security_group_id"]],
vpc_security_group_ids=[set_up_sg],
publicly_accessible=True,
cluster_type="single-node",
node_type="dc2.large",
Expand All @@ -145,7 +132,7 @@ def await_cluster_snapshot(cluster_identifier):
cluster_identifier=redshift_cluster_identifier,
target_status="available",
poke_interval=15,
timeout=60 * 30,
timeout=60 * 15,
)
# [END howto_sensor_redshift_cluster]

Expand All @@ -161,6 +148,14 @@ def await_cluster_snapshot(cluster_identifier):
)
# [END howto_operator_redshift_create_cluster_snapshot]

wait_cluster_available_before_pause = RedshiftClusterSensor(
task_id="wait_cluster_available_before_pause",
cluster_identifier=redshift_cluster_identifier,
target_status="available",
poke_interval=15,
timeout=60 * 15,
)

# [START howto_operator_redshift_pause_cluster]
pause_cluster = RedshiftPauseClusterOperator(
task_id="pause_cluster",
Expand All @@ -173,7 +168,7 @@ def await_cluster_snapshot(cluster_identifier):
cluster_identifier=redshift_cluster_identifier,
target_status="paused",
poke_interval=15,
timeout=60 * 30,
timeout=60 * 15,
)

# [START howto_operator_redshift_resume_cluster]
Expand All @@ -188,7 +183,7 @@ def await_cluster_snapshot(cluster_identifier):
cluster_identifier=redshift_cluster_identifier,
target_status="available",
poke_interval=15,
timeout=60 * 30,
timeout=60 * 15,
)

set_up_connection = create_connection(conn_id_name, cluster_id=redshift_cluster_identifier)
Expand Down Expand Up @@ -269,7 +264,7 @@ def await_cluster_snapshot(cluster_identifier):
# [END howto_operator_redshift_delete_cluster_snapshot]

delete_sg = delete_security_group(
sec_group_id=set_up_sg["security_group_id"],
sec_group_id=set_up_sg,
sec_group_name=sg_name,
)
chain(
Expand All @@ -280,7 +275,7 @@ def await_cluster_snapshot(cluster_identifier):
create_cluster,
wait_cluster_available,
create_cluster_snapshot,
await_cluster_snapshot(redshift_cluster_identifier),
wait_cluster_available_before_pause,
pause_cluster,
wait_cluster_paused,
resume_cluster,
Expand Down

0 comments on commit 2063e14

Please sign in to comment.