diff --git a/smart_open/http.py b/smart_open/http.py index f3b3e40b..1510fce2 100644 --- a/smart_open/http.py +++ b/smart_open/http.py @@ -51,7 +51,7 @@ def open(uri, mode, kerberos=False, user=None, password=None, headers=None): """ if mode == 'rb': - return BufferedInputBase( + return SeekableBufferedInputBase( uri, mode, kerberos=kerberos, user=user, password=password, headers=headers ) diff --git a/smart_open/tests/test_http.py b/smart_open/tests/test_http.py index dacc71cd..c8fe86af 100644 --- a/smart_open/tests/test_http.py +++ b/smart_open/tests/test_http.py @@ -8,6 +8,7 @@ BYTES = b'i tried so hard and got so far but in the end it doesn\'t even matter' URL = 'http://localhost' +HTTPS_URL = 'https://localhost' HEADERS = { 'Content-Length': str(len(BYTES)), 'Accept-Ranges': 'bytes', @@ -107,3 +108,35 @@ def test_headers(self): responses.add_callback(responses.GET, URL, callback=request_callback) reader = smart_open.http.open(URL, 'rb', headers={'Foo': 'bar'}) self.assertEqual(reader.headers['Foo'], 'bar') + + @responses.activate + def test_https_seek_start(self): + """Did the seek start over HTTPS work?""" + responses.add_callback(responses.GET, HTTPS_URL, callback=request_callback) + + with smart_open.open(HTTPS_URL, "rb") as fin: + read_bytes_1 = fin.read(size=10) + fin.seek(0) + read_bytes_2 = fin.read(size=10) + self.assertEqual(read_bytes_1, read_bytes_2) + + @responses.activate + def test_https_seek_forward(self): + """Did the seek forward over HTTPS work?""" + responses.add_callback(responses.GET, HTTPS_URL, callback=request_callback) + + with smart_open.open(HTTPS_URL, "rb") as fin: + fin.seek(10) + read_bytes = fin.read(size=10) + self.assertEqual(BYTES[10:20], read_bytes) + + @responses.activate + def test_https_seek_reverse(self): + """Did the seek in reverse over HTTPS work?""" + responses.add_callback(responses.GET, HTTPS_URL, callback=request_callback) + + with smart_open.open(HTTPS_URL, "rb") as fin: + read_bytes_1 = fin.read(size=10) + fin.seek(-10, whence=smart_open.s3.CURRENT) + read_bytes_2 = fin.read(size=10) + self.assertEqual(read_bytes_1, read_bytes_2)