1
1
/* *
2
- * Copyright 2020 Xiaomi Corporation (authors: Daniel Povey)
2
+ * Copyright 2020 Xiaomi Corporation (authors: Daniel Povey, Wei kang )
3
3
*
4
4
* See LICENSE for clarification regarding multiple authors
5
5
*
@@ -1014,6 +1014,350 @@ class Hash {
1014
1014
};
1015
1015
1016
1016
1017
+ /*
1018
+ How class Hash64 works:
1019
+
1020
+ - It can function as a map from key=uint64_t to value=uint64_t, you must
1021
+ decide the number of buckets, when you create the hash, but you can resize
1022
+ it (manually).
1023
+
1024
+ Note:
1025
+ Each bucket contains a pair of key/value, each 64bits, key is stored at
1026
+ data[2 * bucket_index] and value is stored at data[2 * bucket_index + 1].
1027
+
1028
+ Some constraints:
1029
+ - You can store any (key,value) pair, except the pair where all the bits of
1030
+ both key and value are set [that is used to mean "nothing here"]
1031
+ - The number of buckets must always be a power of 2.
1032
+ - When deleting values from the hash you must delete them all at
1033
+ once (necessary because there is no concept of a "tombstone".
1034
+
1035
+ Some notes on usage:
1036
+
1037
+ You use it by: constructing it, obtaining its Accessor with GetAccessor();
1038
+ and inside kernels (or host code), calling functions Insert(), Find() or
1039
+ Delete() of the Accessor object. Resizing is not automatic; it is the
1040
+ user's responsibility to make sure the hash does not get too full
1041
+ (which could cause assertion failures in kernels, and will be very slow).
1042
+
1043
+ Some implementation notes:
1044
+ - When accessing hash[key], we use bucket_index == key % num_buckets,
1045
+ bucket_inc = 1 | (((key * 2) / num_buckets) ^ key).
1046
+ - If the bucket at `bucket_index` is occupied, we look in locations
1047
+ `(bucket_index + n * bucket_inc)%num_buckets` for n = 1, 2, ...;
1048
+ this choice ensures that if multiple keys hash to the same bucket,
1049
+ they don't all access the same sequence of locations; and bucket_inc
1050
+ being odd ensures we eventually try all locations (of course for
1051
+ reasonable hash occupancy levels, we shouldn't ever have to try
1052
+ more than two or three).
1053
+
1054
+ */
1055
+ class Hash64 {
1056
+ public:
1057
+ /* Constructor. Context can be for CPU or GPU.
1058
+
1059
+ @param [in] num_buckets Number of buckets in the hash; must be
1060
+ a power of 2 and >= 128 (this limit was arbitrarily chosen).
1061
+ The number of items in the hash cannot exceed the number of
1062
+ buckets, or the code will loop infinitely when you try to add
1063
+ items; aim for less than 50% occupancy.
1064
+ */
1065
+ Hash64 (ContextPtr c, int64_t num_buckets) {
1066
+ K2_CHECK_GE (num_buckets, 128 );
1067
+ data_ = Array1<uint64_t >(c, num_buckets * 2 , ~(uint64_t )0 );
1068
+ int64_t n = 2 ;
1069
+ for (buckets_num_bitsm1_ = 0 ; n < num_buckets;
1070
+ n *= 2 , buckets_num_bitsm1_++) {
1071
+ }
1072
+ K2_CHECK_EQ (num_buckets, 2 << buckets_num_bitsm1_)
1073
+ << " num_buckets must be a power of 2." ;
1074
+ }
1075
+
1076
+ // Only to be used prior to assignment.
1077
+ Hash64 () = default ;
1078
+
1079
+ int64_t NumBuckets () const { return data_.Dim () / 2 ; }
1080
+
1081
+ // Returns data pointer; for testing..
1082
+ uint64_t *Data () { return data_.Data (); }
1083
+
1084
+ // Shallow copy
1085
+ Hash64 &operator =(const Hash64 &src) = default ;
1086
+ // Copy constructor (shallow copy)
1087
+ explicit Hash64 (const Hash64 &src) = default;
1088
+
1089
+ ContextPtr &Context () const { return data_.Context (); }
1090
+
1091
+ class Accessor {
1092
+ public:
1093
+ Accessor (Hash64 &hash)
1094
+ : data_(hash.data_.Data()),
1095
+ num_buckets_mask_ (uint64_t (hash.NumBuckets()) - 1),
1096
+ buckets_num_bitsm1_(hash.buckets_num_bitsm1_) {}
1097
+
1098
+ // Copy constructor
1099
+ Accessor (const Accessor &src) = default;
1100
+
1101
+ /*
1102
+ Try to insert pair (key,value) into hash.
1103
+ @param [in] key Key into hash, it is an error if ~key == 0, i.e. if all
1104
+ the allowed bits of `key` are set.
1105
+ @param [in] value Value to set, it is an error if ~value == 0, i.e. if
1106
+ all the allowed bits `value` are set.
1107
+ @param [out] old_value If not nullptr, this location will be set to
1108
+ the existing value *if this key was already present* in the
1109
+ hash (or set by another thread in this kernel), i.e. only
1110
+ if this function returns false.
1111
+ @param [out] key_value_location If not nullptr, its contents will be
1112
+ set to the address of the (key,value) pair (either the
1113
+ existing or newly-written one).
1114
+ @return Returns true if this (key,value) pair was inserted, false
1115
+ otherwise.
1116
+
1117
+ Note: the const is with respect to the metadata only; it is required, to
1118
+ avoid compilation errors.
1119
+ */
1120
+ __forceinline__ __host__ __device__ bool Insert (
1121
+ uint64_t key, uint64_t value, uint64_t *old_value = nullptr ,
1122
+ uint64_t **key_value_location = nullptr ) const {
1123
+ uint64_t cur_bucket = key & num_buckets_mask_,
1124
+ bucket_inc = 1 | ((key >> buckets_num_bitsm1_) ^ key);
1125
+
1126
+ while (1 ) {
1127
+ uint64_t cur_key = data_[2 * cur_bucket];
1128
+ uint64_t cur_value = data_[2 * cur_bucket + 1 ];
1129
+ if (cur_key == key) {
1130
+ if (old_value) *old_value = cur_value;
1131
+ if (key_value_location) *key_value_location = data_ + 2 * cur_bucket;
1132
+ return false ; // key exists in hash
1133
+ } else if (~cur_key == 0 ) {
1134
+ // we have a version of AtomicCAS that also works on host.
1135
+ uint64_t old_key = AtomicCAS (
1136
+ (unsigned long long *)(data_ + 2 * cur_bucket), cur_key, key);
1137
+ if (old_key == cur_key) {
1138
+ // set value
1139
+ data_[2 * cur_bucket + 1 ] = value;
1140
+ if (key_value_location)
1141
+ *key_value_location = data_ + 2 * cur_bucket;
1142
+ return true ; // Successfully inserted.
1143
+ }
1144
+ if (old_key == key) {
1145
+ if (old_value) *old_value = cur_value;
1146
+ if (key_value_location)
1147
+ *key_value_location = data_ + 2 * cur_bucket;
1148
+ return false ; // Another thread inserted this key
1149
+ }
1150
+ }
1151
+ // Rotate bucket index until we find a free location. This will
1152
+ // eventually visit all bucket indexes before it returns to the same
1153
+ // location, because bucket_inc is odd (so only satisfies
1154
+ // (n * bucket_inc) % num_buckets == 0 for n == num_buckets).
1155
+ // Note: n here is the number of times we went around the loop.
1156
+ cur_bucket = (cur_bucket + bucket_inc) & num_buckets_mask_;
1157
+ }
1158
+ }
1159
+
1160
+ /*
1161
+ Look up this key in this hash; output the value and optionally the
1162
+ location of the (key,value) pair if found.
1163
+
1164
+ @param [in] key Key to look up;
1165
+ @param [out] value_out If found, value will be written to here. This may
1166
+ seem redundant with key_value_location, but this should
1167
+ compile to a local variable, and we want to avoid
1168
+ redundant memory reads.
1169
+ @param [out] key_value_location (optional) The memory address of the
1170
+ (key,value) pair, in case the caller wants to overwrite
1171
+ the value via SetValue(); must be used for no other
1172
+ purpose.
1173
+ @return Returns true if an item with this key was found in the
1174
+ hash, otherwise false.
1175
+
1176
+ Note: the const is with respect to the metadata only; it is required, to
1177
+ avoid compilation errors.
1178
+ */
1179
+ __forceinline__ __host__ __device__ bool Find (
1180
+ uint64_t key, uint64_t *value_out,
1181
+ uint64_t **key_value_location = nullptr ) const {
1182
+ uint64_t cur_bucket = key & num_buckets_mask_,
1183
+ bucket_inc = 1 | ((key >> buckets_num_bitsm1_) ^ key);
1184
+ while (1 ) {
1185
+ uint64_t old_key = data_[2 * cur_bucket];
1186
+ uint64_t old_value = data_[2 * cur_bucket + 1 ];
1187
+ if (~old_key == 0 ) {
1188
+ return false ;
1189
+ } else if (old_key == key) {
1190
+ while (~old_value == 0 ) old_value = data_[2 * cur_bucket + 1 ];
1191
+ *value_out = old_value;
1192
+ if (key_value_location) *key_value_location = data_ + 2 * cur_bucket;
1193
+ return true ;
1194
+ } else {
1195
+ cur_bucket = (cur_bucket + bucket_inc) & num_buckets_mask_;
1196
+ }
1197
+ }
1198
+ }
1199
+
1200
+ /*
1201
+ Overwrite a value in a (key,value) pair whose location was obtained using
1202
+ Find().
1203
+ @param [in] key_value_location Location that was obtained from
1204
+ a successful call to Find().
1205
+ @param [in] value Value to write;
1206
+
1207
+ Note: the const is with respect to the metadata only; it is required, to
1208
+ avoid compilation errors.
1209
+ */
1210
+ __forceinline__ __host__ __device__ void SetValue (
1211
+ uint64_t *key_value_location, uint64_t value) const {
1212
+ *(key_value_location + 1 ) = value;
1213
+ }
1214
+
1215
+ /* Deletes a key from a hash. Caution: this cannot be combined with other
1216
+ operations on a hash; after you delete a key you cannot do Insert() or
1217
+ Find() until you have deleted all keys. This is an open-addressing hash
1218
+ table with no tombstones, which is why this limitation exists).
1219
+
1220
+ @param [in] key Key to be deleted. Each key present in the hash must
1221
+ be deleted by exactly one thread, or it will loop
1222
+ forever!
1223
+
1224
+ Note: the const is with respect to the metadata only; required, to avoid
1225
+ compilation errors.
1226
+ */
1227
+ __forceinline__ __host__ __device__ void Delete (uint64_t key) const {
1228
+ uint64_t cur_bucket = key & num_buckets_mask_,
1229
+ bucket_inc = 1 | ((key >> buckets_num_bitsm1_) ^ key);
1230
+ while (1 ) {
1231
+ uint64_t old_key = data_[2 * cur_bucket];
1232
+ if (old_key == key) {
1233
+ data_[2 * cur_bucket] = ~((uint64_t )0 );
1234
+ data_[2 * cur_bucket + 1 ] = ~((uint64_t )0 );
1235
+ return ;
1236
+ } else {
1237
+ cur_bucket = (cur_bucket + bucket_inc) & num_buckets_mask_;
1238
+ }
1239
+ }
1240
+ }
1241
+
1242
+ private:
1243
+ // pointer to data
1244
+ uint64_t *data_;
1245
+ // num_buckets_mask is num_buckets (i.e. size of `data_` array) minus one;
1246
+ // num_buckets is a power of 2 so this can be used as a mask to get a number
1247
+ // modulo num_buckets.
1248
+ uint64_t num_buckets_mask_;
1249
+ // A number satisfying num_buckets == 1 << (1+buckets_num_bitsm1_)
1250
+ // the number of bits in `num_buckets` minus one.
1251
+ uint64_t buckets_num_bitsm1_;
1252
+ };
1253
+
1254
+ /*
1255
+ Return an Accessor object which can be used in kernel code (or on CPU if the
1256
+ context is a CPU context).
1257
+ */
1258
+ Accessor GetAccessor () { return Accessor (*this ); }
1259
+
1260
+ // You should call this before the destructor is called if the hash will still
1261
+ // contain values when it is destroyed, to bypass a check.
1262
+ void Destroy () { data_ = Array1<uint64_t >(); }
1263
+
1264
+ void CheckEmpty () const {
1265
+ if (data_.Dim () == 0 ) return ;
1266
+ ContextPtr c = Context ();
1267
+ Array1<int64_t > error (c, 1 , -1 );
1268
+ int64_t *error_data = error.Data ();
1269
+ const uint64_t *hash_data = data_.Data ();
1270
+
1271
+ K2_EVAL (
1272
+ Context (), data_.Dim (), lambda_check_data, (int64_t i)->void {
1273
+ if (~(hash_data[i]) != 0 ) error_data[0 ] = i;
1274
+ });
1275
+ int64_t i = error[0 ];
1276
+ if (i >= 0 ) { // there was an error; i is the index into the hash where
1277
+ // there was an element.
1278
+ int64_t elem = data_[i];
1279
+ // We don't know the number of bits the user was using for the key vs.
1280
+ // value, so print in hex, maybe they can figure it out.
1281
+ K2_LOG (FATAL) << " Destroying hash: still contains values: position " << i
1282
+ << " , content = " << std::hex << elem;
1283
+ }
1284
+ }
1285
+
1286
+ /* Resize the hash to a new number of buckets.
1287
+
1288
+ @param [in] new_num_buckets New number of buckets; must be a power of 2,
1289
+ and must be large enough to accommodate all values in the hash
1290
+ (we assume the caller is keeping track of the number of elements
1291
+ in the hash somehow).
1292
+
1293
+ CAUTION: Resizing will invalidate any accessor objects you have; you need
1294
+ to re-get the accessors before accessing the hash again.
1295
+ */
1296
+ void Resize (int64_t new_num_buckets, bool copy_data = true ) {
1297
+ NVTX_RANGE (K2_FUNC);
1298
+
1299
+ K2_CHECK_GT (new_num_buckets, 0 );
1300
+ K2_CHECK_EQ (new_num_buckets & (new_num_buckets - 1 ), 0 ); // power of 2.
1301
+
1302
+ ContextPtr c = data_.Context ();
1303
+ Hash64 new_hash (c, new_num_buckets);
1304
+
1305
+ if (copy_data) {
1306
+ new_hash.CopyDataFromSimple (*this );
1307
+ }
1308
+
1309
+ *this = new_hash;
1310
+ new_hash.Destroy (); // avoid failed check in destructor (it would otherwise
1311
+ // expect the hash to be empty when destroyed).
1312
+ }
1313
+
1314
+ /*
1315
+ Copies all data elements from `src` to `*this`.
1316
+ */
1317
+ void CopyDataFromSimple (Hash64 &src) {
1318
+ NVTX_RANGE (K2_FUNC);
1319
+ int64_t num_buckets = data_.Dim () / 2 ,
1320
+ src_num_buckets = src.data_ .Dim () / 2 ;
1321
+ const uint64_t *src_data = src.data_ .Data ();
1322
+ uint64_t *data = data_.Data ();
1323
+ uint64_t new_num_buckets_mask = static_cast <uint64_t >(num_buckets) - 1 ,
1324
+ new_buckets_num_bitsm1 = buckets_num_bitsm1_;
1325
+ ContextPtr c = data_.Context ();
1326
+ K2_EVAL (c, src_num_buckets, lambda_copy_data, (uint64_t i) -> void {
1327
+ uint64_t key = src_data[2 * i];
1328
+ uint64_t value = src_data[2 * i + 1 ];
1329
+ if (~key == 0 ) return ; // equals -1.. nothing there.
1330
+ uint64_t bucket_inc = 1 | ((key >> new_buckets_num_bitsm1) ^ key);
1331
+ uint64_t cur_bucket = key & new_num_buckets_mask;
1332
+ while (1 ) {
1333
+ uint64_t assumed = ~((uint64_t )0 ),
1334
+ old_elem = AtomicCAS ((unsigned long long *)(data + 2 * cur_bucket),
1335
+ assumed, key);
1336
+ if (old_elem == assumed) {
1337
+ *(data + 2 * cur_bucket + 1 ) = value;
1338
+ return ;
1339
+ }
1340
+ cur_bucket = (cur_bucket + bucket_inc) & new_num_buckets_mask;
1341
+ // Keep iterating until we find a free spot in the new hash...
1342
+ }
1343
+ });
1344
+ }
1345
+
1346
+ // The destructor checks that the hash is empty, if we are in debug mode.
1347
+ // If you don't want this, call Destroy() before the destructor is called.
1348
+ ~Hash64 () {
1349
+ #ifndef NDEBUG
1350
+ if (data_.Dim () != 0 ) CheckEmpty ();
1351
+ #endif
1352
+ }
1353
+
1354
+ private:
1355
+ Array1<uint64_t > data_;
1356
+
1357
+ // number satisfying data_.Dim() == 1 << (1+buckets_num_bitsm1_)
1358
+ uint64_t buckets_num_bitsm1_;
1359
+ };
1360
+
1017
1361
/*
1018
1362
Returns the number of bits needed for an unsigned integer sufficient to
1019
1363
store the nonnegative value `size`.
@@ -1029,7 +1373,6 @@ inline int32_t NumBitsNeededFor(int64_t size) {
1029
1373
return 1 + HighestBitSet (size);
1030
1374
}
1031
1375
1032
-
1033
1376
} // namespace k2
1034
1377
1035
1378
#endif // K2_CSRC_HASH_H_
0 commit comments