Skip to content
This repository has been archived by the owner on Mar 20, 2018. It is now read-only.

Refactor auth code to ease transition to google-auth #135

Merged
merged 7 commits into from
Nov 2, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ based on `GRPC`_ and `Google APIs`_ conventions.
:toctree: generated

google.gax
google.gax.auth
google.gax.api_callable
google.gax.bundling
google.gax.config
Expand Down
94 changes: 94 additions & 0 deletions google/gax/_grpc_oauth2client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Copyright 2015, Google Inc.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool. So this file goes away entirely once googleapis/google-auth-library-python#67 is merged, right?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It'll stay around until we formally deprecate oauth2client. I'll also be adding another module named _grpc_google_auth.py that'll be a very small interface to google.auth.

# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# pylint: disable=too-few-public-methods
"""Provides gRPC authentication support using oauth2client."""

from __future__ import absolute_import

import grpc
import oauth2client.client


class AuthMetadataPlugin(grpc.AuthMetadataPlugin):
"""A `gRPC AuthMetadataPlugin`_ that inserts the credentials into each
request.

.. _gRPC AuthMetadataPlugin:
http://www.grpc.io/grpc/python/grpc.html#grpc.AuthMetadataPlugin

Args:
credentials (oauth2client.client.Credentials): The credentials to
add to requests.
"""
def __init__(self, credentials, ):
self._credentials = credentials

def _get_authorization_headers(self):
"""Gets the authorization headers for a request.

Returns:
Sequence[Tuple[str, str]]: A list of request headers (key, value)
to add to the request.
"""
bearer_token = self._credentials.get_access_token().access_token
return [
('authorization', 'Bearer {}'.format(bearer_token))
]

def __call__(self, context, callback):
"""Passes authorization metadata into the given callback.

Args:
context (grpc.AuthMetadataContext): The RPC context.
callback (grpc.AuthMetadataPluginCallback): The callback that will
be invoked to pass in the authorization metadata.
"""
callback(self._get_authorization_headers(), None)


def get_default_credentials(scopes):
"""Gets the Application Default Credentials."""
credentials = (
oauth2client.client.GoogleCredentials.get_application_default())
return credentials.create_scoped(scopes or [])


def secure_authorized_channel(
credentials, target, ssl_credentials=None):
"""Creates a secure authorized gRPC channel."""
if ssl_credentials is None:
ssl_credentials = grpc.ssl_channel_credentials()

metadata_plugin = AuthMetadataPlugin(credentials)
call_credentials = grpc.metadata_call_credentials(metadata_plugin)
channel_creds = grpc.composite_channel_credentials(
ssl_credentials, call_credentials)

return grpc.secure_channel(target, channel_creds)
49 changes: 0 additions & 49 deletions google/gax/auth.py

This file was deleted.

61 changes: 22 additions & 39 deletions google/gax/grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@
"""Adapts the grpc surface."""

from __future__ import absolute_import
import grpc

from grpc import RpcError, StatusCode
from . import auth
from . import _grpc_oauth2client
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

think this file accidentally got left out of the commit

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, fixed. I also did away with the use of closures and went with @dhermes' suggestion of using a class like we do in google-auth.



API_ERRORS = (RpcError, )
Expand Down Expand Up @@ -73,53 +73,36 @@ def exc_to_code(exc):
return None


def _make_grpc_auth_func(auth_func):
"""Creates the auth func expected by the grpc callback."""

def grpc_auth(dummy_context, callback):
"""The auth signature required by grpc."""
callback(auth_func(), None)

return grpc_auth


def _make_channel_creds(auth_func, ssl_creds):
"""Converts the auth func into the composite creds expected by grpc."""
grpc_auth_func = _make_grpc_auth_func(auth_func)
call_creds = grpc.metadata_call_credentials(grpc_auth_func)
return grpc.composite_channel_credentials(ssl_creds, call_creds)


def create_stub(generated_create_stub, service_path, port, ssl_creds=None,
channel=None, metadata_transformer=None, scopes=None):
def create_stub(generated_create_stub, channel=None, service_path=None,
service_port=None, credentials=None, scopes=None,
ssl_credentials=None):
"""Creates a gRPC client stub.

Args:
generated_create_stub: The generated gRPC method to create a stub.
service_path: The domain name of the API remote host.
port: The port on which to connect to the remote host.
ssl_creds: A ClientCredentials object for use with an SSL-enabled
Channel. If none, credentials are pulled from a default location.
channel: A Channel object through which to make calls. If none, a secure
channel is constructed.
metadata_transformer: A function that transforms the metadata for
requests, e.g., to give OAuth credentials.
channel is constructed. If specified, all remaining arguments are
ignored.
service_path: The domain name of the API remote host.
service_port: The port on which to connect to the remote host.
credentials: The authorization credentials to attach to requests.
These credentials identify your application to the service.
scopes: The OAuth scopes for this service. This parameter is ignored if
a custom metadata_transformer is supplied.
a credentials is specified.
ssl_credentials: gRPC channel credentials used to create a secure
gRPC channel. If not specified, SSL credentials will be created
using default certificates.

Returns:
A gRPC client stub.
"""
if channel is None:
if ssl_creds is None:
ssl_creds = grpc.ssl_channel_credentials()
if metadata_transformer is None:
if scopes is None:
scopes = []
metadata_transformer = auth.make_auth_func(scopes)

channel_creds = _make_channel_creds(metadata_transformer, ssl_creds)
target = '{}:{}'.format(service_path, port)
channel = grpc.secure_channel(target, channel_creds)
target = '{}:{}'.format(service_path, service_port)

if credentials is None:
credentials = _grpc_oauth2client.get_default_credentials(scopes)

channel = _grpc_oauth2client.secure_authorized_channel(
credentials, target, ssl_credentials=ssl_credentials)

return generated_create_stub(channel)
57 changes: 49 additions & 8 deletions test/test_auth.py → test/test__grpc_oauth2client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,29 +28,70 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# pylint: disable=missing-docstring,no-self-use,no-init,invalid-name
"""Unit tests for auth."""
"""Unit tests for _grpc_oauth2client."""

from __future__ import absolute_import

import mock
import unittest2

from google.gax import auth
from google.gax import _grpc_oauth2client


class TestMakeAuthFunc(unittest2.TestCase):
class TestAuthMetadataPlugin(unittest2.TestCase):
TEST_TOKEN = 'an_auth_token'

def test(self):
credentials = mock.Mock()
credentials.get_access_token.return_value = mock.Mock(
access_token=self.TEST_TOKEN)

metadata_plugin = _grpc_oauth2client.AuthMetadataPlugin(credentials)

self.assertFalse(credentials.create_scoped.called)

callback = mock.Mock()
metadata_plugin(None, callback)

callback.assert_called_once_with([
('authorization', 'Bearer {}'.format(self.TEST_TOKEN))], None)


class TestGetDefaultCredentials(unittest2.TestCase):
TEST_TOKEN = 'an_auth_token'

@mock.patch('oauth2client.client.GoogleCredentials.get_application_default')
def test_uses_application_default_credentials(self, factory):
def test(self, factory):
creds = mock.Mock()
creds.get_access_token.return_value = mock.Mock(
access_token=self.TEST_TOKEN)
factory_mock_config = {'create_scoped.return_value': creds}
factory.return_value = mock.Mock(**factory_mock_config)
fake_scopes = ['fake', 'scopes']
the_func = auth.make_auth_func(fake_scopes)

got = _grpc_oauth2client.get_default_credentials(fake_scopes)

factory.return_value.create_scoped.assert_called_once_with(fake_scopes)
got = the_func()
want = [('authorization', 'Bearer an_auth_token')]
self.assertEqual(got, want)
self.assertEqual(got, creds)


class TestSecureAuthorizedChannel(unittest2.TestCase):
FAKE_TARGET = 'service_path:10101'

@mock.patch('grpc.composite_channel_credentials')
@mock.patch('grpc.ssl_channel_credentials')
@mock.patch('grpc.secure_channel')
@mock.patch('google.gax._grpc_oauth2client.AuthMetadataPlugin')
def test(
self, auth_metadata_plugin, secure_channel, ssl_channel_credentials,
composite_channel_credentials):
credentials = mock.Mock()

got_channel = _grpc_oauth2client.secure_authorized_channel(
credentials, self.FAKE_TARGET)

ssl_channel_credentials.assert_called_once_with()
secure_channel.assert_called_once_with(
self.FAKE_TARGET, composite_channel_credentials.return_value)
auth_metadata_plugin.assert_called_once_with(credentials)
self.assertEqual(got_channel, secure_channel.return_value)
Loading