Skip to content

Commit e529319

Browse files
authored
Refactor introduce recording storage (#274)
* refactor: separate injection and enable/disable logic * refactor: add class that handles request records
1 parent a5b5e34 commit e529319

File tree

6 files changed

+230
-130
lines changed

6 files changed

+230
-130
lines changed

.github/workflows/main.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ jobs:
2222
runs-on: ubuntu-20.04
2323
strategy:
2424
matrix:
25-
python-version: ['3.8', '3.9', '3.10', '3.11', '3.12', ' 3.13', 'pypy3.10']
25+
python-version: ['3.8', '3.9', '3.10', '3.11', '3.12', '3.13', 'pypy3.10']
2626

2727
steps:
2828
- uses: actions/checkout@v4

mocket/inject.py

+1-17
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
import contextlib
4-
import os
54
import socket
65
import ssl
76
from types import ModuleType
@@ -23,10 +22,7 @@ def _restore(module: ModuleType, name: str) -> None:
2322
module.__dict__[name] = original_value
2423

2524

26-
def enable(
27-
namespace: str | None = None,
28-
truesocket_recording_dir: str | None = None,
29-
) -> None:
25+
def enable() -> None:
3026
from mocket.socket import (
3127
MocketSocket,
3228
mock_create_connection,
@@ -73,14 +69,6 @@ def enable(
7369

7470
extract_from_urllib3()
7571

76-
from mocket.mocket import Mocket
77-
78-
Mocket._namespace = namespace
79-
Mocket._truesocket_recording_dir = truesocket_recording_dir
80-
if truesocket_recording_dir and not os.path.isdir(truesocket_recording_dir):
81-
# JSON dumps will be saved here
82-
raise AssertionError
83-
8472

8573
def disable() -> None:
8674
for module, name in list(_patches_restore.keys()):
@@ -90,7 +78,3 @@ def disable() -> None:
9078
from urllib3.contrib.pyopenssl import inject_into_urllib3
9179

9280
inject_into_urllib3()
93-
94-
from mocket.mocket import Mocket
95-
96-
Mocket.reset()

mocket/mocket.py

+39-7
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
import collections
44
import itertools
55
import os
6+
from pathlib import Path
67
from typing import TYPE_CHECKING, ClassVar
78

89
import mocket.inject
10+
from mocket.recording import MocketRecordStorage
911

1012
# NOTE this is here for backwards-compat to keep old import-paths working
1113
# from mocket.socket import MocketSocket as MocketSocket
@@ -20,11 +22,36 @@ class Mocket:
2022
_address: ClassVar[Address] = (None, None)
2123
_entries: ClassVar[dict[Address, list[MocketEntry]]] = collections.defaultdict(list)
2224
_requests: ClassVar[list] = []
23-
_namespace: ClassVar[str] = str(id(_entries))
24-
_truesocket_recording_dir: ClassVar[str | None] = None
25+
_record_storage: ClassVar[MocketRecordStorage | None] = None
2526

26-
enable = mocket.inject.enable
27-
disable = mocket.inject.disable
27+
@classmethod
28+
def enable(
29+
cls,
30+
namespace: str | None = None,
31+
truesocket_recording_dir: str | None = None,
32+
) -> None:
33+
if namespace is None:
34+
namespace = str(id(cls._entries))
35+
36+
if truesocket_recording_dir is not None:
37+
recording_dir = Path(truesocket_recording_dir)
38+
39+
if not recording_dir.is_dir():
40+
# JSON dumps will be saved here
41+
raise AssertionError
42+
43+
cls._record_storage = MocketRecordStorage(
44+
directory=recording_dir,
45+
namespace=namespace,
46+
)
47+
48+
mocket.inject.enable()
49+
50+
@classmethod
51+
def disable(cls) -> None:
52+
cls.reset()
53+
54+
mocket.inject.disable()
2855

2956
@classmethod
3057
def get_pair(cls, address: Address) -> tuple[int, int] | tuple[None, None]:
@@ -69,6 +96,7 @@ def reset(cls) -> None:
6996
cls._socket_pairs = {}
7097
cls._entries = collections.defaultdict(list)
7198
cls._requests = []
99+
cls._record_storage = None
72100

73101
@classmethod
74102
def last_request(cls):
@@ -89,12 +117,16 @@ def has_requests(cls) -> bool:
89117
return bool(cls.request_list())
90118

91119
@classmethod
92-
def get_namespace(cls) -> str:
93-
return cls._namespace
120+
def get_namespace(cls) -> str | None:
121+
if not cls._record_storage:
122+
return None
123+
return cls._record_storage.namespace
94124

95125
@classmethod
96126
def get_truesocket_recording_dir(cls) -> str | None:
97-
return cls._truesocket_recording_dir
127+
if not cls._record_storage:
128+
return None
129+
return str(cls._record_storage.directory)
98130

99131
@classmethod
100132
def assert_fail_if_entries_not_served(cls) -> None:

mocket/recording.py

+147
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
from __future__ import annotations
2+
3+
import contextlib
4+
import hashlib
5+
import json
6+
from collections import defaultdict
7+
from dataclasses import dataclass
8+
from pathlib import Path
9+
10+
from mocket.compat import decode_from_bytes, encode_to_bytes
11+
from mocket.types import Address
12+
from mocket.utils import hexdump, hexload
13+
14+
hash_function = hashlib.md5
15+
16+
with contextlib.suppress(ImportError):
17+
from xxhash_cffi import xxh32 as xxhash_cffi_xxh32
18+
19+
hash_function = xxhash_cffi_xxh32
20+
21+
with contextlib.suppress(ImportError):
22+
from xxhash import xxh32 as xxhash_xxh32
23+
24+
hash_function = xxhash_xxh32
25+
26+
27+
def _hash_prepare_request(data: bytes) -> bytes:
28+
_data = decode_from_bytes(data)
29+
return encode_to_bytes("".join(sorted(_data.split("\r\n"))))
30+
31+
32+
def _hash_request(data: bytes) -> str:
33+
_data = _hash_prepare_request(data)
34+
return hash_function(_data).hexdigest()
35+
36+
37+
def _hash_request_fallback(data: bytes) -> str:
38+
_data = _hash_prepare_request(data)
39+
return hashlib.md5(_data).hexdigest()
40+
41+
42+
@dataclass
43+
class MocketRecord:
44+
host: str
45+
port: int
46+
request: bytes
47+
response: bytes
48+
49+
50+
class MocketRecordStorage:
51+
def __init__(self, directory: Path, namespace: str) -> None:
52+
self._directory = directory
53+
self._namespace = namespace
54+
self._records: defaultdict[Address, defaultdict[str, MocketRecord]] = (
55+
defaultdict(defaultdict)
56+
)
57+
58+
self._load()
59+
60+
@property
61+
def directory(self) -> Path:
62+
return self._directory
63+
64+
@property
65+
def namespace(self) -> str:
66+
return self._namespace
67+
68+
@property
69+
def file(self) -> Path:
70+
return self._directory / f"{self._namespace}.json"
71+
72+
def _load(self) -> None:
73+
if not self.file.exists():
74+
return
75+
76+
json_data = self.file.read_text()
77+
records = json.loads(json_data)
78+
for host, port_signature_record in records.items():
79+
for port, signature_record in port_signature_record.items():
80+
for signature, record in signature_record.items():
81+
# NOTE backward-compat
82+
try:
83+
request_data = hexload(record["request"])
84+
except ValueError:
85+
request_data = record["request"]
86+
87+
self._records[(host, int(port))][signature] = MocketRecord(
88+
host=host,
89+
port=port,
90+
request=request_data,
91+
response=hexload(record["response"]),
92+
)
93+
94+
def _save(self) -> None:
95+
data: dict[str, dict[str, dict[str, dict[str, str]]]] = defaultdict(
96+
lambda: defaultdict(defaultdict)
97+
)
98+
for address, signature_record in self._records.items():
99+
host, port = address
100+
for signature, record in signature_record.items():
101+
data[host][str(port)][signature] = dict(
102+
request=decode_from_bytes(record.request),
103+
response=hexdump(record.response),
104+
)
105+
106+
json_data = json.dumps(data, indent=4, sort_keys=True)
107+
self.file.parent.mkdir(exist_ok=True)
108+
self.file.write_text(json_data)
109+
110+
def get_records(self, address: Address) -> list[MocketRecord]:
111+
return list(self._records[address].values())
112+
113+
def get_record(self, address: Address, request: bytes) -> MocketRecord | None:
114+
# NOTE for backward-compat
115+
request_signature_fallback = _hash_request_fallback(request)
116+
if request_signature_fallback in self._records[address]:
117+
return self._records[address].get(request_signature_fallback)
118+
119+
request_signature = _hash_request(request)
120+
if request_signature in self._records[address]:
121+
return self._records[address][request_signature]
122+
123+
return None
124+
125+
def put_record(
126+
self,
127+
address: Address,
128+
request: bytes,
129+
response: bytes,
130+
) -> None:
131+
host, port = address
132+
record = MocketRecord(
133+
host=host,
134+
port=port,
135+
request=request,
136+
response=response,
137+
)
138+
139+
# NOTE for backward-compat
140+
request_signature_fallback = _hash_request_fallback(request)
141+
if request_signature_fallback in self._records[address]:
142+
self._records[address][request_signature_fallback] = record
143+
return
144+
145+
request_signature = _hash_request(request)
146+
self._records[address][request_signature] = record
147+
self._save()

0 commit comments

Comments
 (0)