Handle transient state errors in RedshiftResumeClusterOperator and RedshiftPauseClusterOperator #27276

Merged · 6 commits · Nov 17, 2022
15 changes: 12 additions & 3 deletions airflow/providers/amazon/aws/hooks/redshift_cluster.py

@@ -16,6 +16,7 @@
 # under the License.
 from __future__ import annotations
 
+import warnings
 from typing import Any, Sequence
 
 from botocore.exceptions import ClientError
@@ -157,16 +158,24 @@ def create_cluster_snapshot(
         )
         return response["Snapshot"] if response["Snapshot"] else None
 
-    def get_cluster_snapshot_status(self, snapshot_identifier: str, cluster_identifier: str):
+    def get_cluster_snapshot_status(self, snapshot_identifier: str, cluster_identifier: str | None = None):
         """
         Return Redshift cluster snapshot status. If cluster snapshot not found return ``None``
 
         :param snapshot_identifier: A unique identifier for the snapshot that you are requesting
-        :param cluster_identifier: The unique identifier of the cluster the snapshot was created from
+        :param cluster_identifier: (deprecated) The unique identifier of the cluster
+            the snapshot was created from
Comment on lines -165 to +167

Contributor

Can you please share some context?
The PR description says the motivation is issues with system tests, but is there more to it?
Why are we deprecating this parameter? If boto3 accepts this as valid input, why should we prevent users from using it?

Contributor Author

The change was made as a result of debugging during testing for the Redshift system test. It was not originally in scope of the refactoring of RedshiftResumeClusterOperator and RedshiftPauseClusterOperator; I included it in this PR because the get_cluster_snapshot_status function does not work as written.

> If boto3 accepts this as valid input, why should we prevent users from using it?

Boto3 does not accept both inputs at the same time. If both are provided, we get an InvalidParameterCombination exception.
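
For illustration, a minimal sketch of the constraint described above (the identifiers here are hypothetical):

    import boto3

    client = boto3.client("redshift")

    # Supplying both identifiers is what triggered InvalidParameterCombination:
    # client.describe_cluster_snapshots(
    #     ClusterIdentifier="my-cluster",
    #     SnapshotIdentifier="my-snapshot",
    # )

    # The snapshot identifier alone is enough to look up the snapshot's status:
    response = client.describe_cluster_snapshots(SnapshotIdentifier="my-snapshot")
    status = response["Snapshots"][0]["Status"]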

Contributor

@ferruzzi ferruzzi Nov 4, 2022

In that case, maybe leave the parameter in and add a block at the beginning of the method along the lines of:

    if not (bool(cluster_identifier) ^ bool(snapshot_identifier)):
        raise SomeException("cluster_identifier or snapshot_identifier must be included, but not both")

The ^ is the XOR operator, which returns True when exactly one of its operands is truthy (the bool() casts are needed because ^ is not defined between strings):

    In [1]: t = True ; f = False

    In [2]: t ^ f
    Out[2]: True

    In [3]: f ^ t
    Out[3]: True

    In [4]: t ^ t
    Out[4]: False

    In [5]: f ^ f
    Out[5]: False
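
As a sketch of how that guard might sit in the hook method (``SomeException`` above is a placeholder; ``AirflowException`` is one plausible concrete choice):

    from airflow.exceptions import AirflowException

    def get_cluster_snapshot_status(self, snapshot_identifier: str | None = None, cluster_identifier: str | None = None):
        # Accept exactly one of the two identifiers, rejecting both-or-neither.
        if not (bool(snapshot_identifier) ^ bool(cluster_identifier)):
            raise AirflowException("cluster_identifier or snapshot_identifier must be included, but not both")
        ...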

"""
if cluster_identifier:
warnings.warn(
"Parameter `cluster_identifier` is deprecated."
"This option will be removed in a future version.",
DeprecationWarning,
stacklevel=2,
)

try:
response = self.get_conn().describe_cluster_snapshots(
Member

Why this change? I think the purpose of this hook method is to get the snapshot of a cluster, WDYT?

Contributor Author

See the comment above. Basically, the describe_cluster_snapshots function requires only one of the two parameters.

-                ClusterIdentifier=cluster_identifier,
                 SnapshotIdentifier=snapshot_identifier,
             )
             snapshot = response.get("Snapshots")[0]
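
After this change, callers pass only the snapshot identifier; a sketch of the intended call (the connection id here is illustrative):

    from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook

    hook = RedshiftHook(aws_conn_id="aws_default")
    status = hook.get_cluster_snapshot_status(snapshot_identifier="my-snapshot")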
49 changes: 36 additions & 13 deletions airflow/providers/amazon/aws/operators/redshift_cluster.py

@@ -367,7 +367,6 @@ def execute(self, context: Context) -> Any:
     def get_status(self) -> str:
         return self.redshift_hook.get_cluster_snapshot_status(
             snapshot_identifier=self.snapshot_identifier,
-            cluster_identifier=self.cluster_identifier,
         )

@@ -397,15 +396,27 @@ def __init__(
         super().__init__(**kwargs)
         self.cluster_identifier = cluster_identifier
         self.aws_conn_id = aws_conn_id
+        # These parameters are added to address an issue with the boto3 API where the API
+        # prematurely reports the cluster as available to receive requests. This causes the cluster
+        # to reject initial attempts to resume the cluster despite reporting the correct state.
+        self._attempts = 10
+        self._attempt_interval = 15
 
     def execute(self, context: Context):
         redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id)
-        cluster_state = redshift_hook.cluster_status(cluster_identifier=self.cluster_identifier)
-        if cluster_state == "paused":
-            self.log.info("Starting Redshift cluster %s", self.cluster_identifier)
-            redshift_hook.get_conn().resume_cluster(ClusterIdentifier=self.cluster_identifier)
-        else:
-            raise Exception(f"Unable to resume cluster - cluster state is {cluster_state}")
+        while self._attempts >= 1:
+            try:
+                redshift_hook.get_conn().resume_cluster(ClusterIdentifier=self.cluster_identifier)
+                return
+            except redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error:
+                self._attempts = self._attempts - 1
+
+                if self._attempts > 0:
+                    self.log.error("Unable to resume cluster. %d attempts remaining.", self._attempts)
+                    time.sleep(self._attempt_interval)
+                else:
+                    raise error
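
With the defaults above (10 attempts, 15-second interval), a cluster that keeps rejecting the call is tried ten times with nine 15-second pauses in between, so the operator tolerates the transient InvalidClusterStateFault for roughly 135 seconds before re-raising it.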


class RedshiftPauseClusterOperator(BaseOperator):
@@ -434,15 +445,27 @@ def __init__(
         super().__init__(**kwargs)
         self.cluster_identifier = cluster_identifier
         self.aws_conn_id = aws_conn_id
+        # These parameters are added to address an issue with the boto3 API where the API
+        # prematurely reports the cluster as available to receive requests. This causes the cluster
+        # to reject initial attempts to pause the cluster despite reporting the correct state.
+        self._attempts = 10
+        self._attempt_interval = 15
 
     def execute(self, context: Context):
         redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id)
-        cluster_state = redshift_hook.cluster_status(cluster_identifier=self.cluster_identifier)
-        if cluster_state == "available":
-            self.log.info("Pausing Redshift cluster %s", self.cluster_identifier)
-            redshift_hook.get_conn().pause_cluster(ClusterIdentifier=self.cluster_identifier)
-        else:
-            raise Exception(f"Unable to pause cluster - cluster state is {cluster_state}")
+        while self._attempts >= 1:
+            try:
+                redshift_hook.get_conn().pause_cluster(ClusterIdentifier=self.cluster_identifier)
+                return
+            except redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error:
+                self._attempts = self._attempts - 1
+
+                if self._attempts > 0:
+                    self.log.error("Unable to pause cluster. %d attempts remaining.", self._attempts)
+                    time.sleep(self._attempt_interval)
+                else:
+                    raise error
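
As an aside on the design choice, the same behavior could be expressed with the tenacity package (already an Airflow dependency) instead of a hand-rolled loop; a hypothetical sketch, with names following the operator's attributes:

    from tenacity import Retrying, retry_if_exception_type, stop_after_attempt, wait_fixed

    def pause_with_retry(conn, cluster_identifier: str, attempts: int = 10, interval: int = 15) -> None:
        # Retry only on the transient state fault; re-raise it once attempts are exhausted.
        for attempt in Retrying(
            retry=retry_if_exception_type(conn.exceptions.InvalidClusterStateFault),
            stop=stop_after_attempt(attempts),
            wait=wait_fixed(interval),
            reraise=True,
        ):
            with attempt:
                conn.pause_cluster(ClusterIdentifier=cluster_identifier)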


class RedshiftDeleteClusterOperator(BaseOperator):
86 changes: 65 additions & 21 deletions tests/providers/amazon/aws/operators/test_redshift_cluster.py

@@ -18,6 +18,7 @@
 
 from unittest import mock
 
+import boto3
 import pytest
 
 from airflow.exceptions import AirflowException

@@ -172,7 +173,6 @@ def test_delete_cluster_snapshot_wait(self, mock_get_conn, mock_get_cluster_snap
         )
 
         mock_get_cluster_snapshot_status.assert_called_once_with(
-            cluster_identifier="test_cluster",
             snapshot_identifier="test_snapshot",
        )
@@ -205,26 +205,47 @@ def test_init(self):
         assert redshift_operator.cluster_identifier == "test_cluster"
         assert redshift_operator.aws_conn_id == "aws_conn_test"
 
-    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.cluster_status")
     @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_conn")
-    def test_resume_cluster_is_called_when_cluster_is_paused(self, mock_get_conn, mock_cluster_status):
-        mock_cluster_status.return_value = "paused"
+    def test_resume_cluster_is_called_when_cluster_is_paused(self, mock_get_conn):
         redshift_operator = RedshiftResumeClusterOperator(
             task_id="task_test", cluster_identifier="test_cluster", aws_conn_id="aws_conn_test"
         )
         redshift_operator.execute(None)
         mock_get_conn.return_value.resume_cluster.assert_called_once_with(ClusterIdentifier="test_cluster")
 
-    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.cluster_status")
-    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_conn")
-    def test_resume_cluster_not_called_when_cluster_is_not_paused(self, mock_get_conn, mock_cluster_status):
-        mock_cluster_status.return_value = "available"
-        redshift_operator = RedshiftResumeClusterOperator(
-            task_id="task_test", cluster_identifier="test_cluster", aws_conn_id="aws_conn_test"
-        )
-        with pytest.raises(Exception):
-            redshift_operator.execute(None)
-        mock_get_conn.return_value.resume_cluster.assert_not_called()
+    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.conn")
+    @mock.patch("time.sleep", return_value=None)
+    def test_resume_cluster_multiple_attempts(self, mock_sleep, mock_conn):
+        exception = boto3.client("redshift").exceptions.InvalidClusterStateFault({}, "test")
+        returned_exception = type(exception)
+
+        mock_conn.exceptions.InvalidClusterStateFault = returned_exception
+        mock_conn.resume_cluster.side_effect = [exception, exception, True]
+
+        redshift_operator = RedshiftResumeClusterOperator(
+            task_id="task_test",
+            cluster_identifier="test_cluster",
+            aws_conn_id="aws_conn_test",
+        )
+        redshift_operator.execute(None)
+        assert mock_conn.resume_cluster.call_count == 3
+
+    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.conn")
+    @mock.patch("time.sleep", return_value=None)
+    def test_resume_cluster_multiple_attempts_fail(self, mock_sleep, mock_conn):
+        exception = boto3.client("redshift").exceptions.InvalidClusterStateFault({}, "test")
+        returned_exception = type(exception)
+
+        mock_conn.exceptions.InvalidClusterStateFault = returned_exception
+        mock_conn.resume_cluster.side_effect = exception
+
+        redshift_operator = RedshiftResumeClusterOperator(
+            task_id="task_test",
+            cluster_identifier="test_cluster",
+            aws_conn_id="aws_conn_test",
+        )
+        with pytest.raises(returned_exception):
+            redshift_operator.execute(None)
+        assert mock_conn.resume_cluster.call_count == 10
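
A note on the test pattern, in case it looks odd: botocore generates service exception classes dynamically from the service model, so the tests build a real InvalidClusterStateFault instance via a throwaway boto3 client and recover its class with type(exception), which then serves both as the mocked conn.exceptions attribute and as the pytest.raises target.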


class TestPauseClusterOperator:
@@ -236,26 +257,49 @@ def test_init(self):
         assert redshift_operator.cluster_identifier == "test_cluster"
         assert redshift_operator.aws_conn_id == "aws_conn_test"
 
-    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.cluster_status")
     @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_conn")
-    def test_pause_cluster_is_called_when_cluster_is_available(self, mock_get_conn, mock_cluster_status):
-        mock_cluster_status.return_value = "available"
+    def test_pause_cluster_is_called_when_cluster_is_available(self, mock_get_conn):
         redshift_operator = RedshiftPauseClusterOperator(
             task_id="task_test", cluster_identifier="test_cluster", aws_conn_id="aws_conn_test"
         )
         redshift_operator.execute(None)
         mock_get_conn.return_value.pause_cluster.assert_called_once_with(ClusterIdentifier="test_cluster")
 
-    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.cluster_status")
-    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_conn")
-    def test_pause_cluster_not_called_when_cluster_is_not_available(self, mock_get_conn, mock_cluster_status):
-        mock_cluster_status.return_value = "paused"
-        redshift_operator = RedshiftPauseClusterOperator(
-            task_id="task_test", cluster_identifier="test_cluster", aws_conn_id="aws_conn_test"
-        )
-        with pytest.raises(Exception):
-            redshift_operator.execute(None)
-        mock_get_conn.return_value.pause_cluster.assert_not_called()
+    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.conn")
+    @mock.patch("time.sleep", return_value=None)
+    def test_pause_cluster_multiple_attempts(self, mock_sleep, mock_conn):
+        exception = boto3.client("redshift").exceptions.InvalidClusterStateFault({}, "test")
+        returned_exception = type(exception)
+
+        mock_conn.exceptions.InvalidClusterStateFault = returned_exception
+        mock_conn.pause_cluster.side_effect = [exception, exception, True]
+
+        redshift_operator = RedshiftPauseClusterOperator(
+            task_id="task_test",
+            cluster_identifier="test_cluster",
+            aws_conn_id="aws_conn_test",
+        )
+
+        redshift_operator.execute(None)
+        assert mock_conn.pause_cluster.call_count == 3
+
+    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.conn")
+    @mock.patch("time.sleep", return_value=None)
+    def test_pause_cluster_multiple_attempts_fail(self, mock_sleep, mock_conn):
+        exception = boto3.client("redshift").exceptions.InvalidClusterStateFault({}, "test")
+        returned_exception = type(exception)
+
+        mock_conn.exceptions.InvalidClusterStateFault = returned_exception
+        mock_conn.pause_cluster.side_effect = exception
+
+        redshift_operator = RedshiftPauseClusterOperator(
+            task_id="task_test",
+            cluster_identifier="test_cluster",
+            aws_conn_id="aws_conn_test",
+        )
+        with pytest.raises(returned_exception):
+            redshift_operator.execute(None)
+        assert mock_conn.pause_cluster.call_count == 10


class TestDeleteClusterOperator:
35 changes: 15 additions & 20 deletions tests/system/providers/amazon/aws/example_redshift.py

@@ -26,7 +26,6 @@
 from airflow.decorators import task
 from airflow.models import Connection
 from airflow.models.baseoperator import chain
-from airflow.operators.python import get_current_context
 from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook
 from airflow.providers.amazon.aws.operators.redshift_cluster import (
     RedshiftCreateClusterOperator,

@@ -90,26 +89,14 @@ def setup_security_group(sec_group_name: str, ip_permissions: list[dict]):
     client.authorize_security_group_ingress(
         GroupId=security_group["GroupId"], GroupName=sec_group_name, IpPermissions=ip_permissions
     )
-    ti = get_current_context()["ti"]
-    ti.xcom_push(key="security_group_id", value=security_group["GroupId"])
+    return security_group["GroupId"]
 
 
 @task(trigger_rule=TriggerRule.ALL_DONE)
 def delete_security_group(sec_group_id: str, sec_group_name: str):
     boto3.client("ec2").delete_security_group(GroupId=sec_group_id, GroupName=sec_group_name)
 
 
-@task
-def await_cluster_snapshot(cluster_identifier):
-    waiter = boto3.client("redshift").get_waiter("snapshot_available")
-    waiter.wait(
-        ClusterIdentifier=cluster_identifier,
-        WaiterConfig={
-            "MaxAttempts": 100,
-        },
-    )
-
-
 with DAG(
     dag_id=DAG_ID,
     start_date=datetime(2021, 1, 1),
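
The security-group change above leans on standard TaskFlow behavior: a value returned from a @task-decorated function is pushed to XCom automatically, and the resulting XComArg can be passed directly to downstream operators, which removes the need for the explicit xcom_push. A minimal sketch of the pattern (names illustrative):

    from airflow.decorators import task

    @task
    def make_id() -> str:
        return "sg-12345"      # pushed to XCom as return_value automatically

    sg_id = make_id()          # an XComArg; usable directly as an operator argument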
@@ -130,7 +117,7 @@
     create_cluster = RedshiftCreateClusterOperator(
         task_id="create_cluster",
         cluster_identifier=redshift_cluster_identifier,
-        vpc_security_group_ids=[set_up_sg["security_group_id"]],
+        vpc_security_group_ids=[set_up_sg],
         publicly_accessible=True,
         cluster_type="single-node",
         node_type="dc2.large",
@@ -145,7 +132,7 @@
         cluster_identifier=redshift_cluster_identifier,
         target_status="available",
         poke_interval=15,
-        timeout=60 * 30,
+        timeout=60 * 15,
     )
     # [END howto_sensor_redshift_cluster]

@@ -161,6 +148,14 @@
     )
     # [END howto_operator_redshift_create_cluster_snapshot]
 
+    wait_cluster_available_before_pause = RedshiftClusterSensor(
+        task_id="wait_cluster_available_before_pause",
+        cluster_identifier=redshift_cluster_identifier,
+        target_status="available",
+        poke_interval=15,
+        timeout=60 * 15,
+    )
+
     # [START howto_operator_redshift_pause_cluster]
     pause_cluster = RedshiftPauseClusterOperator(
         task_id="pause_cluster",
@@ -173,7 +168,7 @@
         cluster_identifier=redshift_cluster_identifier,
         target_status="paused",
         poke_interval=15,
-        timeout=60 * 30,
+        timeout=60 * 15,
     )
 
     # [START howto_operator_redshift_resume_cluster]
@@ -188,7 +183,7 @@
         cluster_identifier=redshift_cluster_identifier,
         target_status="available",
         poke_interval=15,
-        timeout=60 * 30,
+        timeout=60 * 15,
     )
 
     set_up_connection = create_connection(conn_id_name, cluster_id=redshift_cluster_identifier)
@@ -269,7 +264,7 @@
     # [END howto_operator_redshift_delete_cluster_snapshot]
 
     delete_sg = delete_security_group(
-        sec_group_id=set_up_sg["security_group_id"],
+        sec_group_id=set_up_sg,
         sec_group_name=sg_name,
     )
     chain(
@@ -280,7 +275,7 @@
         create_cluster,
         wait_cluster_available,
         create_cluster_snapshot,
-        await_cluster_snapshot(redshift_cluster_identifier),
+        wait_cluster_available_before_pause,
         pause_cluster,
         wait_cluster_paused,
         resume_cluster,