Skip to content

Commit 5ad6689

Browse files
kuaroraabhinavdangeti
authored and committed
Refactor bench_fw to support train, build & search in parallel (facebookresearch#3527)
Summary: Pull Request resolved: facebookresearch#3527 **Context** Design Doc: [Faiss Benchmarking](https://docs.google.com/document/d/1c7zziITa4RD6jZsbG9_yOgyRjWdyueldSPH6QdZzL98/edit) **In this diff** 1. Codecs and indexes can be referenced from the blobstore (bucket & path) outside the experiment. 2. To support #1, naming is moved to descriptors. 3. The built index can be written as well. 4. You can run a benchmark with train, then refer to the trained codec in the index build, and then refer to the built index in knn search. Index serialization is optional, although this is not yet exposed through the index descriptor. 5. The benchmark can support indexes with different dataset sizes. 6. Working with varying datasets now supports multiple ground truths. There may be small fixes before we can use this. 7. Added targets for bench_fw_range, ivf, codecs and optimize. **Analysis of ivf result**: D58823037 Reviewed By: algoriddle Differential Revision: D57236543 fbshipit-source-id: ad03b28bae937a35f8c20f12e0a5b0a27c34ff3b
1 parent 083bfdd commit 5ad6689

9 files changed

+906
-265
lines changed

benchs/bench_fw/benchmark.py

+598-165
Large diffs are not rendered by default.

benchs/bench_fw/benchmark_io.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def clone(self):
5353
def __post_init__(self):
    # Give each instance its own empty `cached_ds` dict.
    # NOTE(review): presumably a cache of loaded datasets -- confirm against
    # the methods that read `cached_ds` (not visible in this hunk).
    self.cached_ds = {}
5555

56+
# TODO(kuarora): rename it as get_local_file
5657
def get_local_filename(self, filename):
5758
if len(filename) > 184:
5859
fn, ext = os.path.splitext(filename)
@@ -61,6 +62,9 @@ def get_local_filename(self, filename):
6162
)
6263
return os.path.join(self.path, filename)
6364

65+
def get_remote_filepath(self, filename) -> Optional[str]:
    """Return the remote path for *filename*.

    This implementation ignores *filename* and always returns None.
    NOTE(review): presumably IO classes backed by a blobstore override this
    to return a real path -- confirm.
    """
    return None
67+
6468
def download_file_from_blobstore(
6569
self,
6670
filename: str,
@@ -219,7 +223,7 @@ def read_index(
219223
fn = self.download_file_from_blobstore(filename, bucket, path)
220224
logger.info(f"Loading index {fn}")
221225
ext = os.path.splitext(fn)[1]
222-
if ext in [".faiss", ".codec"]:
226+
if ext in [".faiss", ".codec", ".index"]:
223227
index = faiss.read_index(fn)
224228
elif ext == ".pkl":
225229
with open(fn, "rb") as model_file:

benchs/bench_fw/descriptors.py

+211-4
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,21 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
from dataclasses import dataclass
76
import logging
7+
import os
8+
from dataclasses import dataclass
89
from typing import Any, Dict, List, Optional
910

1011
import faiss # @manual=//faiss/python:pyfaiss_gpu
12+
13+
from .benchmark_io import BenchmarkIO
1114
from .utils import timer
1215

1316
logger = logging.getLogger(__name__)
1417

1518

1619
@dataclass
17-
class IndexDescriptor:
20+
class IndexDescriptorClassic:
1821
bucket: Optional[str] = None
1922
# either path or factory should be set,
2023
# but not both at the same time.
@@ -45,7 +48,6 @@ class IndexDescriptor:
4548
def __hash__(self):
4649
return hash(str(self))
4750

48-
4951
@dataclass
5052
class DatasetDescriptor:
5153
# namespace possible values:
@@ -81,7 +83,7 @@ def __hash__(self):
8183

8284
def get_filename(
8385
self,
84-
prefix: str = None,
86+
prefix: Optional[str] = None,
8587
) -> str:
8688
filename = ""
8789
if prefix is not None:
@@ -116,3 +118,208 @@ def k_means(self, io, k, dry_run):
116118
else:
117119
t = io.read_json(meta_filename)["k_means_time"]
118120
return kmeans_vectors, t, None
121+
122+
@dataclass
class IndexBaseDescriptor:
    """Fields and naming helpers shared by the descriptor classes.

    Subclasses are expected to implement get_name(); the version defined
    here raises NotImplementedError.
    """

    d: int
    metric: str
    desc_name: Optional[str] = None
    flat_desc_name: Optional[str] = None
    bucket: Optional[str] = None
    path: Optional[str] = None
    num_threads: int = 1

    def get_name(self) -> str:
        """Return the descriptor name. Must be overridden by subclasses."""
        raise NotImplementedError()

    # The annotation is quoted so it is not evaluated when the method is
    # defined; type checkers treat it exactly like the bare name.
    def get_path(self, benchmark_io: "BenchmarkIO") -> Optional[str]:
        """Return `path`, resolving it through *benchmark_io* when unset.

        When `path` is None the result of
        `benchmark_io.get_remote_filepath(...)` is stored in `self.path`
        and returned. That result may itself be None, in which case the
        lookup is repeated on the next call.
        """
        if self.path is not None:
            return self.path
        # Prefer an explicit desc_name. Otherwise compute the name so the IO
        # layer is never asked to resolve a None filename.
        name = self.desc_name if self.desc_name is not None else self.get_name()
        self.path = benchmark_io.get_remote_filepath(name)
        return self.path

    @staticmethod
    def param_dict_list_to_name(param_dict_list):
        """Concatenate the name fragments of a list of parameter dicts.

        The dict at position i is rendered by param_dict_to_name() with the
        prefix "cp<i>". Dicts that render to "" still consume their index.
        Returns "" for a None or empty list.
        """
        if not param_dict_list:
            return ""
        return "".join(
            IndexBaseDescriptor.param_dict_to_name(param_dict, f"cp{position}")
            for position, param_dict in enumerate(param_dict_list)
        )

    @staticmethod
    def param_dict_to_name(param_dict, prefix="sp"):
        """Render a parameter dict as "<prefix>_<name>_<val>...".

        The key "snap" is always skipped; "lsq_gpu" and "use_beam_LUT" are
        skipped when their value equals 0. Returns "" when the dict is
        None/empty or when every entry was skipped.
        """
        if not param_dict:
            return ""
        parts = []
        for name, val in param_dict.items():
            if name == "snap":
                continue
            if name in ("lsq_gpu", "use_beam_LUT") and val == 0:
                continue
            parts.append(f"_{name}_{val}")
        if not parts:
            return ""
        return prefix + "".join(parts) + "."
169+
170+
171+
@dataclass
class CodecDescriptor(IndexBaseDescriptor):
    """Descriptor for a codec, named from its factory string or its path."""

    # either path or factory should be set,
    # but not both at the same time.
    factory: Optional[str] = None
    construction_params: Optional[List[Dict[str, int]]] = None
    training_vectors: Optional[DatasetDescriptor] = None

    def __post_init__(self):
        # Resolve and cache desc_name eagerly; get_name() raises ValueError
        # when desc_name, factory and path are all unset.
        self.get_name()

    def is_trained(self):
        """Return True when no factory is set and a path is set."""
        return self.factory is None and self.path is not None

    def is_valid(self):
        """Return True when at least one of factory or path is set."""
        return self.factory is not None or self.path is not None

    def get_name(self) -> str:
        """Return the codec name, computing and caching it on first use.

        Precedence: an existing desc_name, then the factory string, then
        the path.

        Raises:
            ValueError: if desc_name, factory and path are all None.
        """
        if self.desc_name is not None:
            return self.desc_name
        if self.factory is not None:
            self.desc_name = self.name_from_factory()
            return self.desc_name
        if self.path is not None:
            self.desc_name = self.name_from_path()
            return self.desc_name
        raise ValueError("name, factory or path must be set")

    def flat_name(self) -> str:
        """Return (and cache) the "Flat.d_<d>.<METRIC>." name."""
        if self.flat_desc_name is not None:
            return self.flat_desc_name
        self.flat_desc_name = f"Flat.d_{self.d}.{self.metric.upper()}."
        return self.flat_desc_name

    # NOTE(review): this method shares its name with the inherited dataclass
    # field `path`. The generated __init__ always assigns the instance
    # attribute, so `instance.path` is the field value (None or a string) and
    # this method cannot be reached through an instance. Use get_path()
    # instead; consider renaming or removing this method.
    def path(self, benchmark_io) -> str:
        if self.path is not None:
            return self.path
        return benchmark_io.get_remote_filepath(self.get_name())

    def name_from_factory(self) -> str:
        """Build the name from factory, d, metric, training set and params.

        Commas in the factory string are replaced by underscores. For any
        factory other than "Flat", training_vectors must be set and its
        get_filename("xt") is appended.
        """
        assert self.factory is not None
        name = f"{self.factory.replace(',', '_')}."
        assert self.d is not None
        assert self.metric is not None
        name += f"d_{self.d}.{self.metric.upper()}."
        if self.factory != "Flat":
            assert self.training_vectors is not None
            name += self.training_vectors.get_filename("xt")
        name += IndexBaseDescriptor.param_dict_list_to_name(
            self.construction_params
        )
        return name

    def name_from_path(self):
        """Derive the name from the basename of `path`.

        The text after the last "." is dropped and the dot itself is kept,
        e.g. "a.b.codec" -> "a.b.". A basename without any dot, or one that
        already ends in a dot, is returned unchanged; slicing by the
        extension length would otherwise yield an empty name in both cases.
        """
        assert self.path is not None
        filename = os.path.basename(self.path)
        stem, dot, _extension = filename.rpartition(".")
        if not dot:
            return filename
        return stem + dot

    def alias(self, benchmark_io: BenchmarkIO):
        """Return a new CodecDescriptor that refers to this one by name.

        When *benchmark_io* has a `bucket` attribute, the alias also carries
        that bucket and the path from get_path(); otherwise only desc_name,
        d and metric are copied.
        """
        if hasattr(benchmark_io, "bucket"):
            return CodecDescriptor(
                desc_name=self.get_name(),
                bucket=benchmark_io.bucket,
                path=self.get_path(benchmark_io),
                d=self.d,
                metric=self.metric,
            )
        return CodecDescriptor(
            desc_name=self.get_name(),
            d=self.d,
            metric=self.metric,
        )
236+
237+
238+
@dataclass
class IndexDescriptor(IndexBaseDescriptor):
    """Descriptor for an index, named from a codec plus a database dataset."""

    codec_desc: Optional[CodecDescriptor] = None
    database_desc: Optional[DatasetDescriptor] = None

    def __hash__(self):
        # Hash the dataclass string form so instances can be used as keys.
        return hash(str(self))

    def __post_init__(self):
        # Resolve and cache desc_name eagerly; get_name() raises ValueError
        # when the name can be neither read nor derived.
        self.get_name()

    def is_built(self):
        """Return True when neither codec_desc nor database_desc is set."""
        return self.codec_desc is None and self.database_desc is None

    def get_name(self) -> str:
        """Return the index name, deriving and caching it on first use.

        The derived name is the codec name followed by
        database_desc.get_filename(prefix="xb").

        Raises:
            ValueError: if desc_name is None and codec_desc or
                database_desc is missing.
        """
        if self.desc_name is None:
            if self.codec_desc is None or self.database_desc is None:
                raise ValueError(
                    "desc_name, or both codec_desc and database_desc, "
                    "must be set"
                )
            self.desc_name = (
                self.codec_desc.get_name()
                + self.database_desc.get_filename(prefix="xb")
            )
        return self.desc_name

    def flat_name(self):
        """Return the flat name, deriving and caching it on first use.

        The derived name is codec_desc.flat_name() followed by
        database_desc.get_filename(prefix="xb").

        Raises:
            ValueError: if flat_desc_name is None and codec_desc or
                database_desc is missing.
        """
        if self.flat_desc_name is not None:
            return self.flat_desc_name
        if self.codec_desc is None or self.database_desc is None:
            raise ValueError(
                "flat_desc_name, or both codec_desc and database_desc, "
                "must be set"
            )
        self.flat_desc_name = (
            self.codec_desc.flat_name()
            + self.database_desc.get_filename(prefix="xb")
        )
        return self.flat_desc_name

    # alias is used to refer to an index once it has been uploaded to the
    # blobstore and is referred to again
    def alias(self, benchmark_io: BenchmarkIO):
        """Return a new IndexDescriptor that refers to this one by name.

        When *benchmark_io* has a `bucket` attribute, the alias also carries
        that bucket and the path from get_path(); otherwise only desc_name,
        d and metric are copied.
        """
        if hasattr(benchmark_io, "bucket"):
            return IndexDescriptor(
                desc_name=self.get_name(),
                bucket=benchmark_io.bucket,
                path=self.get_path(benchmark_io),
                d=self.d,
                metric=self.metric,
            )
        return IndexDescriptor(
            desc_name=self.get_name(),
            d=self.d,
            metric=self.metric,
        )
269+
270+
@dataclass
class KnnDescriptor(IndexBaseDescriptor):
    """Descriptor for one search run against an index."""

    index_desc: Optional[IndexDescriptor] = None
    gt_index_desc: Optional[IndexDescriptor] = None
    query_dataset: Optional[DatasetDescriptor] = None
    search_params: Optional[Dict[str, int]] = None
    reconstruct: bool = False
    # range metric definitions
    # key: name
    # value: one of the following:
    #
    # radius
    #    [0..radius) -> 1
    #    [radius..inf) -> 0
    #
    # [[radius1, score1], ...]
    #    [0..radius1) -> score1
    #    [radius1..radius2) -> score2
    #
    # [[radius1_from, radius1_to, score1], ...]
    #    [radius1_from, radius1_to) -> score1,
    #    [radius2_from, radius2_to) -> score2
    range_metrics: Optional[Dict[str, Any]] = None
    radius: Optional[float] = None
    k: int = 1

    range_ref_index_desc: Optional[str] = None

    def __hash__(self):
        # Hash the dataclass string form so instances can be used as keys.
        return hash(str(self))

    def _knn_suffix(self):
        # Trailing "k_<k>.t_<num_threads>." plus "rec." or "knn.", shared by
        # get_name() and flat_name().
        mode = "rec." if self.reconstruct else "knn."
        return f"k_{self.k}.t_{self.num_threads}.{mode}"

    def get_name(self):
        """Build the run name from index, search params, queries and suffix.

        The name is recomputed on every call and is not cached.
        """
        pieces = [
            self.index_desc.get_name(),
            IndexBaseDescriptor.param_dict_to_name(self.search_params),
            self.query_dataset.get_filename("q"),
            self._knn_suffix(),
        ]
        return "".join(pieces)

    def flat_name(self):
        """Return the flat run name, computing and caching it on first use.

        Unlike get_name(), the search params are not part of this name.
        """
        cached = self.flat_desc_name
        if cached is None:
            cached = "".join(
                [
                    self.index_desc.flat_name(),
                    self.query_dataset.get_filename("q"),
                    self._knn_suffix(),
                ]
            )
            self.flat_desc_name = cached
        return cached

0 commit comments

Comments
 (0)