Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add desc_name to dataset descriptor #3935

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
224 changes: 114 additions & 110 deletions benchs/bench_fw/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,10 @@ def build_index_wrapper(self, codec_desc: CodecDescriptor):
else:
assert codec_desc.is_trained()

def train(
def train_one(
self, codec_desc: CodecDescriptor, results: Dict[str, Any], dry_run=False
):
faiss.omp_set_num_threads(codec_desc.num_threads)
self.build_index_wrapper(codec_desc)
if codec_desc.is_trained():
return results, None
Expand All @@ -274,6 +275,16 @@ def train(
results["indices"][codec_desc.get_name()] = meta
return results, requires

def train(self, results, dry_run=False):
for desc in self.codec_descs:
results, requires = self.train_one(desc, results, dry_run=dry_run)
if dry_run:
if requires is None:
continue
return results, requires
assert requires is None
return results, None


@dataclass
class BuildOperator(IndexOperator):
Expand Down Expand Up @@ -322,17 +333,25 @@ def build_index_wrapper(self, index_desc: IndexDescriptor):
else:
assert index_desc.is_built()

def build(self, index_desc: IndexDescriptor, results: Dict[str, Any]):
def build_one(self, index_desc: IndexDescriptor, results: Dict[str, Any]):
faiss.omp_set_num_threads(index_desc.num_threads)
self.build_index_wrapper(index_desc)
if index_desc.is_built():
return
index_desc.index.get_index()

def build(self, results: Dict[str, Any]):
# TODO: add support for dry_run
for index_desc in self.index_descs:
self.build_one(index_desc, results)
return results, None


@dataclass
class SearchOperator(IndexOperator):
knn_descs: List[KnnDescriptor] = field(default_factory=lambda: [])
range: bool = False
compute_gt: bool = True

def get_desc(self, name: str) -> Optional[KnnDescriptor]:
for desc in self.knn_descs:
Expand Down Expand Up @@ -655,85 +674,16 @@ def range_search_benchmark(
index=index,
)


@dataclass
class ExecutionOperator:
distance_metric: str = "L2"
num_threads: int = 1
train_op: Optional[TrainOperator] = None
build_op: Optional[BuildOperator] = None
search_op: Optional[SearchOperator] = None
compute_gt: bool = True

def __post_init__(self):
if self.distance_metric == "IP":
self.distance_metric_type = faiss.METRIC_INNER_PRODUCT
elif self.distance_metric == "L2":
self.distance_metric_type = faiss.METRIC_L2
else:
raise ValueError

def set_io(self, io: BenchmarkIO):
self.io = io
self.io.distance_metric = self.distance_metric
self.io.distance_metric_type = self.distance_metric_type
if self.train_op:
self.train_op.set_io(io)
if self.build_op:
self.build_op.set_io(io)
if self.search_op:
self.search_op.set_io(io)

def train_one(self, codec_desc: CodecDescriptor, results: Dict[str, Any], dry_run):
faiss.omp_set_num_threads(self.num_threads)
assert self.train_op is not None
self.train_op.train(codec_desc, results, dry_run)

def train(self, results, dry_run=False):
faiss.omp_set_num_threads(self.num_threads)
if self.train_op is None:
return

for codec_desc in self.train_op.codec_descs:
self.train_one(codec_desc, results, dry_run)

def build_one(self, results: Dict[str, Any], index_desc: IndexDescriptor):
faiss.omp_set_num_threads(self.num_threads)
assert self.build_op is not None
self.build_op.build(index_desc, results)

def build(self, results: Dict[str, Any]):
faiss.omp_set_num_threads(self.num_threads)
if self.build_op is None:
return

for index_desc in self.build_op.index_descs:
self.build_one(index_desc, results)

def search(self):
faiss.omp_set_num_threads(self.num_threads)
if self.search_op is None:
return

for index_desc in self.search_op.knn_descs:
self.search_one(index_desc)

def search_one(
self,
knn_desc: KnnDescriptor,
results: Dict[str, Any],
dry_run=False,
range=False,
):
faiss.omp_set_num_threads(self.num_threads)
assert self.search_op is not None

if not dry_run and self.compute_gt:
self.create_gt_knn(knn_desc)
self.create_range_ref_knn(knn_desc)

self.search_op.build_index_wrapper(knn_desc)
faiss.omp_set_num_threads(knn_desc.num_threads)

self.build_index_wrapper(knn_desc)
# results, requires = self.reconstruct_benchmark(
# dry_run=True,
# results=results,
Expand All @@ -749,7 +699,7 @@ def search_one(
# index=index_desc.index,
# )
# assert requires is None
results, requires = self.search_op.knn_search_benchmark(
results, requires = self.knn_search_benchmark(
dry_run=True,
results=results,
knn_desc=knn_desc,
Expand All @@ -758,7 +708,7 @@ def search_one(
if dry_run:
return results, requires
else:
results, requires = self.search_op.knn_search_benchmark(
results, requires = self.knn_search_benchmark(
dry_run=False,
results=results,
knn_desc=knn_desc,
Expand All @@ -771,7 +721,7 @@ def search_one(
):
return results, None

ref_index_desc = self.search_op.get_desc(knn_desc.range_ref_index_desc)
ref_index_desc = self.get_desc(knn_desc.range_ref_index_desc)
if ref_index_desc is None:
raise ValueError(
f"{knn_desc.get_name()}: Unknown range index {knn_desc.range_ref_index_desc}"
Expand All @@ -786,17 +736,18 @@ def search_one(
range_search_metric_function,
coefficients,
coefficients_training_data,
) = self.search_op.range_search_reference(
) = self.range_search_reference(
ref_index_desc.index,
ref_index_desc.search_params,
range_metric,
query_dataset=knn_desc.query_dataset,
)
gt_rsm = None
if self.compute_gt:
gt_rsm = self.search_op.range_ground_truth(
gt_rsm = self.range_ground_truth(
gt_radius, range_search_metric_function
)
results, requires = self.search_op.range_search_benchmark(
results, requires = self.range_search_benchmark(
dry_run=True,
results=results,
index=knn_desc.index,
Expand All @@ -805,13 +756,13 @@ def search_one(
gt_radius=gt_radius,
range_search_metric_function=range_search_metric_function,
gt_rsm=gt_rsm,
query_vectors=knn_desc.query_dataset,
query_dataset=knn_desc.query_dataset,
)
if range and requires is not None:
if dry_run:
return results, requires
else:
results, requires = self.search_op.range_search_benchmark(
results, requires = self.range_search_benchmark(
dry_run=False,
results=results,
index=knn_desc.index,
Expand All @@ -820,12 +771,62 @@ def search_one(
gt_radius=gt_radius,
range_search_metric_function=range_search_metric_function,
gt_rsm=gt_rsm,
query_vectors=knn_desc.query_dataset,
query_dataset=knn_desc.query_dataset,
)
assert requires is None

return results, None

def search(
self,
results: Dict[str, Any],
dry_run: bool = False,):
for knn_desc in self.knn_descs:
results, requires = self.search_one(
knn_desc=knn_desc,
results=results,
dry_run=dry_run,
range=self.range)
if dry_run:
if requires is None:
continue
return results, requires

assert requires is None
return results, None


@dataclass
class ExecutionOperator:
distance_metric: str = "L2"
num_threads: int = 1
train_op: Optional[TrainOperator] = None
build_op: Optional[BuildOperator] = None
search_op: Optional[SearchOperator] = None
compute_gt: bool = True

def __post_init__(self):
if self.distance_metric == "IP":
self.distance_metric_type = faiss.METRIC_INNER_PRODUCT
elif self.distance_metric == "L2":
self.distance_metric_type = faiss.METRIC_L2
else:
raise ValueError

if self.search_op is not None:
self.search_op.compute_gt = self.compute_gt

def set_io(self, io: BenchmarkIO):
self.io = io
self.io.distance_metric = self.distance_metric
self.io.distance_metric_type = self.distance_metric_type
if self.train_op:
self.train_op.set_io(io)
if self.build_op:
self.build_op.set_io(io)
if self.search_op:
self.search_op.set_io(io)

def create_gt_codec(
self, codec_desc, results, train=True
) -> Optional[CodecDescriptor]:
Expand All @@ -841,7 +842,7 @@ def create_gt_codec(
)
self.train_op.codec_descs.insert(0, gt_codec_desc)
if train:
self.train_op.train(gt_codec_desc, results, dry_run=False)
self.train_op.train_one(gt_codec_desc, results, dry_run=False)

return gt_codec_desc

Expand All @@ -865,7 +866,7 @@ def create_gt_index(
)
self.build_op.index_descs.insert(0, gt_index_desc)
if build:
self.build_op.build(gt_index_desc, results)
self.build_op.build_one(gt_index_desc, results)

return gt_index_desc

Expand Down Expand Up @@ -906,7 +907,9 @@ def create_range_ref_knn(self, knn_desc):
return

if knn_desc.range_ref_index_desc is not None:
ref_index_desc = self.get_desc(knn_desc.range_ref_index_desc)
ref_index_desc = (
self.search_op.get_desc(knn_desc.range_ref_index_desc)
)
if ref_index_desc is None:
raise ValueError(f"Unknown range index {knn_desc.range_ref_index_desc}")
if ref_index_desc.range_metrics is None:
Expand All @@ -921,19 +924,20 @@ def create_range_ref_knn(self, knn_desc):
range_search_metric_function,
coefficients,
coefficients_training_data,
) = self.range_search_reference(
) = self.search_op.range_search_reference(
knn_desc.index, knn_desc.search_params, range_metric
)
results["metrics"][metric_key] = {
"coefficients": coefficients,
"training_data": coefficients_training_data,
}
knn_desc.gt_rsm = self.range_ground_truth(
knn_desc.gt_rsm = self.search_op.range_ground_truth(
knn_desc.gt_radius, range_search_metric_function
)

def create_ground_truths(self, results: Dict[str, Any]):
# TODO: Create all ground truth descriptors and put them in index descriptor as reference
# TODO: Create all ground truth descriptors and
# put them in index descriptor as reference
if self.train_op is not None:
for codec_desc in self.train_op.codec_descs:
self.create_gt_codec(codec_desc, results)
Expand All @@ -949,33 +953,33 @@ def create_ground_truths(self, results: Dict[str, Any]):
self.create_gt_knn(knn_desc, results)
self.create_range_ref_knn(knn_desc)

def execute(self, results: Dict[str, Any], dry_run: False):
def prepare_gt_or_range_knn(self, results: Dict[str, Any]):
if self.search_op is not None:
for knn_desc in self.search_op.knn_descs:
self.create_gt_knn(knn_desc, results)
self.create_range_ref_knn(knn_desc)

def execute(self, results: Dict[str, Any], dry_run: bool = False):
faiss.omp_set_num_threads(self.num_threads)
if self.train_op is not None:
for desc in self.train_op.codec_descs:
results, requires = self.train_op.train(desc, results, dry_run=dry_run)
if dry_run:
if requires is None:
continue
return results, requires
assert requires is None
results, requires = (
self.train_op.train(results=results, dry_run=dry_run)
)
if dry_run and requires:
return results, requires

if self.build_op is not None:
for desc in self.build_op.index_descs:
self.build_op.build(desc, results)
self.build_op.build(results)

if self.search_op is not None:
for desc in self.search_op.knn_descs:
results, requires = self.search_one(
knn_desc=desc,
results=results,
dry_run=dry_run,
range=self.search_op.range,
)
if dry_run:
if requires is None:
continue
return results, requires
if not dry_run and self.compute_gt:
self.prepare_gt_or_range_knn(results)

assert requires is None
results, requires = (
self.search_op.search(results=results, dry_run=dry_run)
)
if dry_run and requires:
return results, requires
return results, None

def execute_2(self, result_file=None):
Expand Down
Loading
Loading