Skip to content

Commit 3cc74f1

Browse files
authored
Add Hash64 (#895)
* Add hash64 * Fix tests * Resize hash64 * Fix comments * fix typo
1 parent 854b792 commit 3cc74f1

File tree

2 files changed

+442
-3
lines changed

2 files changed

+442
-3
lines changed

k2/csrc/hash.h

+345-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright 2020 Xiaomi Corporation (authors: Daniel Povey)
2+
* Copyright 2020 Xiaomi Corporation (authors: Daniel Povey, Wei kang)
33
*
44
* See LICENSE for clarification regarding multiple authors
55
*
@@ -1014,6 +1014,350 @@ class Hash {
10141014
};
10151015

10161016

1017+
/*
1018+
How class Hash64 works:
1019+
1020+
- It can function as a map from key=uint64_t to value=uint64_t; you must
1021+
decide the number of buckets when you create the hash, but you can resize
1022+
it (manually).
1023+
1024+
Note:
1025+
Each bucket contains a pair of key/value, each 64bits, key is stored at
1026+
data[2 * bucket_index] and value is stored at data[2 * bucket_index + 1].
1027+
1028+
Some constraints:
1029+
- You can store any (key,value) pair, except those where all the bits of
1030+
the key, or all the bits of the value, are set [that is used to mean "nothing here"]
1031+
- The number of buckets must always be a power of 2.
1032+
- When deleting values from the hash you must delete them all at
1033+
once (necessary because there is no concept of a "tombstone").
1034+
1035+
Some notes on usage:
1036+
1037+
You use it by: constructing it, obtaining its Accessor with GetAccessor();
1038+
and inside kernels (or host code), calling functions Insert(), Find() or
1039+
Delete() of the Accessor object. Resizing is not automatic; it is the
1040+
user's responsibility to make sure the hash does not get too full
1041+
(which could cause assertion failures in kernels, and will be very slow).
1042+
1043+
Some implementation notes:
1044+
- When accessing hash[key], we use bucket_index == key % num_buckets,
1045+
bucket_inc = 1 | (((key * 2) / num_buckets) ^ key).
1046+
- If the bucket at `bucket_index` is occupied, we look in locations
1047+
`(bucket_index + n * bucket_inc)%num_buckets` for n = 1, 2, ...;
1048+
this choice ensures that if multiple keys hash to the same bucket,
1049+
they don't all access the same sequence of locations; and bucket_inc
1050+
being odd ensures we eventually try all locations (of course for
1051+
reasonable hash occupancy levels, we shouldn't ever have to try
1052+
more than two or three).
1053+
1054+
*/
1055+
class Hash64 {
1056+
public:
1057+
/* Constructor. Context can be for CPU or GPU.
1058+
1059+
@param [in] num_buckets Number of buckets in the hash; must be
1060+
a power of 2 and >= 128 (this limit was arbitrarily chosen).
1061+
The number of items in the hash cannot exceed the number of
1062+
buckets, or the code will loop infinitely when you try to add
1063+
items; aim for less than 50% occupancy.
1064+
*/
1065+
Hash64(ContextPtr c, int64_t num_buckets) {
1066+
K2_CHECK_GE(num_buckets, 128);
1067+
data_ = Array1<uint64_t>(c, num_buckets * 2, ~(uint64_t)0);
1068+
int64_t n = 2;
1069+
for (buckets_num_bitsm1_ = 0; n < num_buckets;
1070+
n *= 2, buckets_num_bitsm1_++) {
1071+
}
1072+
K2_CHECK_EQ(num_buckets, 2 << buckets_num_bitsm1_)
1073+
<< " num_buckets must be a power of 2.";
1074+
}
1075+
1076+
// Only to be used prior to assignment.
1077+
Hash64() = default;
1078+
1079+
int64_t NumBuckets() const { return data_.Dim() / 2; }
1080+
1081+
// Returns data pointer; for testing..
1082+
uint64_t *Data() { return data_.Data(); }
1083+
1084+
// Shallow copy
1085+
Hash64 &operator=(const Hash64 &src) = default;
1086+
// Copy constructor (shallow copy)
1087+
explicit Hash64(const Hash64 &src) = default;
1088+
1089+
ContextPtr &Context() const { return data_.Context(); }
1090+
1091+
class Accessor {
1092+
public:
1093+
Accessor(Hash64 &hash)
1094+
: data_(hash.data_.Data()),
1095+
num_buckets_mask_(uint64_t(hash.NumBuckets()) - 1),
1096+
buckets_num_bitsm1_(hash.buckets_num_bitsm1_) {}
1097+
1098+
// Copy constructor
1099+
Accessor(const Accessor &src) = default;
1100+
1101+
/*
1102+
Try to insert pair (key,value) into hash.
1103+
@param [in] key Key into hash, it is an error if ~key == 0, i.e. if all
1104+
the allowed bits of `key` are set.
1105+
@param [in] value Value to set, it is an error if ~value == 0, i.e. if
1106+
all the allowed bits `value` are set.
1107+
@param [out] old_value If not nullptr, this location will be set to
1108+
the existing value *if this key was already present* in the
1109+
hash (or set by another thread in this kernel), i.e. only
1110+
if this function returns false.
1111+
@param [out] key_value_location If not nullptr, its contents will be
1112+
set to the address of the (key,value) pair (either the
1113+
existing or newly-written one).
1114+
@return Returns true if this (key,value) pair was inserted, false
1115+
otherwise.
1116+
1117+
Note: the const is with respect to the metadata only; it is required, to
1118+
avoid compilation errors.
1119+
*/
1120+
__forceinline__ __host__ __device__ bool Insert(
1121+
uint64_t key, uint64_t value, uint64_t *old_value = nullptr,
1122+
uint64_t **key_value_location = nullptr) const {
1123+
uint64_t cur_bucket = key & num_buckets_mask_,
1124+
bucket_inc = 1 | ((key >> buckets_num_bitsm1_) ^ key);
1125+
1126+
while (1) {
1127+
uint64_t cur_key = data_[2 * cur_bucket];
1128+
uint64_t cur_value = data_[2 * cur_bucket + 1];
1129+
if (cur_key == key) {
1130+
if (old_value) *old_value = cur_value;
1131+
if (key_value_location) *key_value_location = data_ + 2 * cur_bucket;
1132+
return false; // key exists in hash
1133+
} else if (~cur_key == 0) {
1134+
// we have a version of AtomicCAS that also works on host.
1135+
uint64_t old_key = AtomicCAS(
1136+
(unsigned long long *)(data_ + 2 * cur_bucket), cur_key, key);
1137+
if (old_key == cur_key) {
1138+
// set value
1139+
data_[2 * cur_bucket + 1] = value;
1140+
if (key_value_location)
1141+
*key_value_location = data_ + 2 * cur_bucket;
1142+
return true; // Successfully inserted.
1143+
}
1144+
if (old_key == key) {
1145+
if (old_value) *old_value = cur_value;
1146+
if (key_value_location)
1147+
*key_value_location = data_ + 2 * cur_bucket;
1148+
return false; // Another thread inserted this key
1149+
}
1150+
}
1151+
// Rotate bucket index until we find a free location. This will
1152+
// eventually visit all bucket indexes before it returns to the same
1153+
// location, because bucket_inc is odd (so only satisfies
1154+
// (n * bucket_inc) % num_buckets == 0 for n == num_buckets).
1155+
// Note: n here is the number of times we went around the loop.
1156+
cur_bucket = (cur_bucket + bucket_inc) & num_buckets_mask_;
1157+
}
1158+
}
1159+
1160+
/*
1161+
Look up this key in this hash; output the value and optionally the
1162+
location of the (key,value) pair if found.
1163+
1164+
@param [in] key Key to look up;
1165+
@param [out] value_out If found, value will be written to here. This may
1166+
seem redundant with key_value_location, but this should
1167+
compile to a local variable, and we want to avoid
1168+
redundant memory reads.
1169+
@param [out] key_value_location (optional) The memory address of the
1170+
(key,value) pair, in case the caller wants to overwrite
1171+
the value via SetValue(); must be used for no other
1172+
purpose.
1173+
@return Returns true if an item with this key was found in the
1174+
hash, otherwise false.
1175+
1176+
Note: the const is with respect to the metadata only; it is required, to
1177+
avoid compilation errors.
1178+
*/
1179+
__forceinline__ __host__ __device__ bool Find(
1180+
uint64_t key, uint64_t *value_out,
1181+
uint64_t **key_value_location = nullptr) const {
1182+
uint64_t cur_bucket = key & num_buckets_mask_,
1183+
bucket_inc = 1 | ((key >> buckets_num_bitsm1_) ^ key);
1184+
while (1) {
1185+
uint64_t old_key = data_[2 * cur_bucket];
1186+
uint64_t old_value = data_[2 * cur_bucket + 1];
1187+
if (~old_key == 0) {
1188+
return false;
1189+
} else if (old_key == key) {
1190+
while (~old_value == 0) old_value = data_[2 * cur_bucket + 1];
1191+
*value_out = old_value;
1192+
if (key_value_location) *key_value_location = data_ + 2 * cur_bucket;
1193+
return true;
1194+
} else {
1195+
cur_bucket = (cur_bucket + bucket_inc) & num_buckets_mask_;
1196+
}
1197+
}
1198+
}
1199+
1200+
/*
1201+
Overwrite a value in a (key,value) pair whose location was obtained using
1202+
Find().
1203+
@param [in] key_value_location Location that was obtained from
1204+
a successful call to Find().
1205+
@param [in] value Value to write;
1206+
1207+
Note: the const is with respect to the metadata only; it is required, to
1208+
avoid compilation errors.
1209+
*/
1210+
__forceinline__ __host__ __device__ void SetValue(
1211+
uint64_t *key_value_location, uint64_t value) const {
1212+
*(key_value_location + 1) = value;
1213+
}
1214+
1215+
/* Deletes a key from a hash. Caution: this cannot be combined with other
1216+
operations on a hash; after you delete a key you cannot do Insert() or
1217+
Find() until you have deleted all keys. This is an open-addressing hash
1218+
table with no tombstones, which is why this limitation exists).
1219+
1220+
@param [in] key Key to be deleted. Each key present in the hash must
1221+
be deleted by exactly one thread, or it will loop
1222+
forever!
1223+
1224+
Note: the const is with respect to the metadata only; required, to avoid
1225+
compilation errors.
1226+
*/
1227+
__forceinline__ __host__ __device__ void Delete(uint64_t key) const {
1228+
uint64_t cur_bucket = key & num_buckets_mask_,
1229+
bucket_inc = 1 | ((key >> buckets_num_bitsm1_) ^ key);
1230+
while (1) {
1231+
uint64_t old_key = data_[2 * cur_bucket];
1232+
if (old_key == key) {
1233+
data_[2 * cur_bucket] = ~((uint64_t)0);
1234+
data_[2 * cur_bucket + 1] = ~((uint64_t)0);
1235+
return;
1236+
} else {
1237+
cur_bucket = (cur_bucket + bucket_inc) & num_buckets_mask_;
1238+
}
1239+
}
1240+
}
1241+
1242+
private:
1243+
// pointer to data
1244+
uint64_t *data_;
1245+
// num_buckets_mask is num_buckets (i.e. size of `data_` array) minus one;
1246+
// num_buckets is a power of 2 so this can be used as a mask to get a number
1247+
// modulo num_buckets.
1248+
uint64_t num_buckets_mask_;
1249+
// A number satisfying num_buckets == 1 << (1+buckets_num_bitsm1_)
1250+
// the number of bits in `num_buckets` minus one.
1251+
uint64_t buckets_num_bitsm1_;
1252+
};
1253+
1254+
/*
1255+
Return an Accessor object which can be used in kernel code (or on CPU if the
1256+
context is a CPU context).
1257+
*/
1258+
Accessor GetAccessor() { return Accessor(*this); }
1259+
1260+
// You should call this before the destructor is called if the hash will still
1261+
// contain values when it is destroyed, to bypass a check.
1262+
void Destroy() { data_ = Array1<uint64_t>(); }
1263+
1264+
void CheckEmpty() const {
1265+
if (data_.Dim() == 0) return;
1266+
ContextPtr c = Context();
1267+
Array1<int64_t> error(c, 1, -1);
1268+
int64_t *error_data = error.Data();
1269+
const uint64_t *hash_data = data_.Data();
1270+
1271+
K2_EVAL(
1272+
Context(), data_.Dim(), lambda_check_data, (int64_t i)->void {
1273+
if (~(hash_data[i]) != 0) error_data[0] = i;
1274+
});
1275+
int64_t i = error[0];
1276+
if (i >= 0) { // there was an error; i is the index into the hash where
1277+
// there was an element.
1278+
int64_t elem = data_[i];
1279+
// We don't know the number of bits the user was using for the key vs.
1280+
// value, so print in hex, maybe they can figure it out.
1281+
K2_LOG(FATAL) << "Destroying hash: still contains values: position " << i
1282+
<< ", content = " << std::hex << elem;
1283+
}
1284+
}
1285+
1286+
/* Resize the hash to a new number of buckets.
1287+
1288+
@param [in] new_num_buckets New number of buckets; must be a power of 2,
1289+
and must be large enough to accommodate all values in the hash
1290+
(we assume the caller is keeping track of the number of elements
1291+
in the hash somehow).
1292+
1293+
CAUTION: Resizing will invalidate any accessor objects you have; you need
1294+
to re-get the accessors before accessing the hash again.
1295+
*/
1296+
void Resize(int64_t new_num_buckets, bool copy_data = true) {
1297+
NVTX_RANGE(K2_FUNC);
1298+
1299+
K2_CHECK_GT(new_num_buckets, 0);
1300+
K2_CHECK_EQ(new_num_buckets & (new_num_buckets - 1), 0); // power of 2.
1301+
1302+
ContextPtr c = data_.Context();
1303+
Hash64 new_hash(c, new_num_buckets);
1304+
1305+
if (copy_data) {
1306+
new_hash.CopyDataFromSimple(*this);
1307+
}
1308+
1309+
*this = new_hash;
1310+
new_hash.Destroy(); // avoid failed check in destructor (it would otherwise
1311+
// expect the hash to be empty when destroyed).
1312+
}
1313+
1314+
/*
1315+
Copies all data elements from `src` to `*this`.
1316+
*/
1317+
void CopyDataFromSimple(Hash64 &src) {
1318+
NVTX_RANGE(K2_FUNC);
1319+
int64_t num_buckets = data_.Dim() / 2,
1320+
src_num_buckets = src.data_.Dim() / 2;
1321+
const uint64_t *src_data = src.data_.Data();
1322+
uint64_t *data = data_.Data();
1323+
uint64_t new_num_buckets_mask = static_cast<uint64_t>(num_buckets) - 1,
1324+
new_buckets_num_bitsm1 = buckets_num_bitsm1_;
1325+
ContextPtr c = data_.Context();
1326+
K2_EVAL(c, src_num_buckets, lambda_copy_data, (uint64_t i) -> void {
1327+
uint64_t key = src_data[2 * i];
1328+
uint64_t value = src_data[2 * i + 1];
1329+
if (~key == 0) return; // equals -1.. nothing there.
1330+
uint64_t bucket_inc = 1 | ((key >> new_buckets_num_bitsm1) ^ key);
1331+
uint64_t cur_bucket = key & new_num_buckets_mask;
1332+
while (1) {
1333+
uint64_t assumed = ~((uint64_t)0),
1334+
old_elem = AtomicCAS((unsigned long long*)(data + 2 * cur_bucket),
1335+
assumed, key);
1336+
if (old_elem == assumed) {
1337+
*(data + 2 * cur_bucket + 1) = value;
1338+
return;
1339+
}
1340+
cur_bucket = (cur_bucket + bucket_inc) & new_num_buckets_mask;
1341+
// Keep iterating until we find a free spot in the new hash...
1342+
}
1343+
});
1344+
}
1345+
1346+
// The destructor checks that the hash is empty, if we are in debug mode.
1347+
// If you don't want this, call Destroy() before the destructor is called.
1348+
~Hash64() {
1349+
#ifndef NDEBUG
1350+
if (data_.Dim() != 0) CheckEmpty();
1351+
#endif
1352+
}
1353+
1354+
private:
1355+
Array1<uint64_t> data_;
1356+
1357+
// number satisfying data_.Dim() == 1 << (1+buckets_num_bitsm1_)
1358+
uint64_t buckets_num_bitsm1_;
1359+
};
1360+
10171361
/*
10181362
Returns the number of bits needed for an unsigned integer sufficient to
10191363
store the nonnegative value `size`.
@@ -1029,7 +1373,6 @@ inline int32_t NumBitsNeededFor(int64_t size) {
10291373
return 1 + HighestBitSet(size);
10301374
}
10311375

1032-
10331376
} // namespace k2
10341377

10351378
#endif // K2_CSRC_HASH_H_

0 commit comments

Comments
 (0)