Skip to content

Commit 37f6b76

Browse files
kuarora and facebook-github-bot
authored and committed
Adding support for index builder (#3800)
Summary: Pull Request resolved: #3800 In this diff: 1. The codec can be referenced by either its desc name or its remote path in IndexFromCodec. 2. Serialization of the full index is exposed through BuildOperator. 3. get_local_filename is renamed to get_local_filepath. Reviewed By: satymish Differential Revision: D61813717 fbshipit-source-id: ed422751a1d3712565efa87ecf615620799cb8eb
1 parent 084496a commit 37f6b76

File tree

4 files changed

+23
-15
lines changed

4 files changed

+23
-15
lines changed

benchs/bench_fw/benchmark.py

+2
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,7 @@ def train(
276276
@dataclass
277277
class BuildOperator(IndexOperator):
278278
index_descs: List[IndexDescriptor] = field(default_factory=lambda: [])
279+
serialize_index: bool = False
279280

280281
def get_desc(self, name: str) -> Optional[IndexDescriptor]:
281282
for desc in self.index_descs:
@@ -312,6 +313,7 @@ def build_index_wrapper(self, index_desc: IndexDescriptor):
312313
path=index_desc.codec_desc.path,
313314
index_name=index_desc.get_name(),
314315
codec_name=index_desc.codec_desc.get_name(),
316+
serialize_full_index=self.serialize_index,
315317
)
316318
index.set_io(self.io)
317319
index_desc.index = index

benchs/bench_fw/benchmark_io.py

+8-9
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def merge_rcq_itq(
4545

4646
@dataclass
4747
class BenchmarkIO:
48-
path: str
48+
path: str # local path
4949

5050
def __init__(self, path: str):
5151
self.path = path
@@ -54,8 +54,7 @@ def __init__(self, path: str):
5454
def clone(self):
5555
return BenchmarkIO(path=self.path)
5656

57-
# TODO(kuarora): rename it as get_local_file
58-
def get_local_filename(self, filename):
57+
def get_local_filepath(self, filename):
5958
if len(filename) > 184:
6059
fn, ext = os.path.splitext(filename)
6160
filename = (
@@ -72,7 +71,7 @@ def download_file_from_blobstore(
7271
bucket: Optional[str] = None,
7372
path: Optional[str] = None,
7473
):
75-
return self.get_local_filename(filename)
74+
return self.get_local_filepath(filename)
7675

7776
def upload_file_to_blobstore(
7877
self,
@@ -84,7 +83,7 @@ def upload_file_to_blobstore(
8483
pass
8584

8685
def file_exist(self, filename: str):
87-
fn = self.get_local_filename(filename)
86+
fn = self.get_local_filepath(filename)
8887
exists = os.path.exists(fn)
8988
logger.info(f"{filename} {exists=}")
9089
return exists
@@ -112,7 +111,7 @@ def write_file(
112111
values: List[Any],
113112
overwrite: bool = False,
114113
):
115-
fn = self.get_local_filename(filename)
114+
fn = self.get_local_filepath(filename)
116115
with ZipFile(fn, "w") as zip_file:
117116
for key, value in zip(keys, values, strict=True):
118117
with zip_file.open(key, "w", force_zip64=True) as f:
@@ -187,7 +186,7 @@ def write_nparray(
187186
nparray: np.ndarray,
188187
filename: str,
189188
):
190-
fn = self.get_local_filename(filename)
189+
fn = self.get_local_filepath(filename)
191190
logger.info(f"Saving nparray {nparray.shape} to {fn}")
192191
np.save(fn, nparray)
193192
self.upload_file_to_blobstore(filename)
@@ -209,7 +208,7 @@ def write_json(
209208
filename: str,
210209
overwrite: bool = False,
211210
):
212-
fn = self.get_local_filename(filename)
211+
fn = self.get_local_filepath(filename)
213212
logger.info(f"Saving json {json_dict} to {fn}")
214213
with open(fn, "w") as fp:
215214
json.dump(json_dict, fp)
@@ -239,7 +238,7 @@ def write_index(
239238
index: faiss.Index,
240239
filename: str,
241240
):
242-
fn = self.get_local_filename(filename)
241+
fn = self.get_local_filepath(filename)
243242
logger.info(f"Saving index to {fn}")
244243
faiss.write_index(index, fn)
245244
self.upload_file_to_blobstore(filename)

benchs/bench_fw/descriptors.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def name_from_path(self):
236236
name = filename
237237
return name
238238

239-
def alias(self, benchmark_io : BenchmarkIO):
239+
def alias(self, benchmark_io: BenchmarkIO):
240240
if hasattr(benchmark_io, "bucket"):
241241
return CodecDescriptor(desc_name=self.get_name(), bucket=benchmark_io.bucket, path=self.get_path(benchmark_io), d=self.d, metric=self.metric)
242242
return CodecDescriptor(desc_name=self.get_name(), d=self.d, metric=self.metric)

benchs/bench_fw/index.py

+12-5
Original file line numberDiff line numberDiff line change
@@ -786,12 +786,12 @@ def is_flat_index(self):
786786
# are used to wrap pre-trained Faiss indices (codecs)
787787
@dataclass
788788
class IndexFromCodec(Index):
789-
path: Optional[str] = None
789+
path: Optional[str] = None # remote or local path to the codec
790790

791791
def __post_init__(self):
792792
super().__post_init__()
793-
if self.path is None:
794-
raise ValueError("path is not set")
793+
if self.path is None and self.codec_name is None:
794+
raise ValueError("path or desc_name is not set")
795795

796796
def get_quantizer(self):
797797
if not self.is_ivf():
@@ -814,10 +814,17 @@ def fetch_meta(self, dry_run=False):
814814
return None, None
815815

816816
def fetch_codec(self):
817+
if self.path is not None:
818+
codec_filename = os.path.basename(self.path)
819+
remote_path = os.path.dirname(self.path)
820+
else:
821+
codec_filename = self.get_codec_name() + "codec"
822+
remote_path = None
823+
817824
codec = self.io.read_index(
818-
os.path.basename(self.path),
825+
codec_filename,
819826
self.bucket,
820-
os.path.dirname(self.path),
827+
remote_path,
821828
)
822829
assert self.d == codec.d
823830
assert self.metric_type == codec.metric_type

0 commit comments

Comments
 (0)