Skip to content

Commit 68a39d9

Browse files
derpferdmpenkov
andauthored
Added credentials option for iter_bucket. (#372)
* Added credentials option for iter_bucket. This closes issue #259 * Applied PR Feedback * Formatting changes * explicitly construct session with session_kwargs * fix bug, update unit tests * remove test with tricky mocks * isolate credentials test Co-authored-by: Michael Penkov <m@penkov.dev>
1 parent f729054 commit 68a39d9

File tree

2 files changed

+53
-8
lines changed

2 files changed

+53
-8
lines changed

smart_open/s3.py

+32-8
Original file line numberDiff line numberDiff line change
@@ -655,8 +655,14 @@ def _accept_all(key):
655655
return True
656656

657657

658-
def iter_bucket(bucket_name, prefix='', accept_key=None,
659-
key_limit=None, workers=16, retries=3):
658+
def iter_bucket(
659+
bucket_name,
660+
prefix='',
661+
accept_key=None,
662+
key_limit=None,
663+
workers=16,
664+
retries=3,
665+
**session_kwargs):
660666
"""
661667
Iterate and download all S3 objects under `s3://bucket_name/prefix`.
662668
@@ -676,6 +682,11 @@ def iter_bucket(bucket_name, prefix='', accept_key=None,
676682
The number of subprocesses to use.
677683
retries: int, optional
678684
The number of time to retry a failed download.
685+
session_kwargs: dict, optional
686+
Keyword arguments to pass when creating a new session.
687+
For a list of available names and values, see:
688+
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html#boto3.session.Session
689+
679690
680691
Yields
681692
------
@@ -716,8 +727,16 @@ def iter_bucket(bucket_name, prefix='', accept_key=None,
716727
pass
717728

718729
total_size, key_no = 0, -1
719-
key_iterator = _list_bucket(bucket_name, prefix=prefix, accept_key=accept_key)
720-
download_key = functools.partial(_download_key, bucket_name=bucket_name, retries=retries)
730+
key_iterator = _list_bucket(
731+
bucket_name,
732+
prefix=prefix,
733+
accept_key=accept_key,
734+
**session_kwargs)
735+
download_key = functools.partial(
736+
_download_key,
737+
bucket_name=bucket_name,
738+
retries=retries,
739+
**session_kwargs)
721740

722741
with _create_process_pool(processes=workers) as pool:
723742
result_iterator = pool.imap_unordered(download_key, key_iterator)
@@ -736,8 +755,13 @@ def iter_bucket(bucket_name, prefix='', accept_key=None,
736755
logger.info("processed %i keys, total size %i" % (key_no + 1, total_size))
737756

738757

739-
def _list_bucket(bucket_name, prefix='', accept_key=lambda k: True):
740-
client = boto3.client('s3')
758+
def _list_bucket(
759+
bucket_name,
760+
prefix='',
761+
accept_key=lambda k: True,
762+
**session_kwargs):
763+
session = boto3.session.Session(**session_kwargs)
764+
client = session.client('s3')
741765
ctoken = None
742766

743767
while True:
@@ -762,14 +786,14 @@ def _list_bucket(bucket_name, prefix='', accept_key=lambda k: True):
762786
break
763787

764788

765-
def _download_key(key_name, bucket_name=None, retries=3):
789+
def _download_key(key_name, bucket_name=None, retries=3, **session_kwargs):
766790
if bucket_name is None:
767791
raise ValueError('bucket_name may not be None')
768792

769793
#
770794
# https://geekpete.com/blog/multithreading-boto3/
771795
#
772-
session = boto3.session.Session()
796+
session = boto3.session.Session(**session_kwargs)
773797
s3 = session.resource('s3')
774798
bucket = s3.Bucket(bucket_name)
775799

smart_open/tests/test_s3.py

+21
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,27 @@ def test(self):
564564
self.assertEqual(sorted(keys), sorted(expected))
565565

566566

567+
#
568+
# This has to be a separate test because we cannot run it against real S3
569+
# (we don't want to expose our real S3 credentials).
570+
#
571+
@moto.mock_s3
572+
class IterBucketCredentialsTest(unittest.TestCase):
573+
574+
def test(self):
575+
num_keys = 10
576+
populate_bucket(num_keys=num_keys)
577+
result = list(
578+
smart_open.s3.iter_bucket(
579+
BUCKET_NAME,
580+
workers=None,
581+
aws_access_key_id='access_id',
582+
aws_secret_access_key='access_secret'
583+
)
584+
)
585+
self.assertEqual(len(result), num_keys)
586+
587+
567588
@maybe_mock_s3
568589
class DownloadKeyTest(unittest.TestCase):
569590
def setUp(self):

0 commit comments

Comments
 (0)