diff --git a/benchs/bench_fw/benchmark.py b/benchs/bench_fw/benchmark.py index 237d08bd9a..e6220330b7 100644 --- a/benchs/bench_fw/benchmark.py +++ b/benchs/bench_fw/benchmark.py @@ -276,6 +276,7 @@ def train( @dataclass class BuildOperator(IndexOperator): index_descs: List[IndexDescriptor] = field(default_factory=lambda: []) + serialize_index: bool = False def get_desc(self, name: str) -> Optional[IndexDescriptor]: for desc in self.index_descs: @@ -312,6 +313,7 @@ def build_index_wrapper(self, index_desc: IndexDescriptor): path=index_desc.codec_desc.path, index_name=index_desc.get_name(), codec_name=index_desc.codec_desc.get_name(), + serialize_full_index=self.serialize_index, ) index.set_io(self.io) index_desc.index = index diff --git a/benchs/bench_fw/benchmark_io.py b/benchs/bench_fw/benchmark_io.py index e6f337b89c..79b0fd09c4 100644 --- a/benchs/bench_fw/benchmark_io.py +++ b/benchs/bench_fw/benchmark_io.py @@ -45,7 +45,7 @@ def merge_rcq_itq( @dataclass class BenchmarkIO: - path: str + path: str # local path def __init__(self, path: str): self.path = path @@ -54,8 +54,7 @@ def __init__(self, path: str): def clone(self): return BenchmarkIO(path=self.path) - # TODO(kuarora): rename it as get_local_file - def get_local_filename(self, filename): + def get_local_filepath(self, filename): if len(filename) > 184: fn, ext = os.path.splitext(filename) filename = ( @@ -72,7 +71,7 @@ def download_file_from_blobstore( bucket: Optional[str] = None, path: Optional[str] = None, ): - return self.get_local_filename(filename) + return self.get_local_filepath(filename) def upload_file_to_blobstore( self, @@ -84,7 +83,7 @@ def upload_file_to_blobstore( pass def file_exist(self, filename: str): - fn = self.get_local_filename(filename) + fn = self.get_local_filepath(filename) exists = os.path.exists(fn) logger.info(f"{filename} {exists=}") return exists @@ -112,7 +111,7 @@ def write_file( values: List[Any], overwrite: bool = False, ): - fn = self.get_local_filename(filename) + fn = self.get_local_filepath(filename) with ZipFile(fn, "w") as zip_file: for key, value in zip(keys, values, strict=True): with zip_file.open(key, "w", force_zip64=True) as f: @@ -187,7 +186,7 @@ def write_nparray( nparray: np.ndarray, filename: str, ): - fn = self.get_local_filename(filename) + fn = self.get_local_filepath(filename) logger.info(f"Saving nparray {nparray.shape} to {fn}") np.save(fn, nparray) self.upload_file_to_blobstore(filename) @@ -209,7 +208,7 @@ def write_json( filename: str, overwrite: bool = False, ): - fn = self.get_local_filename(filename) + fn = self.get_local_filepath(filename) logger.info(f"Saving json {json_dict} to {fn}") with open(fn, "w") as fp: json.dump(json_dict, fp) @@ -239,7 +238,7 @@ def write_index( index: faiss.Index, filename: str, ): - fn = self.get_local_filename(filename) + fn = self.get_local_filepath(filename) logger.info(f"Saving index to {fn}") faiss.write_index(index, fn) self.upload_file_to_blobstore(filename) diff --git a/benchs/bench_fw/descriptors.py b/benchs/bench_fw/descriptors.py index a9bae2a1ba..bcd454154d 100644 --- a/benchs/bench_fw/descriptors.py +++ b/benchs/bench_fw/descriptors.py @@ -236,7 +236,7 @@ def name_from_path(self): name = filename return name - def alias(self, benchmark_io : BenchmarkIO): + def alias(self, benchmark_io: BenchmarkIO): if hasattr(benchmark_io, "bucket"): return CodecDescriptor(desc_name=self.get_name(), bucket=benchmark_io.bucket, path=self.get_path(benchmark_io), d=self.d, metric=self.metric) return CodecDescriptor(desc_name=self.get_name(), d=self.d, metric=self.metric) diff --git a/benchs/bench_fw/index.py b/benchs/bench_fw/index.py index 6b6c2d93af..090722f54a 100644 --- a/benchs/bench_fw/index.py +++ b/benchs/bench_fw/index.py @@ -786,12 +786,12 @@ def is_flat_index(self): # are used to wrap pre-trained Faiss indices (codecs) @dataclass class IndexFromCodec(Index): - path: Optional[str] = None + path: Optional[str] = None # remote or local path to the codec def __post_init__(self): super().__post_init__() - if self.path is None: - raise ValueError("path is not set") + if self.path is None and self.codec_name is None: + raise ValueError("path or desc_name is not set") def get_quantizer(self): if not self.is_ivf(): @@ -814,10 +814,17 @@ def fetch_meta(self, dry_run=False): return None, None def fetch_codec(self): + if self.path is not None: + codec_filename = os.path.basename(self.path) + remote_path = os.path.dirname(self.path) + else: + codec_filename = self.get_codec_name() + "codec" + remote_path = None + codec = self.io.read_index( - os.path.basename(self.path), + codec_filename, self.bucket, - os.path.dirname(self.path), + remote_path, ) assert self.d == codec.d assert self.metric_type == codec.metric_type