From 7529a9feb7eccfc657e3df9423b2b46d189ab11f Mon Sep 17 00:00:00 2001 From: Ramit Kataria Date: Mon, 24 Feb 2025 10:06:01 -0800 Subject: [PATCH 1/6] Add a backup implementation in AWS MwaaHook for calling the MWAA API The existing implementation doesn't work when the user doesn't have `airflow:InvokeRestApi` permission in their IAM policy or when they make more than 10 transactions per second. This implementation mitigates those issues by using a session token approach. However, my existing implementation is still used by default because it is simpler. Some context here: https://docs.aws.amazon.com/mwaa/latest/userguide/access-mwaa-apache-airflow-rest-api.html --- .../providers/amazon/aws/hooks/mwaa.py | 91 ++++++++++++++++--- .../tests/unit/amazon/aws/hooks/test_mwaa.py | 6 +- 2 files changed, 77 insertions(+), 20 deletions(-) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/mwaa.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/mwaa.py index d7f01238e6ab8..1890e6762f116 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/mwaa.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/mwaa.py @@ -18,6 +18,9 @@ from __future__ import annotations +from typing import Any + +import requests from botocore.exceptions import ClientError from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook @@ -29,6 +32,10 @@ class MwaaHook(AwsBaseHook): Provide thin wrapper around :external+boto3:py:class:`boto3.client("mwaa") ` + If your IAM policy doesn't have airflow:InvokeRestApi permission or if you reach throttling capacity, the + hook will use a session token to make the requests. Learn more here: + https://docs.aws.amazon.com/mwaa/latest/userguide/access-mwaa-apache-airflow-rest-api.html#granting-access-MWAA-Enhanced-REST-API + Additional arguments (such as ``aws_conn_id``) may be specified and are passed down to the underlying AwsBaseHook. @@ -39,6 +46,7 @@ class MwaaHook(AwsBaseHook): def __init__(self, *args, **kwargs) -> None: kwargs["client_type"] = "mwaa" super().__init__(*args, **kwargs) + self._env_to_session_conn_map: dict[str, dict[str, Any]] = {} def invoke_rest_api( self, @@ -56,30 +64,83 @@ def invoke_rest_api( :param env_name: name of the MWAA environment :param path: Apache Airflow REST API endpoint path to be called - :param method: HTTP method used for making Airflow REST API calls + :param method: HTTP method used for making Airflow REST API calls: 'GET'|'PUT'|'POST'|'PATCH'|'DELETE' :param body: Request body for the Apache Airflow REST API call :param query_params: Query parameters to be included in the Apache Airflow REST API call """ - body = body or {} + # Filter out keys with None values because Airflow REST API doesn't accept requests otherwise + body = {k: v for k, v in body.items() if v is not None} if body else {} + query_params = query_params or {} api_kwargs = { "Name": env_name, "Path": path, "Method": method, - # Filter out keys with None values because Airflow REST API doesn't accept requests otherwise - "Body": {k: v for k, v in body.items() if v is not None}, - "QueryParameters": query_params if query_params else {}, + "Body": body, + "QueryParameters": query_params, } try: - result = self.conn.invoke_rest_api(**api_kwargs) + response = self.conn.invoke_rest_api(**api_kwargs) # ResponseMetadata is removed because it contains data that is either very unlikely to be useful # in XComs and logs, or redundant given the data already included in the response - result.pop("ResponseMetadata", None) - return result + response.pop("ResponseMetadata", None) + return response + except ClientError as e: - to_log = e.response - # ResponseMetadata and Error are removed because they contain data that is either very unlikely to - # be useful in XComs and logs, or redundant given the data already included in the response - to_log.pop("ResponseMetadata", None) - to_log.pop("Error", None) - self.log.error(to_log) - raise e + if e.response["Error"]["Code"] == "AccessDeniedException": + self.log.info( + "Access Denied, possibly due to missing airflow:InvokeRestApi in IAM policy. " + "Trying again with session token..." + ) + return self._invoke_rest_api_using_session_token(env_name, path, method, body, query_params) + else: + to_log = e.response + # ResponseMetadata is removed because it contains data that is either very unlikely to be + # useful in XComs and logs, or redundant given the data already included in the response + to_log.pop("ResponseMetadata", None) + self.log.error(to_log) + raise e + + def _invoke_rest_api_using_session_token( + self, + env_name: str, + path: str, + method: str, + body: dict | None = None, + query_params: dict | None = None, + ) -> dict: + def try_request(): + conn_info = self._env_to_session_conn_map[env_name] + response = conn_info["session"].request( + method=method, + url=f"https://{conn_info['hostname']}/api/v1{path}", + params=query_params, + json=body, + timeout=10, + ) + response.raise_for_status() + return response + + try: + response = try_request() + except (requests.exceptions.HTTPError, KeyError): + self._update_session_conn(env_name) + response = try_request() + + return { + "RestApiStatusCode": response.status_code, + "RestApiResponse": response.json(), + } + + # Based on: https://docs.aws.amazon.com/mwaa/latest/userguide/access-mwaa-apache-airflow-rest-api.html#create-web-server-session-token + def _update_session_conn(self, env_name: str): + create_token_response = self.conn.create_web_login_token(Name=env_name) + web_server_hostname = create_token_response["WebServerHostname"] + web_token = create_token_response["WebToken"] + + login_url = f"https://{web_server_hostname}/aws_mwaa/login" + login_payload = {"token": web_token} + session = requests.Session() + login_response = session.post(login_url, data=login_payload, timeout=10) + login_response.raise_for_status() + + self._env_to_session_conn_map[env_name] = {"session": session, "hostname": web_server_hostname} diff --git a/providers/amazon/tests/unit/amazon/aws/hooks/test_mwaa.py b/providers/amazon/tests/unit/amazon/aws/hooks/test_mwaa.py index 461e325891240..2a6da2d8f05e9 100644 --- a/providers/amazon/tests/unit/amazon/aws/hooks/test_mwaa.py +++ b/providers/amazon/tests/unit/amazon/aws/hooks/test_mwaa.py @@ -121,9 +121,5 @@ def test_invoke_rest_api_failure(self, mock_conn) -> None: self.hook.invoke_rest_api(ENV_NAME, PATH, METHOD) assert caught_error.value == error - expected_log = { - k: v - for k, v in self.example_responses["failure"].items() - if k != "ResponseMetadata" and k != "Error" - } + expected_log = {k: v for k, v in self.example_responses["failure"].items() if k != "ResponseMetadata"} mock_log.assert_called_once_with(expected_log) From e0a744a6bd88969a035acb478aa9cf8d55353cfc Mon Sep 17 00:00:00 2001 From: Ramit Kataria Date: Tue, 25 Feb 2025 14:55:03 -0800 Subject: [PATCH 2/6] Fix spelling errors --- docs/spelling_wordlist.txt | 1 + providers/amazon/src/airflow/providers/amazon/aws/hooks/mwaa.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 87d1909eda62e..ff16e3f01cafe 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -1863,6 +1863,7 @@ urls useHCatalog useLegacySQL useQueryCache +userguide userId userpass usr diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/mwaa.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/mwaa.py index 1890e6762f116..5a006df0a2a0c 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/mwaa.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/mwaa.py @@ -32,7 +32,7 @@ class MwaaHook(AwsBaseHook): Provide thin wrapper around :external+boto3:py:class:`boto3.client("mwaa") ` - If your IAM policy doesn't have airflow:InvokeRestApi permission or if you reach throttling capacity, the + If your IAM policy doesn't have `airflow:InvokeRestApi` permission or if you reach throttling capacity, the hook will use a session token to make the requests. Learn more here: https://docs.aws.amazon.com/mwaa/latest/userguide/access-mwaa-apache-airflow-rest-api.html#granting-access-MWAA-Enhanced-REST-API From 7fe8c87f2d8a8d39081a6cd2db68fceed62e0106 Mon Sep 17 00:00:00 2001 From: Ramit Kataria Date: Fri, 28 Feb 2025 14:48:31 -0800 Subject: [PATCH 3/6] Add fallback parameter and remove storing of sessions in hook --- .../providers/amazon/aws/hooks/mwaa.py | 57 +++++++++---------- 1 file changed, 26 insertions(+), 31 deletions(-) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/mwaa.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/mwaa.py index 5a006df0a2a0c..c1e709a5c982f 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/mwaa.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/mwaa.py @@ -18,8 +18,6 @@ from __future__ import annotations -from typing import Any - import requests from botocore.exceptions import ClientError @@ -32,8 +30,10 @@ class MwaaHook(AwsBaseHook): Provide thin wrapper around :external+boto3:py:class:`boto3.client("mwaa") ` - If your IAM policy doesn't have `airflow:InvokeRestApi` permission or if you reach throttling capacity, the - hook will use a session token to make the requests. Learn more here: + If your IAM policy doesn't have `airflow:InvokeRestApi` permission, the hook will use a fallback method + that uses the AWS credential to create a web login token for the Airflow Web UI and then directly make + requests to the Airflow API. This fallback method can be set as the default (and only) method used by + setting `only_use_web_login` to True. Learn more here: https://docs.aws.amazon.com/mwaa/latest/userguide/access-mwaa-apache-airflow-rest-api.html#granting-access-MWAA-Enhanced-REST-API Additional arguments (such as ``aws_conn_id``) may be specified and @@ -46,7 +46,6 @@ class MwaaHook(AwsBaseHook): def __init__(self, *args, **kwargs) -> None: kwargs["client_type"] = "mwaa" super().__init__(*args, **kwargs) - self._env_to_session_conn_map: dict[str, dict[str, Any]] = {} def invoke_rest_api( self, @@ -55,6 +54,7 @@ def invoke_rest_api( method: str, body: dict | None = None, query_params: dict | None = None, + only_use_web_login: bool = False, ) -> dict: """ Invoke the REST API on the Airflow webserver with the specified inputs. @@ -67,6 +67,8 @@ def invoke_rest_api( :param method: HTTP method used for making Airflow REST API calls: 'GET'|'PUT'|'POST'|'PATCH'|'DELETE' :param body: Request body for the Apache Airflow REST API call :param query_params: Query parameters to be included in the Apache Airflow REST API call + :param only_use_web_login: If True, only the web login method is used without trying boto's + invoke_rest_api first """ # Filter out keys with None values because Airflow REST API doesn't accept requests otherwise body = {k: v for k, v in body.items() if v is not None} if body else {} @@ -78,6 +80,10 @@ def invoke_rest_api( "Body": body, "QueryParameters": query_params, } + + if only_use_web_login: + return self._invoke_rest_api_using_web_login(**api_kwargs) + try: response = self.conn.invoke_rest_api(**api_kwargs) # ResponseMetadata is removed because it contains data that is either very unlikely to be useful @@ -89,9 +95,9 @@ def invoke_rest_api( if e.response["Error"]["Code"] == "AccessDeniedException": self.log.info( "Access Denied, possibly due to missing airflow:InvokeRestApi in IAM policy. " - "Trying again with session token..." + "Trying again using web login..." ) - return self._invoke_rest_api_using_session_token(env_name, path, method, body, query_params) + return self._invoke_rest_api_using_web_login(**api_kwargs) else: to_log = e.response # ResponseMetadata is removed because it contains data that is either very unlikely to be @@ -100,31 +106,20 @@ def invoke_rest_api( self.log.error(to_log) raise e - def _invoke_rest_api_using_session_token( + def _invoke_rest_api_using_web_login( self, - env_name: str, - path: str, - method: str, - body: dict | None = None, - query_params: dict | None = None, + **api_kwargs, ) -> dict: - def try_request(): - conn_info = self._env_to_session_conn_map[env_name] - response = conn_info["session"].request( - method=method, - url=f"https://{conn_info['hostname']}/api/v1{path}", - params=query_params, - json=body, - timeout=10, - ) - response.raise_for_status() - return response + session, hostname = self._get_session_conn(api_kwargs["Name"]) - try: - response = try_request() - except (requests.exceptions.HTTPError, KeyError): - self._update_session_conn(env_name) - response = try_request() + response = session.request( + method=api_kwargs["Method"], + url=f"https://{hostname}/api/v1{api_kwargs['Path']}", + params=api_kwargs["QueryParameters"], + json=api_kwargs["Body"], + timeout=10, + ) + response.raise_for_status() return { "RestApiStatusCode": response.status_code, @@ -132,7 +127,7 @@ def try_request(): } # Based on: https://docs.aws.amazon.com/mwaa/latest/userguide/access-mwaa-apache-airflow-rest-api.html#create-web-server-session-token - def _update_session_conn(self, env_name: str): + def _get_session_conn(self, env_name: str) -> tuple: create_token_response = self.conn.create_web_login_token(Name=env_name) web_server_hostname = create_token_response["WebServerHostname"] web_token = create_token_response["WebToken"] @@ -143,4 +138,4 @@ def _update_session_conn(self, env_name: str): login_response = session.post(login_url, data=login_payload, timeout=10) login_response.raise_for_status() - self._env_to_session_conn_map[env_name] = {"session": session, "hostname": web_server_hostname} + return session, web_server_hostname From 67ffd9431f6924e95d6303240338facfad409c50 Mon Sep 17 00:00:00 2001 From: Ramit Kataria Date: Fri, 28 Feb 2025 19:38:06 -0800 Subject: [PATCH 4/6] Add unit tests for added functionality in MwaaHook --- .../tests/unit/amazon/aws/hooks/test_mwaa.py | 193 +++++++++++++----- 1 file changed, 141 insertions(+), 52 deletions(-) diff --git a/providers/amazon/tests/unit/amazon/aws/hooks/test_mwaa.py b/providers/amazon/tests/unit/amazon/aws/hooks/test_mwaa.py index 2a6da2d8f05e9..414bb83ec74c5 100644 --- a/providers/amazon/tests/unit/amazon/aws/hooks/test_mwaa.py +++ b/providers/amazon/tests/unit/amazon/aws/hooks/test_mwaa.py @@ -27,16 +27,145 @@ ENV_NAME = "test_env" PATH = "/dags/test_dag/dagRuns" METHOD = "POST" +BODY: dict = {"conf": {}} QUERY_PARAMS = {"limit": 30} +HOSTNAME = "example.com" class TestMwaaHook: + @pytest.fixture + def mock_conn(self): + with mock.patch.object(MwaaHook, "conn") as m: + yield m + def setup_method(self): self.hook = MwaaHook() - # these example responses are included here instead of as a constant because the hook will mutate - # responses causing subsequent tests to fail - self.example_responses = { + def test_init(self): + assert self.hook.client_type == "mwaa" + + @mock_aws + def test_get_conn(self): + assert self.hook.conn is not None + + @pytest.mark.parametrize( + "body", + [ + pytest.param(None, id="no_body"), + pytest.param(BODY, id="non_empty_body"), + ], + ) + def test_invoke_rest_api_success(self, body, mock_conn, example_responses) -> None: + boto_invoke_mock = mock.MagicMock(return_value=example_responses["success"]) + mock_conn.invoke_rest_api = boto_invoke_mock + + retval = self.hook.invoke_rest_api( + env_name=ENV_NAME, path=PATH, method=METHOD, body=body, query_params=QUERY_PARAMS + ) + kwargs_to_assert = { + "Name": ENV_NAME, + "Path": PATH, + "Method": METHOD, + "Body": body if body else {}, + "QueryParameters": QUERY_PARAMS, + } + boto_invoke_mock.assert_called_once_with(**kwargs_to_assert) + mock_conn.create_web_login_token.assert_not_called() + assert retval == {k: v for k, v in example_responses["success"].items() if k != "ResponseMetadata"} + + def test_invoke_rest_api_failure(self, mock_conn, example_responses) -> None: + error = ClientError(error_response=example_responses["failure"], operation_name="invoke_rest_api") + mock_conn.invoke_rest_api = mock.MagicMock(side_effect=error) + mock_log = mock.MagicMock() + self.hook.log.error = mock_log + + with pytest.raises(ClientError) as caught_error: + self.hook.invoke_rest_api(env_name=ENV_NAME, path=PATH, method=METHOD) + + assert caught_error.value == error + mock_conn.create_web_login_token.assert_not_called() + expected_log = {k: v for k, v in example_responses["failure"].items() if k != "ResponseMetadata"} + mock_log.assert_called_once_with(expected_log) + + @pytest.mark.parametrize("only_use_web_login", [pytest.param(True), pytest.param(False)]) + @mock.patch("airflow.providers.amazon.aws.hooks.mwaa.requests.Session") + def test_invoke_rest_api_web_login_parameter( + self, mock_create_session, only_use_web_login, mock_conn + ) -> None: + self.hook.invoke_rest_api( + env_name=ENV_NAME, path=PATH, method=METHOD, only_use_web_login=only_use_web_login + ) + if only_use_web_login: + mock_conn.invoke_rest_api.assert_not_called() + mock_conn.create_web_login_token.assert_called_once() + mock_create_session.assert_called_once() + mock_create_session.return_value.request.assert_called_once() + else: + mock_conn.invoke_rest_api.assert_called_once() + + @mock.patch.object(MwaaHook, "_get_session_conn") + def test_invoke_rest_api_fallback_to_web_login_when_iam_fails( + self, mock_get_session_conn, mock_conn, example_responses + ) -> None: + boto_invoke_error = ClientError( + error_response=example_responses["missingIamRole"], operation_name="invoke_rest_api" + ) + mock_conn.invoke_rest_api = mock.MagicMock(side_effect=boto_invoke_error) + + kwargs_to_assert = { + "method": METHOD, + "url": f"https://{HOSTNAME}/api/v1{PATH}", + "params": QUERY_PARAMS, + "json": BODY, + "timeout": 10, + } + + mock_response = mock.MagicMock() + mock_response.status_code = example_responses["success"]["RestApiStatusCode"] + mock_response.json.return_value = example_responses["success"]["RestApiResponse"] + + mock_session = mock.MagicMock() + mock_session.request.return_value = mock_response + + mock_get_session_conn.return_value = (mock_session, HOSTNAME) + + retval = self.hook.invoke_rest_api( + env_name=ENV_NAME, path=PATH, method=METHOD, body=BODY, query_params=QUERY_PARAMS + ) + mock_session.request.assert_called_once_with(**kwargs_to_assert) + + # Since the implementation doesn't have separate branches for success and failure cases, this takes + # care of testing both the success and failure cases + mock_response.raise_for_status.assert_called_once() + + assert retval == {k: v for k, v in example_responses["success"].items() if k != "ResponseMetadata"} + + @mock.patch("airflow.providers.amazon.aws.hooks.mwaa.requests.Session") + def test_web_login_get_session_conn(self, mock_create_session, mock_conn): + token = "token" + mock_conn.create_web_login_token.return_value = {"WebServerHostname": HOSTNAME, "WebToken": token} + login_url = f"https://{HOSTNAME}/aws_mwaa/login" + login_payload = {"token": token} + + mock_session = mock.MagicMock() + mock_create_session.return_value = mock_session + + retval = self.hook._get_session_conn(env_name=ENV_NAME) + + mock_conn.create_web_login_token.assert_called_once_with(Name=ENV_NAME) + mock_create_session.assert_called_once_with() + mock_session.post.assert_called_once_with(login_url, data=login_payload, timeout=10) + + # Since the implementation doesn't have separate branches for success and failure cases, this takes + # care of testing both the success and failure cases + mock_session.post.return_value.raise_for_status.assert_called_once() + + assert retval == (mock_session, HOSTNAME) + + @pytest.fixture + def example_responses(self): + """Fixture for test responses to avoid mutation between tests.""" + return { "success": { "ResponseMetadata": { "RequestId": "some ID", @@ -73,53 +202,13 @@ def setup_method(self): "type": "https://airflow.apache.org/docs/apache-airflow/2.10.3/stable-rest-api-ref.html#section/Errors/NotFound", }, }, + "missingIamRole": { + "Error": {"Message": "No Airflow role granted in IAM.", "Code": "AccessDeniedException"}, + "ResponseMetadata": { + "RequestId": "some ID", + "HTTPStatusCode": 403, + "HTTPHeaders": {"header1": "value1"}, + "RetryAttempts": 0, + }, + }, } - - def test_init(self): - assert self.hook.client_type == "mwaa" - - @mock_aws - def test_get_conn(self): - assert self.hook.conn is not None - - @pytest.mark.parametrize( - "body", - [ - pytest.param(None, id="no_body"), - pytest.param({"conf": {}}, id="non_empty_body"), - ], - ) - @mock.patch.object(MwaaHook, "conn") - def test_invoke_rest_api_success(self, mock_conn, body) -> None: - boto_invoke_mock = mock.MagicMock(return_value=self.example_responses["success"]) - mock_conn.invoke_rest_api = boto_invoke_mock - - retval = self.hook.invoke_rest_api(ENV_NAME, PATH, METHOD, body, QUERY_PARAMS) - kwargs_to_assert = { - "Name": ENV_NAME, - "Path": PATH, - "Method": METHOD, - "Body": body if body else {}, - "QueryParameters": QUERY_PARAMS, - } - boto_invoke_mock.assert_called_once_with(**kwargs_to_assert) - assert retval == { - k: v for k, v in self.example_responses["success"].items() if k != "ResponseMetadata" - } - - @mock.patch.object(MwaaHook, "conn") - def test_invoke_rest_api_failure(self, mock_conn) -> None: - error = ClientError( - error_response=self.example_responses["failure"], operation_name="invoke_rest_api" - ) - boto_invoke_mock = mock.MagicMock(side_effect=error) - mock_conn.invoke_rest_api = boto_invoke_mock - mock_log = mock.MagicMock() - self.hook.log.error = mock_log - - with pytest.raises(ClientError) as caught_error: - self.hook.invoke_rest_api(ENV_NAME, PATH, METHOD) - - assert caught_error.value == error - expected_log = {k: v for k, v in self.example_responses["failure"].items() if k != "ResponseMetadata"} - mock_log.assert_called_once_with(expected_log) From 434328d2363e9d5d216ca1d108d8c7764239ef56 Mon Sep 17 00:00:00 2001 From: Ramit Kataria Date: Mon, 3 Mar 2025 18:24:17 -0800 Subject: [PATCH 5/6] Improved error handling and parameter naming --- .../providers/amazon/aws/hooks/mwaa.py | 55 +++++++++------- .../tests/unit/amazon/aws/hooks/test_mwaa.py | 63 ++++++++++++------- 2 files changed, 72 insertions(+), 46 deletions(-) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/mwaa.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/mwaa.py index c1e709a5c982f..77500725ac9d3 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/mwaa.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/mwaa.py @@ -31,9 +31,9 @@ class MwaaHook(AwsBaseHook): Provide thin wrapper around :external+boto3:py:class:`boto3.client("mwaa") ` If your IAM policy doesn't have `airflow:InvokeRestApi` permission, the hook will use a fallback method - that uses the AWS credential to create a web login token for the Airflow Web UI and then directly make - requests to the Airflow API. This fallback method can be set as the default (and only) method used by - setting `only_use_web_login` to True. Learn more here: + that uses the AWS credential to generate a local web login token for the Airflow Web UI and then directly + make requests to the Airflow API. This fallback method can be set as the default (and only) method used by + setting `only_generate_local_token` to True. Learn more here: https://docs.aws.amazon.com/mwaa/latest/userguide/access-mwaa-apache-airflow-rest-api.html#granting-access-MWAA-Enhanced-REST-API Additional arguments (such as ``aws_conn_id``) may be specified and @@ -54,7 +54,7 @@ def invoke_rest_api( method: str, body: dict | None = None, query_params: dict | None = None, - only_use_web_login: bool = False, + only_generate_local_token: bool = False, ) -> dict: """ Invoke the REST API on the Airflow webserver with the specified inputs. @@ -67,8 +67,8 @@ def invoke_rest_api( :param method: HTTP method used for making Airflow REST API calls: 'GET'|'PUT'|'POST'|'PATCH'|'DELETE' :param body: Request body for the Apache Airflow REST API call :param query_params: Query parameters to be included in the Apache Airflow REST API call - :param only_use_web_login: If True, only the web login method is used without trying boto's - invoke_rest_api first + :param only_generate_local_token: If True, only the local web login token method is used without + trying boto's invoke_rest_api first """ # Filter out keys with None values because Airflow REST API doesn't accept requests otherwise body = {k: v for k, v in body.items() if v is not None} if body else {} @@ -81,8 +81,8 @@ def invoke_rest_api( "QueryParameters": query_params, } - if only_use_web_login: - return self._invoke_rest_api_using_web_login(**api_kwargs) + if only_generate_local_token: + return self._invoke_rest_api_using_local_session_token(**api_kwargs) try: response = self.conn.invoke_rest_api(**api_kwargs) @@ -92,34 +92,41 @@ def invoke_rest_api( return response except ClientError as e: - if e.response["Error"]["Code"] == "AccessDeniedException": + if ( + e.response["Error"]["Code"] == "AccessDeniedException" + and "Airflow role" in e.response["Error"]["Message"] + ): self.log.info( - "Access Denied, possibly due to missing airflow:InvokeRestApi in IAM policy. " - "Trying again using web login..." + "Access Denied due to missing airflow:InvokeRestApi in IAM policy. Trying again by generating local token..." ) - return self._invoke_rest_api_using_web_login(**api_kwargs) + return self._invoke_rest_api_using_local_session_token(**api_kwargs) else: to_log = e.response # ResponseMetadata is removed because it contains data that is either very unlikely to be # useful in XComs and logs, or redundant given the data already included in the response to_log.pop("ResponseMetadata", None) self.log.error(to_log) - raise e + raise - def _invoke_rest_api_using_web_login( + def _invoke_rest_api_using_local_session_token( self, **api_kwargs, ) -> dict: - session, hostname = self._get_session_conn(api_kwargs["Name"]) - - response = session.request( - method=api_kwargs["Method"], - url=f"https://{hostname}/api/v1{api_kwargs['Path']}", - params=api_kwargs["QueryParameters"], - json=api_kwargs["Body"], - timeout=10, - ) - response.raise_for_status() + try: + session, hostname = self._get_session_conn(api_kwargs["Name"]) + + response = session.request( + method=api_kwargs["Method"], + url=f"https://{hostname}/api/v1{api_kwargs['Path']}", + params=api_kwargs["QueryParameters"], + json=api_kwargs["Body"], + timeout=10, + ) + response.raise_for_status() + + except requests.HTTPError as e: + self.log.error(e.response.json()) + raise return { "RestApiStatusCode": response.status_code, diff --git a/providers/amazon/tests/unit/amazon/aws/hooks/test_mwaa.py b/providers/amazon/tests/unit/amazon/aws/hooks/test_mwaa.py index 414bb83ec74c5..f7d3071657c86 100644 --- a/providers/amazon/tests/unit/amazon/aws/hooks/test_mwaa.py +++ b/providers/amazon/tests/unit/amazon/aws/hooks/test_mwaa.py @@ -19,6 +19,7 @@ from unittest import mock import pytest +import requests from botocore.exceptions import ClientError from moto import mock_aws @@ -55,7 +56,7 @@ def test_get_conn(self): pytest.param(BODY, id="non_empty_body"), ], ) - def test_invoke_rest_api_success(self, body, mock_conn, example_responses) -> None: + def test_invoke_rest_api_success(self, body, mock_conn, example_responses): boto_invoke_mock = mock.MagicMock(return_value=example_responses["success"]) mock_conn.invoke_rest_api = boto_invoke_mock @@ -73,11 +74,11 @@ def test_invoke_rest_api_success(self, body, mock_conn, example_responses) -> No mock_conn.create_web_login_token.assert_not_called() assert retval == {k: v for k, v in example_responses["success"].items() if k != "ResponseMetadata"} - def test_invoke_rest_api_failure(self, mock_conn, example_responses) -> None: + def test_invoke_rest_api_failure(self, mock_conn, example_responses): error = ClientError(error_response=example_responses["failure"], operation_name="invoke_rest_api") mock_conn.invoke_rest_api = mock.MagicMock(side_effect=error) - mock_log = mock.MagicMock() - self.hook.log.error = mock_log + mock_error_log = mock.MagicMock() + self.hook.log.error = mock_error_log with pytest.raises(ClientError) as caught_error: self.hook.invoke_rest_api(env_name=ENV_NAME, path=PATH, method=METHOD) @@ -85,17 +86,17 @@ def test_invoke_rest_api_failure(self, mock_conn, example_responses) -> None: assert caught_error.value == error mock_conn.create_web_login_token.assert_not_called() expected_log = {k: v for k, v in example_responses["failure"].items() if k != "ResponseMetadata"} - mock_log.assert_called_once_with(expected_log) + mock_error_log.assert_called_once_with(expected_log) - @pytest.mark.parametrize("only_use_web_login", [pytest.param(True), pytest.param(False)]) + @pytest.mark.parametrize("only_generate_local_token", [pytest.param(True), pytest.param(False)]) @mock.patch("airflow.providers.amazon.aws.hooks.mwaa.requests.Session") - def test_invoke_rest_api_web_login_parameter( - self, mock_create_session, only_use_web_login, mock_conn - ) -> None: + def test_invoke_rest_api_local_token_parameter( + self, mock_create_session, only_generate_local_token, mock_conn + ): self.hook.invoke_rest_api( - env_name=ENV_NAME, path=PATH, method=METHOD, only_use_web_login=only_use_web_login + env_name=ENV_NAME, path=PATH, method=METHOD, only_generate_local_token=only_generate_local_token ) - if only_use_web_login: + if only_generate_local_token: mock_conn.invoke_rest_api.assert_not_called() mock_conn.create_web_login_token.assert_called_once() mock_create_session.assert_called_once() @@ -104,9 +105,9 @@ def test_invoke_rest_api_web_login_parameter( mock_conn.invoke_rest_api.assert_called_once() @mock.patch.object(MwaaHook, "_get_session_conn") - def test_invoke_rest_api_fallback_to_web_login_when_iam_fails( + def test_invoke_rest_api_fallback_success_when_iam_fails( self, mock_get_session_conn, mock_conn, example_responses - ) -> None: + ): boto_invoke_error = ClientError( error_response=example_responses["missingIamRole"], operation_name="invoke_rest_api" ) @@ -123,7 +124,6 @@ def test_invoke_rest_api_fallback_to_web_login_when_iam_fails( mock_response = mock.MagicMock() mock_response.status_code = example_responses["success"]["RestApiStatusCode"] mock_response.json.return_value = example_responses["success"]["RestApiResponse"] - mock_session = mock.MagicMock() mock_session.request.return_value = mock_response @@ -132,16 +132,38 @@ def test_invoke_rest_api_fallback_to_web_login_when_iam_fails( retval = self.hook.invoke_rest_api( env_name=ENV_NAME, path=PATH, method=METHOD, body=BODY, query_params=QUERY_PARAMS ) - mock_session.request.assert_called_once_with(**kwargs_to_assert) - # Since the implementation doesn't have separate branches for success and failure cases, this takes - # care of testing both the success and failure cases + mock_session.request.assert_called_once_with(**kwargs_to_assert) mock_response.raise_for_status.assert_called_once() - assert retval == {k: v for k, v in example_responses["success"].items() if k != "ResponseMetadata"} + @mock.patch.object(MwaaHook, "_get_session_conn") + def test_invoke_rest_api_using_local_session_token_failure( + self, mock_get_session_conn, example_responses + ): + mock_response = mock.MagicMock() + mock_response.json.return_value = example_responses["failure"]["RestApiResponse"] + error = requests.HTTPError(response=mock_response) + mock_response.raise_for_status.side_effect = error + + mock_session = mock.MagicMock() + mock_session.request.return_value = mock_response + + mock_get_session_conn.return_value = (mock_session, HOSTNAME) + + mock_error_log = mock.MagicMock() + self.hook.log.error = mock_error_log + + with pytest.raises(requests.HTTPError) as caught_error: + self.hook.invoke_rest_api( + env_name=ENV_NAME, path=PATH, method=METHOD, only_generate_local_token=True + ) + + assert caught_error.value == error + mock_error_log.assert_called_once_with(example_responses["failure"]["RestApiResponse"]) + @mock.patch("airflow.providers.amazon.aws.hooks.mwaa.requests.Session") - def test_web_login_get_session_conn(self, mock_create_session, mock_conn): + def test_get_session_conn(self, mock_create_session, mock_conn): token = "token" mock_conn.create_web_login_token.return_value = {"WebServerHostname": HOSTNAME, "WebToken": token} login_url = f"https://{HOSTNAME}/aws_mwaa/login" @@ -155,9 +177,6 @@ def test_web_login_get_session_conn(self, mock_create_session, mock_conn): mock_conn.create_web_login_token.assert_called_once_with(Name=ENV_NAME) mock_create_session.assert_called_once_with() mock_session.post.assert_called_once_with(login_url, data=login_payload, timeout=10) - - # Since the implementation doesn't have separate branches for success and failure cases, this takes - # care of testing both the success and failure cases mock_session.post.return_value.raise_for_status.assert_called_once() assert retval == (mock_session, HOSTNAME) From 62a5eefb84ee16d237e68c506508290a058bb6ac Mon Sep 17 00:00:00 2001 From: Ramit Kataria Date: Wed, 5 Mar 2025 17:23:46 -0800 Subject: [PATCH 6/6] Change MwaaHook local token parameter name --- .../src/airflow/providers/amazon/aws/hooks/mwaa.py | 11 ++++++----- .../amazon/tests/unit/amazon/aws/hooks/test_mwaa.py | 12 +++++------- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/mwaa.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/mwaa.py index 77500725ac9d3..0f47f0bafb6c5 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/mwaa.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/mwaa.py @@ -33,7 +33,7 @@ class MwaaHook(AwsBaseHook): If your IAM policy doesn't have `airflow:InvokeRestApi` permission, the hook will use a fallback method that uses the AWS credential to generate a local web login token for the Airflow Web UI and then directly make requests to the Airflow API. This fallback method can be set as the default (and only) method used by - setting `only_generate_local_token` to True. Learn more here: + setting `generate_local_token` to True. Learn more here: https://docs.aws.amazon.com/mwaa/latest/userguide/access-mwaa-apache-airflow-rest-api.html#granting-access-MWAA-Enhanced-REST-API Additional arguments (such as ``aws_conn_id``) may be specified and @@ -54,7 +54,7 @@ def invoke_rest_api( method: str, body: dict | None = None, query_params: dict | None = None, - only_generate_local_token: bool = False, + generate_local_token: bool = False, ) -> dict: """ Invoke the REST API on the Airflow webserver with the specified inputs. @@ -67,8 +67,9 @@ def invoke_rest_api( :param method: HTTP method used for making Airflow REST API calls: 'GET'|'PUT'|'POST'|'PATCH'|'DELETE' :param body: Request body for the Apache Airflow REST API call :param query_params: Query parameters to be included in the Apache Airflow REST API call - :param only_generate_local_token: If True, only the local web login token method is used without - trying boto's invoke_rest_api first + :param generate_local_token: If True, only the local web token method is used without trying boto's + `invoke_rest_api` first. If False, the local web token method is used as a fallback after trying + boto's `invoke_rest_api` """ # Filter out keys with None values because Airflow REST API doesn't accept requests otherwise body = {k: v for k, v in body.items() if v is not None} if body else {} @@ -81,7 +82,7 @@ def invoke_rest_api( "QueryParameters": query_params, } - if only_generate_local_token: + if generate_local_token: return self._invoke_rest_api_using_local_session_token(**api_kwargs) try: diff --git a/providers/amazon/tests/unit/amazon/aws/hooks/test_mwaa.py b/providers/amazon/tests/unit/amazon/aws/hooks/test_mwaa.py index f7d3071657c86..d8046db33a847 100644 --- a/providers/amazon/tests/unit/amazon/aws/hooks/test_mwaa.py +++ b/providers/amazon/tests/unit/amazon/aws/hooks/test_mwaa.py @@ -88,15 +88,15 @@ def test_invoke_rest_api_failure(self, mock_conn, example_responses): expected_log = {k: v for k, v in example_responses["failure"].items() if k != "ResponseMetadata"} mock_error_log.assert_called_once_with(expected_log) - @pytest.mark.parametrize("only_generate_local_token", [pytest.param(True), pytest.param(False)]) + @pytest.mark.parametrize("generate_local_token", [pytest.param(True), pytest.param(False)]) @mock.patch("airflow.providers.amazon.aws.hooks.mwaa.requests.Session") def test_invoke_rest_api_local_token_parameter( - self, mock_create_session, only_generate_local_token, mock_conn + self, mock_create_session, generate_local_token, mock_conn ): self.hook.invoke_rest_api( - env_name=ENV_NAME, path=PATH, method=METHOD, only_generate_local_token=only_generate_local_token + env_name=ENV_NAME, path=PATH, method=METHOD, generate_local_token=generate_local_token ) - if only_generate_local_token: + if generate_local_token: mock_conn.invoke_rest_api.assert_not_called() mock_conn.create_web_login_token.assert_called_once() mock_create_session.assert_called_once() @@ -155,9 +155,7 @@ def test_invoke_rest_api_using_local_session_token_failure( self.hook.log.error = mock_error_log with pytest.raises(requests.HTTPError) as caught_error: - self.hook.invoke_rest_api( - env_name=ENV_NAME, path=PATH, method=METHOD, only_generate_local_token=True - ) + self.hook.invoke_rest_api(env_name=ENV_NAME, path=PATH, method=METHOD, generate_local_token=True) assert caught_error.value == error mock_error_log.assert_called_once_with(example_responses["failure"]["RestApiResponse"])