Skip to content

Commit 2f75026

Browse files
committed
refactor: avoid convoluted code
1 parent f5edabf commit 2f75026

File tree

3 files changed

+90
-27
lines changed

3 files changed

+90
-27
lines changed

src/picklescan/scanner.py

+31-23
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,9 @@ def merge(self, sr: "ScanResult"):
5050

5151

5252
class GenOpsError(Exception):
53-
def __init__(self, msg: str):
53+
def __init__(self, msg: str, globals: Optional[Set[Tuple[str, str]]]):
5454
self.msg = msg
55+
self.globals = globals
5556
super().__init__()
5657

5758
def __str__(self) -> str:
@@ -177,16 +178,10 @@ def _list_globals(data: IO[bytes], multiple_pickles=True) -> Set[Tuple[str, str]
177178
try:
178179
ops = list(pickletools.genops(data))
179180
except Exception as e:
180-
# XXX: pickle will happily load files that contain arbitrarily placed new lines whereas pickletools errors in such cases.
181-
# below is code to circumvent or skip these newlines while succeeding at parsing the opcodes.
182-
err = str(e)
183-
if "opcode b'\\n' unknown" not in err:
184-
raise GenOpsError(err)
185-
else:
186-
pos = int(err.split(",")[0].replace("at position ", ""))
187-
data.seek(-(pos + 1), 1)
188-
ops = list(pickletools.genops(data.read(pos)))
189-
data.seek(1, 1)
181+
# XXX: given we can have multiple pickles in a file, we may have already successfully extracted globals from a valid pickle.
182+
# Thus we return the already found globals in the error and let the caller decide what to do.
183+
globals_opt = globals if len(globals) > 0 else None
184+
raise GenOpsError(str(e), globals_opt)
190185

191186
last_byte = data.read(1)
192187
data.seek(-1, 1)
@@ -241,18 +236,12 @@ def _list_globals(data: IO[bytes], multiple_pickles=True) -> Set[Tuple[str, str]
241236
return globals
242237

243238

244-
def scan_pickle_bytes(data: IO[bytes], file_id, multiple_pickles=True) -> ScanResult:
245-
"""Disassemble a Pickle stream and report issues"""
246-
239+
def _build_scan_result_from_raw_globals(
240+
raw_globals: Set[Tuple[str, str]],
241+
file_id,
242+
scan_err=False,
243+
) -> ScanResult:
247244
globals = []
248-
try:
249-
raw_globals = _list_globals(data, multiple_pickles)
250-
except GenOpsError as e:
251-
_log.error(f"ERROR: parsing pickle in {file_id}: {e}")
252-
return ScanResult(globals, scan_err=True)
253-
254-
_log.debug("Global imports in %s: %s", file_id, raw_globals)
255-
256245
issues_count = 0
257246
for rg in raw_globals:
258247
g = Global(rg[0], rg[1], SafetyLevel.Dangerous)
@@ -278,7 +267,26 @@ def scan_pickle_bytes(data: IO[bytes], file_id, multiple_pickles=True) -> ScanRe
278267
g.safety = SafetyLevel.Suspicious
279268
globals.append(g)
280269

281-
return ScanResult(globals, 1, issues_count, 1 if issues_count > 0 else 0, False)
270+
return ScanResult(globals, 1, issues_count, 1 if issues_count > 0 else 0, scan_err)
271+
272+
273+
def scan_pickle_bytes(data: IO[bytes], file_id, multiple_pickles=True) -> ScanResult:
274+
"""Disassemble a Pickle stream and report issues"""
275+
276+
try:
277+
raw_globals = _list_globals(data, multiple_pickles)
278+
except GenOpsError as e:
279+
_log.error(f"ERROR: parsing pickle in {file_id}: {e}")
280+
if e.globals is not None:
281+
return _build_scan_result_from_raw_globals(
282+
e.globals, file_id, scan_err=True
283+
)
284+
else:
285+
return ScanResult([], scan_err=True)
286+
287+
_log.debug("Global imports in %s: %s", file_id, raw_globals)
288+
289+
return _build_scan_result_from_raw_globals(raw_globals, file_id)
282290

283291

284292
def scan_zip_bytes(data: IO[bytes], file_id) -> ScanResult:
+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
Vos
2+
p2
3+
0Vsystem
4+
p3
5+
0Vtorch
6+
p0
7+
0VLongStorage
8+
p1
9+
0g2
10+
g3
11+
�(Vcat flag.txt
12+
tR.
13+
14+

tests/test_scanner.py

+45-4
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,35 @@ def initialize_pickle_files():
243243
),
244244
)
245245

246+
initialize_data_file(
247+
f"{_root_path}/data/malicious-invalid-bytes.pkl",
248+
b"".join(
249+
[
250+
pickle.UNICODE + b"os\n",
251+
pickle.PUT + b"2\n",
252+
pickle.POP,
253+
pickle.UNICODE + b"system\n",
254+
pickle.PUT + b"3\n",
255+
pickle.POP,
256+
pickle.UNICODE + b"torch\n",
257+
pickle.PUT + b"0\n",
258+
pickle.POP,
259+
pickle.UNICODE + b"LongStorage\n",
260+
pickle.PUT + b"1\n",
261+
pickle.POP,
262+
pickle.GET + b"2\n",
263+
pickle.GET + b"3\n",
264+
pickle.STACK_GLOBAL,
265+
pickle.MARK,
266+
pickle.UNICODE + b"cat flag.txt\n",
267+
pickle.TUPLE,
268+
pickle.REDUCE,
269+
pickle.STOP,
270+
b"\n\n\t\t",
271+
]
272+
),
273+
)
274+
246275
# Code which created malicious12.pkl using pickleassem (see https://github.com/gousaiyang/pickleassem)
247276
#
248277
# p = PickleAssembler(proto=4)
@@ -351,7 +380,6 @@ def test_scan_pickle_bytes():
351380

352381

353382
def test_scan_zip_bytes():
354-
355383
buffer = io.BytesIO()
356384
with zipfile.ZipFile(buffer, "w") as zip:
357385
zip.writestr("data.pkl", pickle.dumps(Malicious1()))
@@ -559,15 +587,17 @@ def test_scan_directory_path():
559587
Global("torch", "_utils", SafetyLevel.Suspicious),
560588
Global("__builtin__", "exec", SafetyLevel.Dangerous),
561589
Global("os", "system", SafetyLevel.Dangerous),
590+
Global("os", "system", SafetyLevel.Dangerous),
562591
Global("operator", "attrgetter", SafetyLevel.Dangerous),
563592
Global("builtins", "__import__", SafetyLevel.Suspicious),
564593
Global("pickle", "loads", SafetyLevel.Dangerous),
565594
Global("_pickle", "loads", SafetyLevel.Dangerous),
566595
Global("_codecs", "encode", SafetyLevel.Suspicious),
567596
],
568-
scanned_files=26,
569-
issues_count=24,
570-
infected_files=21,
597+
scanned_files=27,
598+
issues_count=25,
599+
infected_files=22,
600+
scan_err=True,
571601
)
572602
compare_scan_results(scan_directory_path(f"{_root_path}/data/"), sr)
573603

@@ -610,3 +640,14 @@ def test_pickle_files():
610640
assert pickle.load(file) == 12345
611641
with open(f"{_root_path}/data/malicious13b.pkl", "rb") as file:
612642
assert pickle.load(file) == 12345
643+
644+
645+
def test_invalid_bytes_err():
646+
malicious_invalid_bytes = ScanResult(
647+
[Global("os", "system", SafetyLevel.Dangerous)], 1, 1, 1, True
648+
)
649+
with open(f"{_root_path}/data/malicious-invalid-bytes.pkl", "rb") as file:
650+
compare_scan_results(
651+
scan_pickle_bytes(file, f"{_root_path}/data/malicious-invalid-bytes.pkl"),
652+
malicious_invalid_bytes,
653+
)

0 commit comments

Comments
 (0)