diff --git a/docs/advanced.rst b/docs/advanced.rst index cf93dcc0..9cba3ba6 100644 --- a/docs/advanced.rst +++ b/docs/advanced.rst @@ -136,7 +136,8 @@ Create your own persistence class, see the example below: Your custom persister must implement both ``load_cassette`` and ``save_cassette`` methods. The ``load_cassette`` method must return a deserialized cassette or raise -``ValueError`` if no cassette is found. +either ``CassetteNotFoundError`` if no cassette is found, or ``CassetteDecodeError`` +if the cassette cannot be successfully deserialized. Once the persister class is defined, register with VCR like so... diff --git a/tests/integration/test_register_persister.py b/tests/integration/test_register_persister.py index 1391b981..42b8736b 100644 --- a/tests/integration/test_register_persister.py +++ b/tests/integration/test_register_persister.py @@ -5,9 +5,11 @@ import os from urllib.request import urlopen +import pytest + # Internal imports import vcr -from vcr.persisters.filesystem import FilesystemPersister +from vcr.persisters.filesystem import CassetteDecodeError, CassetteNotFoundError, FilesystemPersister class CustomFilesystemPersister: @@ -25,6 +27,19 @@ def save_cassette(cassette_path, cassette_dict, serializer): FilesystemPersister.save_cassette(cassette_path, cassette_dict, serializer) +class BadPersister(FilesystemPersister): + """A bad persister that raises different errors.""" + + @staticmethod + def load_cassette(cassette_path, serializer): + if "nonexistent" in cassette_path: + raise CassetteNotFoundError() + elif "encoding" in cassette_path: + raise CassetteDecodeError() + else: + raise ValueError("buggy persister") + + def test_save_cassette_with_custom_persister(tmpdir, httpbin): """Ensure you can save a cassette using custom persister""" my_vcr = vcr.VCR() @@ -53,3 +68,22 @@ def test_load_cassette_with_custom_persister(tmpdir, httpbin): with my_vcr.use_cassette(test_fixture, serializer="json"): response = urlopen(httpbin.url).read() assert b"difficult sometimes" in response + + +def test_load_cassette_persister_exception_handling(tmpdir, httpbin): + """ + Ensure expected errors from persister are swallowed while unexpected ones + are passed up the call stack. + """ + my_vcr = vcr.VCR() + my_vcr.register_persister(BadPersister) + + with my_vcr.use_cassette("bad/nonexistent") as cass: + assert len(cass) == 0 + + with my_vcr.use_cassette("bad/encoding") as cass: + assert len(cass) == 0 + + with pytest.raises(ValueError): + with my_vcr.use_cassette("bad/buggy") as cass: + pass diff --git a/tests/unit/test_cassettes.py b/tests/unit/test_cassettes.py index 41e3df53..cbe9de1a 100644 --- a/tests/unit/test_cassettes.py +++ b/tests/unit/test_cassettes.py @@ -29,6 +29,19 @@ def test_cassette_load(tmpdir): assert len(a_cassette) == 1 +def test_cassette_load_nonexistent(): + a_cassette = Cassette.load(path="something/nonexistent.yml") + assert len(a_cassette) == 0 + + +def test_cassette_load_invalid_encoding(tmpdir): + a_file = tmpdir.join("invalid_encoding.yml") + with open(a_file, "wb") as fd: + fd.write(b"\xda") + a_cassette = Cassette.load(path=str(a_file)) + assert len(a_cassette) == 0 + + def test_cassette_not_played(): a = Cassette("test") assert not a.play_count diff --git a/vcr/cassette.py b/vcr/cassette.py index 5822afac..77ffe616 100644 --- a/vcr/cassette.py +++ b/vcr/cassette.py @@ -11,7 +11,7 @@ from .errors import UnhandledHTTPRequestError from .matchers import get_matchers_results, method, requests_match, uri from .patch import CassettePatcherBuilder -from .persisters.filesystem import FilesystemPersister +from .persisters.filesystem import CassetteDecodeError, CassetteNotFoundError, FilesystemPersister from .record_mode import RecordMode from .serializers import yamlserializer from .util import partition_dict @@ -352,7 +352,7 @@ def _load(self): self.append(request, response) self.dirty = False self.rewound = True - except ValueError: + except (CassetteDecodeError, CassetteNotFoundError): pass def __str__(self): diff --git a/vcr/persisters/filesystem.py b/vcr/persisters/filesystem.py index e9710638..d7bd4518 100644 --- a/vcr/persisters/filesystem.py +++ b/vcr/persisters/filesystem.py @@ -5,17 +5,25 @@ from ..serialize import deserialize, serialize +class CassetteNotFoundError(FileNotFoundError): + pass + + +class CassetteDecodeError(ValueError): + pass + + class FilesystemPersister: @classmethod def load_cassette(cls, cassette_path, serializer): cassette_path = Path(cassette_path) # if cassette path is already Path this is no operation if not cassette_path.is_file(): - raise ValueError("Cassette not found.") + raise CassetteNotFoundError() try: with cassette_path.open() as f: data = f.read() - except UnicodeEncodeError as err: - raise ValueError("Can't read Cassette, Encoding is broken") from err + except UnicodeDecodeError as err: + raise CassetteDecodeError("Can't read Cassette, Encoding is broken") from err return deserialize(data, serializer)