@@ -212,19 +212,23 @@ nn::Int32Tensor2D QINCoStep::encode(
212
212
// repeated codebook
213
213
Tensor2D zqs_r (n * K, d); // size n, K, d
214
214
Tensor2D cc (n * K, d * 2 ); // size n, K, d * 2
215
- size_t d_2 = this ->d ;
216
215
217
- auto copy_row = [d_2](Tensor2D& t, size_t i, size_t j, const float * data) {
218
- assert (i <= t.shape [0 ] && j <= t.shape [1 ]);
219
- memcpy (t.data () + i * t.shape [1 ] + j, data, sizeof (float ) * d_2);
220
- };
216
+ size_t local_d = this ->d ;
217
+
218
+ auto copy_row =
219
+ [local_d](Tensor2D& t, size_t i, size_t j, const float * data) {
220
+ assert (i <= t.shape [0 ] && j <= t.shape [1 ]);
221
+ memcpy (t.data () + i * t.shape [1 ] + j,
222
+ data,
223
+ sizeof (float ) * local_d);
224
+ };
221
225
222
226
// manual broadcasting
223
227
for (size_t i = 0 ; i < n; i++) {
224
228
for (size_t j = 0 ; j < K; j++) {
225
- copy_row (zqs_r, i * K + j, 0 , codebook.data () + j * d_2 );
226
- copy_row (cc, i * K + j, 0 , codebook.data () + j * d_2 );
227
- copy_row (cc, i * K + j, d_2 , xhat.data () + i * d_2 );
229
+ copy_row (zqs_r, i * K + j, 0 , codebook.data () + j * d );
230
+ copy_row (cc, i * K + j, 0 , codebook.data () + j * d );
231
+ copy_row (cc, i * K + j, d , xhat.data () + i * d );
228
232
}
229
233
}
230
234
@@ -237,13 +241,13 @@ nn::Int32Tensor2D QINCoStep::encode(
237
241
238
242
// add the xhat
239
243
for (size_t i = 0 ; i < n; i++) {
240
- float * zqs_r_row = zqs_r.data () + i * K * d_2 ;
241
- const float * xhat_row = xhat.data () + i * d_2 ;
244
+ float * zqs_r_row = zqs_r.data () + i * K * d ;
245
+ const float * xhat_row = xhat.data () + i * d ;
242
246
for (size_t l = 0 ; l < K; l++) {
243
- for (size_t j = 0 ; j < d_2 ; j++) {
247
+ for (size_t j = 0 ; j < d ; j++) {
244
248
zqs_r_row[j] += xhat_row[j];
245
249
}
246
- zqs_r_row += d_2 ;
250
+ zqs_r_row += d ;
247
251
}
248
252
}
249
253
@@ -252,31 +256,31 @@ nn::Int32Tensor2D QINCoStep::encode(
252
256
float * res = nullptr ;
253
257
if (residuals) {
254
258
FAISS_THROW_IF_NOT (
255
- residuals->shape [0 ] == n && residuals->shape [1 ] == d_2 );
259
+ residuals->shape [0 ] == n && residuals->shape [1 ] == d );
256
260
res = residuals->data ();
257
261
}
258
262
259
263
for (size_t i = 0 ; i < n; i++) {
260
- const float * q = x.data () + i * d_2 ;
261
- const float * db = zqs_r.data () + i * K * d_2 ;
264
+ const float * q = x.data () + i * d ;
265
+ const float * db = zqs_r.data () + i * K * d ;
262
266
float dis_min = HUGE_VALF;
263
267
int64_t idx = -1 ;
264
268
for (size_t j = 0 ; j < K; j++) {
265
- float dis = fvec_L2sqr (q, db, d_2 );
269
+ float dis = fvec_L2sqr (q, db, d );
266
270
if (dis < dis_min) {
267
271
dis_min = dis;
268
272
idx = j;
269
273
}
270
- db += d_2 ;
274
+ db += d ;
271
275
}
272
276
codes.v [i] = idx;
273
277
if (res) {
274
- const float * xhat_row = xhat.data () + i * d_2 ;
275
- const float * xhat_next_row = zqs_r.data () + (i * K + idx) * d_2 ;
276
- for (size_t j = 0 ; j < d_2 ; j++) {
278
+ const float * xhat_row = xhat.data () + i * d ;
279
+ const float * xhat_next_row = zqs_r.data () + (i * K + idx) * d ;
280
+ for (size_t j = 0 ; j < d ; j++) {
277
281
res[j] = xhat_next_row[j] - xhat_row[j];
278
282
}
279
- res += d_2 ;
283
+ res += d ;
280
284
}
281
285
}
282
286
return codes;
0 commit comments