Skip to content

Commit d64675f

Browse files
committed
Using different pow version for windows and *nix
1 parent a4e7814 commit d64675f

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

k2/csrc/math.h

+8
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,14 @@
2727

2828
namespace k2 {
2929

30+
// Currently, only used in k2/csrc/rnnt_decode.cu
31+
// See https://github.com/k2-fsa/k2/pull/951#issuecomment-1096650842
32+
#ifndef _MSC_VER
33+
#define K2_POW pow
34+
#else
35+
#define K2_POW powf
36+
#endif
37+
3038
/*
3139
Returns index of highest bit set, in range -1..30.
3240
HighestBitSet(0) = -1,

k2/csrc/rnnt_decode.cu

+3-3
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,8 @@ void RnntDecodingStreams::GetContexts(RaggedShape *shape,
159159
int64_t state_value = states_values_data[state_idx01x],
160160
context_state = state_value / num_graph_states,
161161
exp = decoder_history_len - col,
162-
state = context_state % (int64_t)pow(vocab_size, exp);
163-
state = state / (int64_t)pow(vocab_size, exp - 1);
162+
state = context_state % (int64_t)K2_POW(vocab_size, exp);
163+
state = state / (int64_t)K2_POW(vocab_size, exp - 1);
164164
contexts_acc(row, col) = state;
165165
});
166166
}
@@ -540,7 +540,7 @@ void RnntDecodingStreams::Advance(const Array2<float> &logprobs) {
540540
// can be done with `358 % 10^2`, then we append 6 to 58, that can be
541541
// done with `58 * 10 + 6`.
542542
context_state = this_context_state %
543-
(int64_t)pow(vocab_size, decoder_history_len - 1);
543+
(int64_t)K2_POW(vocab_size, decoder_history_len - 1);
544544
context_state = context_state * vocab_size + arc.label;
545545
}
546546

0 commit comments

Comments
 (0)