Skip to content
This repository has been archived by the owner on Jan 18, 2025. It is now read-only.

Commit

Permalink
Handle missing storage files
Browse files Browse the repository at this point in the history
* Warn user if storage file is missing
* Raise an `IOError` exception if the given filename is a directory.
* (test) Expanding single-letter variables
* (test) `assertEqual(None, <obj>)` -> `assertIsNone(<obj>)`
* (test) `assertNotEqual(None, <obj>)` -> `assertIsNotNone(<obj>)`
  • Loading branch information
pferate committed Jul 29, 2016
1 parent ae73312 commit a2a0ede
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 43 deletions.
11 changes: 10 additions & 1 deletion oauth2client/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,17 @@

import os
import threading
import warnings

from oauth2client import client


__author__ = 'jcgregorio@google.com (Joe Gregorio)'

_SYM_LINK_MESSAGE = 'File: {0}: Is a symbolic link.'
_IS_DIR_MESSAGE = '{0}: Is a directory'
_MISSING_FILE_MESSAGE = 'Cannot access {0}: No such file or directory'


class CredentialsFileSymbolicLinkError(Exception):
"""Credentials files must not be symbolic links."""
Expand All @@ -41,7 +46,11 @@ def __init__(self, filename):
def _validate_file(self):
if os.path.islink(self._filename):
raise CredentialsFileSymbolicLinkError(
'File: {0} is a symbolic link.'.format(self._filename))
_SYM_LINK_MESSAGE.format(self._filename))
elif os.path.isdir(self._filename):
raise IOError(_IS_DIR_MESSAGE.format(self._filename))
elif not os.path.isfile(self._filename):
warnings.warn(_MISSING_FILE_MESSAGE.format(self._filename))

def locked_get(self):
"""Retrieve Credential from file.
Expand Down
99 changes: 57 additions & 42 deletions tests/test_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
import pickle
import stat
import tempfile
import warnings

import mock
import six
from six.moves import http_client
import unittest2
Expand Down Expand Up @@ -54,6 +56,7 @@ def tearDown(self):
pass

def setUp(self):
warnings.simplefilter("ignore")
try:
os.unlink(FILENAME)
except OSError:
Expand All @@ -74,40 +77,52 @@ def _create_test_credentials(self, client_id='some_client_id',
user_agent)
return credentials

def test_non_existent_file_storage(self):
s = file.Storage(FILENAME)
credentials = s.get()
self.assertEquals(None, credentials)
@mock.patch('warnings.warn')
def test_non_existent_file_storage(self, warn_mock):
storage = file.Storage(FILENAME)
credentials = storage.get()
warn_mock.assert_called_with(
file._MISSING_FILE_MESSAGE.format(FILENAME))
self.assertIsNone(credentials)

def test_directory_file_storage(self):
storage = file.Storage(FILENAME)
os.mkdir(FILENAME)
try:
with self.assertRaises(IOError):
storage.get()
finally:
os.rmdir(FILENAME)

@unittest2.skipIf(not hasattr(os, 'symlink'), 'No symlink available')
def test_no_sym_link_credentials(self):
SYMFILENAME = FILENAME + '.sym'
os.symlink(FILENAME, SYMFILENAME)
s = file.Storage(SYMFILENAME)
storage = file.Storage(SYMFILENAME)
try:
with self.assertRaises(file.CredentialsFileSymbolicLinkError):
s.get()
storage.get()
finally:
os.unlink(SYMFILENAME)

def test_pickle_and_json_interop(self):
# Write a file with a pickled OAuth2Credentials.
credentials = self._create_test_credentials()

f = open(FILENAME, 'wb')
pickle.dump(credentials, f)
f.close()
cred_file = open(FILENAME, 'wb')
pickle.dump(credentials, cred_file)
cred_file.close()

# Storage should be not be able to read that object, as the capability
# to read and write credentials as pickled objects has been removed.
s = file.Storage(FILENAME)
read_credentials = s.get()
self.assertEquals(None, read_credentials)
storage = file.Storage(FILENAME)
read_credentials = storage.get()
self.assertIsNone(read_credentials)

# Now write it back out and confirm it has been rewritten as JSON
s.put(credentials)
with open(FILENAME) as f:
data = json.load(f)
storage.put(credentials)
with open(FILENAME) as cred_file:
data = json.load(cred_file)

self.assertEquals(data['access_token'], 'foo')
self.assertEquals(data['_class'], 'OAuth2Credentials')
Expand All @@ -118,12 +133,12 @@ def test_token_refresh_store_expired(self):
datetime.timedelta(minutes=15))
credentials = self._create_test_credentials(expiration=expiration)

s = file.Storage(FILENAME)
s.put(credentials)
credentials = s.get()
storage = file.Storage(FILENAME)
storage.put(credentials)
credentials = storage.get()
new_cred = copy.copy(credentials)
new_cred.access_token = 'bar'
s.put(new_cred)
storage.put(new_cred)

access_token = '1/3w'
token_response = {'access_token': access_token, 'expires_in': 3600}
Expand All @@ -141,12 +156,12 @@ def test_token_refresh_store_expires_soon(self):
datetime.timedelta(minutes=15))
credentials = self._create_test_credentials(expiration=expiration)

s = file.Storage(FILENAME)
s.put(credentials)
credentials = s.get()
storage = file.Storage(FILENAME)
storage.put(credentials)
credentials = storage.get()
new_cred = copy.copy(credentials)
new_cred.access_token = 'bar'
s.put(new_cred)
storage.put(new_cred)

access_token = '1/3w'
token_response = {'access_token': access_token, 'expires_in': 3600}
Expand All @@ -170,12 +185,12 @@ def test_token_refresh_good_store(self):
datetime.timedelta(minutes=15))
credentials = self._create_test_credentials(expiration=expiration)

s = file.Storage(FILENAME)
s.put(credentials)
credentials = s.get()
storage = file.Storage(FILENAME)
storage.put(credentials)
credentials = storage.get()
new_cred = copy.copy(credentials)
new_cred.access_token = 'bar'
s.put(new_cred)
storage.put(new_cred)

credentials._refresh(None)
self.assertEquals(credentials.access_token, 'bar')
Expand All @@ -185,12 +200,12 @@ def test_token_refresh_stream_body(self):
datetime.timedelta(minutes=15))
credentials = self._create_test_credentials(expiration=expiration)

s = file.Storage(FILENAME)
s.put(credentials)
credentials = s.get()
storage = file.Storage(FILENAME)
storage.put(credentials)
credentials = storage.get()
new_cred = copy.copy(credentials)
new_cred.access_token = 'bar'
s.put(new_cred)
storage.put(new_cred)

valid_access_token = '1/3w'
token_response = {'access_token': valid_access_token,
Expand All @@ -215,25 +230,25 @@ def test_token_refresh_stream_body(self):
def test_credentials_delete(self):
credentials = self._create_test_credentials()

s = file.Storage(FILENAME)
s.put(credentials)
credentials = s.get()
self.assertNotEquals(None, credentials)
s.delete()
credentials = s.get()
self.assertEquals(None, credentials)
storage = file.Storage(FILENAME)
storage.put(credentials)
credentials = storage.get()
self.assertIsNotNone(credentials)
storage.delete()
credentials = storage.get()
self.assertIsNone(credentials)

def test_access_token_credentials(self):
access_token = 'foo'
user_agent = 'refresh_checker/1.0'

credentials = client.AccessTokenCredentials(access_token, user_agent)

s = file.Storage(FILENAME)
credentials = s.put(credentials)
credentials = s.get()
storage = file.Storage(FILENAME)
credentials = storage.put(credentials)
credentials = storage.get()

self.assertNotEquals(None, credentials)
self.assertIsNotNone(credentials)
self.assertEquals('foo', credentials.access_token)

self.assertTrue(os.path.exists(FILENAME))
Expand Down

0 comments on commit a2a0ede

Please sign in to comment.