Skip to content

Commit 1b29f0a

Browse files
authored
Fix precision (#951)
* Fix precision * Using different pow version for windows and *nix * Use int64_t pow * Minor fixes
1 parent 3b83183 commit 1b29f0a

File tree

2 files changed

+22
-8
lines changed

2 files changed

+22
-8
lines changed

k2/csrc/math.h

+19-5
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,20 @@
2727

2828
namespace k2 {
2929

30+
// Currently, only used in k2/csrc/rnnt_decode.cu
31+
// See https://github.com/k2-fsa/k2/pull/951#issuecomment-1096650842
32+
__host__ __device__ __forceinline__ int64_t Pow(int64_t base,
33+
int64_t exponent) {
34+
K2_CHECK_GE(exponent, 0);
35+
int64_t exp = 0;
36+
int64_t result = 1;
37+
while (exp < exponent) {
38+
result *= base;
39+
exp++;
40+
}
41+
return result;
42+
}
43+
3044
/*
3145
Returns index of highest bit set, in range -1..30.
3246
HighestBitSet(0) = -1,
@@ -106,29 +120,29 @@ int32_t RandIntGeometric(int32_t min, int32_t max);
106120
type, but for types float and double it "fixes" the broken behavior of
107121
the C++ standard w.r.t. infinity allowing infinities to be parsed.
108122
*/
109-
template<class T> struct InputFixer {
123+
template <class T>
124+
struct InputFixer {
110125
T t;
111126
// cast operator
112127
operator T() const { return t; }
113128
};
114129

115-
116130
namespace internal {
117131
template <typename Real>
118132
Real FixedRead(std::istream &is);
119133
}
120134

121135
template <typename T>
122-
inline std::istream &operator >>(std::istream &is, InputFixer<T> &f) {
136+
inline std::istream &operator>>(std::istream &is, InputFixer<T> &f) {
123137
return is >> f.t;
124138
}
125139
template <>
126-
inline std::istream &operator >>(std::istream &is, InputFixer<float> &f) {
140+
inline std::istream &operator>>(std::istream &is, InputFixer<float> &f) {
127141
f.t = internal::FixedRead<float>(is);
128142
return is;
129143
}
130144
template <>
131-
inline std::istream &operator >>(std::istream &is, InputFixer<double> &f) {
145+
inline std::istream &operator>>(std::istream &is, InputFixer<double> &f) {
132146
f.t = internal::FixedRead<double>(is);
133147
return is;
134148
}

k2/csrc/rnnt_decode.cu

+3-3
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,8 @@ void RnntDecodingStreams::GetContexts(RaggedShape *shape,
159159
int64_t state_value = states_values_data[state_idx01x],
160160
context_state = state_value / num_graph_states,
161161
exp = decoder_history_len - col,
162-
state = context_state % (int64_t)powf(vocab_size, exp);
163-
state = state / (int64_t)powf(vocab_size, exp - 1);
162+
state = context_state % Pow(vocab_size, exp);
163+
state = state / Pow(vocab_size, exp - 1);
164164
contexts_acc(row, col) = state;
165165
});
166166
}
@@ -540,7 +540,7 @@ void RnntDecodingStreams::Advance(const Array2<float> &logprobs) {
540540
// can be done with `358 % 10^2`, then we append 6 to 58, that can be
541541
// done with `58 * 10 + 6`.
542542
context_state = this_context_state %
543-
(int64_t)powf(vocab_size, decoder_history_len - 1);
543+
Pow(vocab_size, decoder_history_len - 1);
544544
context_state = context_state * vocab_size + arc.label;
545545
}
546546

0 commit comments

Comments
 (0)