@@ -160,6 +160,8 @@ def __str__(self) -> str:
160
160
# File-name extensions (lower-case, leading dot included) used to pick a
# scanner for a file. Names suggest the format each set stands for; the
# membership tests elsewhere in the module compare against
# os.path.splitext(...)[1].
_pytorch_file_extensions = {
    ".bin",
    ".pt",
    ".pth",
    ".ckpt",
}
_pickle_file_extensions = {
    ".pkl",
    ".pickle",
    ".joblib",
    ".dat",
    ".data",
}
_zip_file_extensions = {
    ".zip",
    ".npz",
    ".7z",
}
163
+ # Pickle files do not actually have magic bytes, but v2+ files
164
+ # start with a PROTO (\x80) opcode followed by a byte with the protocol version
163
165
_pickle_magic_bytes = {
164
166
b"\x80 \x00 " ,
165
167
b"\x80 \x01 " ,
@@ -168,6 +170,7 @@ def __str__(self) -> str:
168
170
b"\x80 \x04 " ,
169
171
b"\x80 \x05 " ,
170
172
}
173
# Magic string that opens every .npy file: byte 0x93 immediately followed by
# the ASCII letters "NUMPY" (six bytes in total, no separator).
_numpy_magic_bytes = b"\x93NUMPY"
171
174
172
175
173
176
def _is_7z_file (f : IO [bytes ]) -> bool :
@@ -364,37 +367,36 @@ def scan_7z_bytes(data: IO[bytes], file_id) -> ScanResult:
364
367
return result
365
368
366
369
367
def get_magic_bytes_from_zipfile(zip: zipfile.ZipFile, num_bytes=8):
    """Return the leading bytes of every member of an open zip archive.

    Args:
        zip: An open ``zipfile.ZipFile`` to inspect.
        num_bytes: Maximum number of bytes to read from the start of each
            member (defaults to 8).

    Returns:
        A dict mapping each member's file name to the first ``num_bytes``
        bytes of its content (fewer if the member is shorter).

    Any exception raised while opening or reading a member propagates to
    the caller.
    """

    def _read_head(member_name):
        # Open by name and read only the requested prefix.
        with zip.open(member_name) as member:
            return member.read(num_bytes)

    return {info.filename: _read_head(info.filename) for info in zip.infolist()}
374
-
375
-
376
370
def scan_zip_bytes(data: IO[bytes], file_id) -> ScanResult:
    """Scan the pickle-like and NumPy-like members of a zip archive.

    Each member is classified by its file extension or by the first 8 bytes
    of its content, then handed to ``scan_pickle_bytes`` or ``scan_numpy``.
    Members that match neither category are ignored.

    Args:
        data: Binary stream positioned at the start of the zip archive.
        file_id: Identifier of the archive; used in log messages and as the
            prefix of each member's id (``"<file_id>:<member name>"``).

    Returns:
        A ``ScanResult`` into which the result of every scanned member has
        been merged.

    A member that raises ``zipfile.BadZipFile`` or ``RuntimeError`` while
    being opened, read or scanned is logged as a warning and skipped, so one
    unreadable member does not abort the scan of the remaining ones.
    """
    result = ScanResult([])
    # bytes.startswith accepts a tuple of prefixes; build it once, not per member.
    pickle_prefixes = tuple(_pickle_magic_bytes)

    with RelaxedZipFile(data, "r") as archive:
        file_names = archive.namelist()
        _log.debug("Files in zip archive %s: %s", file_id, file_names)
        for file_name in file_names:
            try:
                # Peek at the header inside the try block: opening or reading
                # a member is where password/corruption errors surface.
                with archive.open(file_name, "r") as file:
                    magic_bytes = file.read(8)
                file_ext = os.path.splitext(file_name)[1]

                if file_ext in _pickle_file_extensions or magic_bytes.startswith(
                    pickle_prefixes
                ):
                    _log.debug("Scanning file %s in zip archive %s", file_name, file_id)
                    # Re-open so the scanner reads the member from its first byte.
                    with archive.open(file_name, "r") as file:
                        result.merge(scan_pickle_bytes(file, f"{file_id}:{file_name}"))

                elif file_ext in _numpy_file_extensions or magic_bytes.startswith(
                    _numpy_magic_bytes
                ):
                    _log.debug("Scanning file %s in zip archive %s", file_name, file_id)
                    with archive.open(file_name, "r") as file:
                        result.merge(scan_numpy(file, f"{file_id}:{file_name}"))
            except (zipfile.BadZipFile, RuntimeError) as e:
                # Log decompression issues (password protected, corrupted, etc.)
                _log.warning(
                    "Invalid file %s in zip archive %s: %s", file_name, file_id, e
                )

    return result
400
402
0 commit comments