Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a backup implementation in AWS MwaaHook for calling the MWAA API #47035

Merged
merged 6 commits into from
Mar 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1863,6 +1863,7 @@ urls
useHCatalog
useLegacySQL
useQueryCache
userguide
userId
userpass
usr
Expand Down
94 changes: 79 additions & 15 deletions providers/amazon/src/airflow/providers/amazon/aws/hooks/mwaa.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from __future__ import annotations

import requests
from botocore.exceptions import ClientError

from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
Expand All @@ -29,6 +30,12 @@ class MwaaHook(AwsBaseHook):

Provide thin wrapper around :external+boto3:py:class:`boto3.client("mwaa") <MWAA.Client>`

If your IAM policy doesn't have the ``airflow:InvokeRestApi`` permission, the hook will use a fallback method
that uses the AWS credentials to generate a local web login token for the Airflow Web UI and then directly
makes requests to the Airflow API. This fallback method can be set as the default (and only) method used by
setting ``generate_local_token`` to True. Learn more here:
https://docs.aws.amazon.com/mwaa/latest/userguide/access-mwaa-apache-airflow-rest-api.html#granting-access-MWAA-Enhanced-REST-API

Additional arguments (such as ``aws_conn_id``) may be specified and
are passed down to the underlying AwsBaseHook.

Expand All @@ -47,6 +54,7 @@ def invoke_rest_api(
method: str,
body: dict | None = None,
query_params: dict | None = None,
generate_local_token: bool = False,
) -> dict:
"""
Invoke the REST API on the Airflow webserver with the specified inputs.
Expand All @@ -56,30 +64,86 @@ def invoke_rest_api(

:param env_name: name of the MWAA environment
:param path: Apache Airflow REST API endpoint path to be called
:param method: HTTP method used for making Airflow REST API calls
:param method: HTTP method used for making Airflow REST API calls: 'GET'|'PUT'|'POST'|'PATCH'|'DELETE'
:param body: Request body for the Apache Airflow REST API call
:param query_params: Query parameters to be included in the Apache Airflow REST API call
:param generate_local_token: If True, only the local web token method is used without trying boto's
`invoke_rest_api` first. If False, the local web token method is used as a fallback after trying
boto's `invoke_rest_api`
"""
body = body or {}
# Filter out keys with None values because Airflow REST API doesn't accept requests otherwise
body = {k: v for k, v in body.items() if v is not None} if body else {}
query_params = query_params or {}
api_kwargs = {
"Name": env_name,
"Path": path,
"Method": method,
# Filter out keys with None values because Airflow REST API doesn't accept requests otherwise
"Body": {k: v for k, v in body.items() if v is not None},
"QueryParameters": query_params if query_params else {},
"Body": body,
"QueryParameters": query_params,
}

if generate_local_token:
return self._invoke_rest_api_using_local_session_token(**api_kwargs)

try:
result = self.conn.invoke_rest_api(**api_kwargs)
response = self.conn.invoke_rest_api(**api_kwargs)
# ResponseMetadata is removed because it contains data that is either very unlikely to be useful
# in XComs and logs, or redundant given the data already included in the response
result.pop("ResponseMetadata", None)
return result
response.pop("ResponseMetadata", None)
return response

except ClientError as e:
to_log = e.response
# ResponseMetadata and Error are removed because they contain data that is either very unlikely to
# be useful in XComs and logs, or redundant given the data already included in the response
to_log.pop("ResponseMetadata", None)
to_log.pop("Error", None)
self.log.error(to_log)
raise e
if (
e.response["Error"]["Code"] == "AccessDeniedException"
and "Airflow role" in e.response["Error"]["Message"]
):
self.log.info(
"Access Denied due to missing airflow:InvokeRestApi in IAM policy. Trying again by generating local token..."
)
return self._invoke_rest_api_using_local_session_token(**api_kwargs)
else:
to_log = e.response
# ResponseMetadata is removed because it contains data that is either very unlikely to be
# useful in XComs and logs, or redundant given the data already included in the response
to_log.pop("ResponseMetadata", None)
self.log.error(to_log)
raise

def _invoke_rest_api_using_local_session_token(
    self,
    **api_kwargs,
) -> dict:
    """
    Call the Airflow REST API directly through a session authenticated with a web login token.

    The session is obtained from ``_get_session_conn`` and the request is sent to
    ``https://<hostname>/api/v1<Path>``.

    :param api_kwargs: the keyword arguments otherwise passed to boto's ``invoke_rest_api``;
        the keys ``Name``, ``Path``, ``Method``, ``Body`` and ``QueryParameters`` are read
    :return: dict with the keys ``RestApiStatusCode`` and ``RestApiResponse``
    :raises requests.HTTPError: if the login request or the REST API request returns an error
        status; the error body is logged before re-raising
    """
    try:
        session, hostname = self._get_session_conn(api_kwargs["Name"])

        response = session.request(
            method=api_kwargs["Method"],
            url=f"https://{hostname}/api/v1{api_kwargs['Path']}",
            params=api_kwargs["QueryParameters"],
            json=api_kwargs["Body"],
            timeout=10,
        )
        response.raise_for_status()

    except requests.HTTPError as e:
        # The failing response is not guaranteed to carry JSON (the login request made by
        # _get_session_conn is covered by this handler too). Fall back to the raw text so that a
        # decoding error never masks the HTTPError being re-raised.
        try:
            error_details = e.response.json()
        except ValueError:
            error_details = e.response.text
        self.log.error(error_details)
        raise

    return {
        "RestApiStatusCode": response.status_code,
        # Responses without a body (e.g. 204 No Content) cannot be decoded as JSON;
        # report them as an empty document instead of raising after a successful call.
        "RestApiResponse": response.json() if response.content else {},
    }

# Based on: https://docs.aws.amazon.com/mwaa/latest/userguide/access-mwaa-apache-airflow-rest-api.html#create-web-server-session-token
def _get_session_conn(self, env_name: str) -> tuple[requests.Session, str]:
    """
    Create a ``requests`` session that is logged in to the environment's Airflow web server.

    A web login token is requested through ``create_web_login_token`` and posted to the web
    server's ``/aws_mwaa/login`` endpoint using a new session.

    :param env_name: name of the MWAA environment
    :return: tuple of the logged-in session and the web server hostname
    :raises requests.HTTPError: if the login request returns an error status
    """
    create_token_response = self.conn.create_web_login_token(Name=env_name)
    web_server_hostname = create_token_response["WebServerHostname"]
    web_token = create_token_response["WebToken"]

    login_url = f"https://{web_server_hostname}/aws_mwaa/login"
    login_payload = {"token": web_token}
    session = requests.Session()
    try:
        login_response = session.post(login_url, data=login_payload, timeout=10)
        login_response.raise_for_status()
    except Exception:
        # The caller never receives the session when login fails, so release it here.
        session.close()
        raise

    return session, web_server_hostname
214 changes: 158 additions & 56 deletions providers/amazon/tests/unit/amazon/aws/hooks/test_mwaa.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from unittest import mock

import pytest
import requests
from botocore.exceptions import ClientError
from moto import mock_aws

Expand All @@ -27,16 +28,161 @@
ENV_NAME = "test_env"
PATH = "/dags/test_dag/dagRuns"
METHOD = "POST"
BODY: dict = {"conf": {}}
QUERY_PARAMS = {"limit": 30}
HOSTNAME = "example.com"


class TestMwaaHook:
@pytest.fixture
def mock_conn(self):
    """Patch ``MwaaHook.conn`` while a test runs and yield the resulting mock."""
    with mock.patch.object(MwaaHook, "conn") as patched_conn:
        yield patched_conn

def setup_method(self):
    """Create the hook instance used by the tests."""
    self.hook = MwaaHook()

# these example responses are included here instead of as a constant because the hook will mutate
# responses causing subsequent tests to fail
self.example_responses = {
def test_init(self):
    """The hook is configured for the ``mwaa`` client type."""
    assert self.hook.client_type == "mwaa"

@mock_aws
def test_get_conn(self):
    """Under ``mock_aws`` the hook exposes a connection object."""
    assert self.hook.conn is not None

@pytest.mark.parametrize(
    "body",
    [
        pytest.param(None, id="no_body"),
        pytest.param(BODY, id="non_empty_body"),
    ],
)
def test_invoke_rest_api_success(self, body, mock_conn, example_responses):
    """A successful boto call is returned without ``ResponseMetadata`` and without using the fallback."""
    success_response = example_responses["success"]
    mock_conn.invoke_rest_api = mock.MagicMock(return_value=success_response)

    result = self.hook.invoke_rest_api(
        env_name=ENV_NAME, path=PATH, method=METHOD, body=body, query_params=QUERY_PARAMS
    )

    # A missing body must reach boto as an empty dict
    mock_conn.invoke_rest_api.assert_called_once_with(
        Name=ENV_NAME,
        Path=PATH,
        Method=METHOD,
        Body=body or {},
        QueryParameters=QUERY_PARAMS,
    )
    mock_conn.create_web_login_token.assert_not_called()

    expected = {key: value for key, value in success_response.items() if key != "ResponseMetadata"}
    assert result == expected

def test_invoke_rest_api_failure(self, mock_conn, example_responses):
    """A ``ClientError`` from boto is logged without ``ResponseMetadata`` and re-raised as-is."""
    failure_response = example_responses["failure"]
    client_error = ClientError(error_response=failure_response, operation_name="invoke_rest_api")
    mock_conn.invoke_rest_api = mock.MagicMock(side_effect=client_error)
    error_log = mock.MagicMock()
    self.hook.log.error = error_log

    with pytest.raises(ClientError) as caught_error:
        self.hook.invoke_rest_api(env_name=ENV_NAME, path=PATH, method=METHOD)

    assert caught_error.value == client_error
    # This failure must not trigger the local-token fallback
    mock_conn.create_web_login_token.assert_not_called()
    error_log.assert_called_once_with(
        {key: value for key, value in failure_response.items() if key != "ResponseMetadata"}
    )

@pytest.mark.parametrize("generate_local_token", [pytest.param(True), pytest.param(False)])
@mock.patch("airflow.providers.amazon.aws.hooks.mwaa.requests.Session")
def test_invoke_rest_api_local_token_parameter(
    self, mock_create_session, generate_local_token, mock_conn
):
    """``generate_local_token`` selects between the boto call and the local web token method."""
    self.hook.invoke_rest_api(
        env_name=ENV_NAME, path=PATH, method=METHOD, generate_local_token=generate_local_token
    )

    if not generate_local_token:
        mock_conn.invoke_rest_api.assert_called_once()
        return

    # Local-token mode: boto's invoke_rest_api is bypassed entirely
    mock_conn.invoke_rest_api.assert_not_called()
    mock_conn.create_web_login_token.assert_called_once()
    mock_create_session.assert_called_once()
    mock_create_session.return_value.request.assert_called_once()

@mock.patch.object(MwaaHook, "_get_session_conn")
def test_invoke_rest_api_fallback_success_when_iam_fails(
    self, mock_get_session_conn, mock_conn, example_responses
):
    """When boto reports the missing Airflow role, the request is retried through the session."""
    success_response = example_responses["success"]

    # boto's call fails with the "missingIamRole" example error
    mock_conn.invoke_rest_api = mock.MagicMock(
        side_effect=ClientError(
            error_response=example_responses["missingIamRole"], operation_name="invoke_rest_api"
        )
    )

    # The session returned by _get_session_conn answers with the success payload
    fake_response = mock.MagicMock()
    fake_response.status_code = success_response["RestApiStatusCode"]
    fake_response.json.return_value = success_response["RestApiResponse"]
    fake_session = mock.MagicMock()
    fake_session.request.return_value = fake_response
    mock_get_session_conn.return_value = (fake_session, HOSTNAME)

    result = self.hook.invoke_rest_api(
        env_name=ENV_NAME, path=PATH, method=METHOD, body=BODY, query_params=QUERY_PARAMS
    )

    fake_session.request.assert_called_once_with(
        method=METHOD,
        url=f"https://{HOSTNAME}/api/v1{PATH}",
        params=QUERY_PARAMS,
        json=BODY,
        timeout=10,
    )
    fake_response.raise_for_status.assert_called_once()
    assert result == {
        key: value for key, value in success_response.items() if key != "ResponseMetadata"
    }

@mock.patch.object(MwaaHook, "_get_session_conn")
def test_invoke_rest_api_using_local_session_token_failure(
    self, mock_get_session_conn, example_responses
):
    """An ``HTTPError`` on the local-token path is logged with the response JSON and re-raised."""
    failure_body = example_responses["failure"]["RestApiResponse"]

    # Response whose raise_for_status() raises an HTTPError carrying that same response
    fake_response = mock.MagicMock()
    fake_response.json.return_value = failure_body
    http_error = requests.HTTPError(response=fake_response)
    fake_response.raise_for_status.side_effect = http_error

    fake_session = mock.MagicMock()
    fake_session.request.return_value = fake_response
    mock_get_session_conn.return_value = (fake_session, HOSTNAME)

    error_log = mock.MagicMock()
    self.hook.log.error = error_log

    with pytest.raises(requests.HTTPError) as caught_error:
        self.hook.invoke_rest_api(env_name=ENV_NAME, path=PATH, method=METHOD, generate_local_token=True)

    assert caught_error.value == http_error
    error_log.assert_called_once_with(failure_body)

@mock.patch("airflow.providers.amazon.aws.hooks.mwaa.requests.Session")
def test_get_session_conn(self, mock_create_session, mock_conn):
    """``_get_session_conn`` posts the web token to the login URL and returns (session, hostname)."""
    web_token = "token"
    mock_conn.create_web_login_token.return_value = {
        "WebServerHostname": HOSTNAME,
        "WebToken": web_token,
    }
    fake_session = mock.MagicMock()
    mock_create_session.return_value = fake_session

    result = self.hook._get_session_conn(env_name=ENV_NAME)

    mock_conn.create_web_login_token.assert_called_once_with(Name=ENV_NAME)
    mock_create_session.assert_called_once_with()
    fake_session.post.assert_called_once_with(
        f"https://{HOSTNAME}/aws_mwaa/login", data={"token": web_token}, timeout=10
    )
    fake_session.post.return_value.raise_for_status.assert_called_once()

    assert result == (fake_session, HOSTNAME)

@pytest.fixture
def example_responses(self):
"""Fixture for test responses to avoid mutation between tests."""
return {
"success": {
"ResponseMetadata": {
"RequestId": "some ID",
Expand Down Expand Up @@ -73,57 +219,13 @@ def setup_method(self):
"type": "https://airflow.apache.org/docs/apache-airflow/2.10.3/stable-rest-api-ref.html#section/Errors/NotFound",
},
},
"missingIamRole": {
"Error": {"Message": "No Airflow role granted in IAM.", "Code": "AccessDeniedException"},
"ResponseMetadata": {
"RequestId": "some ID",
"HTTPStatusCode": 403,
"HTTPHeaders": {"header1": "value1"},
"RetryAttempts": 0,
},
},
}

def test_init(self):
# The hook created in setup is expected to use the "mwaa" client type.
assert self.hook.client_type == "mwaa"

@mock_aws
def test_get_conn(self):
# With AWS mocked by mock_aws, accessing the hook's connection must yield an object.
assert self.hook.conn is not None

@pytest.mark.parametrize(
"body",
[
pytest.param(None, id="no_body"),
pytest.param({"conf": {}}, id="non_empty_body"),
],
)
@mock.patch.object(MwaaHook, "conn")
def test_invoke_rest_api_success(self, mock_conn, body) -> None:
# boto's invoke_rest_api is replaced by a mock returning the "success" example response.
boto_invoke_mock = mock.MagicMock(return_value=self.example_responses["success"])
mock_conn.invoke_rest_api = boto_invoke_mock

retval = self.hook.invoke_rest_api(ENV_NAME, PATH, METHOD, body, QUERY_PARAMS)
# A body of None is expected to be passed to boto as an empty dict.
kwargs_to_assert = {
"Name": ENV_NAME,
"Path": PATH,
"Method": METHOD,
"Body": body if body else {},
"QueryParameters": QUERY_PARAMS,
}
boto_invoke_mock.assert_called_once_with(**kwargs_to_assert)
# The returned value is the example response without its "ResponseMetadata" key.
assert retval == {
k: v for k, v in self.example_responses["success"].items() if k != "ResponseMetadata"
}

@mock.patch.object(MwaaHook, "conn")
def test_invoke_rest_api_failure(self, mock_conn) -> None:
# boto's invoke_rest_api is replaced by a mock raising a ClientError built from the
# "failure" example response.
error = ClientError(
error_response=self.example_responses["failure"], operation_name="invoke_rest_api"
)
boto_invoke_mock = mock.MagicMock(side_effect=error)
mock_conn.invoke_rest_api = boto_invoke_mock
# Capture error-level logging of the hook.
mock_log = mock.MagicMock()
self.hook.log.error = mock_log

with pytest.raises(ClientError) as caught_error:
self.hook.invoke_rest_api(ENV_NAME, PATH, METHOD)

# The same error object must propagate to the caller.
assert caught_error.value == error
# The logged payload is the failure response without "ResponseMetadata" and "Error".
expected_log = {
k: v
for k, v in self.example_responses["failure"].items()
if k != "ResponseMetadata" and k != "Error"
}
mock_log.assert_called_once_with(expected_log)