|
7 | 7 | import os
|
8 | 8 | import pickletools
|
9 | 9 | from tarfile import TarError
|
| 10 | +from tempfile import TemporaryDirectory |
10 | 11 | from typing import IO, List, Optional, Set, Tuple
|
11 | 12 | import urllib.parse
|
12 | 13 | import zipfile
|
@@ -151,7 +152,23 @@ def __str__(self) -> str:
|
# File-extension sets used to pick the right scanner for a given file.
_numpy_file_extensions = {".npy"}  # Note: .npz is handled as zip files
_pytorch_file_extensions = {".bin", ".pt", ".pth", ".ckpt"}
_pickle_file_extensions = {".pkl", ".pickle", ".joblib", ".dat", ".data"}
# Archive-style containers; .7z entries are additionally confirmed by the
# 7z magic-byte check (_is_7z_file) rather than by the zipfile module.
_zip_file_extensions = {".zip", ".npz", ".7z"}
| 156 | + |
| 157 | + |
| 158 | +def _is_7z_file(f: IO[bytes]) -> bool: |
| 159 | + read_bytes = [] |
| 160 | + start = f.tell() |
| 161 | + |
| 162 | + byte = f.read(1) |
| 163 | + while byte != b"": |
| 164 | + read_bytes.append(byte) |
| 165 | + if len(read_bytes) == 6: |
| 166 | + break |
| 167 | + byte = f.read(1) |
| 168 | + f.seek(start) |
| 169 | + |
| 170 | + local_header_magic_number = [b"7", b"z", b"\xbc", b"\xaf", b"\x27", b"\x1c"] |
| 171 | + return read_bytes == local_header_magic_number |
155 | 172 |
|
156 | 173 |
|
157 | 174 | def _http_get(url) -> bytes:
|
@@ -307,12 +324,37 @@ def scan_pickle_bytes(data: IO[bytes], file_id, multiple_pickles=True) -> ScanRe
|
307 | 324 | return _build_scan_result_from_raw_globals(raw_globals, file_id)
|
308 | 325 |
|
309 | 326 |
|
# XXX: py7zr appears to offer no way to obtain a byte stream for a single
# member of a 7z archive, so we are forced to extract to disk before scanning.
def scan_7z_bytes(data: IO[bytes], file_id) -> ScanResult:
    """Scan every pickle-bearing member of a 7z archive.

    Members whose names end with a known pickle extension are extracted
    to a temporary directory and scanned individually; the merged
    ScanResult is returned. Requires the optional ``py7zr`` dependency.
    """
    try:
        import py7zr
    except ImportError:
        raise Exception(
            "py7zr is required to scan 7z archives, install picklescan using: 'pip install picklescan[7z]'"
        )

    aggregate = ScanResult([])

    with py7zr.SevenZipFile(data, mode="r") as archive:
        pickle_suffixes = tuple(_pickle_file_extensions)
        targets = [
            name for name in archive.getnames() if name.endswith(pickle_suffixes)
        ]
        _log.debug("Files in 7z archive %s: %s", file_id, targets)
        with TemporaryDirectory() as tmpdir:
            archive.extract(path=tmpdir, targets=targets)
            for member in targets:
                extracted_path = os.path.join(tmpdir, member)
                _log.debug("Scanning file %s in 7z archive %s", member, file_id)
                if os.path.isfile(extracted_path):
                    aggregate.merge(scan_file_path(extracted_path))

    return aggregate
| 350 | + |
| 351 | + |
310 | 352 | def scan_zip_bytes(data: IO[bytes], file_id) -> ScanResult:
|
311 | 353 | result = ScanResult([])
|
312 | 354 |
|
313 | 355 | with zipfile.ZipFile(data, "r") as zip:
|
314 | 356 | file_names = zip.namelist()
|
315 |
| - _log.debug("Files in archive %s: %s", file_id, file_names) |
| 357 | + _log.debug("Files in zip archive %s: %s", file_id, file_names) |
316 | 358 | for file_name in file_names:
|
317 | 359 | file_ext = os.path.splitext(file_name)[1]
|
318 | 360 | if file_ext in _pickle_file_extensions:
|
@@ -361,6 +403,8 @@ def scan_pytorch(data: IO[bytes], file_id) -> ScanResult:
|
361 | 403 | # new pytorch format
|
362 | 404 | if _is_zipfile(data):
|
363 | 405 | return scan_zip_bytes(data, file_id)
|
| 406 | + elif _is_7z_file(data): |
| 407 | + return scan_7z_bytes(data, file_id) |
364 | 408 | # old pytorch format
|
365 | 409 | else:
|
366 | 410 | scan_result = ScanResult([])
|
@@ -395,11 +439,12 @@ def scan_bytes(data: IO[bytes], file_id, file_ext: Optional[str] = None) -> Scan
|
395 | 439 | else:
|
396 | 440 | is_zip = zipfile.is_zipfile(data)
|
397 | 441 | data.seek(0)
|
398 |
| - return ( |
399 |
| - scan_zip_bytes(data, file_id) |
400 |
| - if is_zip |
401 |
| - else scan_pickle_bytes(data, file_id) |
402 |
| - ) |
| 442 | + if is_zip: |
| 443 | + return scan_zip_bytes(data, file_id) |
| 444 | + elif _is_7z_file(data): |
| 445 | + return scan_7z_bytes(data, file_id) |
| 446 | + else: |
| 447 | + return scan_pickle_bytes(data, file_id) |
403 | 448 |
|
404 | 449 |
|
405 | 450 | def scan_huggingface_model(repo_id):
|
|
0 commit comments