From 781654d29caafd162d960b0cd2010610617dadbe Mon Sep 17 00:00:00 2001 From: arithmetic1728 Date: Wed, 19 Feb 2020 16:24:20 -0800 Subject: [PATCH 01/12] feat: add SslCredentials class for mTLS ADC (linux) --- google/auth/transport/_mtls_helper.py | 107 ++++++++++++++ google/auth/transport/grpc.py | 37 +++++ tests/data/context_aware_metadata.json | 6 + tests/transport/test__mtls_helper.py | 188 +++++++++++++++++++++++++ tests/transport/test_grpc.py | 81 +++++++++++ 5 files changed, 419 insertions(+) create mode 100644 google/auth/transport/_mtls_helper.py create mode 100644 tests/data/context_aware_metadata.json create mode 100644 tests/transport/test__mtls_helper.py diff --git a/google/auth/transport/_mtls_helper.py b/google/auth/transport/_mtls_helper.py new file mode 100644 index 000000000..9f141b286 --- /dev/null +++ b/google/auth/transport/_mtls_helper.py @@ -0,0 +1,107 @@ +# Copyright 2016 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helper functions for getting mTLS cert and key, for internal use only.""" + +import json +import logging +from os import path +import subprocess + +CONTEXT_AWARE_METADATA_PATH = "~/.secureConnect/context_aware_metadata.json" +_CERT_PROVIDER_COMMAND = "cert_provider_command" +_CERTIFICATE_SURFFIX = b"-----END CERTIFICATE-----\n" + +_LOGGER = logging.getLogger(__name__) + + +def read_metadata_file(metadata_path): + """Function to load context aware metadata from the given path. + + Args: + metadata_path (str): context aware metadata path. + + Returns: + Dict[str]: + The metadata. If metadata reading or parsing fails, return None. + """ + metadata_path = path.expanduser(metadata_path) + if not path.exists(metadata_path): + _LOGGER.debug("%s is not found, skip client SSL authentication.", metadata_path) + return None + + with open(metadata_path) as f: + try: + metadata = json.load(f) + except json.decoder.JSONDecodeError as e: + _LOGGER.debug( + "Failed to decode context_aware_metadata.json with error: %s", str(e) + ) + return None + + return metadata + + +def get_client_ssl_credentials(metadata_json, platform): + """Function to get mTLS client side cert and key. + + Args: + metadata_json (Dict[str]): metadata JSON file which contains the cert + provider command. + platform (str): The OS. + + Returns: + Tuple[bool, bytes, bytes, bytes, bytes]: + The tuple contains the following in order: + (1) boolean to show if client cert and key is obtained successfully + (2) client certificate in PEM forma if successful, otherwise None + (3) client key in PEM format if successful, otherwise None + (4) stdout from cert provider command execution + (5) stderr from cert provider command execution + """ + + # Check the system. For now only Linux is supported. + if not platform.startswith("linux"): + _LOGGER.debug("mTLS for platform: %s is not supported.", platform) + return False, None, None, None, None + + # Execute the cert provider command in the metadata json file. + if _CERT_PROVIDER_COMMAND not in metadata_json: + _LOGGER.debug("cert_provider_command missing, skip client SSL authentication") + return False, None, None, None, None + try: + command = metadata_json[_CERT_PROVIDER_COMMAND] + process = subprocess.Popen( + command, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + stdout, stderr = process.communicate() + except OSError as e: + _LOGGER.debug("Failed to run cert provider command with error: %s", str(e)) + return False, None, None, None, None + + # Check cert provider command execution error. + if stderr != b"": + _LOGGER.debug("Cert provider command failed with error: %s", stderr) + return False, None, None, stdout, stderr + + # Parse stdout, it should be a cert followed by a key, both in PEM format. + cert_end = stdout.find(_CERTIFICATE_SURFFIX) + if cert_end == -1: + _LOGGER.debug("Client SSL certificate is missing") + return False, None, None, stdout, stderr + private_key_start = cert_end + len(_CERTIFICATE_SURFFIX) + if private_key_start >= len(stdout): + _LOGGER.debug("Client SSL private key is missing") + return False, None, None, stdout, stderr + return True, stdout[0:private_key_start], stdout[private_key_start:], stdout, stderr diff --git a/google/auth/transport/grpc.py b/google/auth/transport/grpc.py index fb90fbb4b..d9ad2f4d6 100644 --- a/google/auth/transport/grpc.py +++ b/google/auth/transport/grpc.py @@ -17,9 +17,12 @@ from __future__ import absolute_import from concurrent import futures +from sys import platform import six +from google.auth.transport import _mtls_helper + try: import grpc except ImportError as caught_exc: # pragma: NO COVER @@ -149,3 +152,37 @@ def secure_authorized_channel( ) return grpc.secure_channel(target, composite_credentials, **kwargs) + + +class SslCredentials: + """Class for application default SSL credentials. For Linux with endpoint + verification support, device certificate will be automatically loaded if + available and mutual TLS will be established. + """ + + def __init__(self): + self._is_mtls = False + + # Load client SSL credentials. + context_aware_metadata = _mtls_helper.read_metadata_file( + _mtls_helper.CONTEXT_AWARE_METADATA_PATH + ) + if context_aware_metadata: + self._is_mtls, cert, key, _, _ = _mtls_helper.get_client_ssl_credentials( + context_aware_metadata, platform + ) + + if self._is_mtls: + self._ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_credentials = grpc.ssl_channel_credentials() + + @property + def ssl_credentials(self): + return self._ssl_credentials + + @property + def is_mtls(self): + return self._is_mtls diff --git a/tests/data/context_aware_metadata.json b/tests/data/context_aware_metadata.json new file mode 100644 index 000000000..ec40e783f --- /dev/null +++ b/tests/data/context_aware_metadata.json @@ -0,0 +1,6 @@ +{ + "cert_provider_command":[ + "/opt/google/endpoint-verification/bin/SecureConnectHelper", + "--print_certificate"], + "device_resource_ids":["11111111-1111-1111"] +} diff --git a/tests/transport/test__mtls_helper.py b/tests/transport/test__mtls_helper.py new file mode 100644 index 000000000..90f6cd09e --- /dev/null +++ b/tests/transport/test__mtls_helper.py @@ -0,0 +1,188 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import mock + +from google.auth.transport import _mtls_helper + +DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data") + +with open(os.path.join(DATA_DIR, "privatekey.pub"), "rb") as fh: + PRIVATE_KEY_BYTES = fh.read() + +with open(os.path.join(DATA_DIR, "public_cert.pem"), "rb") as fh: + PUBLIC_CERT_BYTES = fh.read() + +CLIENT_SSL_CREDENTIALS = PUBLIC_CERT_BYTES + PRIVATE_KEY_BYTES + +CONTEXT_AWARE_METADATA = {"cert_provider_command": ["some command"]} + +CONTEXT_AWARE_METADATA_NO_CERT_PROVIDER_COMMAND = {} + + +class TestReadMetadataFile(object): + def test_success(self): + metadata_path = os.path.join(DATA_DIR, "context_aware_metadata.json") + metadata = _mtls_helper.read_metadata_file(metadata_path) + + assert "cert_provider_command" in metadata + + def test_file_not_exist(self): + metadata_path = os.path.join(DATA_DIR, "not_exist.json") + metadata = _mtls_helper.read_metadata_file(metadata_path) + + assert metadata is None + + def test_file_not_json(self): + # read a file which is not json format. + metadata_path = os.path.join(DATA_DIR, "privatekey.pem") + metadata = _mtls_helper.read_metadata_file(metadata_path) + + assert metadata is None + + +class TestGetClientSslCredentials(object): + def create_mock_process(self, output, error): + # There are two steps to execute a script with subprocess.Popen. + # (1) process = subprocess.Popen([comannds]) + # (2) stdout, stderr = process.communicate() + # This function creates a mock process which can be returned by a mock + # subprocess.Popen. The mock process returns the given output and error + # when mock_process.communicate() is called. + mock_process = mock.Mock() + attrs = {"communicate.return_value": (output, error)} + mock_process.configure_mock(**attrs) + return mock_process + + @mock.patch("subprocess.Popen", autospec=True) + def test_success(self, mock_popen): + mock_popen.return_value = self.create_mock_process(CLIENT_SSL_CREDENTIALS, b"") + success, cert, key, output, error = _mtls_helper.get_client_ssl_credentials( + CONTEXT_AWARE_METADATA, "linux" + ) + + assert all( + [ + a == b + for a, b in zip( + (success, cert, key, output, error), + ( + True, + PUBLIC_CERT_BYTES, + PRIVATE_KEY_BYTES, + CLIENT_SSL_CREDENTIALS, + b"", + ), + ) + ] + ) + + def test_not_linux_platform(self): + success, cert, key, stdout, stderr = _mtls_helper.get_client_ssl_credentials( + CONTEXT_AWARE_METADATA, "win32" + ) + + assert all( + [ + a == b + for a, b in zip( + (success, cert, key, stdout, stderr), + (False, None, None, None, None), + ) + ] + ) + + def test_missing_cert_provider_command(self): + success, cert, key, stdout, stderr = _mtls_helper.get_client_ssl_credentials( + CONTEXT_AWARE_METADATA_NO_CERT_PROVIDER_COMMAND, "linux" + ) + + assert all( + [ + a == b + for a, b in zip( + (success, cert, key, stdout, stderr), + (False, None, None, None, None), + ) + ] + ) + + @mock.patch("subprocess.Popen", autospec=True) + def test_missing_cert(self, mock_popen): + mock_popen.return_value = self.create_mock_process(PRIVATE_KEY_BYTES, b"") + success, cert, key, output, error = _mtls_helper.get_client_ssl_credentials( + CONTEXT_AWARE_METADATA, "linux" + ) + + assert all( + [ + a == b + for a, b in zip( + (success, cert, key, output, error), + (False, None, None, PRIVATE_KEY_BYTES, b""), + ) + ] + ) + + @mock.patch("subprocess.Popen", autospec=True) + def test_missing_key(self, mock_popen): + mock_popen.return_value = self.create_mock_process(PUBLIC_CERT_BYTES, b"") + success, cert, key, output, error = _mtls_helper.get_client_ssl_credentials( + CONTEXT_AWARE_METADATA, "linux" + ) + + assert all( + [ + a == b + for a, b in zip( + (success, cert, key, output, error), + (False, None, None, PUBLIC_CERT_BYTES, b""), + ) + ] + ) + + @mock.patch("subprocess.Popen", autospec=True) + def test_cert_provider_returns_error(self, mock_popen): + mock_popen.return_value = self.create_mock_process(b"", b"some error") + success, cert, key, output, error = _mtls_helper.get_client_ssl_credentials( + CONTEXT_AWARE_METADATA, "linux" + ) + + assert all( + [ + a == b + for a, b in zip( + (success, cert, key, output, error), + (False, None, None, b"", b"some error"), + ) + ] + ) + + @mock.patch("subprocess.Popen", autospec=True) + def test_popen_raise_exception(self, mock_popen): + mock_popen.side_effect = OSError() + success, cert, key, output, error = _mtls_helper.get_client_ssl_credentials( + CONTEXT_AWARE_METADATA, "linux" + ) + + assert all( + [ + a == b + for a, b in zip( + (success, cert, key, output, error), (False, None, None, None, None) + ) + ] + ) diff --git a/tests/transport/test_grpc.py b/tests/transport/test_grpc.py index 857c32bb9..0928511c7 100644 --- a/tests/transport/test_grpc.py +++ b/tests/transport/test_grpc.py @@ -13,6 +13,7 @@ # limitations under the License. import datetime +import os import time import mock @@ -154,3 +155,83 @@ def test_secure_authorized_channel_explicit_ssl( composite_channel_credentials.assert_called_once_with( ssl_credentials, metadata_call_credentials.return_value ) + + +@mock.patch("grpc.ssl_channel_credentials", autospec=True) +@mock.patch( + "google.auth.transport._mtls_helper.get_client_ssl_credentials", autospec=True +) +@mock.patch("google.auth.transport._mtls_helper.read_metadata_file", autospec=True) +class TestSslCredentials(object): + def test_no_context_aware_metadata( + self, + mock_read_metadata_file, + mock_get_client_ssl_credentials, + mock_ssl_channel_credentials, + ): + # Mock that read_metadata_file function returns no metadata. + mock_read_metadata_file.return_value = None + + ssl_credentials = google.auth.transport.grpc.SslCredentials() + + # Since no context aware metadata is found, we wouldn't call + # get_client_ssl_credentials, and the SSL channel credentials created is + # non mTLS. + mock_get_client_ssl_credentials.assert_not_called() + mock_ssl_channel_credentials.assert_called_once_with() + assert ssl_credentials.ssl_credentials is not None + assert not ssl_credentials.is_mtls + + def test_get_client_ssl_credentials_failure( + self, + mock_read_metadata_file, + mock_get_client_ssl_credentials, + mock_ssl_channel_credentials, + ): + mock_read_metadata_file.return_value = { + "cert_provider_command": ["some command"] + } + + # Mock that client cert and key are not loaded. + mock_get_client_ssl_credentials.return_value = (False, None, None, None, None) + + ssl_credentials = google.auth.transport.grpc.SslCredentials() + + # Since we failed to get_client_ssl_credentials, the SSL channel + # credentials created is non mTLS. + mock_get_client_ssl_credentials.assert_called_once() + mock_ssl_channel_credentials.assert_called_once_with() + assert ssl_credentials.ssl_credentials is not None + assert not ssl_credentials.is_mtls + + def test_get_client_ssl_credentials_success( + self, + mock_read_metadata_file, + mock_get_client_ssl_credentials, + mock_ssl_channel_credentials, + ): + mock_read_metadata_file.return_value = { + "cert_provider_command": ["some command"] + } + + DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data") + with open(os.path.join(DATA_DIR, "privatekey.pub"), "rb") as fh: + PRIVATE_KEY_BYTES = fh.read() + with open(os.path.join(DATA_DIR, "public_cert.pem"), "rb") as fh: + PUBLIC_CERT_BYTES = fh.read() + mock_get_client_ssl_credentials.return_value = ( + True, + PUBLIC_CERT_BYTES, + PRIVATE_KEY_BYTES, + None, + None, + ) + + ssl_credentials = google.auth.transport.grpc.SslCredentials() + + mock_get_client_ssl_credentials.assert_called_once() + mock_ssl_channel_credentials.assert_called_once_with( + certificate_chain=PUBLIC_CERT_BYTES, private_key=PRIVATE_KEY_BYTES + ) + assert ssl_credentials.ssl_credentials is not None + assert ssl_credentials.is_mtls From 2f0aaa734a39f53f5fdd889c15925899da92ce31 Mon Sep 17 00:00:00 2001 From: Sijun Liu Date: Mon, 24 Feb 2020 10:23:26 -0800 Subject: [PATCH 02/12] Fix JSON exception issue --- google/auth/transport/_mtls_helper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/auth/transport/_mtls_helper.py b/google/auth/transport/_mtls_helper.py index 9f141b286..03c5de782 100644 --- a/google/auth/transport/_mtls_helper.py +++ b/google/auth/transport/_mtls_helper.py @@ -44,7 +44,7 @@ def read_metadata_file(metadata_path): with open(metadata_path) as f: try: metadata = json.load(f) - except json.decoder.JSONDecodeError as e: + except Exception as e: _LOGGER.debug( "Failed to decode context_aware_metadata.json with error: %s", str(e) ) From e628966ec9b3e1f7f8d19e172c01591d6315b8cf Mon Sep 17 00:00:00 2001 From: arithmetic1728 <58957152+arithmetic1728@users.noreply.github.com> Date: Tue, 25 Feb 2020 13:51:08 -0800 Subject: [PATCH 03/12] Update google/auth/transport/_mtls_helper.py Co-Authored-By: Bu Sun Kim <8822365+busunkim96@users.noreply.github.com> --- google/auth/transport/_mtls_helper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/auth/transport/_mtls_helper.py b/google/auth/transport/_mtls_helper.py index 03c5de782..46152f07f 100644 --- a/google/auth/transport/_mtls_helper.py +++ b/google/auth/transport/_mtls_helper.py @@ -1,4 +1,4 @@ -# Copyright 2016 Google LLC +# Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From ed65141490d60b69a8ef8afa55a16503242620e0 Mon Sep 17 00:00:00 2001 From: arithmetic1728 Date: Tue, 25 Feb 2020 15:53:16 -0800 Subject: [PATCH 04/12] modify the code based on comments --- google/auth/transport/_mtls_helper.py | 90 ++++++------ google/auth/transport/grpc.py | 19 ++- tests/transport/test__mtls_helper.py | 190 ++++++++++++-------------- tests/transport/test_grpc.py | 19 ++- 4 files changed, 156 insertions(+), 162 deletions(-) diff --git a/google/auth/transport/_mtls_helper.py b/google/auth/transport/_mtls_helper.py index 46152f07f..fa2512ecf 100644 --- a/google/auth/transport/_mtls_helper.py +++ b/google/auth/transport/_mtls_helper.py @@ -17,23 +17,35 @@ import json import logging from os import path +import re import subprocess CONTEXT_AWARE_METADATA_PATH = "~/.secureConnect/context_aware_metadata.json" _CERT_PROVIDER_COMMAND = "cert_provider_command" -_CERTIFICATE_SURFFIX = b"-----END CERTIFICATE-----\n" +_CERT_REGEX = re.compile( + b"-----BEGIN CERTIFICATE-----.+-----END CERTIFICATE-----\r?\n?", re.DOTALL +) + +# support various format of key files, e.g. +# "-----BEGIN PRIVATE KEY-----...", +# "-----BEGIN EC PRIVATE KEY-----...", +# "-----BEGIN RSA PRIVATE KEY-----..." +_KEY_REGEX = re.compile( + b"-----BEGIN [A-Z ]*PRIVATE KEY-----.+-----END [A-Z ]*PRIVATE KEY-----\r?\n?", + re.DOTALL, +) _LOGGER = logging.getLogger(__name__) -def read_metadata_file(metadata_path): +def _read_dca_metadata_file(metadata_path): """Function to load context aware metadata from the given path. Args: metadata_path (str): context aware metadata path. Returns: - Dict[str]: + Dict[str, str]: The metadata. If metadata reading or parsing fails, return None. """ metadata_path = path.expanduser(metadata_path) @@ -53,55 +65,47 @@ def read_metadata_file(metadata_path): return metadata -def get_client_ssl_credentials(metadata_json, platform): +def get_client_ssl_credentials(metadata_json): """Function to get mTLS client side cert and key. Args: metadata_json (Dict[str]): metadata JSON file which contains the cert provider command. - platform (str): The OS. Returns: - Tuple[bool, bytes, bytes, bytes, bytes]: - The tuple contains the following in order: - (1) boolean to show if client cert and key is obtained successfully - (2) client certificate in PEM forma if successful, otherwise None - (3) client key in PEM format if successful, otherwise None - (4) stdout from cert provider command execution - (5) stderr from cert provider command execution + Tuple[bytes, bytes]: client certificate and key, both in PEM format. + + Raises: + OSError: subprocess throws OSError if failed to run cert provider command + RuntimeError: if cert provider command has runtime error + ValueError: + if metadata json file doesn't contain cert provider command, or the + execution of this command doesn't produce both client certicate and + client key. """ - # Check the system. For now only Linux is supported. - if not platform.startswith("linux"): - _LOGGER.debug("mTLS for platform: %s is not supported.", platform) - return False, None, None, None, None - - # Execute the cert provider command in the metadata json file. + # Check the cert provider command existence in the metadata json file. if _CERT_PROVIDER_COMMAND not in metadata_json: - _LOGGER.debug("cert_provider_command missing, skip client SSL authentication") - return False, None, None, None, None - try: - command = metadata_json[_CERT_PROVIDER_COMMAND] - process = subprocess.Popen( - command, stdout=subprocess.PIPE, stderr=subprocess.PIPE - ) - stdout, stderr = process.communicate() - except OSError as e: - _LOGGER.debug("Failed to run cert provider command with error: %s", str(e)) - return False, None, None, None, None + raise ValueError("Cert provider command is not found") + + # Execute the command. It throws OsError in case of system failure. + command = metadata_json[_CERT_PROVIDER_COMMAND] + process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stdout, stderr = process.communicate() + process_return_code = process.returncode # Check cert provider command execution error. - if stderr != b"": - _LOGGER.debug("Cert provider command failed with error: %s", stderr) - return False, None, None, stdout, stderr - - # Parse stdout, it should be a cert followed by a key, both in PEM format. - cert_end = stdout.find(_CERTIFICATE_SURFFIX) - if cert_end == -1: - _LOGGER.debug("Client SSL certificate is missing") - return False, None, None, stdout, stderr - private_key_start = cert_end + len(_CERTIFICATE_SURFFIX) - if private_key_start >= len(stdout): - _LOGGER.debug("Client SSL private key is missing") - return False, None, None, stdout, stderr - return True, stdout[0:private_key_start], stdout[private_key_start:], stdout, stderr + if process_return_code != 0: + raise RuntimeError( + "Cert provider command returns non-zero status code %s" + % process_return_code + ) + + # Extract certificate (chain) and key. + cert_match = re.findall(_CERT_REGEX, stdout) + if len(cert_match) != 1: + raise ValueError("Client SSL certificate is missing or invalid") + key_match = re.findall(_KEY_REGEX, stdout) + if len(key_match) != 1: + raise ValueError("Client SSL key is missing or invalid") + return cert_match[0], key_match[0] diff --git a/google/auth/transport/grpc.py b/google/auth/transport/grpc.py index d9ad2f4d6..12b21c98c 100644 --- a/google/auth/transport/grpc.py +++ b/google/auth/transport/grpc.py @@ -17,7 +17,7 @@ from __future__ import absolute_import from concurrent import futures -from sys import platform +import logging import six @@ -34,6 +34,8 @@ caught_exc, ) +_LOGGER = logging.getLogger(__name__) + class AuthMetadataPlugin(grpc.AuthMetadataPlugin): """A `gRPC AuthMetadataPlugin`_ that inserts the credentials into each @@ -164,13 +166,20 @@ def __init__(self): self._is_mtls = False # Load client SSL credentials. - context_aware_metadata = _mtls_helper.read_metadata_file( + context_aware_metadata = _mtls_helper._read_dca_metadata_file( _mtls_helper.CONTEXT_AWARE_METADATA_PATH ) if context_aware_metadata: - self._is_mtls, cert, key, _, _ = _mtls_helper.get_client_ssl_credentials( - context_aware_metadata, platform - ) + try: + cert, key = _mtls_helper.get_client_ssl_credentials( + context_aware_metadata + ) + self._is_mtls = True + + except (NotImplementedError, OSError, RuntimeError, ValueError) as e: + _LOGGER.debug( + "Failed to get client SSL credentials with error: %s", str(e) + ) if self._is_mtls: self._ssl_credentials = grpc.ssl_channel_credentials( diff --git a/tests/transport/test__mtls_helper.py b/tests/transport/test__mtls_helper.py index 90f6cd09e..3d6b1a0e4 100644 --- a/tests/transport/test__mtls_helper.py +++ b/tests/transport/test__mtls_helper.py @@ -13,43 +13,97 @@ # limitations under the License. import os +import re import mock +import pytest from google.auth.transport import _mtls_helper DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data") -with open(os.path.join(DATA_DIR, "privatekey.pub"), "rb") as fh: +with open(os.path.join(DATA_DIR, "privatekey.pem"), "rb") as fh: PRIVATE_KEY_BYTES = fh.read() with open(os.path.join(DATA_DIR, "public_cert.pem"), "rb") as fh: PUBLIC_CERT_BYTES = fh.read() -CLIENT_SSL_CREDENTIALS = PUBLIC_CERT_BYTES + PRIVATE_KEY_BYTES - CONTEXT_AWARE_METADATA = {"cert_provider_command": ["some command"]} CONTEXT_AWARE_METADATA_NO_CERT_PROVIDER_COMMAND = {} +def check_cert_and_key(content, expected_cert, expected_key): + success = True + + cert_match = re.findall(_mtls_helper._CERT_REGEX, content) + success = success and len(cert_match) == 1 and cert_match[0] == expected_cert + + key_match = re.findall(_mtls_helper._KEY_REGEX, content) + success = success and len(key_match) == 1 and key_match[0] == expected_key + + return success + + +class TestCertAndKeyRegex(object): + def test_cert_and_key(self): + # Test signle cert and single key + check_cert_and_key( + PUBLIC_CERT_BYTES + PRIVATE_KEY_BYTES, PUBLIC_CERT_BYTES, PRIVATE_KEY_BYTES + ) + check_cert_and_key( + PRIVATE_KEY_BYTES + PUBLIC_CERT_BYTES, PUBLIC_CERT_BYTES, PRIVATE_KEY_BYTES + ) + + # Test cert chain and single key + check_cert_and_key( + PUBLIC_CERT_BYTES + PUBLIC_CERT_BYTES + PRIVATE_KEY_BYTES, + PUBLIC_CERT_BYTES + PUBLIC_CERT_BYTES, + PRIVATE_KEY_BYTES, + ) + check_cert_and_key( + PRIVATE_KEY_BYTES + PUBLIC_CERT_BYTES + PUBLIC_CERT_BYTES, + PUBLIC_CERT_BYTES + PUBLIC_CERT_BYTES, + PRIVATE_KEY_BYTES, + ) + + def test_key(self): + # Create some fake keys for regex check. + KEY = b"""-----BEGIN PRIVATE KEY----- + MIIBCgKCAQEA4ej0p7bQ7L/r4rVGUz9RN4VQWoej1Bg1mYWIDYslvKrk1gpj7wZg + /fy3ZpsL7WqgsZS7Q+0VRK8gKfqkxg5OYQIDAQAB + -----END PRIVATE KEY-----""" + RSA_KEY = b"""-----BEGIN RSA PRIVATE KEY----- + MIIBCgKCAQEA4ej0p7bQ7L/r4rVGUz9RN4VQWoej1Bg1mYWIDYslvKrk1gpj7wZg + /fy3ZpsL7WqgsZS7Q+0VRK8gKfqkxg5OYQIDAQAB + -----END RSA PRIVATE KEY-----""" + EC_KEY = b"""-----BEGIN EC PRIVATE KEY----- + MIIBCgKCAQEA4ej0p7bQ7L/r4rVGUz9RN4VQWoej1Bg1mYWIDYslvKrk1gpj7wZg + /fy3ZpsL7WqgsZS7Q+0VRK8gKfqkxg5OYQIDAQAB + -----END EC PRIVATE KEY-----""" + + check_cert_and_key(PUBLIC_CERT_BYTES + KEY, PUBLIC_CERT_BYTES, KEY) + check_cert_and_key(PUBLIC_CERT_BYTES + RSA_KEY, PUBLIC_CERT_BYTES, RSA_KEY) + check_cert_and_key(PUBLIC_CERT_BYTES + EC_KEY, PUBLIC_CERT_BYTES, EC_KEY) + + class TestReadMetadataFile(object): def test_success(self): metadata_path = os.path.join(DATA_DIR, "context_aware_metadata.json") - metadata = _mtls_helper.read_metadata_file(metadata_path) + metadata = _mtls_helper._read_dca_metadata_file(metadata_path) assert "cert_provider_command" in metadata def test_file_not_exist(self): metadata_path = os.path.join(DATA_DIR, "not_exist.json") - metadata = _mtls_helper.read_metadata_file(metadata_path) + metadata = _mtls_helper._read_dca_metadata_file(metadata_path) assert metadata is None def test_file_not_json(self): # read a file which is not json format. metadata_path = os.path.join(DATA_DIR, "privatekey.pem") - metadata = _mtls_helper.read_metadata_file(metadata_path) + metadata = _mtls_helper._read_dca_metadata_file(metadata_path) assert metadata is None @@ -63,126 +117,56 @@ def create_mock_process(self, output, error): # subprocess.Popen. The mock process returns the given output and error # when mock_process.communicate() is called. mock_process = mock.Mock() - attrs = {"communicate.return_value": (output, error)} + attrs = {"communicate.return_value": (output, error), "returncode": 0} mock_process.configure_mock(**attrs) return mock_process @mock.patch("subprocess.Popen", autospec=True) def test_success(self, mock_popen): - mock_popen.return_value = self.create_mock_process(CLIENT_SSL_CREDENTIALS, b"") - success, cert, key, output, error = _mtls_helper.get_client_ssl_credentials( - CONTEXT_AWARE_METADATA, "linux" - ) - - assert all( - [ - a == b - for a, b in zip( - (success, cert, key, output, error), - ( - True, - PUBLIC_CERT_BYTES, - PRIVATE_KEY_BYTES, - CLIENT_SSL_CREDENTIALS, - b"", - ), - ) - ] - ) - - def test_not_linux_platform(self): - success, cert, key, stdout, stderr = _mtls_helper.get_client_ssl_credentials( - CONTEXT_AWARE_METADATA, "win32" + mock_popen.return_value = self.create_mock_process( + PUBLIC_CERT_BYTES + PRIVATE_KEY_BYTES, b"" ) + cert, key = _mtls_helper.get_client_ssl_credentials(CONTEXT_AWARE_METADATA) + assert cert == PUBLIC_CERT_BYTES + assert key == PRIVATE_KEY_BYTES - assert all( - [ - a == b - for a, b in zip( - (success, cert, key, stdout, stderr), - (False, None, None, None, None), - ) - ] + @mock.patch("subprocess.Popen", autospec=True) + def test_success_with_cert_chain(self, mock_popen): + PUBLIC_CERT_CHAIN_BYTES = PUBLIC_CERT_BYTES + PUBLIC_CERT_BYTES + mock_popen.return_value = self.create_mock_process( + PUBLIC_CERT_CHAIN_BYTES + PRIVATE_KEY_BYTES, b"" ) + cert, key = _mtls_helper.get_client_ssl_credentials(CONTEXT_AWARE_METADATA) + assert cert == PUBLIC_CERT_CHAIN_BYTES + assert key == PRIVATE_KEY_BYTES def test_missing_cert_provider_command(self): - success, cert, key, stdout, stderr = _mtls_helper.get_client_ssl_credentials( - CONTEXT_AWARE_METADATA_NO_CERT_PROVIDER_COMMAND, "linux" - ) - - assert all( - [ - a == b - for a, b in zip( - (success, cert, key, stdout, stderr), - (False, None, None, None, None), - ) - ] - ) + with pytest.raises(ValueError): + assert _mtls_helper.get_client_ssl_credentials( + CONTEXT_AWARE_METADATA_NO_CERT_PROVIDER_COMMAND + ) @mock.patch("subprocess.Popen", autospec=True) def test_missing_cert(self, mock_popen): mock_popen.return_value = self.create_mock_process(PRIVATE_KEY_BYTES, b"") - success, cert, key, output, error = _mtls_helper.get_client_ssl_credentials( - CONTEXT_AWARE_METADATA, "linux" - ) - - assert all( - [ - a == b - for a, b in zip( - (success, cert, key, output, error), - (False, None, None, PRIVATE_KEY_BYTES, b""), - ) - ] - ) + with pytest.raises(ValueError): + assert _mtls_helper.get_client_ssl_credentials(CONTEXT_AWARE_METADATA) @mock.patch("subprocess.Popen", autospec=True) def test_missing_key(self, mock_popen): mock_popen.return_value = self.create_mock_process(PUBLIC_CERT_BYTES, b"") - success, cert, key, output, error = _mtls_helper.get_client_ssl_credentials( - CONTEXT_AWARE_METADATA, "linux" - ) - - assert all( - [ - a == b - for a, b in zip( - (success, cert, key, output, error), - (False, None, None, PUBLIC_CERT_BYTES, b""), - ) - ] - ) + with pytest.raises(ValueError): + assert _mtls_helper.get_client_ssl_credentials(CONTEXT_AWARE_METADATA) @mock.patch("subprocess.Popen", autospec=True) def test_cert_provider_returns_error(self, mock_popen): mock_popen.return_value = self.create_mock_process(b"", b"some error") - success, cert, key, output, error = _mtls_helper.get_client_ssl_credentials( - CONTEXT_AWARE_METADATA, "linux" - ) - - assert all( - [ - a == b - for a, b in zip( - (success, cert, key, output, error), - (False, None, None, b"", b"some error"), - ) - ] - ) + mock_popen.return_value.returncode = 1 + with pytest.raises(RuntimeError): + assert _mtls_helper.get_client_ssl_credentials(CONTEXT_AWARE_METADATA) @mock.patch("subprocess.Popen", autospec=True) def test_popen_raise_exception(self, mock_popen): mock_popen.side_effect = OSError() - success, cert, key, output, error = _mtls_helper.get_client_ssl_credentials( - CONTEXT_AWARE_METADATA, "linux" - ) - - assert all( - [ - a == b - for a, b in zip( - (success, cert, key, output, error), (False, None, None, None, None) - ) - ] - ) + with pytest.raises(OSError): + assert _mtls_helper.get_client_ssl_credentials(CONTEXT_AWARE_METADATA) diff --git a/tests/transport/test_grpc.py b/tests/transport/test_grpc.py index 0928511c7..0788a3a6c 100644 --- a/tests/transport/test_grpc.py +++ b/tests/transport/test_grpc.py @@ -161,16 +161,16 @@ def test_secure_authorized_channel_explicit_ssl( @mock.patch( "google.auth.transport._mtls_helper.get_client_ssl_credentials", autospec=True ) -@mock.patch("google.auth.transport._mtls_helper.read_metadata_file", autospec=True) +@mock.patch("google.auth.transport._mtls_helper._read_dca_metadata_file", autospec=True) class TestSslCredentials(object): def test_no_context_aware_metadata( self, - mock_read_metadata_file, + mock_read_dca_metadata_file, mock_get_client_ssl_credentials, mock_ssl_channel_credentials, ): - # Mock that read_metadata_file function returns no metadata. - mock_read_metadata_file.return_value = None + # Mock that _read_dca_metadata_file function returns no metadata. + mock_read_dca_metadata_file.return_value = None ssl_credentials = google.auth.transport.grpc.SslCredentials() @@ -184,11 +184,11 @@ def test_no_context_aware_metadata( def test_get_client_ssl_credentials_failure( self, - mock_read_metadata_file, + mock_read_dca_metadata_file, mock_get_client_ssl_credentials, mock_ssl_channel_credentials, ): - mock_read_metadata_file.return_value = { + mock_read_dca_metadata_file.return_value = { "cert_provider_command": ["some command"] } @@ -206,11 +206,11 @@ def test_get_client_ssl_credentials_failure( def test_get_client_ssl_credentials_success( self, - mock_read_metadata_file, + mock_read_dca_metadata_file, mock_get_client_ssl_credentials, mock_ssl_channel_credentials, ): - mock_read_metadata_file.return_value = { + mock_read_dca_metadata_file.return_value = { "cert_provider_command": ["some command"] } @@ -220,11 +220,8 @@ def test_get_client_ssl_credentials_success( with open(os.path.join(DATA_DIR, "public_cert.pem"), "rb") as fh: PUBLIC_CERT_BYTES = fh.read() mock_get_client_ssl_credentials.return_value = ( - True, PUBLIC_CERT_BYTES, PRIVATE_KEY_BYTES, - None, - None, ) ssl_credentials = google.auth.transport.grpc.SslCredentials() From f5671e48419ef462afe8cf7d2c68bd25ff447526 Mon Sep 17 00:00:00 2001 From: arithmetic1728 Date: Wed, 26 Feb 2020 12:57:45 -0800 Subject: [PATCH 05/12] fix docstring --- google/auth/transport/grpc.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/google/auth/transport/grpc.py b/google/auth/transport/grpc.py index 12b21c98c..aa3a3c21e 100644 --- a/google/auth/transport/grpc.py +++ b/google/auth/transport/grpc.py @@ -133,7 +133,9 @@ def secure_authorized_channel( without using a standard http transport. target (str): The host and port of the service. ssl_credentials (grpc.ChannelCredentials): Optional SSL channel - credentials. This can be used to specify different certificates. + credentials. This can be used to specify different certificates. If + not provided, application default SSL channel credentials will be + used. kwargs: Additional arguments to pass to :func:`grpc.secure_channel`. Returns: @@ -146,7 +148,8 @@ def secure_authorized_channel( google_auth_credentials = grpc.metadata_call_credentials(metadata_plugin) if ssl_credentials is None: - ssl_credentials = grpc.ssl_channel_credentials() + adc_ssl_credentils = SslCredentials() + ssl_credentials = adc_ssl_credentils.ssl_credentials # Combine the ssl credentials and the authorization credentials. composite_credentials = grpc.composite_channel_credentials( @@ -157,9 +160,11 @@ def secure_authorized_channel( class SslCredentials: - """Class for application default SSL credentials. For Linux with endpoint - verification support, device certificate will be automatically loaded if - available and mutual TLS will be established. + """Class for application default SSL credentials. + + For Linux with endpoint verification support, device certificate will be + automatically loaded if available and mutual TLS will be established. + See https://cloud.google.com/endpoint-verification/docs/overview. """ def __init__(self): @@ -190,8 +195,10 @@ def __init__(self): @property def ssl_credentials(self): + """Get the created SSL channel credentials.""" return self._ssl_credentials @property def is_mtls(self): + """"Property indicting if the created SSL channel credentials is mutual TLS.""" return self._is_mtls From 296977e7c570f30a05c3cbf95a6a27f2e1cdada7 Mon Sep 17 00:00:00 2001 From: arithmetic1728 Date: Wed, 26 Feb 2020 15:08:05 -0800 Subject: [PATCH 06/12] throw exceptions to user, add client_cert_callback --- google/auth/transport/_mtls_helper.py | 2 + google/auth/transport/grpc.py | 91 +++++++--- tests/transport/test_grpc.py | 245 ++++++++++++++++++-------- 3 files changed, 241 insertions(+), 97 deletions(-) diff --git a/google/auth/transport/_mtls_helper.py b/google/auth/transport/_mtls_helper.py index fa2512ecf..327b72acb 100644 --- a/google/auth/transport/_mtls_helper.py +++ b/google/auth/transport/_mtls_helper.py @@ -83,6 +83,8 @@ def get_client_ssl_credentials(metadata_json): execution of this command doesn't produce both client certicate and client key. """ + # TODO: implement an in-memory cache of cert and key so we don't have to + # run cert provider command every time. # Check the cert provider command existence in the metadata json file. if _CERT_PROVIDER_COMMAND not in metadata_json: diff --git a/google/auth/transport/grpc.py b/google/auth/transport/grpc.py index aa3a3c21e..68683d541 100644 --- a/google/auth/transport/grpc.py +++ b/google/auth/transport/grpc.py @@ -97,7 +97,12 @@ def __del__(self): def secure_authorized_channel( - credentials, request, target, ssl_credentials=None, **kwargs + credentials, + request, + target, + ssl_credentials=None, + client_cert_callback=None, + **kwargs ): """Creates a secure authorized gRPC channel. @@ -133,13 +138,31 @@ def secure_authorized_channel( without using a standard http transport. target (str): The host and port of the service. ssl_credentials (grpc.ChannelCredentials): Optional SSL channel - credentials. This can be used to specify different certificates. If - not provided, application default SSL channel credentials will be + credentials. This can be used to specify different certificates. + This argument is mutually exclusive with ```client_cert_callback```; + providing both will raise an exception. + If ```ssl_credentials``` is ```None``` and ```client_cert_callback``` + is ```None``` or fails, application default SSL credentials will be + used. + client_cert_callback (Callable[[], (bool, bytes, bytes)]): Optional + callback function to obtain client certicate and key for mutual TLS + connection. This argument is mutually exclusive with + ```ssl_credentials```; providing both will raise an exception. + If ```ssl_credentials``` is ```None``` and ```client_cert_callback``` + is ```None``` or fails, application default SSL credentials will be used. kwargs: Additional arguments to pass to :func:`grpc.secure_channel`. Returns: grpc.Channel: The created gRPC channel. + + Raises: + If ```ssl_credentials``` is ```None``` and ```client_cert_callback``` + is ```None``` or fails, application default SSL credentials will be + used. For device with endpoint verification support, exceptions might be + raised during the application default SSL credentials creation + procedure. Please check :func:`~.SslCredentials.ssl_credentials` for the + possible exceptions. """ # Create the metadata plugin for inserting the authorization header. metadata_plugin = AuthMetadataPlugin(credentials, request) @@ -147,6 +170,19 @@ def secure_authorized_channel( # Create a set of grpc.CallCredentials using the metadata plugin. google_auth_credentials = grpc.metadata_call_credentials(metadata_plugin) + if ssl_credentials and client_cert_callback: + raise ValueError( + "Received both ssl_credentials and client_cert_callback; " + "these are mutually exclusive." + ) + + if client_cert_callback: + success, cert, key = client_cert_callback() + if success: + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + if ssl_credentials is None: adc_ssl_credentils = SslCredentials() ssl_credentials = adc_ssl_credentils.ssl_credentials @@ -162,40 +198,47 @@ def secure_authorized_channel( class SslCredentials: """Class for application default SSL credentials. - For Linux with endpoint verification support, device certificate will be - automatically loaded if available and mutual TLS will be established. + For device with endpoint verification support, device certificate will be + automatically loaded and mutual TLS will be established. See https://cloud.google.com/endpoint-verification/docs/overview. """ def __init__(self): - self._is_mtls = False - # Load client SSL credentials. - context_aware_metadata = _mtls_helper._read_dca_metadata_file( + self._context_aware_metadata = _mtls_helper._read_dca_metadata_file( _mtls_helper.CONTEXT_AWARE_METADATA_PATH ) - if context_aware_metadata: - try: - cert, key = _mtls_helper.get_client_ssl_credentials( - context_aware_metadata - ) - self._is_mtls = True - - except (NotImplementedError, OSError, RuntimeError, ValueError) as e: - _LOGGER.debug( - "Failed to get client SSL credentials with error: %s", str(e) - ) - - if self._is_mtls: + if self._context_aware_metadata: + self._is_mtls = True + else: + self._is_mtls = False + + @property + def ssl_credentials(self): + """Get the created SSL channel credentials. + + For device with endpoint verification support, if device certificate + loading has any problems, corresponding exceptions will be raised. For + device without endpoint verification support, no exceptions will be + raised. + + Raises: + OSError: cert provider command launch failure + RuntimeError: cert provider command runtime error + ValueError: + if context aware metadata file is malformed, or cert provider + command doesn't produce both client certicate and key. + """ + if self._context_aware_metadata: + cert, key = _mtls_helper.get_client_ssl_credentials( + self._context_aware_metadata + ) self._ssl_credentials = grpc.ssl_channel_credentials( certificate_chain=cert, private_key=key ) else: self._ssl_credentials = grpc.ssl_channel_credentials() - @property - def ssl_credentials(self): - """Get the created SSL channel credentials.""" return self._ssl_credentials @property diff --git a/tests/transport/test_grpc.py b/tests/transport/test_grpc.py index 0788a3a6c..8d2e8a35c 100644 --- a/tests/transport/test_grpc.py +++ b/tests/transport/test_grpc.py @@ -32,6 +32,11 @@ except ImportError: # pragma: NO COVER HAS_GRPC = False +DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data") +with open(os.path.join(DATA_DIR, "privatekey.pem"), "rb") as fh: + PRIVATE_KEY_BYTES = fh.read() +with open(os.path.join(DATA_DIR, "public_cert.pem"), "rb") as fh: + PUBLIC_CERT_BYTES = fh.read() pytestmark = pytest.mark.skipif(not HAS_GRPC, reason="gRPC is unavailable.") @@ -88,73 +93,179 @@ def test_call_refresh(self): ) +@mock.patch( + "google.auth.transport._mtls_helper.get_client_ssl_credentials", autospec=True +) @mock.patch("grpc.composite_channel_credentials", autospec=True) @mock.patch("grpc.metadata_call_credentials", autospec=True) @mock.patch("grpc.ssl_channel_credentials", autospec=True) @mock.patch("grpc.secure_channel", autospec=True) -def test_secure_authorized_channel( - secure_channel, - ssl_channel_credentials, - metadata_call_credentials, - composite_channel_credentials, -): - credentials = CredentialsStub() - request = mock.create_autospec(transport.Request) - target = "example.com:80" - - channel = google.auth.transport.grpc.secure_authorized_channel( - credentials, request, target, options=mock.sentinel.options +class TestSecureAuthorizedChannel(object): + @mock.patch( + "google.auth.transport._mtls_helper._read_dca_metadata_file", autospec=True ) + def test_secure_authorized_channel( + self, + read_dca_metadata_file, + secure_channel, + ssl_channel_credentials, + metadata_call_credentials, + composite_channel_credentials, + get_client_ssl_credentials, + ): + credentials = CredentialsStub() + request = mock.create_autospec(transport.Request) + target = "example.com:80" - # Check the auth plugin construction. - auth_plugin = metadata_call_credentials.call_args[0][0] - assert isinstance(auth_plugin, google.auth.transport.grpc.AuthMetadataPlugin) - assert auth_plugin._credentials == credentials - assert auth_plugin._request == request + # Mock the context aware metadata and client cert/key so mTLS SSL channel + # will be used. + read_dca_metadata_file.return_value = { + "cert_provider_command": ["some command"] + } + get_client_ssl_credentials.return_value = (PUBLIC_CERT_BYTES, PRIVATE_KEY_BYTES) - # Check the ssl channel call. - assert ssl_channel_credentials.called + channel = google.auth.transport.grpc.secure_authorized_channel( + credentials, request, target, options=mock.sentinel.options + ) - # Check the composite credentials call. - composite_channel_credentials.assert_called_once_with( - ssl_channel_credentials.return_value, metadata_call_credentials.return_value - ) + # Check the auth plugin construction. + auth_plugin = metadata_call_credentials.call_args[0][0] + assert isinstance(auth_plugin, google.auth.transport.grpc.AuthMetadataPlugin) + assert auth_plugin._credentials == credentials + assert auth_plugin._request == request - # Check the channel call. - secure_channel.assert_called_once_with( - target, - composite_channel_credentials.return_value, - options=mock.sentinel.options, - ) - assert channel == secure_channel.return_value + # Check the ssl channel call. + ssl_channel_credentials.assert_called_once_with( + certificate_chain=PUBLIC_CERT_BYTES, private_key=PRIVATE_KEY_BYTES + ) + # Check the composite credentials call. + composite_channel_credentials.assert_called_once_with( + ssl_channel_credentials.return_value, metadata_call_credentials.return_value + ) -@mock.patch("grpc.composite_channel_credentials", autospec=True) -@mock.patch("grpc.metadata_call_credentials", autospec=True) -@mock.patch("grpc.ssl_channel_credentials", autospec=True) -@mock.patch("grpc.secure_channel", autospec=True) -def test_secure_authorized_channel_explicit_ssl( - secure_channel, - ssl_channel_credentials, - metadata_call_credentials, - composite_channel_credentials, -): - credentials = mock.Mock() - request = mock.Mock() - target = "example.com:80" - ssl_credentials = mock.Mock() - - google.auth.transport.grpc.secure_authorized_channel( - credentials, request, target, ssl_credentials=ssl_credentials - ) + # Check the channel call. + secure_channel.assert_called_once_with( + target, + composite_channel_credentials.return_value, + options=mock.sentinel.options, + ) + assert channel == secure_channel.return_value + + def test_secure_authorized_channel_explicit_ssl( + self, + secure_channel, + ssl_channel_credentials, + metadata_call_credentials, + composite_channel_credentials, + get_client_ssl_credentials, + ): + credentials = mock.Mock() + request = mock.Mock() + target = "example.com:80" + ssl_credentials = mock.Mock() + + google.auth.transport.grpc.secure_authorized_channel( + credentials, request, target, ssl_credentials=ssl_credentials + ) + + # Since explicit SSL credentials are provided, get_client_ssl_credentials + # shouldn't be called. + assert not get_client_ssl_credentials.called + + # Check the ssl channel call. + assert not ssl_channel_credentials.called + + # Check the composite credentials call. + composite_channel_credentials.assert_called_once_with( + ssl_credentials, metadata_call_credentials.return_value + ) + + def test_secure_authorized_channel_mutual_exclusive( + self, + secure_channel, + ssl_channel_credentials, + metadata_call_credentials, + composite_channel_credentials, + get_client_ssl_credentials, + ): + credentials = mock.Mock() + request = mock.Mock() + target = "example.com:80" + ssl_credentials = mock.Mock() + client_cert_callback = mock.Mock() + + with pytest.raises(ValueError): + google.auth.transport.grpc.secure_authorized_channel( + credentials, + request, + target, + ssl_credentials=ssl_credentials, + client_cert_callback=client_cert_callback, + ) + + def test_secure_authorized_channel_with_client_cert_callback_success( + self, + secure_channel, + ssl_channel_credentials, + metadata_call_credentials, + composite_channel_credentials, + get_client_ssl_credentials, + ): + credentials = mock.Mock() + request = mock.Mock() + target = "example.com:80" + client_cert_callback = mock.Mock() + client_cert_callback.return_value = (True, PUBLIC_CERT_BYTES, PRIVATE_KEY_BYTES) + + google.auth.transport.grpc.secure_authorized_channel( + credentials, request, target, client_cert_callback=client_cert_callback + ) + + client_cert_callback.assert_called_once() + + # Check we are using the cert and key provided by client_cert_callback. + ssl_channel_credentials.assert_called_once_with( + certificate_chain=PUBLIC_CERT_BYTES, private_key=PRIVATE_KEY_BYTES + ) - # Check the ssl channel call. - assert not ssl_channel_credentials.called + # Check the composite credentials call. + composite_channel_credentials.assert_called_once_with( + ssl_channel_credentials.return_value, metadata_call_credentials.return_value + ) - # Check the composite credentials call. - composite_channel_credentials.assert_called_once_with( - ssl_credentials, metadata_call_credentials.return_value + @mock.patch( + "google.auth.transport._mtls_helper._read_dca_metadata_file", autospec=True ) + def test_secure_authorized_channel_with_client_cert_callback_failure( + self, + read_dca_metadata_file, + secure_channel, + ssl_channel_credentials, + metadata_call_credentials, + composite_channel_credentials, + get_client_ssl_credentials, + ): + credentials = mock.Mock() + request = mock.Mock() + target = "example.com:80" + client_cert_callback = mock.Mock() + client_cert_callback.return_value = (False, None, None) + + # Set DCA metadata to None to not trigger mTLS DCA for test simplicity. + read_dca_metadata_file.return_value = None + + google.auth.transport.grpc.secure_authorized_channel( + credentials, request, target, client_cert_callback=client_cert_callback + ) + + client_cert_callback.assert_called_once() + ssl_channel_credentials.assert_called_once_with() + + # Check the composite credentials call. + composite_channel_credentials.assert_called_once_with( + ssl_channel_credentials.return_value, metadata_call_credentials.return_value + ) @mock.patch("grpc.ssl_channel_credentials", autospec=True) @@ -177,10 +288,10 @@ def test_no_context_aware_metadata( # Since no context aware metadata is found, we wouldn't call # get_client_ssl_credentials, and the SSL channel credentials created is # non mTLS. - mock_get_client_ssl_credentials.assert_not_called() - mock_ssl_channel_credentials.assert_called_once_with() assert ssl_credentials.ssl_credentials is not None assert not ssl_credentials.is_mtls + mock_get_client_ssl_credentials.assert_not_called() + mock_ssl_channel_credentials.assert_called_once_with() def test_get_client_ssl_credentials_failure( self, @@ -192,17 +303,11 @@ def test_get_client_ssl_credentials_failure( "cert_provider_command": ["some command"] } - # Mock that client cert and key are not loaded. - mock_get_client_ssl_credentials.return_value = (False, None, None, None, None) - - ssl_credentials = google.auth.transport.grpc.SslCredentials() + # Mock that client cert and key are not loaded and exception is raised. + mock_get_client_ssl_credentials.side_effect = ValueError() - # Since we failed to get_client_ssl_credentials, the SSL channel - # credentials created is non mTLS. - mock_get_client_ssl_credentials.assert_called_once() - mock_ssl_channel_credentials.assert_called_once_with() - assert ssl_credentials.ssl_credentials is not None - assert not ssl_credentials.is_mtls + with pytest.raises(ValueError): + assert google.auth.transport.grpc.SslCredentials().ssl_credentials def test_get_client_ssl_credentials_success( self, @@ -213,12 +318,6 @@ def test_get_client_ssl_credentials_success( mock_read_dca_metadata_file.return_value = { "cert_provider_command": ["some command"] } - - DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data") - with open(os.path.join(DATA_DIR, "privatekey.pub"), "rb") as fh: - PRIVATE_KEY_BYTES = fh.read() - with open(os.path.join(DATA_DIR, "public_cert.pem"), "rb") as fh: - PUBLIC_CERT_BYTES = fh.read() mock_get_client_ssl_credentials.return_value = ( PUBLIC_CERT_BYTES, PRIVATE_KEY_BYTES, @@ -226,9 +325,9 @@ def test_get_client_ssl_credentials_success( ssl_credentials = google.auth.transport.grpc.SslCredentials() + assert ssl_credentials.ssl_credentials is not None + assert ssl_credentials.is_mtls mock_get_client_ssl_credentials.assert_called_once() mock_ssl_channel_credentials.assert_called_once_with( certificate_chain=PUBLIC_CERT_BYTES, private_key=PRIVATE_KEY_BYTES ) - assert ssl_credentials.ssl_credentials is not None - assert ssl_credentials.is_mtls From 7d8dff7908dbcbf2a091973c8c88d6ec0bce82ee Mon Sep 17 00:00:00 2001 From: arithmetic1728 Date: Mon, 2 Mar 2020 14:55:38 -0800 Subject: [PATCH 07/12] throw exception is metadata file is not json --- google/auth/transport/_mtls_helper.py | 31 ++++++++++++------- google/auth/transport/grpc.py | 44 +++++++++++++++------------ tests/transport/test__mtls_helper.py | 23 ++++++++------ tests/transport/test_grpc.py | 24 +++++++++++++-- 4 files changed, 79 insertions(+), 43 deletions(-) diff --git a/google/auth/transport/_mtls_helper.py b/google/auth/transport/_mtls_helper.py index 327b72acb..b4265ec44 100644 --- a/google/auth/transport/_mtls_helper.py +++ b/google/auth/transport/_mtls_helper.py @@ -38,29 +38,38 @@ _LOGGER = logging.getLogger(__name__) -def _read_dca_metadata_file(metadata_path): - """Function to load context aware metadata from the given path. +def _check_dca_metadata_path(metadata_path): + """Check the existence of context aware metadata. If exists, return the + absolute path; otherwise return None. Args: metadata_path (str): context aware metadata path. Returns: - Dict[str, str]: - The metadata. If metadata reading or parsing fails, return None. + str: absolute path if exists and None otherwise. """ metadata_path = path.expanduser(metadata_path) + print(metadata_path) if not path.exists(metadata_path): _LOGGER.debug("%s is not found, skip client SSL authentication.", metadata_path) return None + return metadata_path + +def _read_dca_metadata_file(metadata_path): + """Function to load context aware metadata from the given path. + + Args: + metadata_path (str): context aware metadata path. + + Returns: + Dict[str, str]: The metadata. + + Raises: + ValueError: If failed to parse metadata as JSON. + """ with open(metadata_path) as f: - try: - metadata = json.load(f) - except Exception as e: - _LOGGER.debug( - "Failed to decode context_aware_metadata.json with error: %s", str(e) - ) - return None + metadata = json.load(f) return metadata diff --git a/google/auth/transport/grpc.py b/google/auth/transport/grpc.py index 68683d541..cba94afb5 100644 --- a/google/auth/transport/grpc.py +++ b/google/auth/transport/grpc.py @@ -139,30 +139,33 @@ def secure_authorized_channel( target (str): The host and port of the service. ssl_credentials (grpc.ChannelCredentials): Optional SSL channel credentials. This can be used to specify different certificates. - This argument is mutually exclusive with ```client_cert_callback```; + This argument is mutually exclusive with client_cert_callback; providing both will raise an exception. - If ```ssl_credentials``` is ```None``` and ```client_cert_callback``` - is ```None``` or fails, application default SSL credentials will be - used. + If ssl_credentials is None and client_cert_callback is None or + fails, application default SSL credentials will be used. client_cert_callback (Callable[[], (bool, bytes, bytes)]): Optional callback function to obtain client certicate and key for mutual TLS connection. This argument is mutually exclusive with - ```ssl_credentials```; providing both will raise an exception. - If ```ssl_credentials``` is ```None``` and ```client_cert_callback``` - is ```None``` or fails, application default SSL credentials will be - used. + ssl_credentials; providing both will raise an exception. + If ssl_credentials is None and client_cert_callback is None or + fails, application default SSL credentials will be used. kwargs: Additional arguments to pass to :func:`grpc.secure_channel`. Returns: grpc.Channel: The created gRPC channel. Raises: - If ```ssl_credentials``` is ```None``` and ```client_cert_callback``` - is ```None``` or fails, application default SSL credentials will be - used. For device with endpoint verification support, exceptions might be - raised during the application default SSL credentials creation - procedure. Please check :func:`~.SslCredentials.ssl_credentials` for the - possible exceptions. + OSError: cert provider command launch failure, in application default SSL + credentials loading process on devices with endpoint verification + support. + RuntimeError: cert provider command runtime error, in application + default SSL credentials loading process on devices with endpoint + verification support. + ValueError: + if context aware metadata file is malformed, or cert provider + command doesn't produce both client certicate and key, in application + default SSL credentials loading process on devices with endpoint + verification support. """ # Create the metadata plugin for inserting the authorization header. metadata_plugin = AuthMetadataPlugin(credentials, request) @@ -205,10 +208,10 @@ class SslCredentials: def __init__(self): # Load client SSL credentials. - self._context_aware_metadata = _mtls_helper._read_dca_metadata_file( + self._context_aware_metadata_path = _mtls_helper._check_dca_metadata_path( _mtls_helper.CONTEXT_AWARE_METADATA_PATH ) - if self._context_aware_metadata: + if self._context_aware_metadata_path: self._is_mtls = True else: self._is_mtls = False @@ -229,10 +232,11 @@ def ssl_credentials(self): if context aware metadata file is malformed, or cert provider command doesn't produce both client certicate and key. """ - if self._context_aware_metadata: - cert, key = _mtls_helper.get_client_ssl_credentials( - self._context_aware_metadata + if self._context_aware_metadata_path: + metadata = _mtls_helper._read_dca_metadata_file( + self._context_aware_metadata_path ) + cert, key = _mtls_helper.get_client_ssl_credentials(metadata) self._ssl_credentials = grpc.ssl_channel_credentials( certificate_chain=cert, private_key=key ) @@ -243,5 +247,5 @@ def ssl_credentials(self): @property def is_mtls(self): - """"Property indicting if the created SSL channel credentials is mutual TLS.""" + """Property indicting if the created SSL channel credentials is mutual TLS.""" return self._is_mtls diff --git a/tests/transport/test__mtls_helper.py b/tests/transport/test__mtls_helper.py index 3d6b1a0e4..b0be9bc34 100644 --- a/tests/transport/test__mtls_helper.py +++ b/tests/transport/test__mtls_helper.py @@ -87,25 +87,30 @@ def test_key(self): check_cert_and_key(PUBLIC_CERT_BYTES + EC_KEY, PUBLIC_CERT_BYTES, EC_KEY) -class TestReadMetadataFile(object): +class TestCheckaMetadataPath(object): def test_success(self): metadata_path = os.path.join(DATA_DIR, "context_aware_metadata.json") - metadata = _mtls_helper._read_dca_metadata_file(metadata_path) + returned_path = _mtls_helper._check_dca_metadata_path(metadata_path) + assert returned_path is not None - assert "cert_provider_command" in metadata + def test_failure(self): + metadata_path = os.path.join(DATA_DIR, "not_exists.json") + returned_path = _mtls_helper._check_dca_metadata_path(metadata_path) + assert returned_path is None - def test_file_not_exist(self): - metadata_path = os.path.join(DATA_DIR, "not_exist.json") + +class TestReadMetadataFile(object): + def test_success(self): + metadata_path = os.path.join(DATA_DIR, "context_aware_metadata.json") metadata = _mtls_helper._read_dca_metadata_file(metadata_path) - assert metadata is None + assert "cert_provider_command" in metadata def test_file_not_json(self): # read a file which is not json format. metadata_path = os.path.join(DATA_DIR, "privatekey.pem") - metadata = _mtls_helper._read_dca_metadata_file(metadata_path) - - assert metadata is None + with pytest.raises(ValueError): + _mtls_helper._read_dca_metadata_file(metadata_path) class TestGetClientSslCredentials(object): diff --git a/tests/transport/test_grpc.py b/tests/transport/test_grpc.py index 8d2e8a35c..54bf00b55 100644 --- a/tests/transport/test_grpc.py +++ b/tests/transport/test_grpc.py @@ -33,6 +33,7 @@ HAS_GRPC = False DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data") +METADATA_PATH = os.path.join(DATA_DIR, "context_aware_metadata.json") with open(os.path.join(DATA_DIR, "privatekey.pem"), "rb") as fh: PRIVATE_KEY_BYTES = fh.read() with open(os.path.join(DATA_DIR, "public_cert.pem"), "rb") as fh: @@ -104,8 +105,12 @@ class TestSecureAuthorizedChannel(object): @mock.patch( "google.auth.transport._mtls_helper._read_dca_metadata_file", autospec=True ) + @mock.patch( + "google.auth.transport._mtls_helper._check_dca_metadata_path", autospec=True + ) def test_secure_authorized_channel( self, + check_dca_metadata_path, read_dca_metadata_file, secure_channel, ssl_channel_credentials, @@ -119,6 +124,7 @@ def test_secure_authorized_channel( # Mock the context aware metadata and client cert/key so mTLS SSL channel # will be used. + check_dca_metadata_path.return_value = METADATA_PATH read_dca_metadata_file.return_value = { "cert_provider_command": ["some command"] } @@ -237,8 +243,12 @@ def test_secure_authorized_channel_with_client_cert_callback_success( @mock.patch( "google.auth.transport._mtls_helper._read_dca_metadata_file", autospec=True ) + @mock.patch( + "google.auth.transport._mtls_helper._check_dca_metadata_path", autospec=True + ) def test_secure_authorized_channel_with_client_cert_callback_failure( self, + check_dca_metadata_path, read_dca_metadata_file, secure_channel, ssl_channel_credentials, @@ -253,7 +263,7 @@ def test_secure_authorized_channel_with_client_cert_callback_failure( client_cert_callback.return_value = (False, None, None) # Set DCA metadata to None to not trigger mTLS DCA for test simplicity. - read_dca_metadata_file.return_value = None + check_dca_metadata_path.return_value = None google.auth.transport.grpc.secure_authorized_channel( credentials, request, target, client_cert_callback=client_cert_callback @@ -273,15 +283,19 @@ def test_secure_authorized_channel_with_client_cert_callback_failure( "google.auth.transport._mtls_helper.get_client_ssl_credentials", autospec=True ) @mock.patch("google.auth.transport._mtls_helper._read_dca_metadata_file", autospec=True) +@mock.patch( + "google.auth.transport._mtls_helper._check_dca_metadata_path", autospec=True +) class TestSslCredentials(object): def test_no_context_aware_metadata( self, + mock_check_dca_metadata_path, mock_read_dca_metadata_file, mock_get_client_ssl_credentials, mock_ssl_channel_credentials, ): - # Mock that _read_dca_metadata_file function returns no metadata. - mock_read_dca_metadata_file.return_value = None + # Mock that the metadata file doesn't exist. + mock_check_dca_metadata_path.return_value = None ssl_credentials = google.auth.transport.grpc.SslCredentials() @@ -295,10 +309,12 @@ def test_no_context_aware_metadata( def test_get_client_ssl_credentials_failure( self, + mock_check_dca_metadata_path, mock_read_dca_metadata_file, mock_get_client_ssl_credentials, mock_ssl_channel_credentials, ): + mock_check_dca_metadata_path.return_value = METADATA_PATH mock_read_dca_metadata_file.return_value = { "cert_provider_command": ["some command"] } @@ -311,10 +327,12 @@ def test_get_client_ssl_credentials_failure( def test_get_client_ssl_credentials_success( self, + mock_check_dca_metadata_path, mock_read_dca_metadata_file, mock_get_client_ssl_credentials, mock_ssl_channel_credentials, ): + mock_check_dca_metadata_path.return_value = METADATA_PATH mock_read_dca_metadata_file.return_value = { "cert_provider_command": ["some command"] } From 1dc22c0eb98b4cdc11bf068978e0158e6308a45c Mon Sep 17 00:00:00 2001 From: arithmetic1728 Date: Tue, 3 Mar 2020 11:08:21 -0800 Subject: [PATCH 08/12] fix typo --- google/auth/transport/grpc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/auth/transport/grpc.py b/google/auth/transport/grpc.py index cba94afb5..498372b3e 100644 --- a/google/auth/transport/grpc.py +++ b/google/auth/transport/grpc.py @@ -247,5 +247,5 @@ def ssl_credentials(self): @property def is_mtls(self): - """Property indicting if the created SSL channel credentials is mutual TLS.""" + """Property indicating if the created SSL channel credentials is mutual TLS.""" return self._is_mtls From 12e69d8650b516cfa79eb3784cd93e67557f40eb Mon Sep 17 00:00:00 2001 From: Sijun Liu Date: Thu, 5 Mar 2020 01:35:05 -0800 Subject: [PATCH 09/12] update --- google/auth/transport/_mtls_helper.py | 26 +++--- google/auth/transport/grpc.py | 117 ++++++++++++++++++++++---- tests/transport/test__mtls_helper.py | 2 +- 3 files changed, 111 insertions(+), 34 deletions(-) diff --git a/google/auth/transport/_mtls_helper.py b/google/auth/transport/_mtls_helper.py index b4265ec44..1ce9fa554 100644 --- a/google/auth/transport/_mtls_helper.py +++ b/google/auth/transport/_mtls_helper.py @@ -39,8 +39,8 @@ def _check_dca_metadata_path(metadata_path): - """Check the existence of context aware metadata. If exists, return the - absolute path; otherwise return None. + """Checks for context aware metadata. If it exists, returns the absolute path; + otherwise returns None. Args: metadata_path (str): context aware metadata path. @@ -49,7 +49,6 @@ def _check_dca_metadata_path(metadata_path): str: absolute path if exists and None otherwise. """ metadata_path = path.expanduser(metadata_path) - print(metadata_path) if not path.exists(metadata_path): _LOGGER.debug("%s is not found, skip client SSL authentication.", metadata_path) return None @@ -57,7 +56,7 @@ def _check_dca_metadata_path(metadata_path): def _read_dca_metadata_file(metadata_path): - """Function to load context aware metadata from the given path. + """Loads context aware metadata from the given path. Args: metadata_path (str): context aware metadata path. @@ -75,22 +74,19 @@ def _read_dca_metadata_file(metadata_path): def get_client_ssl_credentials(metadata_json): - """Function to get mTLS client side cert and key. + """Returns the client side mTLS cert and key. Args: - metadata_json (Dict[str]): metadata JSON file which contains the cert + metadata_json (Dict[str, str]): metadata JSON file which contains the cert provider command. Returns: Tuple[bytes, bytes]: client certificate and key, both in PEM format. Raises: - OSError: subprocess throws OSError if failed to run cert provider command - RuntimeError: if cert provider command has runtime error - ValueError: - if metadata json file doesn't contain cert provider command, or the - execution of this command doesn't produce both client certicate and - client key. + OSError: If the cert provider command failed to run. + RuntimeError: If the cert provider command has a runtime error. + ValueError: If the metadata json file doesn't contain the cert provider command or if the command doesn't produce both the client certificate and client key. """ # TODO: implement an in-memory cache of cert and key so we don't have to # run cert provider command every time. @@ -103,13 +99,11 @@ def get_client_ssl_credentials(metadata_json): command = metadata_json[_CERT_PROVIDER_COMMAND] process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) stdout, stderr = process.communicate() - process_return_code = process.returncode # Check cert provider command execution error. - if process_return_code != 0: + if process.returncode != 0: raise RuntimeError( - "Cert provider command returns non-zero status code %s" - % process_return_code + "Cert provider command returns non-zero status code %s" % process.returncode ) # Extract certificate (chain) and key. diff --git a/google/auth/transport/grpc.py b/google/auth/transport/grpc.py index 498372b3e..84a16096a 100644 --- a/google/auth/transport/grpc.py +++ b/google/auth/transport/grpc.py @@ -124,11 +124,91 @@ def secure_authorized_channel( # Create a channel. channel = google.auth.transport.grpc.secure_authorized_channel( - credentials, 'speech.googleapis.com:443', request) + credentials, regular_endpoint, request, + ssl_credentials=grpc.ssl_channel_credentials()) # Use the channel to create a stub. cloud_speech.create_Speech_stub(channel) + # There are actually a couple of options to create a channel, depending + # on if you want to create a regular or mutual TLS channel. + # First let's list the endpoints (regular vs mutual TLS) to choose from. + regular_endpoint = 'speech.googleapis.com:443' + mtls_endpoint = 'speech.mtls.googleapis.com:443' + + # Option 1: create a regular (non-mutual) TLS channel by explicitly + # setting the ssl_credentials. + regular_ssl_credentials = grpc.ssl_channel_credentials() + channel = google.auth.transport.grpc.secure_authorized_channel( + credentials, regular_endpoint, request, + ssl_credentials=regular_ssl_credentials) + + # Option 2: create a mutual TLS channel by calling a callback which + # returns the call status, the client side certificate and the key. + def my_client_cert_callback(): + code_to_load_client_cert_and_key() + if loaded: + return (True, pem_cert_bytes, pem_key_bytes) + raise MyClientCertFailureException() + + try: + channel = google.auth.transport.grpc.secure_authorized_channel( + credentials, mtls_endpoint, request, + client_cert_callback=my_client_cert_callback) + except MyClientCertFailureException: + # handle the exception or use regular endpoint instead. + + # Alternatively you don't throw exceptions in the callback and return a + # False status instead. In this case `secure_authorized_channel` creates + # a regular TLS channel. Since you are still using mtls_endpoint, future + # API calls using this channel will be rejected. + def my_client_cert_callback(): + code_to_load_client_cert_and_key() + if loaded: + return (True, pem_cert_bytes, pem_key_bytes) + else: + return (False, None, None) + + channel = google.auth.transport.grpc.secure_authorized_channel( + credentials, mtls_endpoint, request, + client_cert_callback=my_client_cert_callback) + + # Option 3: use application default SSL credentials. It searches and uses + # the command in a context aware metadata file, which is available on + # devices with endpoint verification support. + # See https://cloud.google.com/endpoint-verification/docs/overview. + try: + default_ssl_credentials = SslCredentials() + except: + # Exception can be raised if the context aware metadata is malformed. + # See :class:`SslCredentials` for the possible exceptions. + + # Choose the endpoint based on the SSL credentials type. + if default_ssl_credentials.is_mtls: + endpoint_to_use = mtls_endpoint + else: + endpoint_to_use = regular_endpoint + channel = google.auth.transport.grpc.secure_authorized_channel( + credentials, endpoint_to_use, request, + ssl_credentials=default_ssl_credentials) + + # Option 4: not setting ssl_credentials and client_cert_callback. For + # devices without endpoint verification support, a regular TLS channel + # is created; otherwise, a mutual TLS channel is created, however, the + # call should be wrapped in a try/except block in case of malformed + # context aware metadata. + + # The following code uses regular_endpoint, it works the same no matter + # the created channle is regular or mutual TLS. Regular endpoint ignores + # client certificate and key. + channel = google.auth.transport.grpc.secure_authorized_channel( + credentials, regular_endpoint, request) + + # The following code uses mtls_endpoint, if the created channle is regular, + # future API calls using this channel will be rejected. + channel = google.auth.transport.grpc.secure_authorized_channel( + credentials, mtls_endpoint, request) + Args: credentials (google.auth.credentials.Credentials): The credentials to add to requests. @@ -155,17 +235,17 @@ def secure_authorized_channel( grpc.Channel: The created gRPC channel. Raises: - OSError: cert provider command launch failure, in application default SSL - credentials loading process on devices with endpoint verification - support. - RuntimeError: cert provider command runtime error, in application + OSError: If the cert provider command launch fails during the application default SSL credentials loading process on devices with endpoint verification support. + RuntimeError: If the cert provider command has a runtime error during the + application default SSL credentials loading process on devices with + endpoint verification support. ValueError: - if context aware metadata file is malformed, or cert provider - command doesn't produce both client certicate and key, in application - default SSL credentials loading process on devices with endpoint - verification support. + If the context aware metadata file is malformed or if the cert provider + command doesn't produce both client certificate and key during the + application default SSL credentials loading process on devices with + endpoint verification support. """ # Create the metadata plugin for inserting the authorization header. metadata_plugin = AuthMetadataPlugin(credentials, request) @@ -201,7 +281,7 @@ def secure_authorized_channel( class SslCredentials: """Class for application default SSL credentials. - For device with endpoint verification support, device certificate will be + For devices with endpoint verification support, a device certificate will be automatically loaded and mutual TLS will be established. See https://cloud.google.com/endpoint-verification/docs/overview. """ @@ -220,17 +300,20 @@ def __init__(self): def ssl_credentials(self): """Get the created SSL channel credentials. - For device with endpoint verification support, if device certificate + For devices with endpoint verification support, if the device certificate loading has any problems, corresponding exceptions will be raised. For - device without endpoint verification support, no exceptions will be + a device without endpoint verification support, no exceptions will be raised. + Returns: + grpc.ChannelCredentials: The created grpc channel credentials. + Raises: - OSError: cert provider command launch failure - RuntimeError: cert provider command runtime error + OSError: If the cert provider command launch fails. + RuntimeError: If the cert provider command has a runtime error. ValueError: - if context aware metadata file is malformed, or cert provider - command doesn't produce both client certicate and key. + If the context aware metadata file is malformed or if the cert provider + command doesn't produce both the client certificate and key. """ if self._context_aware_metadata_path: metadata = _mtls_helper._read_dca_metadata_file( @@ -247,5 +330,5 @@ def ssl_credentials(self): @property def is_mtls(self): - """Property indicating if the created SSL channel credentials is mutual TLS.""" + """Indicates if the created SSL channel credentials is mutual TLS.""" return self._is_mtls diff --git a/tests/transport/test__mtls_helper.py b/tests/transport/test__mtls_helper.py index b0be9bc34..6e7175f17 100644 --- a/tests/transport/test__mtls_helper.py +++ b/tests/transport/test__mtls_helper.py @@ -47,7 +47,7 @@ def check_cert_and_key(content, expected_cert, expected_key): class TestCertAndKeyRegex(object): def test_cert_and_key(self): - # Test signle cert and single key + # Test single cert and single key check_cert_and_key( PUBLIC_CERT_BYTES + PRIVATE_KEY_BYTES, PUBLIC_CERT_BYTES, PRIVATE_KEY_BYTES ) From b2e394b6d5005fc5b956b6ba7825cc62f8265403 Mon Sep 17 00:00:00 2001 From: Sijun Liu Date: Thu, 5 Mar 2020 14:34:44 -0800 Subject: [PATCH 10/12] update docstring --- google/auth/transport/grpc.py | 61 +++++++++++++++++++++-------------- 1 file changed, 36 insertions(+), 25 deletions(-) diff --git a/google/auth/transport/grpc.py b/google/auth/transport/grpc.py index 84a16096a..e039d4046 100644 --- a/google/auth/transport/grpc.py +++ b/google/auth/transport/grpc.py @@ -130,21 +130,28 @@ def secure_authorized_channel( # Use the channel to create a stub. cloud_speech.create_Speech_stub(channel) - # There are actually a couple of options to create a channel, depending - # on if you want to create a regular or mutual TLS channel. - # First let's list the endpoints (regular vs mutual TLS) to choose from. + Usage: + + There are actually a couple of options to create a channel, depending on if + you want to create a regular or mutual TLS channel. + + First let's list the endpoints (regular vs mutual TLS) to choose from:: + regular_endpoint = 'speech.googleapis.com:443' mtls_endpoint = 'speech.mtls.googleapis.com:443' - # Option 1: create a regular (non-mutual) TLS channel by explicitly - # setting the ssl_credentials. + Option 1: create a regular (non-mutual) TLS channel by explicitly setting + the ssl_credentials:: + regular_ssl_credentials = grpc.ssl_channel_credentials() + channel = google.auth.transport.grpc.secure_authorized_channel( credentials, regular_endpoint, request, ssl_credentials=regular_ssl_credentials) - # Option 2: create a mutual TLS channel by calling a callback which - # returns the call status, the client side certificate and the key. + Option 2: create a mutual TLS channel by calling a callback which returns + the call status, the client side certificate and the key:: + def my_client_cert_callback(): code_to_load_client_cert_and_key() if loaded: @@ -158,10 +165,11 @@ def my_client_cert_callback(): except MyClientCertFailureException: # handle the exception or use regular endpoint instead. - # Alternatively you don't throw exceptions in the callback and return a - # False status instead. In this case `secure_authorized_channel` creates - # a regular TLS channel. Since you are still using mtls_endpoint, future - # API calls using this channel will be rejected. + Alternatively you don't throw exceptions in the callback and return a False + status instead. In this case `secure_authorized_channel` creates a regular + TLS channel. If your API mtls_endpoint is confgured to require client SSL + credentials, then API calls using this channel will be rejected:: + def my_client_cert_callback(): code_to_load_client_cert_and_key() if loaded: @@ -173,10 +181,11 @@ def my_client_cert_callback(): credentials, mtls_endpoint, request, client_cert_callback=my_client_cert_callback) - # Option 3: use application default SSL credentials. It searches and uses - # the command in a context aware metadata file, which is available on - # devices with endpoint verification support. - # See https://cloud.google.com/endpoint-verification/docs/overview. + Option 3: use application default SSL credentials. It searches and uses + the command in a context aware metadata file, which is available on devices + with endpoint verification support. + See https://cloud.google.com/endpoint-verification/docs/overview:: + try: default_ssl_credentials = SslCredentials() except: @@ -192,20 +201,22 @@ def my_client_cert_callback(): credentials, endpoint_to_use, request, ssl_credentials=default_ssl_credentials) - # Option 4: not setting ssl_credentials and client_cert_callback. For - # devices without endpoint verification support, a regular TLS channel - # is created; otherwise, a mutual TLS channel is created, however, the - # call should be wrapped in a try/except block in case of malformed - # context aware metadata. + Option 4: not setting ssl_credentials and client_cert_callback. For devices + without endpoint verification support, a regular TLS channel is created; + otherwise, a mutual TLS channel is created, however, the call should be + wrapped in a try/except block in case of malformed context aware metadata. + + The following code uses regular_endpoint, it works the same no matter the + created channle is regular or mutual TLS. Regular endpoint ignores client + certificate and key:: - # The following code uses regular_endpoint, it works the same no matter - # the created channle is regular or mutual TLS. Regular endpoint ignores - # client certificate and key. channel = google.auth.transport.grpc.secure_authorized_channel( credentials, regular_endpoint, request) - # The following code uses mtls_endpoint, if the created channle is regular, - # future API calls using this channel will be rejected. + The following code uses mtls_endpoint, if the created channle is regular, + and API mtls_endpoint is confgured to require client SSL credentials, API + calls using this channel will be rejected:: + channel = google.auth.transport.grpc.secure_authorized_channel( credentials, mtls_endpoint, request) From e6f2f1031e26e6e933baa606fdd924b6828a5e9c Mon Sep 17 00:00:00 2001 From: Sijun Liu Date: Wed, 11 Mar 2020 12:11:45 -0700 Subject: [PATCH 11/12] don't use ADC if client_cert_callback is provided --- google/auth/transport/grpc.py | 42 +++++++++++------------------------ tests/transport/test_grpc.py | 25 ++++++++------------- 2 files changed, 22 insertions(+), 45 deletions(-) diff --git a/google/auth/transport/grpc.py b/google/auth/transport/grpc.py index e039d4046..90a4ed39d 100644 --- a/google/auth/transport/grpc.py +++ b/google/auth/transport/grpc.py @@ -150,12 +150,12 @@ def secure_authorized_channel( ssl_credentials=regular_ssl_credentials) Option 2: create a mutual TLS channel by calling a callback which returns - the call status, the client side certificate and the key:: + the client side certificate and the key:: def my_client_cert_callback(): code_to_load_client_cert_and_key() if loaded: - return (True, pem_cert_bytes, pem_key_bytes) + return (pem_cert_bytes, pem_key_bytes) raise MyClientCertFailureException() try: @@ -163,23 +163,7 @@ def my_client_cert_callback(): credentials, mtls_endpoint, request, client_cert_callback=my_client_cert_callback) except MyClientCertFailureException: - # handle the exception or use regular endpoint instead. - - Alternatively you don't throw exceptions in the callback and return a False - status instead. In this case `secure_authorized_channel` creates a regular - TLS channel. If your API mtls_endpoint is confgured to require client SSL - credentials, then API calls using this channel will be rejected:: - - def my_client_cert_callback(): - code_to_load_client_cert_and_key() - if loaded: - return (True, pem_cert_bytes, pem_key_bytes) - else: - return (False, None, None) - - channel = google.auth.transport.grpc.secure_authorized_channel( - credentials, mtls_endpoint, request, - client_cert_callback=my_client_cert_callback) + # handle the exception Option 3: use application default SSL credentials. It searches and uses the command in a context aware metadata file, which is available on devices @@ -234,12 +218,10 @@ def my_client_cert_callback(): providing both will raise an exception. If ssl_credentials is None and client_cert_callback is None or fails, application default SSL credentials will be used. - client_cert_callback (Callable[[], (bool, bytes, bytes)]): Optional + client_cert_callback (Callable[[], (bytes, bytes)]): Optional callback function to obtain client certicate and key for mutual TLS connection. This argument is mutually exclusive with ssl_credentials; providing both will raise an exception. - If ssl_credentials is None and client_cert_callback is None or - fails, application default SSL credentials will be used. kwargs: Additional arguments to pass to :func:`grpc.secure_channel`. Returns: @@ -270,16 +252,18 @@ def my_client_cert_callback(): "these are mutually exclusive." ) - if client_cert_callback: - success, cert, key = client_cert_callback() - if success: + # If SSL credentials are not explicitly set, try client_cert_callback and ADC. + if not ssl_credentials: + if client_cert_callback: + # Use the callback if provided. + cert, key = client_cert_callback() ssl_credentials = grpc.ssl_channel_credentials( certificate_chain=cert, private_key=key ) - - if ssl_credentials is None: - adc_ssl_credentils = SslCredentials() - ssl_credentials = adc_ssl_credentils.ssl_credentials + else: + # Use application default SSL credentials. + adc_ssl_credentils = SslCredentials() + ssl_credentials = adc_ssl_credentils.ssl_credentials # Combine the ssl credentials and the authorization credentials. composite_credentials = grpc.composite_channel_credentials( diff --git a/tests/transport/test_grpc.py b/tests/transport/test_grpc.py index 54bf00b55..23e62a213 100644 --- a/tests/transport/test_grpc.py +++ b/tests/transport/test_grpc.py @@ -108,7 +108,7 @@ class TestSecureAuthorizedChannel(object): @mock.patch( "google.auth.transport._mtls_helper._check_dca_metadata_path", autospec=True ) - def test_secure_authorized_channel( + def test_secure_authorized_channel_adc( self, check_dca_metadata_path, read_dca_metadata_file, @@ -222,7 +222,7 @@ def test_secure_authorized_channel_with_client_cert_callback_success( request = mock.Mock() target = "example.com:80" client_cert_callback = mock.Mock() - client_cert_callback.return_value = (True, PUBLIC_CERT_BYTES, PRIVATE_KEY_BYTES) + client_cert_callback.return_value = (PUBLIC_CERT_BYTES, PRIVATE_KEY_BYTES) google.auth.transport.grpc.secure_authorized_channel( credentials, request, target, client_cert_callback=client_cert_callback @@ -259,23 +259,16 @@ def test_secure_authorized_channel_with_client_cert_callback_failure( credentials = mock.Mock() request = mock.Mock() target = "example.com:80" - client_cert_callback = mock.Mock() - client_cert_callback.return_value = (False, None, None) - - # Set DCA metadata to None to not trigger mTLS DCA for test simplicity. - check_dca_metadata_path.return_value = None - google.auth.transport.grpc.secure_authorized_channel( - credentials, request, target, client_cert_callback=client_cert_callback - ) + client_cert_callback = mock.Mock() + client_cert_callback.side_effect = Exception("callback exception") - client_cert_callback.assert_called_once() - ssl_channel_credentials.assert_called_once_with() + with pytest.raises(Exception) as excinfo: + google.auth.transport.grpc.secure_authorized_channel( + credentials, request, target, client_cert_callback=client_cert_callback + ) - # Check the composite credentials call. - composite_channel_credentials.assert_called_once_with( - ssl_channel_credentials.return_value, metadata_call_credentials.return_value - ) + assert str(excinfo.value) == "callback exception" @mock.patch("grpc.ssl_channel_credentials", autospec=True) From da135ad3f8d3b6dd4c1ba46607f14e88a7302b40 Mon Sep 17 00:00:00 2001 From: Sijun Liu Date: Wed, 11 Mar 2020 12:22:04 -0700 Subject: [PATCH 12/12] update docstring --- google/auth/transport/grpc.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/google/auth/transport/grpc.py b/google/auth/transport/grpc.py index 90a4ed39d..ca387392e 100644 --- a/google/auth/transport/grpc.py +++ b/google/auth/transport/grpc.py @@ -216,12 +216,14 @@ def my_client_cert_callback(): credentials. This can be used to specify different certificates. This argument is mutually exclusive with client_cert_callback; providing both will raise an exception. - If ssl_credentials is None and client_cert_callback is None or - fails, application default SSL credentials will be used. + If ssl_credentials and client_cert_callback are None, application + default SSL credentials will be used. client_cert_callback (Callable[[], (bytes, bytes)]): Optional callback function to obtain client certicate and key for mutual TLS connection. This argument is mutually exclusive with ssl_credentials; providing both will raise an exception. + If ssl_credentials and client_cert_callback are None, application + default SSL credentials will be used. kwargs: Additional arguments to pass to :func:`grpc.secure_channel`. Returns: