
Commit 4c83965

algoriddle authored and facebook-github-bot committed
benchmark view results (facebookresearch#3144)
Summary:
Pull Request resolved: facebookresearch#3144

Visualize results of running the benchmark with Pareto optima filtering:
1. per index or across indices
2. for space, time or space & time
3. knn or range search, the latter @ specific precision

Reviewed By: mdouze

Differential Revision: D51552775

fbshipit-source-id: d4f29e3d46ef044e71b54439b3972548c86af5a7
1 parent 9519a19 commit 4c83965
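
A minimal usage sketch of what the added notebook does, assuming the helpers it defines in the diff below (filter_results, plot_metric, ParetoMode, ParetoMetric) are in scope and that a previous benchmark run has written result.json under the chosen root:

# Sketch only: filter_results, plot_metric, ParetoMode and ParetoMetric are
# defined in bench_fw_notebook.ipynb below; result.json is assumed to exist.
import matplotlib.pyplot as plt

from bench_fw.benchmark_io import BenchmarkIO as BIO

root = "/checkpoint"  # results root used in the notebook
results = BIO(root).read_json("result.json")

# Keep per-index Pareto optima on the (knn accuracy, time) trade-off,
# then plot time against knn intersection for each index.
fr = filter_results(
    results,
    evaluation="knn",
    accuracy_metric="knn_intersection",
    pareto_mode=ParetoMode.INDEX,
    pareto_metric=ParetoMetric.TIME,
    scaling_factor=1,
)
plot_metric(fr, accuracy_title="knn intersection", cost_title="time (seconds, 16 cores)")
plt.show()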

File tree

1 file changed: +289 -0 lines changed


benchs/bench_fw_notebook.ipynb (+289)
@@ -0,0 +1,289 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be081589-e1b2-4569-acb7-44203e273899",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import itertools\n",
    "from faiss.contrib.evaluation import OperatingPoints\n",
    "from enum import Enum\n",
    "from bench_fw.benchmark_io import BenchmarkIO as BIO"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6492e95-24c7-4425-bf0a-27e10e879ca6",
   "metadata": {},
   "outputs": [],
   "source": [
    "root = \"/checkpoint\"\n",
    "results = BIO(root).read_json(\"result.json\")\n",
    "results.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0875d269-aef4-426d-83dd-866970f43777",
   "metadata": {},
   "outputs": [],
   "source": [
    "results['indices']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7ff7078-29c7-407c-a079-201877b764ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Cost:\n",
    "    def __init__(self, values):\n",
    "        self.values = values\n",
    "\n",
    "    def __le__(self, other):\n",
    "        return all(v1 <= v2 for v1, v2 in zip(self.values, other.values, strict=True))\n",
    "\n",
    "    def __lt__(self, other):\n",
    "        return all(v1 < v2 for v1, v2 in zip(self.values, other.values, strict=True))\n",
    "\n",
    "\n",
    "class ParetoMode(Enum):\n",
    "    DISABLE = 1  # no Pareto filtering\n",
    "    INDEX = 2  # index-local optima\n",
    "    GLOBAL = 3  # global optima\n",
    "\n",
    "\n",
    "class ParetoMetric(Enum):\n",
    "    TIME = 0  # time vs accuracy\n",
    "    SPACE = 1  # space vs accuracy\n",
    "    TIME_SPACE = 2  # (time, space) vs accuracy\n",
    "\n",
    "\n",
    "def range_search_recall_at_precision(experiment, precision):\n",
    "    return round(max(r for r, p in zip(experiment['range_search_pr']['recall'], experiment['range_search_pr']['precision']) if p > precision), 6)\n",
    "\n",
    "\n",
    "def filter_results(\n",
    "    results,\n",
    "    evaluation,\n",
    "    accuracy_metric,  # str or func\n",
    "    time_metric=None,  # func or None -> use default\n",
    "    space_metric=None,  # func or None -> use default\n",
    "    min_accuracy=0,\n",
    "    max_space=0,\n",
    "    max_time=0,\n",
    "    scaling_factor=1.0,\n",
    "    pareto_mode=ParetoMode.DISABLE,\n",
    "    pareto_metric=ParetoMetric.TIME,\n",
    "):\n",
    "    if isinstance(accuracy_metric, str):\n",
    "        accuracy_key = accuracy_metric\n",
    "        accuracy_metric = lambda v: v[accuracy_key]\n",
    "\n",
    "    if time_metric is None:\n",
    "        time_metric = lambda v: v['time'] * scaling_factor + (v['quantizer']['time'] if 'quantizer' in v else 0)\n",
    "\n",
    "    if space_metric is None:\n",
    "        space_metric = lambda v: results['indices'][v['codec']]['code_size']\n",
    "\n",
    "    fe = []\n",
    "    ops = {}\n",
    "    if pareto_mode == ParetoMode.GLOBAL:\n",
    "        op = OperatingPoints()\n",
    "        ops[\"global\"] = op\n",
    "    for k, v in results['experiments'].items():\n",
    "        if f\".{evaluation}\" in k:\n",
    "            accuracy = accuracy_metric(v)\n",
    "            if min_accuracy > 0 and accuracy < min_accuracy:\n",
    "                continue\n",
    "            space = space_metric(v)\n",
    "            if max_space > 0 and space > max_space:\n",
    "                continue\n",
    "            time = time_metric(v)\n",
    "            if max_time > 0 and time > max_time:\n",
    "                continue\n",
    "            idx_name = v['index']\n",
    "            experiment = (accuracy, space, time, k, v)\n",
    "            if pareto_mode == ParetoMode.DISABLE:\n",
    "                fe.append(experiment)\n",
    "                continue\n",
    "            if pareto_mode == ParetoMode.INDEX:\n",
    "                if idx_name not in ops:\n",
    "                    ops[idx_name] = OperatingPoints()\n",
    "                op = ops[idx_name]\n",
    "            if pareto_metric == ParetoMetric.TIME:\n",
    "                op.add_operating_point(experiment, accuracy, time)\n",
    "            elif pareto_metric == ParetoMetric.SPACE:\n",
    "                op.add_operating_point(experiment, accuracy, space)\n",
    "            else:\n",
    "                op.add_operating_point(experiment, accuracy, Cost([time, space]))\n",
    "\n",
    "    if ops:\n",
    "        for op in ops.values():\n",
    "            for v, _, _ in op.operating_points:\n",
    "                fe.append(v)\n",
    "\n",
    "    fe.sort()\n",
    "    return fe"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f080a6e2-1565-418b-8732-4adeff03a099",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_metric(experiments, accuracy_title, cost_title, plot_space=False):\n",
    "    x = {}\n",
    "    y = {}\n",
    "    for accuracy, space, time, k, v in experiments:\n",
    "        idx_name = v['index']\n",
    "        if idx_name not in x:\n",
    "            x[idx_name] = []\n",
    "            y[idx_name] = []\n",
    "        x[idx_name].append(accuracy)\n",
    "        if plot_space:\n",
    "            y[idx_name].append(space)\n",
    "        else:\n",
    "            y[idx_name].append(time)\n",
    "\n",
    "    # plt.figure(figsize=(10, 6))\n",
    "    plt.yscale(\"log\")\n",
    "    plt.title(accuracy_title)\n",
    "    plt.xlabel(accuracy_title)\n",
    "    plt.ylabel(cost_title)\n",
    "    marker = itertools.cycle((\"o\", \"v\", \"^\", \"<\", \">\", \"s\", \"p\", \"P\", \"*\", \"h\", \"X\", \"D\"))\n",
    "    for index in x.keys():\n",
    "        plt.plot(x[index], y[index], marker=next(marker), label=index)\n",
    "    plt.legend(bbox_to_anchor=(1, 1), loc='upper left')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61007155-5edc-449e-835e-c141a01a2ae5",
   "metadata": {},
   "outputs": [],
   "source": [
    "accuracy_metric = \"knn_intersection\"\n",
    "fr = filter_results(results, evaluation=\"knn\", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.INDEX, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n",
    "plot_metric(fr, accuracy_title=\"knn intersection\", cost_title=\"time (seconds, 16 cores)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36e82084-18f6-4546-a717-163eb0224ee8",
   "metadata": {},
   "outputs": [],
   "source": [
    "precision = 0.8\n",
    "accuracy_metric = lambda exp: range_search_recall_at_precision(exp, precision)\n",
    "fr = filter_results(results, evaluation=\"weighted\", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.INDEX, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n",
    "plot_metric(fr, accuracy_title=f\"range recall @ precision {precision}\", cost_title=\"time (seconds, 16 cores)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aff79376-39f7-47c0-8b83-1efe5192bb7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# index local optima\n",
    "precision = 0.2\n",
    "accuracy_metric = lambda exp: range_search_recall_at_precision(exp, precision)\n",
    "fr = filter_results(results, evaluation=\"weighted\", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.INDEX, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n",
    "plot_metric(fr, accuracy_title=f\"range recall @ precision {precision}\", cost_title=\"time (seconds, 16 cores)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4834f1f-bbbe-4cae-9aa0-a459b0c842d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# global optima\n",
    "precision = 0.8\n",
    "accuracy_metric = lambda exp: range_search_recall_at_precision(exp, precision)\n",
    "fr = filter_results(results, evaluation=\"weighted\", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.GLOBAL, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n",
    "plot_metric(fr, accuracy_title=f\"range recall @ precision {precision}\", cost_title=\"time (seconds, 16 cores)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9aead830-6209-4956-b7ea-4a5e0029d616",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_range_search_pr_curves(experiments):\n",
    "    x = {}\n",
    "    y = {}\n",
    "    show = {\n",
    "        'Flat': None,\n",
    "    }\n",
    "    # collect the precision-recall curve of each filtered range-search experiment\n",
    "    for _, _, _, k, v in experiments:\n",
    "        if \".weighted\" in k:  # and v['index'] in show:\n",
    "            x[k] = v['range_search_pr']['recall']\n",
    "            y[k] = v['range_search_pr']['precision']\n",
    "\n",
    "    plt.title(\"range search recall\")\n",
    "    plt.xlabel(\"recall\")\n",
    "    plt.ylabel(\"precision\")\n",
    "    for index in x.keys():\n",
    "        plt.plot(x[index], y[index], '.', label=index)\n",
    "    plt.legend(bbox_to_anchor=(1.0, 1.0), loc='upper left')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92e45502-7a31-4a15-90df-fa3032d7d350",
   "metadata": {},
   "outputs": [],
   "source": [
    "precision = 0.8\n",
    "accuracy_metric = lambda exp: range_search_recall_at_precision(exp, precision)\n",
    "fr = filter_results(results, evaluation=\"weighted\", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.GLOBAL, pareto_metric=ParetoMetric.TIME_SPACE, scaling_factor=1)\n",
    "plot_range_search_pr_curves(fr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdf8148a-0da6-4c5e-8d60-f8f85314574c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:faiss_cpu_from_source] *",
   "language": "python",
   "name": "conda-env-faiss_cpu_from_source-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
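
filter_results in the notebook leans on faiss.contrib.evaluation.OperatingPoints to keep only Pareto-optimal (accuracy, cost) points. A small standalone sketch of that behaviour, with made-up index names and numbers rather than benchmark output:

# Toy illustration of OperatingPoints as used by the notebook: a point is kept
# only if no other point reaches at least the same accuracy at lower or equal cost.
# The index names and numbers below are invented for illustration only.
from faiss.contrib.evaluation import OperatingPoints

op = OperatingPoints()
op.add_operating_point("HNSW32", 0.95, 2.0)   # most accurate, slowest
op.add_operating_point("IVF1024", 0.90, 0.5)  # good accuracy, fast
op.add_operating_point("IVF4096", 0.85, 0.8)  # dominated by IVF1024

for key, perf, cost in op.operating_points:
    print(key, perf, cost)
# Expected to list only the non-dominated points (IVF1024 and HNSW32).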
