Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new top-level compression parameter #609

Merged
merged 20 commits into from
May 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions smart_open/compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,23 @@
_COMPRESSOR_REGISTRY = {}


NO_COMPRESSION = 'none'
"""Use no compression. Read/write the data as-is."""
INFER_FROM_EXTENSION = 'extension'
"""Determine the compression to use from the file extension.

See get_supported_extensions().
"""


def get_supported_compression_types():
"""Return the list of supported compression types available to open.

See compression paratemeter to smart_open.open().
"""
return [NO_COMPRESSION, INFER_FROM_EXTENSION] + [ext[1:] for ext in get_supported_extensions()]


def get_supported_extensions():
"""Return the list of file extensions for which we have registered compressors."""
return sorted(_COMPRESSOR_REGISTRY.keys())
Expand Down
37 changes: 29 additions & 8 deletions smart_open/smart_open_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
# smart_open.submodule to reference to the submodules.
#
import smart_open.local_file as so_file
import smart_open.compression as so_compression

from smart_open import compression
from smart_open import doctools
from smart_open import transport

Expand Down Expand Up @@ -107,6 +107,7 @@ def open(
closefd=True,
opener=None,
ignore_ext=False,
compression=None,
transport_params=None,
):
r"""Open the URI object, returning a file-like object.
Expand Down Expand Up @@ -139,6 +140,9 @@ def open(
Mimicks built-in open parameter of the same name. Ignored.
ignore_ext: boolean, optional
Disable transparent compression/decompression based on the file extension.
compression: str, optional (see smart_open.compression.get_supported_compression_types)
Explicitly specify the compression/decompression behavior.
If you specify this parameter, then ignore_ext must not be specified.
transport_params: dict, optional
Additional parameters for the transport layer (see notes below).

Expand Down Expand Up @@ -168,13 +172,23 @@ def open(
if not isinstance(mode, str):
raise TypeError('mode should be a string')

if compression and ignore_ext:
raise ValueError('ignore_ext and compression parameters are mutually exclusive')
elif compression and compression not in so_compression.get_supported_compression_types():
raise ValueError(f'invalid compression type: {compression}')
elif ignore_ext:
compression = so_compression.NO_COMPRESSION
warnings.warn("'ignore_ext' will be deprecated in a future release", PendingDeprecationWarning)
elif compression is None:
compression = so_compression.INFER_FROM_EXTENSION

if transport_params is None:
transport_params = {}

fobj = _shortcut_open(
uri,
mode,
ignore_ext=ignore_ext,
compression=compression,
buffering=buffering,
encoding=encoding,
errors=errors,
Expand Down Expand Up @@ -219,10 +233,13 @@ def open(
raise NotImplementedError(ve.args[0])

binary = _open_binary_stream(uri, binary_mode, transport_params)
if ignore_ext:
if compression == so_compression.NO_COMPRESSION:
decompressed = binary
elif compression == so_compression.INFER_FROM_EXTENSION:
decompressed = so_compression.compression_wrapper(binary, binary_mode)
else:
decompressed = compression.compression_wrapper(binary, binary_mode)
faked_extension = f"{binary.name}.{compression.lower()}"
decompressed = so_compression.compression_wrapper(binary, binary_mode, filename=faked_extension)

if 'b' not in mode or explicit_encoding is not None:
decoded = _encoding_wrapper(
Expand Down Expand Up @@ -295,7 +312,7 @@ def transfer(char):
def _shortcut_open(
uri,
mode,
ignore_ext=False,
compression,
buffering=-1,
encoding=None,
errors=None,
Expand All @@ -309,12 +326,13 @@ def _shortcut_open(
This is only possible under the following conditions:

1. Opening a local file; and
2. Ignore extension is set to True
2. Compression is disabled

If it is not possible to use the built-in open for the specified URI, returns None.

:param str uri: A string indicating what to open.
:param str mode: The mode to pass to the open function.
:param str compression: The compression type selected.
:returns: The opened file
:rtype: file
"""
Expand All @@ -326,8 +344,11 @@ def _shortcut_open(
return None

local_path = so_file.extract_local_path(uri)
_, extension = P.splitext(local_path)
if extension in compression.get_supported_extensions() and not ignore_ext:
if compression == so_compression.INFER_FROM_EXTENSION:
_, extension = P.splitext(local_path)
if extension in so_compression.get_supported_extensions():
return None
elif compression != so_compression.NO_COMPRESSION:
return None

open_kwargs = {}
Expand Down
139 changes: 139 additions & 0 deletions smart_open/tests/test_smart_open.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import hashlib
import logging
import os
from smart_open.compression import INFER_FROM_EXTENSION, NO_COMPRESSION
import tempfile
import unittest
import warnings
Expand Down Expand Up @@ -1840,6 +1841,144 @@ def test(self):
self.assertEqual(expected, actual)


_RAW_DATA = "не слышны в саду даже шорохи".encode("utf-8")


@mock_s3
class HandleS3CompressionTestCase(parameterizedtestcase.ParameterizedTestCase):

def setUp(self):
s3 = boto3.resource("s3")
s3.create_bucket(Bucket="bucket").wait_until_exists()

# compression | ignore_ext | behavior |
# ----------- | ---------- | -------- |
# 'gz' | False | Override |
# 'bz2' | False | Override |
@parameterizedtestcase.ParameterizedTestCase.parameterize(
("_compression", "decompressor"),
[
("gz", gzip.decompress),
("bz2", bz2.decompress),
],
)
def test_rw_compression_prescribed(self, _compression, decompressor):
"""Should read/write files with `_compression`, as prescribed."""
key = "s3://bucket/key.txt"

with smart_open.open(key, "wb", compression=_compression) as fout:
fout.write(_RAW_DATA)

#
# Check that what we've created is compressed as expected.
#
with smart_open.open(key, "rb", compression=NO_COMPRESSION) as fin:
data = decompressor(fin.read())
assert data == _RAW_DATA

# compression | ignore_ext | behavior |
# ----------- | ---------- | -------- |
# 'extension' | False | Enable |
# 'none' | False | Disable |
@parameterizedtestcase.ParameterizedTestCase.parameterize(
("_compression", "decompressor"),
[
(
"gz",
gzip.decompress,
),
(
"bz2",
bz2.decompress,
)
],
)
def test_rw_compression_by_extension(
self, _compression, decompressor
):
"""Should read/write files with `_compression`, explicitily inferred by file extension."""
key = f"s3://bucket/key.{_compression}"

with smart_open.open(key, "wb", compression=INFER_FROM_EXTENSION) as fout:
fout.write(_RAW_DATA)

#
# Check that what we've created is compressed as expected.
#
with smart_open.open(key, "rb", compression=NO_COMPRESSION) as fin:
assert decompressor(fin.read()) == _RAW_DATA

# compression | ignore_ext | behavior |
# ----------- | ---------- | -------- |
# None | False | Enable |
# None | True | Disable |
@parameterizedtestcase.ParameterizedTestCase.parameterize(
("_compression", "decompressor"),
[
(
"gz",
gzip.decompress,
),
(
"bz2",
bz2.decompress,
),
],
)
def test_rw_compression_by_extension_deprecated(
self, _compression, decompressor
):
"""Should read/write files with `_compression`, implicitly inferred by file extension."""
key = f"s3://bucket/key.{_compression}"

with smart_open.open(key, "wb") as fout:
fout.write(_RAW_DATA)

#
# Check that what we've created is compressed as expected.
#
with smart_open.open(key, "rb", ignore_ext=True) as fin:
assert decompressor(fin.read()) == _RAW_DATA

# extension | compression | ignore_ext | behavior |
# ----------| ----------- | ---------- | -------- |
# <any> | <invalid> | <any> | Error |
# <any> | 'none' | True | Error |
# 'gz' | 'extension' | True | Error |
# 'bz2' | 'extension' | True | Error |
# <any> | 'gz' | True | Error |
# <any> | 'bz2' | True | Error |
@parameterizedtestcase.ParameterizedTestCase.parameterize(
("extension", "kwargs", "error"),
[
("", dict(compression="foo"), ValueError),
("", dict(compression="foo", ignore_ext=True), ValueError),
("", dict(compression=NO_COMPRESSION, ignore_ext=True), ValueError),
(
".gz",
dict(compression=INFER_FROM_EXTENSION, ignore_ext=True),
ValueError,
),
(
".bz2",
dict(compression=INFER_FROM_EXTENSION, ignore_ext=True),
ValueError,
),
("", dict(compression="gz", ignore_ext=True), ValueError),
("", dict(compression="bz2", ignore_ext=True), ValueError),
],
)
def test_compression_invalid(self, extension, kwargs, error):
"""Should detect and error on these invalid inputs"""
key = f"s3://bucket/key{extension}"

with pytest.raises(error):
smart_open.open(key, "wb", **kwargs)

with pytest.raises(error):
smart_open.open(key, "rb", **kwargs)


class GetBinaryModeTest(parameterizedtestcase.ParameterizedTestCase):
@parameterizedtestcase.ParameterizedTestCase.parameterize(
('mode', 'expected'),
Expand Down