Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Accept custom HTTP headers when reading over HTTP(S) #272

Merged
merged 15 commits into from
Jul 18, 2019
Merged
35 changes: 25 additions & 10 deletions smart_open/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
"""


def open(uri, mode, kerberos=False, user=None, password=None):
def open(uri, mode, kerberos=False, user=None, password=None, headers=None):
"""Implement streamed reader from a web site.

Supports Kerberos and Basic HTTP authentication.
Expand All @@ -39,20 +39,27 @@ def open(uri, mode, kerberos=False, user=None, password=None):
The username for authenticating over HTTP
password: str, optional
The password for authenticating over HTTP
headers: dict, optional
Any headers to send in the request. If none, default headers sent are:
{'Accept-Encoding': 'identity'}. To not use default headers or any other
headers, set this variable to an empty dict, {}.

Note
----
If neither kerberos or (user, password) are set, will connect unauthenticated.
If neither kerberos or (user, password) are set, will connect
unauthenticated, unless set separately in headers.

"""
if mode == 'rb':
return BufferedInputBase(uri, mode, kerberos=kerberos, user=user, password=password)
return BufferedInputBase(uri, mode, kerberos=kerberos,
user=user, password=password, headers=headers)
else:
raise NotImplementedError('http support for mode %r not implemented' % mode)


class BufferedInputBase(io.BufferedIOBase):
def __init__(self, url, mode='r', buffer_size=DEFAULT_BUFFER_SIZE, kerberos=False, user=None, password=None):
def __init__(self, url, mode='r', buffer_size=DEFAULT_BUFFER_SIZE,
kerberos=False, user=None, password=None, headers=None):
if kerberos:
import requests_kerberos
auth = requests_kerberos.HTTPKerberosAuth()
Expand All @@ -64,7 +71,12 @@ def __init__(self, url, mode='r', buffer_size=DEFAULT_BUFFER_SIZE, kerberos=Fals
self.buffer_size = buffer_size
self.mode = mode

self.response = requests.get(url, auth=auth, stream=True, headers=_HEADERS)
if headers is None:
self.headers = _HEADERS.copy()
else:
self.headers = headers

self.response = requests.get(url, auth=auth, stream=True, headers=self.headers)

if not self.response.ok:
self.response.raise_for_status()
Expand Down Expand Up @@ -154,7 +166,7 @@ class SeekableBufferedInputBase(BufferedInputBase):
"""

def __init__(self, url, mode='r', buffer_size=DEFAULT_BUFFER_SIZE,
kerberos=False, user=None, password=None):
kerberos=False, user=None, password=None, headers=None):
"""
If Kerberos is True, will attempt to use the local Kerberos credentials.
Otherwise, will try to use "basic" HTTP authentication via username/password.
Expand All @@ -171,6 +183,11 @@ def __init__(self, url, mode='r', buffer_size=DEFAULT_BUFFER_SIZE,
else:
self.auth = None

if headers is None:
self.headers = _HEADERS.copy()
else:
self.headers = headers

self.buffer_size = buffer_size
self.mode = mode
self.response = self._partial_request()
Expand Down Expand Up @@ -253,10 +270,8 @@ def truncate(self, size=None):
raise io.UnsupportedOperation

def _partial_request(self, start_pos=None):
headers = _HEADERS.copy()

if start_pos is not None:
headers.update({"range": s3.make_range_string(start_pos)})
self.headers.update({"range": s3.make_range_string(start_pos)})

response = requests.get(self.url, auth=self.auth, stream=True, headers=headers)
response = requests.get(self.url, auth=self.auth, stream=True, headers=self.headers)
return response
24 changes: 24 additions & 0 deletions smart_open/tests/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,27 @@ def test_seek_from_end(self):
read_bytes = reader.read(size=10)
self.assertEqual(reader.tell(), len(BYTES))
self.assertEqual(BYTES[-10:], read_bytes)

@responses.activate
def test_headers_are_as_assigned(self):
responses.add_callback(responses.GET, URL, callback=request_callback)

# use default _HEADERS
x = smart_open.http.BufferedInputBase(URL)
# set different ones
x.headers['Accept-Encoding'] = 'compress, gzip'
x.headers['Other-Header'] = 'value'

# use default again, global shoudn't overwritten from x
y = smart_open.http.BufferedInputBase(URL)
# should be default headers
self.assertEqual(y.headers, {'Accept-Encoding': 'identity'})
# should be assigned headers
self.assertEqual(x.headers, {'Accept-Encoding': 'compress, gzip', 'Other-Header': 'value'})

@responses.activate
def test_headers(self):
"""Does the top-level http.open function correctly handle headers?"""
responses.add_callback(responses.GET, URL, callback=request_callback)
reader = smart_open.http.open(URL, 'rb', headers={'Foo': 'bar'})
self.assertEqual(reader.headers['Foo'], 'bar')