Skip to content

Commit 0f86f33

Browse files
mengdilinfacebook-github-bot
authored andcommitted
Improve naming due to codemod (#4070)
Summary: Pull Request resolved: #4070 rename `d_2` to `d` from codemod Reviewed By: junjieqi Differential Revision: D66823543 fbshipit-source-id: d1b4c702a31127cba105d4d4f1514f5e33925a50
1 parent 0d568bc commit 0f86f33

File tree

1 file changed

+25
-21
lines changed

1 file changed

+25
-21
lines changed

faiss/utils/NeuralNet.cpp

+25-21
Original file line numberDiff line numberDiff line change
@@ -212,19 +212,23 @@ nn::Int32Tensor2D QINCoStep::encode(
212212
// repeated codebook
213213
Tensor2D zqs_r(n * K, d); // size n, K, d
214214
Tensor2D cc(n * K, d * 2); // size n, K, d * 2
215-
size_t d_2 = this->d;
216215

217-
auto copy_row = [d_2](Tensor2D& t, size_t i, size_t j, const float* data) {
218-
assert(i <= t.shape[0] && j <= t.shape[1]);
219-
memcpy(t.data() + i * t.shape[1] + j, data, sizeof(float) * d_2);
220-
};
216+
size_t local_d = this->d;
217+
218+
auto copy_row =
219+
[local_d](Tensor2D& t, size_t i, size_t j, const float* data) {
220+
assert(i <= t.shape[0] && j <= t.shape[1]);
221+
memcpy(t.data() + i * t.shape[1] + j,
222+
data,
223+
sizeof(float) * local_d);
224+
};
221225

222226
// manual broadcasting
223227
for (size_t i = 0; i < n; i++) {
224228
for (size_t j = 0; j < K; j++) {
225-
copy_row(zqs_r, i * K + j, 0, codebook.data() + j * d_2);
226-
copy_row(cc, i * K + j, 0, codebook.data() + j * d_2);
227-
copy_row(cc, i * K + j, d_2, xhat.data() + i * d_2);
229+
copy_row(zqs_r, i * K + j, 0, codebook.data() + j * d);
230+
copy_row(cc, i * K + j, 0, codebook.data() + j * d);
231+
copy_row(cc, i * K + j, d, xhat.data() + i * d);
228232
}
229233
}
230234

@@ -237,13 +241,13 @@ nn::Int32Tensor2D QINCoStep::encode(
237241

238242
// add the xhat
239243
for (size_t i = 0; i < n; i++) {
240-
float* zqs_r_row = zqs_r.data() + i * K * d_2;
241-
const float* xhat_row = xhat.data() + i * d_2;
244+
float* zqs_r_row = zqs_r.data() + i * K * d;
245+
const float* xhat_row = xhat.data() + i * d;
242246
for (size_t l = 0; l < K; l++) {
243-
for (size_t j = 0; j < d_2; j++) {
247+
for (size_t j = 0; j < d; j++) {
244248
zqs_r_row[j] += xhat_row[j];
245249
}
246-
zqs_r_row += d_2;
250+
zqs_r_row += d;
247251
}
248252
}
249253

@@ -252,31 +256,31 @@ nn::Int32Tensor2D QINCoStep::encode(
252256
float* res = nullptr;
253257
if (residuals) {
254258
FAISS_THROW_IF_NOT(
255-
residuals->shape[0] == n && residuals->shape[1] == d_2);
259+
residuals->shape[0] == n && residuals->shape[1] == d);
256260
res = residuals->data();
257261
}
258262

259263
for (size_t i = 0; i < n; i++) {
260-
const float* q = x.data() + i * d_2;
261-
const float* db = zqs_r.data() + i * K * d_2;
264+
const float* q = x.data() + i * d;
265+
const float* db = zqs_r.data() + i * K * d;
262266
float dis_min = HUGE_VALF;
263267
int64_t idx = -1;
264268
for (size_t j = 0; j < K; j++) {
265-
float dis = fvec_L2sqr(q, db, d_2);
269+
float dis = fvec_L2sqr(q, db, d);
266270
if (dis < dis_min) {
267271
dis_min = dis;
268272
idx = j;
269273
}
270-
db += d_2;
274+
db += d;
271275
}
272276
codes.v[i] = idx;
273277
if (res) {
274-
const float* xhat_row = xhat.data() + i * d_2;
275-
const float* xhat_next_row = zqs_r.data() + (i * K + idx) * d_2;
276-
for (size_t j = 0; j < d_2; j++) {
278+
const float* xhat_row = xhat.data() + i * d;
279+
const float* xhat_next_row = zqs_r.data() + (i * K + idx) * d;
280+
for (size_t j = 0; j < d; j++) {
277281
res[j] = xhat_next_row[j] - xhat_row[j];
278282
}
279-
res += d_2;
283+
res += d;
280284
}
281285
}
282286
return codes;

0 commit comments

Comments
 (0)