@@ -1247,3 +1247,152 @@ def replacement_insert(self, codes, inserted=None):
        return inserted

    replace_method(the_class, 'insert', replacement_insert)
+
+######################################################
+# Syntactic sugar for NeuralNet classes
+######################################################
+
+
+def handle_Tensor2D(the_class):
+    """Add numpy interop to the SWIG Tensor2D class: construct from a
+    2D numpy array and convert back with .numpy()."""
+    the_class.original_init = the_class.__init__
+
+    def replacement_init(self, *args):
+        if len(args) == 1:
+            array, = args
+            n, d = array.shape
+            self.original_init(n, d)
+            faiss.copy_array_to_vector(
+                np.ascontiguousarray(array).ravel(), self.v)
+        else:
+            self.original_init(*args)
+
+    def numpy(self):
+        shape = np.zeros(2, dtype=np.int64)
+        faiss.memcpy(faiss.swig_ptr(shape), self.shape, shape.nbytes)
+        return faiss.vector_to_array(self.v).reshape(shape[0], shape[1])
+
+    the_class.__init__ = replacement_init
+    the_class.numpy = numpy
+
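+# Usage sketch (assumes the handler has been applied to the SWIG Tensor2D
+# class and that it stores float32 data):
+#
+#   t = faiss.Tensor2D(np.zeros((5, 3), dtype='float32'))
+#   assert t.numpy().shape == (5, 3)
+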
+def handle_Embedding(the_class):
+    """Allow constructing an Embedding directly from a torch.nn.Embedding
+    and copying weights from torch or numpy."""
+    the_class.original_init = the_class.__init__
+
+    def replacement_init(self, *args):
+        if len(args) != 1 or args[0].__class__ == the_class:
+            self.original_init(*args)
+            return
+        # assume it's a torch.nn.Embedding
+        emb = args[0]
+        self.original_init(emb.num_embeddings, emb.embedding_dim)
+        self.from_torch(emb)
+
+    def from_torch(self, emb):
+        """copy weights from a torch.nn.Embedding"""
+        assert emb.weight.shape == (self.num_embeddings, self.embedding_dim)
+        faiss.copy_array_to_vector(
+            np.ascontiguousarray(emb.weight.data).ravel(), self.weight)
+
+    def from_array(self, array):
+        """copy weights from a numpy array"""
+        assert array.shape == (self.num_embeddings, self.embedding_dim)
+        faiss.copy_array_to_vector(
+            np.ascontiguousarray(array).ravel(), self.weight)
+
+    the_class.from_array = from_array
+    the_class.from_torch = from_torch
+    the_class.__init__ = replacement_init
+
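+# Usage sketch (hypothetical sizes; assumes the handler is applied to the
+# SWIG Embedding class and that torch is available):
+#
+#   torch_emb = torch.nn.Embedding(num_embeddings=16, embedding_dim=8)
+#   emb = faiss.Embedding(torch_emb)                    # copies torch weights
+#   emb.from_array(np.zeros((16, 8), dtype='float32'))  # numpy route
+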
+def handle_Linear(the_class):
+    """Allow constructing a Linear layer directly from a torch.nn.Linear
+    and copying weights (and optional bias) from torch or numpy."""
+    the_class.original_init = the_class.__init__
+
+    def replacement_init(self, *args):
+        if len(args) != 1 or args[0].__class__ == the_class:
+            self.original_init(*args)
+            return
+        # assume it's a torch.nn.Linear
+        linear = args[0]
+        bias = linear.bias is not None
+        self.original_init(linear.in_features, linear.out_features, bias)
+        self.from_torch(linear)
+
+    def from_torch(self, linear):
+        """copy weights from a torch.nn.Linear"""
+        assert linear.weight.shape == (self.out_features, self.in_features)
+        faiss.copy_array_to_vector(
+            linear.weight.data.numpy().ravel(), self.weight)
+        if linear.bias is not None:
+            assert linear.bias.shape == (self.out_features,)
+            faiss.copy_array_to_vector(linear.bias.data.numpy(), self.bias)
+
+    def from_array(self, array, bias=None):
+        """copy weights from a numpy array"""
+        assert array.shape == (self.out_features, self.in_features)
+        faiss.copy_array_to_vector(
+            np.ascontiguousarray(array).ravel(), self.weight)
+        if bias is not None:
+            assert bias.shape == (self.out_features,)
+            faiss.copy_array_to_vector(bias, self.bias)
+
+    the_class.__init__ = replacement_init
+    the_class.from_array = from_array
+    the_class.from_torch = from_torch
+
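+# Usage sketch (hypothetical sizes; same assumptions as above):
+#
+#   torch_lin = torch.nn.Linear(8, 4, bias=True)
+#   lin = faiss.Linear(torch_lin)   # copies weight and bias
+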
1342
+ ######################################################
1343
+ # Syntatic sugar for QINCo and QINCoStep
1344
+ ######################################################
1345
+
+def handle_QINCoStep(the_class):
+    """Allow constructing a QINCoStep from its Torch counterpart and
+    copying over its weights."""
+    the_class.original_init = the_class.__init__
+
+    def replacement_init(self, *args):
+        if len(args) != 1 or args[0].__class__ == the_class:
+            self.original_init(*args)
+            return
+        step = args[0]
+        # assume it's a Torch QINCoStep
+        self.original_init(step.d, step.K, step.L, step.h)
+        self.from_torch(step)
+
+    def from_torch(self, step):
+        """copy weights from a Torch QINCoStep"""
+        assert (step.d, step.K, step.L, step.h) == (self.d, self.K, self.L, self.h)
+        self.codebook.from_torch(step.codebook)
+        self.MLPconcat.from_torch(step.MLPconcat)
+
+        for l in range(step.L):
+            src = step.residual_blocks[l]
+            dest = self.get_residual_block(l)
+            dest.linear1.from_torch(src[0])
+            dest.linear2.from_torch(src[2])
+
+    the_class.__init__ = replacement_init
+    the_class.from_torch = from_torch
+
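+# Note: each Torch residual block is assumed to be a Sequential of
+# (Linear, activation, Linear), hence the src[0] / src[2] indexing above.
+# Usage sketch:
+#
+#   faiss_step = faiss.QINCoStep(torch_step)   # torch_step: Torch QINCoStep
+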
+def handle_QINCo(the_class):
+    """Allow constructing a QINCo model from its Torch counterpart and
+    copying over all of its weights."""
+    the_class.original_init = the_class.__init__
+
+    def replacement_init(self, *args):
+        if len(args) != 1 or args[0].__class__ == the_class:
+            self.original_init(*args)
+            return
+
+        # assume it's a Torch QINCo
+        qinco = args[0]
+        self.original_init(qinco.d, qinco.K, qinco.L, qinco.M, qinco.h)
+        self.from_torch(qinco)
+
+    def from_torch(self, qinco):
+        """copy weights from a Torch QINCo"""
+        assert (
+            (qinco.d, qinco.K, qinco.L, qinco.M, qinco.h) ==
+            (self.d, self.K, self.L, self.M, self.h)
+        )
+        self.codebook0.from_torch(qinco.codebook0)
+        for m in range(qinco.M - 1):
+            self.get_step(m).from_torch(qinco.steps[m])
+
+    the_class.__init__ = replacement_init
+    the_class.from_torch = from_torch
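+
+# End-to-end sketch (assumes a trained Torch QINCo model `torch_qinco`
+# with matching d, K, L, M, h; codebook0 is copied directly and the
+# remaining M - 1 steps are converted one by one):
+#
+#   fq = faiss.QINCo(torch_qinco)
+#   # equivalent: fq = faiss.QINCo(d, K, L, M, h); fq.from_torch(torch_qinco)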