diff --git a/storage/google/cloud/storage/_helpers.py b/storage/google/cloud/storage/_helpers.py index 3a4eb2e232a2..5322f8a615b7 100644 --- a/storage/google/cloud/storage/_helpers.py +++ b/storage/google/cloud/storage/_helpers.py @@ -67,6 +67,11 @@ def client(self): """Abstract getter for the object client.""" raise NotImplementedError + @property + def user_project(self): + """Abstract getter for the object user_project.""" + raise NotImplementedError + def _require_client(self, client): """Check client or verify over-ride. @@ -94,6 +99,8 @@ def reload(self, client=None): # Pass only '?projection=noAcl' here because 'acl' and related # are handled via custom endpoints. query_params = {'projection': 'noAcl'} + if self.user_project is not None: + query_params['userProject'] = self.user_project api_response = client._connection.api_request( method='GET', path=self.path, query_params=query_params, _target_object=self) @@ -140,13 +147,16 @@ def patch(self, client=None): client = self._require_client(client) # Pass '?projection=full' here because 'PATCH' documented not # to work properly w/ 'noAcl'. + query_params = {'projection': 'full'} + if self.user_project is not None: + query_params['userProject'] = self.user_project update_properties = {key: self._properties[key] for key in self._changes} # Make the API call. api_response = client._connection.api_request( method='PATCH', path=self.path, data=update_properties, - query_params={'projection': 'full'}, _target_object=self) + query_params=query_params, _target_object=self) self._set_properties(api_response) def update(self, client=None): @@ -160,9 +170,12 @@ def update(self, client=None): ``client`` stored on the current object. """ client = self._require_client(client) + query_params = {'projection': 'full'} + if self.user_project is not None: + query_params['userProject'] = self.user_project api_response = client._connection.api_request( method='PUT', path=self.path, data=self._properties, - query_params={'projection': 'full'}, _target_object=self) + query_params=query_params, _target_object=self) self._set_properties(api_response) diff --git a/storage/google/cloud/storage/acl.py b/storage/google/cloud/storage/acl.py index c4525ea88735..240662c4dc8d 100644 --- a/storage/google/cloud/storage/acl.py +++ b/storage/google/cloud/storage/acl.py @@ -198,6 +198,7 @@ class ACL(object): # as properties). 
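    # user_project, when set, names the project billed for requests made via this ACL; concrete subclasses derive it from their bucket or blob.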
reload_path = None save_path = None + user_project = None def __init__(self): self.entities = {} @@ -405,10 +406,18 @@ def reload(self, client=None): """ path = self.reload_path client = self._require_client(client) + query_params = {} + + if self.user_project is not None: + query_params['userProject'] = self.user_project self.entities.clear() - found = client._connection.api_request(method='GET', path=path) + found = client._connection.api_request( + method='GET', + path=path, + query_params=query_params, + ) self.loaded = True for entry in found.get('items', ()): self.add_entity(self.entity_from_dict(entry)) @@ -435,8 +444,12 @@ def _save(self, acl, predefined, client): acl = [] query_params[self._PREDEFINED_QUERY_PARAM] = predefined + if self.user_project is not None: + query_params['userProject'] = self.user_project + path = self.save_path client = self._require_client(client) + result = client._connection.api_request( method='PATCH', path=path, @@ -532,6 +545,11 @@ def save_path(self): """Compute the path for PATCH API requests for this ACL.""" return self.bucket.path + @property + def user_project(self): + """Compute the user project charged for API requests for this ACL.""" + return self.bucket.user_project + class DefaultObjectACL(BucketACL): """A class representing the default object ACL for a bucket.""" @@ -565,3 +583,8 @@ def reload_path(self): def save_path(self): """Compute the path for PATCH API requests for this ACL.""" return self.blob.path + + @property + def user_project(self): + """Compute the user project charged for API requests for this ACL.""" + return self.blob.user_project diff --git a/storage/google/cloud/storage/blob.py b/storage/google/cloud/storage/blob.py index d2784d6e9ad6..21d92acd955a 100644 --- a/storage/google/cloud/storage/blob.py +++ b/storage/google/cloud/storage/blob.py @@ -34,7 +34,11 @@ import time import warnings +from six.moves.urllib.parse import parse_qsl from six.moves.urllib.parse import quote +from six.moves.urllib.parse import urlencode +from six.moves.urllib.parse import urlsplit +from six.moves.urllib.parse import urlunsplit from google import resumable_media from google.resumable_media.requests import ChunkedDownload @@ -220,6 +224,16 @@ def client(self): """The client bound to this blob.""" return self.bucket.client + @property + def user_project(self): + """Project ID used for API requests made via this blob. + + Derived from bucket's value. + + :rtype: str + """ + return self.bucket.user_project + @property def public_url(self): """The public URL for this blob's object. @@ -328,10 +342,14 @@ def exists(self, client=None): :returns: True if the blob exists in Cloud Storage. """ client = self._require_client(client) + # We only need the status code (200 or not) so we seek to + # minimize the returned payload. + query_params = {'fields': 'name'} + + if self.user_project is not None: + query_params['userProject'] = self.user_project + try: - # We only need the status code (200 or not) so we seek to - # minimize the returned payload. - query_params = {'fields': 'name'} # We intentionally pass `_target_object=None` since fields=name # would limit the local properties. client._connection.api_request( @@ -385,13 +403,19 @@ def _get_download_url(self): :rtype: str :returns: The download URL for the current blob. 
""" + name_value_pairs = [] if self.media_link is None: - download_url = _DOWNLOAD_URL_TEMPLATE.format(path=self.path) + base_url = _DOWNLOAD_URL_TEMPLATE.format(path=self.path) if self.generation is not None: - download_url += u'&generation={:d}'.format(self.generation) - return download_url + name_value_pairs.append( + ('generation', '{:d}'.format(self.generation))) else: - return self.media_link + base_url = self.media_link + + if self.user_project is not None: + name_value_pairs.append(('userProject', self.user_project)) + + return _add_query_parameters(base_url, name_value_pairs) def _do_download(self, transport, file_obj, download_url, headers): """Perform a download without any error handling. @@ -637,8 +661,14 @@ def _do_multipart_upload(self, client, stream, content_type, info = self._get_upload_arguments(content_type) headers, object_metadata, content_type = info - upload_url = _MULTIPART_URL_TEMPLATE.format( + base_url = _MULTIPART_URL_TEMPLATE.format( bucket_path=self.bucket.path) + name_value_pairs = [] + + if self.user_project is not None: + name_value_pairs.append(('userProject', self.user_project)) + + upload_url = _add_query_parameters(base_url, name_value_pairs) upload = MultipartUpload(upload_url, headers=headers) if num_retries is not None: @@ -709,8 +739,14 @@ def _initiate_resumable_upload(self, client, stream, content_type, if extra_headers is not None: headers.update(extra_headers) - upload_url = _RESUMABLE_URL_TEMPLATE.format( + base_url = _RESUMABLE_URL_TEMPLATE.format( bucket_path=self.bucket.path) + name_value_pairs = [] + + if self.user_project is not None: + name_value_pairs.append(('userProject', self.user_project)) + + upload_url = _add_query_parameters(base_url, name_value_pairs) upload = ResumableUpload(upload_url, chunk_size, headers=headers) if num_retries is not None: @@ -1064,9 +1100,16 @@ def get_iam_policy(self, client=None): the ``getIamPolicy`` API request. """ client = self._require_client(client) + + query_params = {} + + if self.user_project is not None: + query_params['userProject'] = self.user_project + info = client._connection.api_request( method='GET', path='%s/iam' % (self.path,), + query_params=query_params, _target_object=None) return Policy.from_api_repr(info) @@ -1089,11 +1132,18 @@ def set_iam_policy(self, policy, client=None): the ``setIamPolicy`` API request. """ client = self._require_client(client) + + query_params = {} + + if self.user_project is not None: + query_params['userProject'] = self.user_project + resource = policy.to_api_repr() resource['resourceId'] = self.path info = client._connection.api_request( method='PUT', path='%s/iam' % (self.path,), + query_params=query_params, data=resource, _target_object=None) return Policy.from_api_repr(info) @@ -1117,12 +1167,17 @@ def test_iam_permissions(self, permissions, client=None): request. 
""" client = self._require_client(client) - query = {'permissions': permissions} + query_params = {'permissions': permissions} + + if self.user_project is not None: + query_params['userProject'] = self.user_project + path = '%s/iam/testPermissions' % (self.path,) resp = client._connection.api_request( method='GET', path=path, - query_params=query) + query_params=query_params) + return resp.get('permissions', []) def make_public(self, client=None): @@ -1152,13 +1207,22 @@ def compose(self, sources, client=None): """ if self.content_type is None: raise ValueError("Destination 'content_type' not set.") + client = self._require_client(client) + query_params = {} + + if self.user_project is not None: + query_params['userProject'] = self.user_project + request = { 'sourceObjects': [{'name': source.name} for source in sources], 'destination': self._properties.copy(), } api_response = client._connection.api_request( - method='POST', path=self.path + '/compose', data=request, + method='POST', + path=self.path + '/compose', + query_params=query_params, + data=request, _target_object=self) self._set_properties(api_response) @@ -1190,14 +1254,20 @@ def rewrite(self, source, token=None, client=None): headers.update(_get_encryption_headers( source._encryption_key, source=True)) + query_params = {} + if token: - query_params = {'rewriteToken': token} - else: - query_params = {} + query_params['rewriteToken'] = token + + if self.user_project is not None: + query_params['userProject'] = self.user_project api_response = client._connection.api_request( - method='POST', path=source.path + '/rewriteTo' + self.path, - query_params=query_params, data=self._properties, headers=headers, + method='POST', + path=source.path + '/rewriteTo' + self.path, + query_params=query_params, + data=self._properties, + headers=headers, _target_object=self) rewritten = int(api_response['totalBytesRewritten']) size = int(api_response['objectSize']) @@ -1228,13 +1298,22 @@ def update_storage_class(self, new_class, client=None): raise ValueError("Invalid storage class: %s" % (new_class,)) client = self._require_client(client) + + query_params = {} + + if self.user_project is not None: + query_params['userProject'] = self.user_project + headers = _get_encryption_headers(self._encryption_key) headers.update(_get_encryption_headers( self._encryption_key, source=True)) api_response = client._connection.api_request( - method='POST', path=self.path + '/rewriteTo' + self.path, - data={'storageClass': new_class}, headers=headers, + method='POST', + path=self.path + '/rewriteTo' + self.path, + query_params=query_params, + data={'storageClass': new_class}, + headers=headers, _target_object=self) self._set_properties(api_response['resource']) @@ -1603,3 +1682,24 @@ def _raise_from_invalid_response(error): to the failed status code """ raise exceptions.from_http_response(error.response) + + +def _add_query_parameters(base_url, name_value_pairs): + """Add one query parameter to a base URL. + + :type base_url: string + :param base_url: Base URL (may already contain query parameters) + + :type name_value_pairs: list of (string, string) tuples. + :param name_value_pairs: Names and values of the query parameters to add + + :rtype: string + :returns: URL with additional query strings appended. 
+ """ + if len(name_value_pairs) == 0: + return base_url + + scheme, netloc, path, query, frag = urlsplit(base_url) + query = parse_qsl(query) + query.extend(name_value_pairs) + return urlunsplit((scheme, netloc, path, urlencode(query), frag)) diff --git a/storage/google/cloud/storage/bucket.py b/storage/google/cloud/storage/bucket.py index b280ecd55e30..f39bf8e1e7dd 100644 --- a/storage/google/cloud/storage/bucket.py +++ b/storage/google/cloud/storage/bucket.py @@ -108,6 +108,10 @@ class Bucket(_PropertyMixin): :type name: str :param name: The name of the bucket. Bucket names must start and end with a number or letter. + + :type user_project: str + :param user_project: (Optional) the project ID to be billed for API + requests made via this instance. """ _MAX_OBJECTS_FOR_ITERATION = 256 @@ -131,13 +135,14 @@ class Bucket(_PropertyMixin): https://cloud.google.com/storage/docs/storage-classes """ - def __init__(self, client, name=None): + def __init__(self, client, name=None, user_project=None): name = _validate_name(name) super(Bucket, self).__init__(name=name) self._client = client self._acl = BucketACL(self) self._default_object_acl = DefaultObjectACL(self) self._label_removals = set() + self._user_project = user_project def __repr__(self): return '' % (self.name,) @@ -156,6 +161,16 @@ def _set_properties(self, value): self._label_removals.clear() return super(Bucket, self)._set_properties(value) + @property + def user_project(self): + """Project ID to be billed for API requests made via this bucket. + + If unset, API requests are billed to the bucket owner. + + :rtype: str + """ + return self._user_project + def blob(self, blob_name, chunk_size=None, encryption_key=None): """Factory constructor for blob object. @@ -214,10 +229,14 @@ def exists(self, client=None): :returns: True if the bucket exists in Cloud Storage. """ client = self._require_client(client) + # We only need the status code (200 or not) so we seek to + # minimize the returned payload. + query_params = {'fields': 'name'} + + if self.user_project is not None: + query_params['userProject'] = self.user_project + try: - # We only need the status code (200 or not) so we seek to - # minimize the returned payload. - query_params = {'fields': 'name'} # We intentionally pass `_target_object=None` since fields=name # would limit the local properties. client._connection.api_request( @@ -243,6 +262,9 @@ def create(self, client=None): :param client: Optional. The client to use. If not passed, falls back to the ``client`` stored on the current bucket. """ + if self.user_project is not None: + raise ValueError("Cannot create bucket with 'user_project' set.") + client = self._require_client(client) query_params = {'project': client.project} properties = {key: self._properties[key] for key in self._changes} @@ -334,13 +356,21 @@ def get_blob(self, blob_name, client=None, encryption_key=None, **kwargs): :returns: The blob object if it exists, otherwise None. """ client = self._require_client(client) + query_params = {} + + if self.user_project is not None: + query_params['userProject'] = self.user_project blob = Blob(bucket=self, name=blob_name, encryption_key=encryption_key, **kwargs) try: headers = _get_encryption_headers(encryption_key) response = client._connection.api_request( - method='GET', path=blob.path, _target_object=blob, - headers=headers) + method='GET', + path=blob.path, + query_params=query_params, + headers=headers, + _target_object=blob, + ) # NOTE: We assume response.get('name') matches `blob_name`. 
blob._set_properties(response) # NOTE: This will not fail immediately in a batch. However, when @@ -394,7 +424,7 @@ def list_blobs(self, max_results=None, page_token=None, prefix=None, :returns: Iterator of all :class:`~google.cloud.storage.blob.Blob` in this bucket matching the arguments. """ - extra_params = {} + extra_params = {'projection': projection} if prefix is not None: extra_params['prefix'] = prefix @@ -405,11 +435,12 @@ def list_blobs(self, max_results=None, page_token=None, prefix=None, if versions is not None: extra_params['versions'] = versions - extra_params['projection'] = projection - if fields is not None: extra_params['fields'] = fields + if self.user_project is not None: + extra_params['userProject'] = self.user_project + client = self._require_client(client) path = self.path + '/o' iterator = page_iterator.HTTPIterator( @@ -478,6 +509,11 @@ def delete(self, force=False, client=None): contains more than 256 objects / blobs. """ client = self._require_client(client) + query_params = {} + + if self.user_project is not None: + query_params['userProject'] = self.user_project + if force: blobs = list(self.list_blobs( max_results=self._MAX_OBJECTS_FOR_ITERATION + 1, @@ -499,7 +535,10 @@ def delete(self, force=False, client=None): # request has no response value (whether in a standard request or # in a batch request). client._connection.api_request( - method='DELETE', path=self.path, _target_object=None) + method='DELETE', + path=self.path, + query_params=query_params, + _target_object=None) def delete_blob(self, blob_name, client=None): """Deletes a blob from the current bucket. @@ -531,12 +570,20 @@ def delete_blob(self, blob_name, client=None): """ client = self._require_client(client) + query_params = {} + + if self.user_project is not None: + query_params['userProject'] = self.user_project + blob_path = Blob.path_helper(self.path, blob_name) # We intentionally pass `_target_object=None` since a DELETE # request has no response value (whether in a standard request or # in a batch request). client._connection.api_request( - method='DELETE', path=blob_path, _target_object=None) + method='DELETE', + path=blob_path, + query_params=query_params, + _target_object=None) def delete_blobs(self, blobs, on_error=None, client=None): """Deletes a list of blobs from the current bucket. @@ -599,14 +646,26 @@ def copy_blob(self, blob, destination_bucket, new_name=None, :returns: The new Blob. """ client = self._require_client(client) + query_params = {} + + if self.user_project is not None: + query_params['userProject'] = self.user_project + if new_name is None: new_name = blob.name + new_blob = Blob(bucket=destination_bucket, name=new_name) api_path = blob.path + '/copyTo' + new_blob.path copy_result = client._connection.api_request( - method='POST', path=api_path, _target_object=new_blob) + method='POST', + path=api_path, + query_params=query_params, + _target_object=new_blob, + ) + if not preserve_acl: new_blob.acl.save(acl={}, client=client) + new_blob._set_properties(copy_result) return new_blob @@ -964,10 +1023,40 @@ def versioning_enabled(self, value): details. :type value: convertible to boolean - :param value: should versioning be anabled for the bucket? + :param value: should versioning be enabled for the bucket? """ self._patch_property('versioning', {'enabled': bool(value)}) + @property + def requester_pays(self): + """Does the requester pay for API requests for this bucket? + + .. note:: + + No public docs exist yet for the "requester pays" feature. 
+ + :setter: Update whether requester pays for this bucket. + :getter: Query whether requester pays for this bucket. + + :rtype: bool + :returns: True if requester pays for API requests for the bucket, + else False. + """ + versioning = self._properties.get('billing', {}) + return versioning.get('requesterPays', False) + + @requester_pays.setter + def requester_pays(self, value): + """Update whether requester pays for API requests for this bucket. + + See https://cloud.google.com/storage/docs/ for + details. + + :type value: convertible to boolean + :param value: should requester pay for API requests for the bucket? + """ + self._patch_property('billing', {'requesterPays': bool(value)}) + def configure_website(self, main_page_suffix=None, not_found_page=None): """Configure website-related properties. @@ -1033,9 +1122,15 @@ def get_iam_policy(self, client=None): the ``getIamPolicy`` API request. """ client = self._require_client(client) + query_params = {} + + if self.user_project is not None: + query_params['userProject'] = self.user_project + info = client._connection.api_request( method='GET', path='%s/iam' % (self.path,), + query_params=query_params, _target_object=None) return Policy.from_api_repr(info) @@ -1058,11 +1153,17 @@ def set_iam_policy(self, policy, client=None): the ``setIamPolicy`` API request. """ client = self._require_client(client) + query_params = {} + + if self.user_project is not None: + query_params['userProject'] = self.user_project + resource = policy.to_api_repr() resource['resourceId'] = self.path info = client._connection.api_request( method='PUT', path='%s/iam' % (self.path,), + query_params=query_params, data=resource, _target_object=None) return Policy.from_api_repr(info) @@ -1086,12 +1187,16 @@ def test_iam_permissions(self, permissions, client=None): request. """ client = self._require_client(client) - query = {'permissions': permissions} + query_params = {'permissions': permissions} + + if self.user_project is not None: + query_params['userProject'] = self.user_project + path = '%s/iam/testPermissions' % (self.path,) resp = client._connection.api_request( method='GET', path=path, - query_params=query) + query_params=query_params) return resp.get('permissions', []) def make_public(self, recursive=False, future=False, client=None): diff --git a/storage/google/cloud/storage/client.py b/storage/google/cloud/storage/client.py index 5743dc059936..db66b8aa27c6 100644 --- a/storage/google/cloud/storage/client.py +++ b/storage/google/cloud/storage/client.py @@ -121,7 +121,7 @@ def current_batch(self): """ return self._batch_stack.top - def bucket(self, bucket_name): + def bucket(self, bucket_name, user_project=None): """Factory constructor for bucket object. .. note:: @@ -131,10 +131,14 @@ def bucket(self, bucket_name): :type bucket_name: str :param bucket_name: The name of the bucket to be instantiated. + :type user_project: str + :param user_project: (Optional) the project ID to be billed for API + requests made via this instance. + :rtype: :class:`google.cloud.storage.bucket.Bucket` :returns: The bucket object created. """ - return Bucket(client=self, name=bucket_name) + return Bucket(client=self, name=bucket_name, user_project=user_project) def batch(self): """Factory constructor for batch object. @@ -194,7 +198,7 @@ def lookup_bucket(self, bucket_name): except NotFound: return None - def create_bucket(self, bucket_name): + def create_bucket(self, bucket_name, requester_pays=None): """Create a new bucket. 
For example: @@ -211,10 +215,17 @@ def create_bucket(self, bucket_name): :type bucket_name: str :param bucket_name: The bucket name to create. + :type requester_pays: bool + :param requester_pays: + (Optional) Whether requester pays for API requests for this + bucket and its blobs. + :rtype: :class:`google.cloud.storage.bucket.Bucket` :returns: The newly created bucket. """ bucket = Bucket(self, name=bucket_name) + if requester_pays is not None: + bucket.requester_pays = requester_pays bucket.create(client=self) return bucket diff --git a/storage/tests/system.py b/storage/tests/system.py index 37ac4aeedb96..9e2aa38916ee 100644 --- a/storage/tests/system.py +++ b/storage/tests/system.py @@ -28,6 +28,9 @@ from test_utils.system import unique_resource_id +USER_PROJECT = os.environ.get('GOOGLE_CLOUD_TESTS_USER_PROJECT') + + def _bad_copy(bad_request): """Predicate: pass only exceptions for a failed copyTo.""" err_msg = bad_request.message @@ -133,6 +136,135 @@ def test_bucket_update_labels(self): bucket.update() self.assertEqual(bucket.labels, {}) + @unittest.skipUnless(USER_PROJECT, 'USER_PROJECT not set in environment.') + def test_crud_bucket_with_requester_pays(self): + new_bucket_name = 'w-requester-pays' + unique_resource_id('-') + created = Config.CLIENT.create_bucket( + new_bucket_name, requester_pays=True) + self.case_buckets_to_delete.append(new_bucket_name) + self.assertEqual(created.name, new_bucket_name) + self.assertTrue(created.requester_pays) + + with_user_project = Config.CLIENT.bucket( + new_bucket_name, user_project=USER_PROJECT) + + # Bucket will be deleted in-line below. + self.case_buckets_to_delete.remove(new_bucket_name) + + try: + # Exercise 'buckets.get' w/ userProject. + self.assertTrue(with_user_project.exists()) + with_user_project.reload() + self.assertTrue(with_user_project.requester_pays) + + # Exercise 'buckets.patch' w/ userProject. + with_user_project.configure_website( + main_page_suffix='index.html', not_found_page='404.html') + with_user_project.patch() + self.assertEqual( + with_user_project._properties['website'], { + 'mainPageSuffix': 'index.html', + 'notFoundPage': '404.html', + }) + + # Exercise 'buckets.update' w/ userProject. + new_labels = {'another-label': 'another-value'} + with_user_project.labels = new_labels + with_user_project.update() + self.assertEqual(with_user_project.labels, new_labels) + + finally: + # Exercise 'buckets.delete' w/ userProject. 
+ with_user_project.delete() + + @unittest.skipUnless(USER_PROJECT, 'USER_PROJECT not set in environment.') + def test_bucket_acls_iam_with_user_project(self): + new_bucket_name = 'acl-w-user-project' + unique_resource_id('-') + created = Config.CLIENT.create_bucket( + new_bucket_name, requester_pays=True) + self.case_buckets_to_delete.append(new_bucket_name) + + with_user_project = Config.CLIENT.bucket( + new_bucket_name, user_project=USER_PROJECT) + + # Exercise bucket ACL w/ userProject + acl = with_user_project.acl + acl.reload() + acl.all().grant_read() + acl.save() + self.assertIn('READER', acl.all().get_roles()) + del acl.entities['allUsers'] + acl.save() + self.assertFalse(acl.has_entity('allUsers')) + + # Exercise default object ACL w/ userProject + doa = with_user_project.default_object_acl + doa.reload() + doa.all().grant_read() + doa.save() + self.assertIn('READER', doa.all().get_roles()) + + # Exercise IAM w/ userProject + test_permissions = ['storage.buckets.get'] + self.assertEqual( + with_user_project.test_iam_permissions(test_permissions), + test_permissions) + + policy = with_user_project.get_iam_policy() + viewers = policy.setdefault('roles/storage.objectViewer', set()) + viewers.add(policy.all_users()) + with_user_project.set_iam_policy(policy) + + @unittest.skipUnless(USER_PROJECT, 'USER_PROJECT not set in environment.') + def test_copy_existing_file_with_user_project(self): + new_bucket_name = 'copy-w-requester-pays' + unique_resource_id('-') + created = Config.CLIENT.create_bucket( + new_bucket_name, requester_pays=True) + self.case_buckets_to_delete.append(new_bucket_name) + self.assertEqual(created.name, new_bucket_name) + self.assertTrue(created.requester_pays) + + to_delete = [] + blob = storage.Blob('simple', bucket=created) + blob.upload_from_string(b'DEADBEEF') + to_delete.append(blob) + try: + with_user_project = Config.CLIENT.bucket( + new_bucket_name, user_project=USER_PROJECT) + + new_blob = retry_bad_copy(with_user_project.copy_blob)( + blob, with_user_project, 'simple-copy') + to_delete.append(new_blob) + + base_contents = blob.download_as_string() + copied_contents = new_blob.download_as_string() + self.assertEqual(base_contents, copied_contents) + finally: + for blob in to_delete: + retry_429(blob.delete)() + + @unittest.skipUnless(USER_PROJECT, 'USER_PROJECT not set in environment.') + def test_bucket_get_blob_with_user_project(self): + new_bucket_name = 'w-requester-pays' + unique_resource_id('-') + data = b'DEADBEEF' + created = Config.CLIENT.create_bucket( + new_bucket_name, requester_pays=True) + self.case_buckets_to_delete.append(new_bucket_name) + self.assertEqual(created.name, new_bucket_name) + self.assertTrue(created.requester_pays) + + with_user_project = Config.CLIENT.bucket( + new_bucket_name, user_project=USER_PROJECT) + + self.assertIsNone(with_user_project.get_blob('nonesuch')) + to_add = created.blob('blob-name') + to_add.upload_from_string(data) + try: + found = with_user_project.get_blob('blob-name') + self.assertEqual(found.download_as_string(), data) + finally: + to_add.delete() + class TestStorageFiles(unittest.TestCase): @@ -216,6 +348,66 @@ def test_small_file_write_from_filename(self): md5_hash = md5_hash.encode('utf-8') self.assertEqual(md5_hash, file_data['hash']) + @unittest.skipUnless(USER_PROJECT, 'USER_PROJECT not set in environment.') + def test_crud_blob_w_user_project(self): + with_user_project = Config.CLIENT.bucket( + self.bucket.name, user_project=USER_PROJECT) + blob = with_user_project.blob('SmallFile') + + 
file_data = self.FILES['simple'] + with open(file_data['path'], mode='rb') as to_read: + file_contents = to_read.read() + + # Exercise 'objects.insert' w/ userProject. + blob.upload_from_filename(file_data['path']) + + try: + # Exercise 'objects.get' (metadata) w/ userProject. + self.assertTrue(blob.exists()) + blob.reload() + + # Exercise 'objects.get' (media) w/ userProject. + downloaded = blob.download_as_string() + self.assertEqual(downloaded, file_contents) + + # Exercise 'objects.patch' w/ userProject. + blob.content_language = 'en' + blob.patch() + self.assertEqual(blob.content_language, 'en') + + # Exercise 'objects.update' w/ userProject. + metadata = { + 'foo': 'Foo', + 'bar': 'Bar', + } + blob.metadata = metadata + blob.update() + self.assertEqual(blob.metadata, metadata) + finally: + # Exercise 'objects.delete' (metadata) w/ userProject. + blob.delete() + + @unittest.skipUnless(USER_PROJECT, 'USER_PROJECT not set in environment.') + def test_blob_acl_w_user_project(self): + with_user_project = Config.CLIENT.bucket( + self.bucket.name, user_project=USER_PROJECT) + blob = with_user_project.blob('SmallFile') + + file_data = self.FILES['simple'] + + blob.upload_from_filename(file_data['path']) + self.case_blobs_to_delete.append(blob) + + # Exercise bucket ACL w/ userProject + acl = blob.acl + acl.reload() + acl.all().grant_read() + acl.save() + self.assertIn('READER', acl.all().get_roles()) + del acl.entities['allUsers'] + acl.save() + self.assertFalse(acl.has_entity('allUsers')) + def test_write_metadata(self): filename = self.FILES['logo']['path'] blob_name = os.path.basename(filename) @@ -314,6 +506,15 @@ def test_list_files(self): self.assertEqual(sorted(blob.name for blob in all_blobs), sorted(self.FILENAMES)) + @unittest.skipUnless(USER_PROJECT, 'USER_PROJECT not set in environment.') + @RetryErrors(unittest.TestCase.failureException) + def test_list_files_with_user_project(self): + with_user_project = Config.CLIENT.bucket( + self.bucket.name, user_project=USER_PROJECT) + all_blobs = list(with_user_project.list_blobs()) + self.assertEqual(sorted(blob.name for blob in all_blobs), + sorted(self.FILENAMES)) + @RetryErrors(unittest.TestCase.failureException) def test_paginate_files(self): truncation_size = 1 @@ -501,6 +702,32 @@ def test_compose_replace_existing_blob(self): composed = original.download_as_string() self.assertEqual(composed, BEFORE + TO_APPEND) + @unittest.skipUnless(USER_PROJECT, 'USER_PROJECT not set in environment.') + def test_compose_with_user_project(self): + new_bucket_name = 'compose-user-project' + unique_resource_id('-') + created = Config.CLIENT.create_bucket( + new_bucket_name, requester_pays=True) + try: + SOURCE_1 = b'AAA\n' + source_1 = created.blob('source-1') + source_1.upload_from_string(SOURCE_1) + + SOURCE_2 = b'BBB\n' + source_2 = created.blob('source-2') + source_2.upload_from_string(SOURCE_2) + + with_user_project = Config.CLIENT.bucket( + new_bucket_name, user_project=USER_PROJECT) + + destination = with_user_project.blob('destination') + destination.content_type = 'text/plain' + destination.compose([source_1, source_2]) + + composed = destination.download_as_string() + self.assertEqual(composed, SOURCE_1 + SOURCE_2) + finally: + retry_429(created.delete)(force=True) + class TestStorageRewrite(TestStorageFiles): @@ -544,12 +771,66 @@ def test_rewrite_rotate_encryption_key(self): # Not adding 'dest' to 'self.case_blobs_to_delete': it is the # same object as 'source'. 
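        # rewrite() returns (token, bytes_rewritten, total_bytes); the token is None once the server reports the rewrite complete, so a finished single-pass rewrite yields (None, size, size).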
- self.assertEqual(token, None) + self.assertIsNone(token) self.assertEqual(rewritten, len(source_data)) self.assertEqual(total, len(source_data)) self.assertEqual(dest.download_as_string(), source_data) + def test_rewrite_add_key_with_user_project(self): + file_data = self.FILES['simple'] + new_bucket_name = 'rewrite-key-up' + unique_resource_id('-') + created = Config.CLIENT.create_bucket( + new_bucket_name, requester_pays=True) + try: + with_user_project = Config.CLIENT.bucket( + new_bucket_name, user_project=USER_PROJECT) + + source = with_user_project.blob('source') + source.upload_from_filename(file_data['path']) + source_data = source.download_as_string() + + KEY = os.urandom(32) + dest = with_user_project.blob('dest', encryption_key=KEY) + token, rewritten, total = dest.rewrite(source) + + self.assertEqual(token, None) + self.assertEqual(rewritten, len(source_data)) + self.assertEqual(total, len(source_data)) + + self.assertEqual(source.download_as_string(), + dest.download_as_string()) + finally: + retry_429(created.delete)(force=True) + + def test_rewrite_rotate_with_user_project(self): + BLOB_NAME = 'rotating-keys' + file_data = self.FILES['simple'] + new_bucket_name = 'rewrite-rotate-up' + unique_resource_id('-') + created = Config.CLIENT.create_bucket( + new_bucket_name, requester_pays=True) + try: + with_user_project = Config.CLIENT.bucket( + new_bucket_name, user_project=USER_PROJECT) + + SOURCE_KEY = os.urandom(32) + source = with_user_project.blob( + BLOB_NAME, encryption_key=SOURCE_KEY) + source.upload_from_filename(file_data['path']) + source_data = source.download_as_string() + + DEST_KEY = os.urandom(32) + dest = with_user_project.blob(BLOB_NAME, encryption_key=DEST_KEY) + token, rewritten, total = dest.rewrite(source) + + self.assertEqual(token, None) + self.assertEqual(rewritten, len(source_data)) + self.assertEqual(total, len(source_data)) + + self.assertEqual(dest.download_as_string(), source_data) + finally: + retry_429(created.delete)(force=True) + class TestStorageNotificationCRUD(unittest.TestCase): diff --git a/storage/tests/unit/test__helpers.py b/storage/tests/unit/test__helpers.py index 90def4867268..78bf5fcf3d0a 100644 --- a/storage/tests/unit/test__helpers.py +++ b/storage/tests/unit/test__helpers.py @@ -26,7 +26,7 @@ def _get_target_class(): def _make_one(self, *args, **kw): return self._get_target_class()(*args, **kw) - def _derivedClass(self, path=None): + def _derivedClass(self, path=None, user_project=None): class Derived(self._get_target_class()): @@ -36,30 +36,67 @@ class Derived(self._get_target_class()): def path(self): return path + @property + def user_project(self): + return user_project + return Derived def test_path_is_abstract(self): mixin = self._make_one() - self.assertRaises(NotImplementedError, lambda: mixin.path) + with self.assertRaises(NotImplementedError): + mixin.path def test_client_is_abstract(self): mixin = self._make_one() - self.assertRaises(NotImplementedError, lambda: mixin.client) + with self.assertRaises(NotImplementedError): + mixin.client + + def test_user_project_is_abstract(self): + mixin = self._make_one() + with self.assertRaises(NotImplementedError): + mixin.user_project def test_reload(self): connection = _Connection({'foo': 'Foo'}) client = _Client(connection) derived = self._derivedClass('/path')() - # Make sure changes is not a set, so we can observe a change. + # Make sure changes is not a set instance before calling reload + # (which will clear / replace it with an empty set), checked below. 
derived._changes = object() derived.reload(client=client) self.assertEqual(derived._properties, {'foo': 'Foo'}) kw = connection._requested self.assertEqual(len(kw), 1) - self.assertEqual(kw[0]['method'], 'GET') - self.assertEqual(kw[0]['path'], '/path') - self.assertEqual(kw[0]['query_params'], {'projection': 'noAcl'}) - # Make sure changes get reset by reload. + self.assertEqual(kw[0], { + 'method': 'GET', + 'path': '/path', + 'query_params': {'projection': 'noAcl'}, + '_target_object': derived, + }) + self.assertEqual(derived._changes, set()) + + def test_reload_w_user_project(self): + user_project = 'user-project-123' + connection = _Connection({'foo': 'Foo'}) + client = _Client(connection) + derived = self._derivedClass('/path', user_project)() + # Make sure changes is not a set instance before calling reload + # (which will clear / replace it with an empty set), checked below. + derived._changes = object() + derived.reload(client=client) + self.assertEqual(derived._properties, {'foo': 'Foo'}) + kw = connection._requested + self.assertEqual(len(kw), 1) + self.assertEqual(kw[0], { + 'method': 'GET', + 'path': '/path', + 'query_params': { + 'projection': 'noAcl', + 'userProject': user_project, + }, + '_target_object': derived, + }) self.assertEqual(derived._changes, set()) def test__set_properties(self): @@ -87,11 +124,42 @@ def test_patch(self): self.assertEqual(derived._properties, {'foo': 'Foo'}) kw = connection._requested self.assertEqual(len(kw), 1) - self.assertEqual(kw[0]['method'], 'PATCH') - self.assertEqual(kw[0]['path'], '/path') - self.assertEqual(kw[0]['query_params'], {'projection': 'full'}) - # Since changes does not include `baz`, we don't see it sent. - self.assertEqual(kw[0]['data'], {'bar': BAR}) + self.assertEqual(kw[0], { + 'method': 'PATCH', + 'path': '/path', + 'query_params': {'projection': 'full'}, + # Since changes does not include `baz`, we don't see it sent. + 'data': {'bar': BAR}, + '_target_object': derived, + }) + # Make sure changes get reset by patch(). + self.assertEqual(derived._changes, set()) + + def test_patch_w_user_project(self): + user_project = 'user-project-123' + connection = _Connection({'foo': 'Foo'}) + client = _Client(connection) + derived = self._derivedClass('/path', user_project)() + # Make sure changes is non-empty, so we can observe a change. + BAR = object() + BAZ = object() + derived._properties = {'bar': BAR, 'baz': BAZ} + derived._changes = set(['bar']) # Ignore baz. + derived.patch(client=client) + self.assertEqual(derived._properties, {'foo': 'Foo'}) + kw = connection._requested + self.assertEqual(len(kw), 1) + self.assertEqual(kw[0], { + 'method': 'PATCH', + 'path': '/path', + 'query_params': { + 'projection': 'full', + 'userProject': user_project, + }, + # Since changes does not include `baz`, we don't see it sent. + 'data': {'bar': BAR}, + '_target_object': derived, + }) # Make sure changes get reset by patch(). self.assertEqual(derived._changes, set()) @@ -115,6 +183,31 @@ def test_update(self): # Make sure changes get reset by patch(). self.assertEqual(derived._changes, set()) + def test_update_w_user_project(self): + user_project = 'user-project-123' + connection = _Connection({'foo': 'Foo'}) + client = _Client(connection) + derived = self._derivedClass('/path', user_project)() + # Make sure changes is non-empty, so we can observe a change. + BAR = object() + BAZ = object() + derived._properties = {'bar': BAR, 'baz': BAZ} + derived._changes = set(['bar']) # Update sends 'baz' anyway. 
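+        # update() issues a PUT with the full local property set (both 'bar' and 'baz'), plus projection=full and the userProject query parameter, whereas patch() only sends the keys recorded in _changes.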
+ derived.update(client=client) + self.assertEqual(derived._properties, {'foo': 'Foo'}) + kw = connection._requested + self.assertEqual(len(kw), 1) + self.assertEqual(kw[0]['method'], 'PUT') + self.assertEqual(kw[0]['path'], '/path') + self.assertEqual( + kw[0]['query_params'], { + 'projection': 'full', + 'userProject': user_project, + }) + self.assertEqual(kw[0]['data'], {'bar': BAR, 'baz': BAZ}) + # Make sure changes get reset by patch(). + self.assertEqual(derived._changes, set()) + class Test__scalar_property(unittest.TestCase): diff --git a/storage/tests/unit/test_acl.py b/storage/tests/unit/test_acl.py index 1159c8c1f2aa..4e4018ae7c8c 100644 --- a/storage/tests/unit/test_acl.py +++ b/storage/tests/unit/test_acl.py @@ -532,8 +532,11 @@ def test_reload_missing(self): self.assertEqual(list(acl), []) kw = connection._requested self.assertEqual(len(kw), 1) - self.assertEqual(kw[0]['method'], 'GET') - self.assertEqual(kw[0]['path'], '/testing/acl') + self.assertEqual(kw[0], { + 'method': 'GET', + 'path': '/testing/acl', + 'query_params': {}, + }) def test_reload_empty_result_clears_local(self): ROLE = 'role' @@ -543,29 +546,41 @@ def test_reload_empty_result_clears_local(self): acl.reload_path = '/testing/acl' acl.loaded = True acl.entity('allUsers', ROLE) + acl.reload(client=client) + self.assertTrue(acl.loaded) self.assertEqual(list(acl), []) kw = connection._requested self.assertEqual(len(kw), 1) - self.assertEqual(kw[0]['method'], 'GET') - self.assertEqual(kw[0]['path'], '/testing/acl') + self.assertEqual(kw[0], { + 'method': 'GET', + 'path': '/testing/acl', + 'query_params': {}, + }) - def test_reload_nonempty_result(self): + def test_reload_nonempty_result_w_user_project(self): ROLE = 'role' + USER_PROJECT = 'user-project-123' connection = _Connection( {'items': [{'entity': 'allUsers', 'role': ROLE}]}) client = _Client(connection) acl = self._make_one() acl.reload_path = '/testing/acl' acl.loaded = True + acl.user_project = USER_PROJECT + acl.reload(client=client) + self.assertTrue(acl.loaded) self.assertEqual(list(acl), [{'entity': 'allUsers', 'role': ROLE}]) kw = connection._requested self.assertEqual(len(kw), 1) - self.assertEqual(kw[0]['method'], 'GET') - self.assertEqual(kw[0]['path'], '/testing/acl') + self.assertEqual(kw[0], { + 'method': 'GET', + 'path': '/testing/acl', + 'query_params': {'userProject': USER_PROJECT}, + }) def test_save_none_set_none_passed(self): connection = _Connection() @@ -606,30 +621,43 @@ def test_save_no_acl(self): self.assertEqual(len(kw), 1) self.assertEqual(kw[0]['method'], 'PATCH') self.assertEqual(kw[0]['path'], '/testing') - self.assertEqual(kw[0]['data'], {'acl': AFTER}) - self.assertEqual(kw[0]['query_params'], {'projection': 'full'}) - - def test_save_w_acl(self): + self.assertEqual(kw[0], { + 'method': 'PATCH', + 'path': '/testing', + 'query_params': {'projection': 'full'}, + 'data': {'acl': AFTER}, + }) + + def test_save_w_acl_w_user_project(self): ROLE1 = 'role1' ROLE2 = 'role2' STICKY = {'entity': 'allUsers', 'role': ROLE2} + USER_PROJECT = 'user-project-123' new_acl = [{'entity': 'allUsers', 'role': ROLE1}] connection = _Connection({'acl': [STICKY] + new_acl}) client = _Client(connection) acl = self._make_one() acl.save_path = '/testing' acl.loaded = True + acl.user_project = USER_PROJECT + acl.save(new_acl, client=client) + entries = list(acl) self.assertEqual(len(entries), 2) self.assertTrue(STICKY in entries) self.assertTrue(new_acl[0] in entries) kw = connection._requested self.assertEqual(len(kw), 1) - 
self.assertEqual(kw[0]['method'], 'PATCH') - self.assertEqual(kw[0]['path'], '/testing') - self.assertEqual(kw[0]['data'], {'acl': new_acl}) - self.assertEqual(kw[0]['query_params'], {'projection': 'full'}) + self.assertEqual(kw[0], { + 'method': 'PATCH', + 'path': '/testing', + 'query_params': { + 'projection': 'full', + 'userProject': USER_PROJECT, + }, + 'data': {'acl': new_acl}, + }) def test_save_prefefined_invalid(self): connection = _Connection() @@ -652,11 +680,15 @@ def test_save_predefined_valid(self): self.assertEqual(len(entries), 0) kw = connection._requested self.assertEqual(len(kw), 1) - self.assertEqual(kw[0]['method'], 'PATCH') - self.assertEqual(kw[0]['path'], '/testing') - self.assertEqual(kw[0]['data'], {'acl': []}) - self.assertEqual(kw[0]['query_params'], - {'projection': 'full', 'predefinedAcl': PREDEFINED}) + self.assertEqual(kw[0], { + 'method': 'PATCH', + 'path': '/testing', + 'query_params': { + 'projection': 'full', + 'predefinedAcl': PREDEFINED, + }, + 'data': {'acl': []}, + }) def test_save_predefined_w_XML_alias(self): PREDEFINED_XML = 'project-private' @@ -671,12 +703,15 @@ def test_save_predefined_w_XML_alias(self): self.assertEqual(len(entries), 0) kw = connection._requested self.assertEqual(len(kw), 1) - self.assertEqual(kw[0]['method'], 'PATCH') - self.assertEqual(kw[0]['path'], '/testing') - self.assertEqual(kw[0]['data'], {'acl': []}) - self.assertEqual(kw[0]['query_params'], - {'projection': 'full', - 'predefinedAcl': PREDEFINED_JSON}) + self.assertEqual(kw[0], { + 'method': 'PATCH', + 'path': '/testing', + 'query_params': { + 'projection': 'full', + 'predefinedAcl': PREDEFINED_JSON, + }, + 'data': {'acl': []}, + }) def test_save_predefined_valid_w_alternate_query_param(self): # Cover case where subclass overrides _PREDEFINED_QUERY_PARAM @@ -692,11 +727,15 @@ def test_save_predefined_valid_w_alternate_query_param(self): self.assertEqual(len(entries), 0) kw = connection._requested self.assertEqual(len(kw), 1) - self.assertEqual(kw[0]['method'], 'PATCH') - self.assertEqual(kw[0]['path'], '/testing') - self.assertEqual(kw[0]['data'], {'acl': []}) - self.assertEqual(kw[0]['query_params'], - {'projection': 'full', 'alternate': PREDEFINED}) + self.assertEqual(kw[0], { + 'method': 'PATCH', + 'path': '/testing', + 'query_params': { + 'projection': 'full', + 'alternate': PREDEFINED, + }, + 'data': {'acl': []}, + }) def test_clear(self): ROLE1 = 'role1' @@ -712,10 +751,12 @@ def test_clear(self): self.assertEqual(list(acl), [STICKY]) kw = connection._requested self.assertEqual(len(kw), 1) - self.assertEqual(kw[0]['method'], 'PATCH') - self.assertEqual(kw[0]['path'], '/testing') - self.assertEqual(kw[0]['data'], {'acl': []}) - self.assertEqual(kw[0]['query_params'], {'projection': 'full'}) + self.assertEqual(kw[0], { + 'method': 'PATCH', + 'path': '/testing', + 'query_params': {'projection': 'full'}, + 'data': {'acl': []}, + }) class Test_BucketACL(unittest.TestCase): @@ -739,6 +780,15 @@ def test_ctor(self): self.assertEqual(acl.reload_path, '/b/%s/acl' % NAME) self.assertEqual(acl.save_path, '/b/%s' % NAME) + def test_user_project(self): + NAME = 'name' + USER_PROJECT = 'user-project-123' + bucket = _Bucket(NAME) + acl = self._make_one(bucket) + self.assertIsNone(acl.user_project) + bucket.user_project = USER_PROJECT + self.assertEqual(acl.user_project, USER_PROJECT) + class Test_DefaultObjectACL(unittest.TestCase): @@ -785,9 +835,22 @@ def test_ctor(self): self.assertEqual(acl.reload_path, '/b/%s/o/%s/acl' % (NAME, BLOB_NAME)) self.assertEqual(acl.save_path, 
'/b/%s/o/%s' % (NAME, BLOB_NAME)) + def test_user_project(self): + NAME = 'name' + BLOB_NAME = 'blob-name' + USER_PROJECT = 'user-project-123' + bucket = _Bucket(NAME) + blob = _Blob(bucket, BLOB_NAME) + acl = self._make_one(blob) + self.assertIsNone(acl.user_project) + blob.user_project = USER_PROJECT + self.assertEqual(acl.user_project, USER_PROJECT) + class _Blob(object): + user_project = None + def __init__(self, bucket, blob): self.bucket = bucket self.blob = blob @@ -799,6 +862,8 @@ def path(self): class _Bucket(object): + user_project = None + def __init__(self, name): self.name = name diff --git a/storage/tests/unit/test_blob.py b/storage/tests/unit/test_blob.py index e0a41ee793d2..9ce326d818c4 100644 --- a/storage/tests/unit/test_blob.py +++ b/storage/tests/unit/test_blob.py @@ -141,6 +141,19 @@ def test_path_with_non_ascii(self): blob = self._make_one(blob_name, bucket=bucket) self.assertEqual(blob.path, '/b/name/o/Caf%C3%A9') + def test_client(self): + blob_name = 'BLOB' + bucket = _Bucket() + blob = self._make_one(blob_name, bucket=bucket) + self.assertIs(blob.client, bucket.client) + + def test_user_project(self): + user_project = 'user-project-123' + blob_name = 'BLOB' + bucket = _Bucket(user_project=user_project) + blob = self._make_one(blob_name, bucket=bucket) + self.assertEqual(blob.user_project, user_project) + def test_public_url(self): BLOB_NAME = 'blob-name' bucket = _Bucket() @@ -304,16 +317,31 @@ def test_exists_miss(self): bucket = _Bucket(client) blob = self._make_one(NONESUCH, bucket=bucket) self.assertFalse(blob.exists()) + self.assertEqual(len(connection._requested), 1) + self.assertEqual(connection._requested[0], { + 'method': 'GET', + 'path': '/b/name/o/{}'.format(NONESUCH), + 'query_params': {'fields': 'name'}, + '_target_object': None, + }) - def test_exists_hit(self): + def test_exists_hit_w_user_project(self): BLOB_NAME = 'blob-name' + USER_PROJECT = 'user-project-123' found_response = ({'status': http_client.OK}, b'') connection = _Connection(found_response) client = _Client(connection) - bucket = _Bucket(client) + bucket = _Bucket(client, user_project=USER_PROJECT) blob = self._make_one(BLOB_NAME, bucket=bucket) bucket._blobs[BLOB_NAME] = 1 self.assertTrue(blob.exists()) + self.assertEqual(len(connection._requested), 1) + self.assertEqual(connection._requested[0], { + 'method': 'GET', + 'path': '/b/name/o/{}'.format(BLOB_NAME), + 'query_params': {'fields': 'name', 'userProject': USER_PROJECT}, + '_target_object': None, + }) def test_delete(self): BLOB_NAME = 'blob-name' @@ -338,7 +366,7 @@ def test__get_transport(self): def test__get_download_url_with_media_link(self): blob_name = 'something.txt' - bucket = mock.Mock(spec=[]) + bucket = _Bucket(name='IRRELEVANT') blob = self._make_one(blob_name, bucket=bucket) media_link = 'http://test.invalid' # Set the media link on the blob @@ -347,9 +375,22 @@ def test__get_download_url_with_media_link(self): download_url = blob._get_download_url() self.assertEqual(download_url, media_link) + def test__get_download_url_with_media_link_w_user_project(self): + blob_name = 'something.txt' + user_project = 'user-project-123' + bucket = _Bucket(name='IRRELEVANT', user_project=user_project) + blob = self._make_one(blob_name, bucket=bucket) + media_link = 'http://test.invalid' + # Set the media link on the blob + blob._properties['mediaLink'] = media_link + + download_url = blob._get_download_url() + self.assertEqual( + download_url, '{}?userProject={}'.format(media_link, user_project)) + def 
test__get_download_url_on_the_fly(self): blob_name = 'bzzz-fly.txt' - bucket = mock.Mock(path='/b/buhkit', spec=['path']) + bucket = _Bucket(name='buhkit') blob = self._make_one(blob_name, bucket=bucket) self.assertIsNone(blob.media_link) @@ -361,7 +402,7 @@ def test__get_download_url_on_the_fly(self): def test__get_download_url_on_the_fly_with_generation(self): blob_name = 'pretend.txt' - bucket = mock.Mock(path='/b/fictional', spec=['path']) + bucket = _Bucket(name='fictional') blob = self._make_one(blob_name, bucket=bucket) generation = 1493058489532987 # Set the media link on the blob @@ -374,6 +415,20 @@ def test__get_download_url_on_the_fly_with_generation(self): 'fictional/o/pretend.txt?alt=media&generation=1493058489532987') self.assertEqual(download_url, expected_url) + def test__get_download_url_on_the_fly_with_user_project(self): + blob_name = 'pretend.txt' + user_project = 'user-project-123' + bucket = _Bucket(name='fictional', user_project=user_project) + blob = self._make_one(blob_name, bucket=bucket) + + self.assertIsNone(blob.media_link) + download_url = blob._get_download_url() + expected_url = ( + 'https://www.googleapis.com/download/storage/v1/b/' + 'fictional/o/pretend.txt?alt=media&userProject={}'.format( + user_project)) + self.assertEqual(download_url, expected_url) + @staticmethod def _mock_requests_response( status_code, headers, content=b'', stream=False): @@ -759,8 +814,8 @@ def _mock_transport(self, status_code, headers, content=b''): return fake_transport def _do_multipart_success(self, mock_get_boundary, size=None, - num_retries=None): - bucket = mock.Mock(path='/b/w00t', spec=[u'path']) + num_retries=None, user_project=None): + bucket = _Bucket(name='w00t', user_project=user_project) blob = self._make_one(u'blob-name', bucket=bucket) self.assertIsNone(blob.chunk_size) @@ -790,6 +845,8 @@ def _do_multipart_success(self, mock_get_boundary, size=None, 'https://www.googleapis.com/upload/storage/v1' + bucket.path + '/o?uploadType=multipart') + if user_project is not None: + upload_url += '&userProject={}'.format(user_project) payload = ( b'--==0==\r\n' + b'content-type: application/json; charset=UTF-8\r\n\r\n' + @@ -812,6 +869,13 @@ def test__do_multipart_upload_no_size(self, mock_get_boundary): def test__do_multipart_upload_with_size(self, mock_get_boundary): self._do_multipart_success(mock_get_boundary, size=10) + @mock.patch(u'google.resumable_media._upload.get_boundary', + return_value=b'==0==') + def test__do_multipart_upload_with_user_project(self, mock_get_boundary): + user_project = 'user-project-123' + self._do_multipart_success( + mock_get_boundary, user_project=user_project) + @mock.patch(u'google.resumable_media._upload.get_boundary', return_value=b'==0==') def test__do_multipart_upload_with_retry(self, mock_get_boundary): @@ -833,11 +897,12 @@ def test__do_multipart_upload_bad_size(self): 'was specified but the file-like object only had', exc_contents) self.assertEqual(stream.tell(), len(data)) - def _initiate_resumable_helper(self, size=None, extra_headers=None, - chunk_size=None, num_retries=None): + def _initiate_resumable_helper( + self, size=None, extra_headers=None, chunk_size=None, + num_retries=None, user_project=None): from google.resumable_media.requests import ResumableUpload - bucket = mock.Mock(path='/b/whammy', spec=[u'path']) + bucket = _Bucket(name='whammy', user_project=user_project) blob = self._make_one(u'blob-name', bucket=bucket) blob.metadata = {'rook': 'takes knight'} blob.chunk_size = 3 * blob._CHUNK_SIZE_MULTIPLE @@ -869,6 
+934,8 @@ def _initiate_resumable_helper(self, size=None, extra_headers=None, 'https://www.googleapis.com/upload/storage/v1' + bucket.path + '/o?uploadType=resumable') + if user_project is not None: + upload_url += '&userProject={}'.format(user_project) self.assertEqual(upload.upload_url, upload_url) if extra_headers is None: self.assertEqual(upload._headers, {}) @@ -920,6 +987,10 @@ def test__initiate_resumable_upload_no_size(self): def test__initiate_resumable_upload_with_size(self): self._initiate_resumable_helper(size=10000) + def test__initiate_resumable_upload_with_user_project(self): + user_project = 'user-project-123' + self._initiate_resumable_helper(user_project=user_project) + def test__initiate_resumable_upload_with_chunk_size(self): one_mb = 1048576 self._initiate_resumable_helper(chunk_size=one_mb) @@ -1000,7 +1071,7 @@ def _do_resumable_upload_call2(blob, content_type, data, 'PUT', resumable_url, data=payload, headers=expected_headers) def _do_resumable_helper(self, use_size=False, num_retries=None): - bucket = mock.Mock(path='/b/yesterday', spec=[u'path']) + bucket = _Bucket(name='yesterday') blob = self._make_one(u'blob-name', bucket=bucket) blob.chunk_size = blob._CHUNK_SIZE_MULTIPLE self.assertIsNotNone(blob.chunk_size) @@ -1245,7 +1316,7 @@ def test_upload_from_string_w_text(self): def _create_resumable_upload_session_helper(self, origin=None, side_effect=None): - bucket = mock.Mock(path='/b/alex-trebek', spec=[u'path']) + bucket = _Bucket(name='alex-trebek') blob = self._make_one('blob-name', bucket=bucket) chunk_size = 99 * blob._CHUNK_SIZE_MULTIPLE blob.chunk_size = chunk_size @@ -1354,8 +1425,49 @@ def test_get_iam_policy(self): kw = connection._requested self.assertEqual(len(kw), 1) - self.assertEqual(kw[0]['method'], 'GET') - self.assertEqual(kw[0]['path'], '%s/iam' % (PATH,)) + self.assertEqual(kw[0], { + 'method': 'GET', + 'path': '%s/iam' % (PATH,), + 'query_params': {}, + '_target_object': None, + }) + + def test_get_iam_policy_w_user_project(self): + from google.cloud.iam import Policy + + BLOB_NAME = 'blob-name' + USER_PROJECT = 'user-project-123' + PATH = '/b/name/o/%s' % (BLOB_NAME,) + ETAG = 'DEADBEEF' + VERSION = 17 + RETURNED = { + 'resourceId': PATH, + 'etag': ETAG, + 'version': VERSION, + 'bindings': [], + } + after = ({'status': http_client.OK}, RETURNED) + EXPECTED = {} + connection = _Connection(after) + client = _Client(connection) + bucket = _Bucket(client=client, user_project=USER_PROJECT) + blob = self._make_one(BLOB_NAME, bucket=bucket) + + policy = blob.get_iam_policy() + + self.assertIsInstance(policy, Policy) + self.assertEqual(policy.etag, RETURNED['etag']) + self.assertEqual(policy.version, RETURNED['version']) + self.assertEqual(dict(policy), EXPECTED) + + kw = connection._requested + self.assertEqual(len(kw), 1) + self.assertEqual(kw[0], { + 'method': 'GET', + 'path': '%s/iam' % (PATH,), + 'query_params': {'userProject': USER_PROJECT}, + '_target_object': None, + }) def test_set_iam_policy(self): import operator @@ -1404,6 +1516,7 @@ def test_set_iam_policy(self): self.assertEqual(len(kw), 1) self.assertEqual(kw[0]['method'], 'PUT') self.assertEqual(kw[0]['path'], '%s/iam' % (PATH,)) + self.assertEqual(kw[0]['query_params'], {}) sent = kw[0]['data'] self.assertEqual(sent['resourceId'], PATH) self.assertEqual(len(sent['bindings']), len(BINDINGS)) @@ -1415,6 +1528,41 @@ def test_set_iam_policy(self): self.assertEqual( sorted(found['members']), sorted(expected['members'])) + def test_set_iam_policy_w_user_project(self): + from 
google.cloud.iam import Policy + + BLOB_NAME = 'blob-name' + USER_PROJECT = 'user-project-123' + PATH = '/b/name/o/%s' % (BLOB_NAME,) + ETAG = 'DEADBEEF' + VERSION = 17 + BINDINGS = [] + RETURNED = { + 'etag': ETAG, + 'version': VERSION, + 'bindings': BINDINGS, + } + after = ({'status': http_client.OK}, RETURNED) + policy = Policy() + + connection = _Connection(after) + client = _Client(connection) + bucket = _Bucket(client=client, user_project=USER_PROJECT) + blob = self._make_one(BLOB_NAME, bucket=bucket) + + returned = blob.set_iam_policy(policy) + + self.assertEqual(returned.etag, ETAG) + self.assertEqual(returned.version, VERSION) + self.assertEqual(dict(returned), dict(policy)) + + kw = connection._requested + self.assertEqual(len(kw), 1) + self.assertEqual(kw[0]['method'], 'PUT') + self.assertEqual(kw[0]['path'], '%s/iam' % (PATH,)) + self.assertEqual(kw[0]['query_params'], {'userProject': USER_PROJECT}) + self.assertEqual(kw[0]['data'], {'resourceId': PATH}) + def test_test_iam_permissions(self): from google.cloud.storage.iam import STORAGE_OBJECTS_LIST from google.cloud.storage.iam import STORAGE_BUCKETS_GET @@ -1445,6 +1593,39 @@ def test_test_iam_permissions(self): self.assertEqual(kw[0]['path'], '%s/iam/testPermissions' % (PATH,)) self.assertEqual(kw[0]['query_params'], {'permissions': PERMISSIONS}) + def test_test_iam_permissions_w_user_project(self): + from google.cloud.storage.iam import STORAGE_OBJECTS_LIST + from google.cloud.storage.iam import STORAGE_BUCKETS_GET + from google.cloud.storage.iam import STORAGE_BUCKETS_UPDATE + + BLOB_NAME = 'blob-name' + USER_PROJECT = 'user-project-123' + PATH = '/b/name/o/%s' % (BLOB_NAME,) + PERMISSIONS = [ + STORAGE_OBJECTS_LIST, + STORAGE_BUCKETS_GET, + STORAGE_BUCKETS_UPDATE, + ] + ALLOWED = PERMISSIONS[1:] + RETURNED = {'permissions': ALLOWED} + after = ({'status': http_client.OK}, RETURNED) + connection = _Connection(after) + client = _Client(connection) + bucket = _Bucket(client=client, user_project=USER_PROJECT) + blob = self._make_one(BLOB_NAME, bucket=bucket) + + allowed = blob.test_iam_permissions(PERMISSIONS) + + self.assertEqual(allowed, ALLOWED) + + kw = connection._requested + self.assertEqual(len(kw), 1) + self.assertEqual(kw[0]['method'], 'GET') + self.assertEqual(kw[0]['path'], '%s/iam/testPermissions' % (PATH,)) + self.assertEqual( + kw[0]['query_params'], + {'permissions': PERMISSIONS, 'userProject': USER_PROJECT}) + def test_make_public(self): from google.cloud.storage.acl import _ACLEntity @@ -1479,17 +1660,18 @@ def test_compose_wo_content_type_set(self): with self.assertRaises(ValueError): destination.compose(sources=[source_1, source_2]) - def test_compose_minimal(self): + def test_compose_minimal_w_user_project(self): SOURCE_1 = 'source-1' SOURCE_2 = 'source-2' DESTINATION = 'destinaton' RESOURCE = { 'etag': 'DEADBEEF' } + USER_PROJECT = 'user-project-123' after = ({'status': http_client.OK}, RESOURCE) connection = _Connection(after) client = _Client(connection) - bucket = _Bucket(client=client) + bucket = _Bucket(client=client, user_project=USER_PROJECT) source_1 = self._make_one(SOURCE_1, bucket=bucket) source_2 = self._make_one(SOURCE_2, bucket=bucket) destination = self._make_one(DESTINATION, bucket=bucket) @@ -1499,20 +1681,23 @@ def test_compose_minimal(self): self.assertEqual(destination.etag, 'DEADBEEF') - SENT = { - 'sourceObjects': [ - {'name': source_1.name}, - {'name': source_2.name}, - ], - 'destination': { - 'contentType': 'text/plain', - }, - } kw = connection._requested self.assertEqual(len(kw), 
-        self.assertEqual(kw[0]['method'], 'POST')
-        self.assertEqual(kw[0]['path'], '/b/name/o/%s/compose' % DESTINATION)
-        self.assertEqual(kw[0]['data'], SENT)
+        self.assertEqual(kw[0], {
+            'method': 'POST',
+            'path': '/b/name/o/%s/compose' % DESTINATION,
+            'query_params': {'userProject': USER_PROJECT},
+            'data': {
+                'sourceObjects': [
+                    {'name': source_1.name},
+                    {'name': source_2.name},
+                ],
+                'destination': {
+                    'contentType': 'text/plain',
+                },
+            },
+            '_target_object': destination,
+        })

     def test_compose_w_additional_property_changes(self):
         SOURCE_1 = 'source-1'
@@ -1536,24 +1721,27 @@

         self.assertEqual(destination.etag, 'DEADBEEF')

-        SENT = {
-            'sourceObjects': [
-                {'name': source_1.name},
-                {'name': source_2.name},
-            ],
-            'destination': {
-                'contentType': 'text/plain',
-                'contentLanguage': 'en-US',
-                'metadata': {
-                    'my-key': 'my-value',
-                }
-            },
-        }
         kw = connection._requested
         self.assertEqual(len(kw), 1)
-        self.assertEqual(kw[0]['method'], 'POST')
-        self.assertEqual(kw[0]['path'], '/b/name/o/%s/compose' % DESTINATION)
-        self.assertEqual(kw[0]['data'], SENT)
+        self.assertEqual(kw[0], {
+            'method': 'POST',
+            'path': '/b/name/o/%s/compose' % DESTINATION,
+            'query_params': {},
+            'data': {
+                'sourceObjects': [
+                    {'name': source_1.name},
+                    {'name': source_2.name},
+                ],
+                'destination': {
+                    'contentType': 'text/plain',
+                    'contentLanguage': 'en-US',
+                    'metadata': {
+                        'my-key': 'my-value',
+                    }
+                },
+            },
+            '_target_object': destination,
+        })

     def test_rewrite_response_without_resource(self):
         SOURCE_BLOB = 'source'
@@ -1625,7 +1813,7 @@ def test_rewrite_other_bucket_other_name_no_encryption_partial(self):
         self.assertNotIn('X-Goog-Encryption-Key', headers)
         self.assertNotIn('X-Goog-Encryption-Key-Sha256', headers)

-    def test_rewrite_same_name_no_old_key_new_key_done(self):
+    def test_rewrite_same_name_no_old_key_new_key_done_w_user_project(self):
         import base64
         import hashlib

@@ -1634,6 +1822,7 @@ def test_rewrite_same_name_no_old_key_new_key_done(self):
         KEY_HASH = hashlib.sha256(KEY).digest()
         KEY_HASH_B64 = base64.b64encode(KEY_HASH).rstrip().decode('ascii')
         BLOB_NAME = 'blob'
+        USER_PROJECT = 'user-project-123'
         RESPONSE = {
             'totalBytesRewritten': 42,
             'objectSize': 42,
@@ -1643,7 +1832,7 @@
         response = ({'status': http_client.OK}, RESPONSE)
         connection = _Connection(response)
         client = _Client(connection)
-        bucket = _Bucket(client=client)
+        bucket = _Bucket(client=client, user_project=USER_PROJECT)
         plain = self._make_one(BLOB_NAME, bucket=bucket)
         encrypted = self._make_one(BLOB_NAME, bucket=bucket, encryption_key=KEY)
@@ -1659,7 +1848,7 @@
         self.assertEqual(kw[0]['method'], 'POST')
         PATH = '/b/name/o/%s/rewriteTo/b/name/o/%s' % (BLOB_NAME, BLOB_NAME)
         self.assertEqual(kw[0]['path'], PATH)
-        self.assertEqual(kw[0]['query_params'], {})
+        self.assertEqual(kw[0]['query_params'], {'userProject': USER_PROJECT})
         SENT = {}
         self.assertEqual(kw[0]['data'], SENT)
@@ -1762,7 +1951,7 @@ def test_update_storage_class_wo_encryption_key(self):
         self.assertEqual(kw[0]['method'], 'POST')
         PATH = '/b/name/o/%s/rewriteTo/b/name/o/%s' % (BLOB_NAME, BLOB_NAME)
         self.assertEqual(kw[0]['path'], PATH)
-        self.assertNotIn('query_params', kw[0])
+        self.assertEqual(kw[0]['query_params'], {})
         SENT = {'storageClass': STORAGE_CLASS}
         self.assertEqual(kw[0]['data'], SENT)
@@ -1776,7 +1965,7 @@
         self.assertNotIn('X-Goog-Encryption-Key', headers)
         self.assertNotIn('X-Goog-Encryption-Key-Sha256', headers)

-    def test_update_storage_class_w_encryption_key(self):
+    def test_update_storage_class_w_encryption_key_w_user_project(self):
         import base64
         import hashlib

@@ -1787,13 +1976,14 @@
         BLOB_KEY_HASH_B64 = base64.b64encode(
             BLOB_KEY_HASH).rstrip().decode('ascii')
         STORAGE_CLASS = u'NEARLINE'
+        USER_PROJECT = 'user-project-123'
         RESPONSE = {
             'resource': {'storageClass': STORAGE_CLASS},
         }
         response = ({'status': http_client.OK}, RESPONSE)
         connection = _Connection(response)
         client = _Client(connection)
-        bucket = _Bucket(client=client)
+        bucket = _Bucket(client=client, user_project=USER_PROJECT)
         blob = self._make_one(
             BLOB_NAME, bucket=bucket, encryption_key=BLOB_KEY)
@@ -1806,7 +1996,7 @@
         self.assertEqual(kw[0]['method'], 'POST')
         PATH = '/b/name/o/%s/rewriteTo/b/name/o/%s' % (BLOB_NAME, BLOB_NAME)
         self.assertEqual(kw[0]['path'], PATH)
-        self.assertNotIn('query_params', kw[0])
+        self.assertEqual(kw[0]['query_params'], {'userProject': USER_PROJECT})
         SENT = {'storageClass': STORAGE_CLASS}
         self.assertEqual(kw[0]['data'], SENT)
@@ -2240,6 +2430,37 @@ def test_default(self):
         self.assertEqual(exc_info.exception.errors, [])

+
+class Test__add_query_parameters(unittest.TestCase):
+
+    @staticmethod
+    def _call_fut(*args, **kwargs):
+        from google.cloud.storage.blob import _add_query_parameters
+
+        return _add_query_parameters(*args, **kwargs)
+
+    def test_w_empty_list(self):
+        BASE_URL = 'https://test.example.com/base'
+        self.assertEqual(self._call_fut(BASE_URL, []), BASE_URL)
+
+    def test_wo_existing_qs(self):
+        BASE_URL = 'https://test.example.com/base'
+        NV_LIST = [('one', 'One'), ('two', 'Two')]
+        expected = '&'.join([
+            '{}={}'.format(name, value) for name, value in NV_LIST])
+        self.assertEqual(
+            self._call_fut(BASE_URL, NV_LIST),
+            '{}?{}'.format(BASE_URL, expected))
+
+    def test_w_existing_qs(self):
+        BASE_URL = 'https://test.example.com/base?one=Three'
+        NV_LIST = [('one', 'One'), ('two', 'Two')]
+        expected = '&'.join([
+            '{}={}'.format(name, value) for name, value in NV_LIST])
+        self.assertEqual(
+            self._call_fut(BASE_URL, NV_LIST),
+            '{}&{}'.format(BASE_URL, expected))
+
+
 class _Connection(object):

     API_BASE_URL = 'http://example.com'
@@ -2267,7 +2488,7 @@ def api_request(self, **kw):

 class _Bucket(object):

-    def __init__(self, client=None, name='name'):
+    def __init__(self, client=None, name='name', user_project=None):
         if client is None:
             connection = _Connection()
             client = _Client(connection)
@@ -2277,6 +2498,7 @@ def __init__(self, client=None, name='name'):
         self._deleted = []
         self.name = name
         self.path = '/b/' + name
+        self.user_project = user_project

     def delete_blob(self, blob_name, client=None):
         del self._blobs[blob_name]
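The Test__add_query_parameters cases above pin down the contract of the new _add_query_parameters helper in blob.py: an empty pair list leaves the URL untouched, and any supplied pairs are appended after whatever query string is already present. A minimal sketch consistent with those tests, assuming the helper is built on the six.moves.urllib.parse functions (the shipped implementation may differ in detail):

    # Hypothetical sketch of the helper exercised by Test__add_query_parameters.
    from six.moves.urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit


    def _add_query_parameters(base_url, name_value_pairs):
        # Nothing to add: return the URL unchanged, as test_w_empty_list expects.
        if len(name_value_pairs) == 0:
            return base_url

        # Merge the new pairs after any query string already on the URL.
        scheme, netloc, path, query, frag = urlsplit(base_url)
        query = parse_qsl(query)
        query.extend(name_value_pairs)
        return urlunsplit((scheme, netloc, path, urlencode(query), frag))

Presumably this is what lets request URLs built by the blob pick up a userProject parameter when the bucket has a user project set.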
diff --git a/storage/tests/unit/test_bucket.py b/storage/tests/unit/test_bucket.py
index 31e5f817e1aa..370e401f2bb7 100644
--- a/storage/tests/unit/test_bucket.py
+++ b/storage/tests/unit/test_bucket.py
@@ -33,13 +33,21 @@ class _SigningCredentials(

 class Test_Bucket(unittest.TestCase):

-    def _make_one(self, client=None, name=None, properties=None):
+    @staticmethod
+    def _get_target_class():
         from google.cloud.storage.bucket import Bucket
+
+        return Bucket
+
+    def _make_one(
+            self, client=None, name=None, properties=None, user_project=None):
         if client is None:
             connection = _Connection()
             client = _Client(connection)
-        bucket = Bucket(client, name=name)
+        if user_project is None:
+            bucket = self._get_target_class()(client, name=name)
+        else:
+            bucket = self._get_target_class()(
+                client, name=name, user_project=user_project)
         bucket._properties = properties or {}
         return bucket

@@ -53,6 +61,21 @@ def test_ctor(self):
         self.assertIs(bucket._acl.bucket, bucket)
         self.assertFalse(bucket._default_object_acl.loaded)
         self.assertIs(bucket._default_object_acl.bucket, bucket)
+        self.assertIsNone(bucket.user_project)
+
+    def test_ctor_w_user_project(self):
+        NAME = 'name'
+        USER_PROJECT = 'user-project-123'
+        connection = _Connection()
+        client = _Client(connection)
+        bucket = self._make_one(client, name=NAME, user_project=USER_PROJECT)
+        self.assertEqual(bucket.name, NAME)
+        self.assertEqual(bucket._properties, {})
+        self.assertEqual(bucket.user_project, USER_PROJECT)
+        self.assertFalse(bucket._acl.loaded)
+        self.assertIs(bucket._acl.bucket, bucket)
+        self.assertFalse(bucket._default_object_acl.loaded)
+        self.assertIs(bucket._default_object_acl.bucket, bucket)

     def test_blob(self):
         from google.cloud.storage.blob import Blob
@@ -131,9 +154,8 @@ def test_notification_explicit(self):
             notification.payload_format, JSON_API_V1_PAYLOAD_FORMAT)

     def test_bucket_name_value(self):
-        bucket_name = 'testing123'
-        mixin = self._make_one(name=bucket_name)
-        self.assertEqual(mixin.name, bucket_name)
+        BUCKET_NAME = 'bucket-name'
+        bucket = self._make_one(name=BUCKET_NAME)

         bad_start_bucket_name = '/testing123'
         with self.assertRaises(ValueError):
@@ -143,6 +165,13 @@
         with self.assertRaises(ValueError):
             self._make_one(name=bad_end_bucket_name)

+    def test_user_project(self):
+        BUCKET_NAME = 'name'
+        USER_PROJECT = 'user-project-123'
+        bucket = self._make_one(name=BUCKET_NAME)
+        bucket._user_project = USER_PROJECT
+        self.assertEqual(bucket.user_project, USER_PROJECT)
+
     def test_exists_miss(self):
         from google.cloud.exceptions import NotFound

@@ -170,7 +199,9 @@ def api_request(cls, *args, **kwargs):
         expected_cw = [((), expected_called_kwargs)]
         self.assertEqual(_FakeConnection._called_with, expected_cw)

-    def test_exists_hit(self):
+    def test_exists_hit_w_user_project(self):
+        USER_PROJECT = 'user-project-123'
+
         class _FakeConnection(object):

             _called_with = []
@@ -182,7 +213,7 @@ def api_request(cls, *args, **kwargs):
                 return object()

         BUCKET_NAME = 'bucket-name'
-        bucket = self._make_one(name=BUCKET_NAME)
+        bucket = self._make_one(name=BUCKET_NAME, user_project=USER_PROJECT)
         client = _Client(_FakeConnection)
         self.assertTrue(bucket.exists(client=client))
         expected_called_kwargs = {
@@ -190,17 +221,29 @@
             'path': bucket.path,
             'query_params': {
                 'fields': 'name',
+                'userProject': USER_PROJECT,
             },
             '_target_object': None,
         }
         expected_cw = [((), expected_called_kwargs)]
         self.assertEqual(_FakeConnection._called_with, expected_cw)

+    def test_create_w_user_project(self):
+        PROJECT = 'PROJECT'
+        BUCKET_NAME = 'bucket-name'
+        USER_PROJECT = 'user-project-123'
+        connection = _Connection()
+        client = _Client(connection, project=PROJECT)
+        bucket = self._make_one(client, BUCKET_NAME, user_project=USER_PROJECT)
+
+        with self.assertRaises(ValueError):
+            bucket.create()
+
     def test_create_hit(self):
+        PROJECT = 'PROJECT'
         BUCKET_NAME = 'bucket-name'
         DATA = {'name': BUCKET_NAME}
         connection = _Connection(DATA)
-        PROJECT = 'PROJECT'
         client = _Client(connection, project=PROJECT)
         bucket = self._make_one(client=client, name=BUCKET_NAME)
         bucket.create()
@@ -234,6 +277,7 @@ def test_create_w_extra_properties(self):
             'location': LOCATION,
             'storageClass': STORAGE_CLASS,
             'versioning': {'enabled': True},
+            'billing': {'requesterPays': True},
             'labels': LABELS,
         }
         connection = _Connection(DATA)
@@ -244,6 +288,7 @@
         bucket.location = LOCATION
         bucket.storage_class = STORAGE_CLASS
         bucket.versioning_enabled = True
+        bucket.requester_pays = True
         bucket.labels = LABELS
         bucket.create()

@@ -290,18 +335,20 @@ def test_get_blob_miss(self):
         self.assertEqual(kw['method'], 'GET')
         self.assertEqual(kw['path'], '/b/%s/o/%s' % (NAME, NONESUCH))

-    def test_get_blob_hit(self):
+    def test_get_blob_hit_w_user_project(self):
         NAME = 'name'
         BLOB_NAME = 'blob-name'
+        USER_PROJECT = 'user-project-123'
         connection = _Connection({'name': BLOB_NAME})
         client = _Client(connection)
-        bucket = self._make_one(name=NAME)
+        bucket = self._make_one(name=NAME, user_project=USER_PROJECT)
         blob = bucket.get_blob(BLOB_NAME, client=client)
         self.assertIs(blob.bucket, bucket)
         self.assertEqual(blob.name, BLOB_NAME)
         kw, = connection._requested
         self.assertEqual(kw['method'], 'GET')
         self.assertEqual(kw['path'], '/b/%s/o/%s' % (NAME, BLOB_NAME))
+        self.assertEqual(kw['query_params'], {'userProject': USER_PROJECT})

     def test_get_blob_hit_with_kwargs(self):
         from google.cloud.storage.blob import _get_encryption_headers
@@ -339,8 +386,9 @@ def test_list_blobs_defaults(self):
         self.assertEqual(kw['path'], '/b/%s/o' % NAME)
         self.assertEqual(kw['query_params'], {'projection': 'noAcl'})

-    def test_list_blobs_w_all_arguments(self):
+    def test_list_blobs_w_all_arguments_and_user_project(self):
         NAME = 'name'
+        USER_PROJECT = 'user-project-123'
         MAX_RESULTS = 10
         PAGE_TOKEN = 'ABCD'
         PREFIX = 'subfolder'
@@ -356,10 +404,11 @@
             'versions': VERSIONS,
             'projection': PROJECTION,
             'fields': FIELDS,
+            'userProject': USER_PROJECT,
         }
         connection = _Connection({'items': []})
         client = _Client(connection)
-        bucket = self._make_one(name=NAME)
+        bucket = self._make_one(name=NAME, user_project=USER_PROJECT)
         iterator = bucket.list_blobs(
             max_results=MAX_RESULTS,
             page_token=PAGE_TOKEN,
@@ -453,23 +502,27 @@ def test_delete_miss(self):
         expected_cw = [{
             'method': 'DELETE',
             'path': bucket.path,
+            'query_params': {},
             '_target_object': None,
         }]
         self.assertEqual(connection._deleted_buckets, expected_cw)

-    def test_delete_hit(self):
+    def test_delete_hit_with_user_project(self):
         NAME = 'name'
+        USER_PROJECT = 'user-project-123'
         GET_BLOBS_RESP = {'items': []}
         connection = _Connection(GET_BLOBS_RESP)
         connection._delete_bucket = True
         client = _Client(connection)
-        bucket = self._make_one(client=client, name=NAME)
+        bucket = self._make_one(
+            client=client, name=NAME, user_project=USER_PROJECT)
         result = bucket.delete(force=True)
         self.assertIsNone(result)
         expected_cw = [{
             'method': 'DELETE',
             'path': bucket.path,
             '_target_object': None,
+            'query_params': {'userProject': USER_PROJECT},
         }]
         self.assertEqual(connection._deleted_buckets, expected_cw)

@@ -494,6 +547,7 @@ def test_delete_force_delete_blobs(self):
         expected_cw = [{
             'method': 'DELETE',
             'path': bucket.path,
+            'query_params': {},
             '_target_object': None,
         }]
         self.assertEqual(connection._deleted_buckets, expected_cw)

@@ -512,6 +566,7 @@ def test_delete_force_miss_blobs(self):
         expected_cw = [{
             'method': 'DELETE',
             'path': bucket.path,
+            'query_params': {},
             '_target_object': None,
         }]
         self.assertEqual(connection._deleted_buckets, expected_cw)

@@ -548,18 +603,22 @@ def test_delete_blob_miss(self):
         kw, = connection._requested
         self.assertEqual(kw['method'], 'DELETE')
         self.assertEqual(kw['path'], '/b/%s/o/%s' % (NAME, NONESUCH))
+        self.assertEqual(kw['query_params'], {})

-    def test_delete_blob_hit(self):
+    def test_delete_blob_hit_with_user_project(self):
         NAME = 'name'
         BLOB_NAME = 'blob-name'
+        USER_PROJECT = 'user-project-123'
         connection = _Connection({})
         client = _Client(connection)
-        bucket = self._make_one(client=client, name=NAME)
+        bucket = self._make_one(
+            client=client, name=NAME, user_project=USER_PROJECT)
         result = bucket.delete_blob(BLOB_NAME)
         self.assertIsNone(result)
         kw, = connection._requested
         self.assertEqual(kw['method'], 'DELETE')
         self.assertEqual(kw['path'], '/b/%s/o/%s' % (NAME, BLOB_NAME))
+        self.assertEqual(kw['query_params'], {'userProject': USER_PROJECT})

     def test_delete_blobs_empty(self):
         NAME = 'name'
@@ -569,17 +628,20 @@ def test_delete_blobs_empty(self):
         bucket.delete_blobs([])
         self.assertEqual(connection._requested, [])

-    def test_delete_blobs_hit(self):
+    def test_delete_blobs_hit_w_user_project(self):
         NAME = 'name'
         BLOB_NAME = 'blob-name'
+        USER_PROJECT = 'user-project-123'
         connection = _Connection({})
         client = _Client(connection)
-        bucket = self._make_one(client=client, name=NAME)
+        bucket = self._make_one(
+            client=client, name=NAME, user_project=USER_PROJECT)
         bucket.delete_blobs([BLOB_NAME])
         kw = connection._requested
         self.assertEqual(len(kw), 1)
         self.assertEqual(kw[0]['method'], 'DELETE')
         self.assertEqual(kw[0]['path'], '/b/%s/o/%s' % (NAME, BLOB_NAME))
+        self.assertEqual(kw[0]['query_params'], {'userProject': USER_PROJECT})

     def test_delete_blobs_miss_no_on_error(self):
         from google.cloud.exceptions import NotFound
@@ -637,6 +699,7 @@ class _Blob(object):
             DEST, BLOB_NAME)
         self.assertEqual(kw['method'], 'POST')
         self.assertEqual(kw['path'], COPY_PATH)
+        self.assertEqual(kw['query_params'], {})

     def test_copy_blobs_preserve_acl(self):
         from google.cloud.storage.acl import ObjectACL
@@ -668,14 +731,17 @@ class _Blob(object):
         self.assertEqual(len(kw), 2)
         self.assertEqual(kw[0]['method'], 'POST')
         self.assertEqual(kw[0]['path'], COPY_PATH)
+        self.assertEqual(kw[0]['query_params'], {})
         self.assertEqual(kw[1]['method'], 'PATCH')
         self.assertEqual(kw[1]['path'], NEW_BLOB_PATH)
+        self.assertEqual(kw[1]['query_params'], {'projection': 'full'})

-    def test_copy_blobs_w_name(self):
+    def test_copy_blobs_w_name_and_user_project(self):
         SOURCE = 'source'
         DEST = 'dest'
         BLOB_NAME = 'blob-name'
         NEW_NAME = 'new_name'
+        USER_PROJECT = 'user-project-123'

         class _Blob(object):
             name = BLOB_NAME
+
         connection = _Connection({})
         client = _Client(connection)
-        source = self._make_one(client=client, name=SOURCE)
+        source = self._make_one(
+            client=client, name=SOURCE, user_project=USER_PROJECT)
         dest = self._make_one(client=client, name=DEST)
         blob = _Blob()
         new_blob = source.copy_blob(blob, dest, NEW_NAME)
@@ -694,6 +761,7 @@ class _Blob(object):
             DEST, NEW_NAME)
         self.assertEqual(kw['method'], 'POST')
         self.assertEqual(kw['path'], COPY_PATH)
+        self.assertEqual(kw['query_params'], {'userProject': USER_PROJECT})

     def test_rename_blob(self):
         BUCKET_NAME = 'BUCKET_NAME'
@@ -1026,6 +1094,24 @@ def test_versioning_enabled_setter(self):
         bucket.versioning_enabled = True
         self.assertTrue(bucket.versioning_enabled)

+    def test_requester_pays_getter_missing(self):
+        NAME = 'name'
+        bucket = self._make_one(name=NAME)
+        self.assertEqual(bucket.requester_pays, False)
+
+    def test_requester_pays_getter(self):
+        NAME = 'name'
+        before = {'billing': {'requesterPays': True}}
+        bucket = self._make_one(name=NAME, properties=before)
+        self.assertEqual(bucket.requester_pays, True)
+
+    def test_requester_pays_setter(self):
+        NAME = 'name'
+        bucket = self._make_one(name=NAME)
+        self.assertFalse(bucket.requester_pays)
+        bucket.requester_pays = True
+        self.assertTrue(bucket.requester_pays)
+
     def test_configure_website_defaults(self):
         NAME = 'name'
         UNSET = {'website': {'mainPageSuffix': None,
@@ -1094,6 +1180,40 @@ def test_get_iam_policy(self):
         self.assertEqual(len(kw), 1)
         self.assertEqual(kw[0]['method'], 'GET')
         self.assertEqual(kw[0]['path'], '%s/iam' % (PATH,))
+        self.assertEqual(kw[0]['query_params'], {})
+
+    def test_get_iam_policy_w_user_project(self):
+        from google.cloud.iam import Policy
+
+        NAME = 'name'
+        USER_PROJECT = 'user-project-123'
+        PATH = '/b/%s' % (NAME,)
+        ETAG = 'DEADBEEF'
+        VERSION = 17
+        RETURNED = {
+            'resourceId': PATH,
+            'etag': ETAG,
+            'version': VERSION,
+            'bindings': [],
+        }
+        EXPECTED = {}
+        connection = _Connection(RETURNED)
+        client = _Client(connection, None)
+        bucket = self._make_one(
+            client=client, name=NAME, user_project=USER_PROJECT)
+
+        policy = bucket.get_iam_policy()
+
+        self.assertIsInstance(policy, Policy)
+        self.assertEqual(policy.etag, RETURNED['etag'])
+        self.assertEqual(policy.version, RETURNED['version'])
+        self.assertEqual(dict(policy), EXPECTED)
+
+        kw = connection._requested
+        self.assertEqual(len(kw), 1)
+        self.assertEqual(kw[0]['method'], 'GET')
+        self.assertEqual(kw[0]['path'], '%s/iam' % (PATH,))
+        self.assertEqual(kw[0]['query_params'], {'userProject': USER_PROJECT})

     def test_set_iam_policy(self):
         import operator
@@ -1140,6 +1260,66 @@ def test_set_iam_policy(self):
         self.assertEqual(len(kw), 1)
         self.assertEqual(kw[0]['method'], 'PUT')
         self.assertEqual(kw[0]['path'], '%s/iam' % (PATH,))
+        self.assertEqual(kw[0]['query_params'], {})
+        sent = kw[0]['data']
+        self.assertEqual(sent['resourceId'], PATH)
+        self.assertEqual(len(sent['bindings']), len(BINDINGS))
+        key = operator.itemgetter('role')
+        for found, expected in zip(
+                sorted(sent['bindings'], key=key),
+                sorted(BINDINGS, key=key)):
+            self.assertEqual(found['role'], expected['role'])
+            self.assertEqual(
+                sorted(found['members']), sorted(expected['members']))
+
+    def test_set_iam_policy_w_user_project(self):
+        import operator
+        from google.cloud.storage.iam import STORAGE_OWNER_ROLE
+        from google.cloud.storage.iam import STORAGE_EDITOR_ROLE
+        from google.cloud.storage.iam import STORAGE_VIEWER_ROLE
+        from google.cloud.iam import Policy
+
+        NAME = 'name'
+        USER_PROJECT = 'user-project-123'
+        PATH = '/b/%s' % (NAME,)
+        ETAG = 'DEADBEEF'
+        VERSION = 17
+        OWNER1 = 'user:phred@example.com'
+        OWNER2 = 'group:cloud-logs@google.com'
+        EDITOR1 = 'domain:google.com'
+        EDITOR2 = 'user:phred@example.com'
+        VIEWER1 = 'serviceAccount:1234-abcdef@service.example.com'
+        VIEWER2 = 'user:phred@example.com'
+        BINDINGS = [
+            {'role': STORAGE_OWNER_ROLE, 'members': [OWNER1, OWNER2]},
+            {'role': STORAGE_EDITOR_ROLE, 'members': [EDITOR1, EDITOR2]},
+            {'role': STORAGE_VIEWER_ROLE, 'members': [VIEWER1, VIEWER2]},
+        ]
+        RETURNED = {
+            'etag': ETAG,
+            'version': VERSION,
+            'bindings': BINDINGS,
+        }
+        policy = Policy()
+        for binding in BINDINGS:
+            policy[binding['role']] = binding['members']
+
+        connection = _Connection(RETURNED)
+        client = _Client(connection, None)
+        bucket = self._make_one(
+            client=client, name=NAME, user_project=USER_PROJECT)
+
+        returned = bucket.set_iam_policy(policy)
+
+        self.assertEqual(returned.etag, ETAG)
+        self.assertEqual(returned.version, VERSION)
+        self.assertEqual(dict(returned), dict(policy))
+
+        kw = connection._requested
+        self.assertEqual(len(kw), 1)
+        self.assertEqual(kw[0]['method'], 'PUT')
+        self.assertEqual(kw[0]['path'], '%s/iam' % (PATH,))
+        self.assertEqual(kw[0]['query_params'], {'userProject': USER_PROJECT})
         sent = kw[0]['data']
         self.assertEqual(sent['resourceId'], PATH)
         self.assertEqual(len(sent['bindings']), len(BINDINGS))
@@ -1179,6 +1359,38 @@ def test_test_iam_permissions(self):
         self.assertEqual(kw[0]['path'], '%s/iam/testPermissions' % (PATH,))
         self.assertEqual(kw[0]['query_params'], {'permissions': PERMISSIONS})

+    def test_test_iam_permissions_w_user_project(self):
+        from google.cloud.storage.iam import STORAGE_OBJECTS_LIST
+        from google.cloud.storage.iam import STORAGE_BUCKETS_GET
+        from google.cloud.storage.iam import STORAGE_BUCKETS_UPDATE
+
+        NAME = 'name'
+        USER_PROJECT = 'user-project-123'
+        PATH = '/b/%s' % (NAME,)
+        PERMISSIONS = [
+            STORAGE_OBJECTS_LIST,
+            STORAGE_BUCKETS_GET,
+            STORAGE_BUCKETS_UPDATE,
+        ]
+        ALLOWED = PERMISSIONS[1:]
+        RETURNED = {'permissions': ALLOWED}
+        connection = _Connection(RETURNED)
+        client = _Client(connection, None)
+        bucket = self._make_one(
+            client=client, name=NAME, user_project=USER_PROJECT)
+
+        allowed = bucket.test_iam_permissions(PERMISSIONS)
+
+        self.assertEqual(allowed, ALLOWED)
+
+        kw = connection._requested
+        self.assertEqual(len(kw), 1)
+        self.assertEqual(kw[0]['method'], 'GET')
+        self.assertEqual(kw[0]['path'], '%s/iam/testPermissions' % (PATH,))
+        self.assertEqual(
+            kw[0]['query_params'],
+            {'permissions': PERMISSIONS, 'userProject': USER_PROJECT})
+
     def test_make_public_defaults(self):
         from google.cloud.storage.acl import _ACLEntity
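The test_client.py changes that follow cover the user-facing entry points for the same feature: Client.bucket() accepts a user_project argument and Client.create_bucket() accepts requester_pays. A short usage sketch consistent with those tests (bucket, object, and project names here are placeholders, not values from this patch):

    from google.cloud import storage

    client = storage.Client(project='my-project')

    # Create a bucket with requester-pays billing enabled; per the tests this
    # sends 'billing': {'requesterPays': True} in the bucket-insert request.
    bucket = client.create_bucket('my-bucket', requester_pays=True)

    # Bind a Bucket handle to the project that should be billed for access;
    # subsequent calls then carry userProject=<that project> as a query parameter.
    bucket = client.bucket('my-bucket', user_project='my-billed-project')
    blob = bucket.get_blob('my-object')

Per the bucket tests above, the same userProject value is threaded through exists(), delete(), copy_blob(), the IAM helpers, and the rest of the bucket surface.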
diff --git a/storage/tests/unit/test_client.py b/storage/tests/unit/test_client.py
index 39a27b9c0773..cb3da57117ad 100644
--- a/storage/tests/unit/test_client.py
+++ b/storage/tests/unit/test_client.py
@@ -140,6 +140,22 @@ def test_bucket(self):
         self.assertIsInstance(bucket, Bucket)
         self.assertIs(bucket.client, client)
         self.assertEqual(bucket.name, BUCKET_NAME)
+        self.assertIsNone(bucket.user_project)
+
+    def test_bucket_w_user_project(self):
+        from google.cloud.storage.bucket import Bucket
+
+        PROJECT = 'PROJECT'
+        USER_PROJECT = 'USER_PROJECT'
+        CREDENTIALS = _make_credentials()
+        BUCKET_NAME = 'BUCKET_NAME'
+
+        client = self._make_one(project=PROJECT, credentials=CREDENTIALS)
+        bucket = client.bucket(BUCKET_NAME, user_project=USER_PROJECT)
+        self.assertIsInstance(bucket, Bucket)
+        self.assertIs(bucket.client, client)
+        self.assertEqual(bucket.name, BUCKET_NAME)
+        self.assertEqual(bucket.user_project, USER_PROJECT)

     def test_batch(self):
         from google.cloud.storage.batch import Batch
@@ -184,23 +200,23 @@ def test_get_bucket_hit(self):
         CREDENTIALS = _make_credentials()
         client = self._make_one(project=PROJECT, credentials=CREDENTIALS)

-        BLOB_NAME = 'blob-name'
+        BUCKET_NAME = 'bucket-name'
         URI = '/'.join([
             client._connection.API_BASE_URL,
             'storage',
             client._connection.API_VERSION,
             'b',
-            '%s?projection=noAcl' % (BLOB_NAME,),
+            '%s?projection=noAcl' % (BUCKET_NAME,),
         ])
-        data = {'name': BLOB_NAME}
+        data = {'name': BUCKET_NAME}
         http = _make_requests_session([_make_json_response(data)])
         client._http_internal = http

-        bucket = client.get_bucket(BLOB_NAME)
+        bucket = client.get_bucket(BUCKET_NAME)

         self.assertIsInstance(bucket, Bucket)
-        self.assertEqual(bucket.name, BLOB_NAME)
+        self.assertEqual(bucket.name, BUCKET_NAME)
         http.request.assert_called_once_with(
             method='GET', url=URI, data=mock.ANY, headers=mock.ANY)

@@ -234,22 +250,22 @@ def test_lookup_bucket_hit(self):
         CREDENTIALS = _make_credentials()
         client = self._make_one(project=PROJECT, credentials=CREDENTIALS)

-        BLOB_NAME = 'blob-name'
+        BUCKET_NAME = 'bucket-name'
         URI = '/'.join([
             client._connection.API_BASE_URL,
             'storage',
             client._connection.API_VERSION,
             'b',
-            '%s?projection=noAcl' % (BLOB_NAME,),
+            '%s?projection=noAcl' % (BUCKET_NAME,),
         ])
-        data = {'name': BLOB_NAME}
+        data = {'name': BUCKET_NAME}
         http = _make_requests_session([_make_json_response(data)])
         client._http_internal = http

-        bucket = client.lookup_bucket(BLOB_NAME)
+        bucket = client.lookup_bucket(BUCKET_NAME)

         self.assertIsInstance(bucket, Bucket)
-        self.assertEqual(bucket.name, BLOB_NAME)
+        self.assertEqual(bucket.name, BUCKET_NAME)
         http.request.assert_called_once_with(
             method='GET', url=URI, data=mock.ANY, headers=mock.ANY)

@@ -260,7 +276,7 @@ def test_create_bucket_conflict(self):
         CREDENTIALS = _make_credentials()
         client = self._make_one(project=PROJECT, credentials=CREDENTIALS)

-        BLOB_NAME = 'blob-name'
+        BUCKET_NAME = 'bucket-name'
         URI = '/'.join([
             client._connection.API_BASE_URL,
             'storage',
@@ -268,13 +284,16 @@
             'b?project=%s' % (PROJECT,),
         ])
         data = {'error': {'message': 'Conflict'}}
+        sent = {'name': BUCKET_NAME}
         http = _make_requests_session([
             _make_json_response(data, status=http_client.CONFLICT)])
         client._http_internal = http

-        self.assertRaises(Conflict, client.create_bucket, BLOB_NAME)
+        self.assertRaises(Conflict, client.create_bucket, BUCKET_NAME)
         http.request.assert_called_once_with(
             method='POST', url=URI, data=mock.ANY, headers=mock.ANY)
+        json_sent = http.request.call_args_list[0][1]['data']
+        self.assertEqual(sent, json.loads(json_sent))

     def test_create_bucket_success(self):
         from google.cloud.storage.bucket import Bucket
@@ -283,23 +302,27 @@ def test_create_bucket_success(self):
         CREDENTIALS = _make_credentials()
         client = self._make_one(project=PROJECT, credentials=CREDENTIALS)

-        BLOB_NAME = 'blob-name'
+        BUCKET_NAME = 'bucket-name'
         URI = '/'.join([
             client._connection.API_BASE_URL,
             'storage',
             client._connection.API_VERSION,
             'b?project=%s' % (PROJECT,),
         ])
-        data = {'name': BLOB_NAME}
+        sent = {'name': BUCKET_NAME, 'billing': {'requesterPays': True}}
+        data = sent
         http = _make_requests_session([_make_json_response(data)])
         client._http_internal = http

-        bucket = client.create_bucket(BLOB_NAME)
+        bucket = client.create_bucket(BUCKET_NAME, requester_pays=True)

         self.assertIsInstance(bucket, Bucket)
-        self.assertEqual(bucket.name, BLOB_NAME)
+        self.assertEqual(bucket.name, BUCKET_NAME)
+        self.assertTrue(bucket.requester_pays)
         http.request.assert_called_once_with(
             method='POST', url=URI, data=mock.ANY, headers=mock.ANY)
+        json_sent = http.request.call_args_list[0][1]['data']
+        self.assertEqual(sent, json.loads(json_sent))

     def test_list_buckets_empty(self):
         from six.moves.urllib.parse import parse_qs
@@ -422,7 +445,7 @@ def test_page_non_empty_response(self):
         credentials = _make_credentials()
         client = self._make_one(project=project, credentials=credentials)

-        blob_name = 'blob-name'
+        blob_name = 'bucket-name'
         response = {'items': [{'name': blob_name}]}

         def dummy_response():