Skip to content

Commit b3e3c2d

Browse files
Amir Sadoughifacebook-github-bot
Amir Sadoughi
authored andcommitted
TimeoutCallback C++ and Python (facebookresearch#3417)
Summary: Pull Request resolved: facebookresearch#3417 facebookresearch#3351 Reviewed By: junjieqi Differential Revision: D57120422 fbshipit-source-id: e2e446642e7be8647f5115f90916fad242e31286
1 parent 0cc0e19 commit b3e3c2d

8 files changed

+127
-2
lines changed

faiss/gpu/perf/PerfClustering.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <vector>
1818

1919
#include <cuda_profiler_api.h>
20+
#include <faiss/impl/AuxIndexStructures.h>
2021

2122
DEFINE_int32(num, 10000, "# of vecs");
2223
DEFINE_int32(k, 100, "# of clusters");
@@ -34,6 +35,7 @@ DEFINE_int64(
3435
"minimum size to use CPU -> GPU paged copies");
3536
DEFINE_int64(pinned_mem, -1, "pinned memory allocation to use");
3637
DEFINE_int32(max_points, -1, "max points per centroid");
38+
DEFINE_double(timeout, 0, "timeout in seconds");
3739

3840
using namespace faiss::gpu;
3941

@@ -99,10 +101,14 @@ int main(int argc, char** argv) {
99101
cp.max_points_per_centroid = FLAGS_max_points;
100102
}
101103

104+
auto tc = new faiss::TimeoutCallback();
105+
faiss::InterruptCallback::instance.reset(tc);
106+
102107
faiss::Clustering kmeans(FLAGS_dim, FLAGS_k, cp);
103108

104109
// Time k-means
105110
{
111+
tc->set_timeout(FLAGS_timeout);
106112
CpuTimer timer;
107113

108114
kmeans.train(FLAGS_num, vecs.data(), *(gpuIndex.getIndex()));

faiss/impl/AuxIndexStructures.cpp

+25
Original file line numberDiff line numberDiff line change
@@ -236,4 +236,29 @@ size_t InterruptCallback::get_period_hint(size_t flops) {
236236
return std::max((size_t)10 * 10 * 1000 * 1000 / (flops + 1), (size_t)1);
237237
}
238238

239+
void TimeoutCallback::set_timeout(double timeout_in_seconds) {
240+
timeout = timeout_in_seconds;
241+
start = std::chrono::steady_clock::now();
242+
}
243+
244+
bool TimeoutCallback::want_interrupt() {
245+
if (timeout == 0) {
246+
return false;
247+
}
248+
auto end = std::chrono::steady_clock::now();
249+
std::chrono::duration<float, std::milli> duration = end - start;
250+
float elapsed_in_seconds = duration.count() / 1000.0;
251+
if (elapsed_in_seconds > timeout) {
252+
timeout = 0;
253+
return true;
254+
}
255+
return false;
256+
}
257+
258+
void TimeoutCallback::reset(double timeout_in_seconds) {
259+
auto tc(new faiss::TimeoutCallback());
260+
faiss::InterruptCallback::instance.reset(tc);
261+
tc->set_timeout(timeout_in_seconds);
262+
}
263+
239264
} // namespace faiss

faiss/impl/AuxIndexStructures.h

+8
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,14 @@ struct FAISS_API InterruptCallback {
161161
static size_t get_period_hint(size_t flops);
162162
};
163163

164+
struct TimeoutCallback : InterruptCallback {
165+
std::chrono::time_point<std::chrono::steady_clock> start;
166+
double timeout;
167+
bool want_interrupt() override;
168+
void set_timeout(double timeout_in_seconds);
169+
static void reset(double timeout_in_seconds);
170+
};
171+
164172
/// set implementation optimized for fast access.
165173
struct VisitedTable {
166174
std::vector<uint8_t> visited;

faiss/python/__init__.py

+11
Original file line numberDiff line numberDiff line change
@@ -316,3 +316,14 @@ def deserialize_index_binary(data):
316316
reader = VectorIOReader()
317317
copy_array_to_vector(data, reader.data)
318318
return read_index_binary(reader)
319+
320+
321+
class TimeoutGuard:
322+
def __init__(self, timeout_in_seconds: float):
323+
self.timeout = timeout_in_seconds
324+
325+
def __enter__(self):
326+
TimeoutCallback.reset(self.timeout)
327+
328+
def __exit__(self, exc_type, exc_value, traceback):
329+
PythonInterruptCallback.reset()

faiss/python/swigfaiss.swig

+7-2
Original file line numberDiff line numberDiff line change
@@ -1041,7 +1041,9 @@ PyObject *swig_ptr (PyObject *a)
10411041
PyErr_SetString(PyExc_ValueError, "did not recognize array type");
10421042
return NULL;
10431043
}
1044+
%}
10441045

1046+
%inline %{
10451047

10461048
struct PythonInterruptCallback: faiss::InterruptCallback {
10471049

@@ -1056,15 +1058,18 @@ struct PythonInterruptCallback: faiss::InterruptCallback {
10561058
return err == -1;
10571059
}
10581060

1061+
static void reset() {
1062+
faiss::InterruptCallback::instance.reset(new PythonInterruptCallback());
1063+
}
10591064
};
1065+
10601066
%}
10611067

10621068
%init %{
10631069
/* needed, else crash at runtime */
10641070
import_array();
10651071

1066-
faiss::InterruptCallback::instance.reset(new PythonInterruptCallback());
1067-
1072+
PythonInterruptCallback::reset();
10681073
%}
10691074

10701075
// return a pointer usable as input for functions that expect pointers

tests/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ set(FAISS_TEST_SRC
3434
test_fastscan_perf.cpp
3535
test_disable_pq_sdc_tables.cpp
3636
test_common_ivf_empty_index.cpp
37+
test_callback.cpp
3738
)
3839

3940
add_executable(faiss_test ${FAISS_TEST_SRC})

tests/test_callback.cpp

+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
/**
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
*
4+
* This source code is licensed under the MIT license found in the
5+
* LICENSE file in the root directory of this source tree.
6+
*/
7+
8+
#include <gtest/gtest.h>
9+
10+
#include <faiss/Clustering.h>
11+
#include <faiss/IndexFlat.h>
12+
#include <faiss/impl/AuxIndexStructures.h>
13+
#include <faiss/impl/FaissException.h>
14+
#include <faiss/utils/random.h>
15+
16+
TEST(TestCallback, timeout) {
17+
int n = 1000;
18+
int k = 100;
19+
int d = 128;
20+
int niter = 1000000000;
21+
int seed = 42;
22+
23+
std::vector<float> vecs(n * d);
24+
faiss::float_rand(vecs.data(), vecs.size(), seed);
25+
26+
auto index(new faiss::IndexFlat(d));
27+
28+
faiss::ClusteringParameters cp;
29+
cp.niter = niter;
30+
cp.verbose = false;
31+
32+
faiss::Clustering kmeans(d, k, cp);
33+
34+
faiss::TimeoutCallback::reset(0.010);
35+
EXPECT_THROW(kmeans.train(n, vecs.data(), *index), faiss::FaissException);
36+
delete index;
37+
}

tests/test_callback_py.py

+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import unittest
7+
import numpy as np
8+
import faiss
9+
10+
11+
class TestCallbackPy(unittest.TestCase):
12+
def setUp(self) -> None:
13+
super().setUp()
14+
15+
def test_timeout(self) -> None:
16+
n = 1000
17+
k = 100
18+
d = 128
19+
niter = 1_000_000_000
20+
21+
x = np.random.rand(n, d).astype('float32')
22+
index = faiss.IndexFlat(d)
23+
24+
cp = faiss.ClusteringParameters()
25+
cp.niter = niter
26+
cp.verbose = False
27+
28+
kmeans = faiss.Clustering(d, k, cp)
29+
30+
with self.assertRaises(RuntimeError):
31+
with faiss.TimeoutGuard(0.010):
32+
kmeans.train(x, index)

0 commit comments

Comments
 (0)