Skip to content

Commit a4e7814

Browse files
committed
Fix precision
1 parent 4fb6b88 commit a4e7814

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

k2/csrc/rnnt_decode.cu

+3-3
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,8 @@ void RnntDecodingStreams::GetContexts(RaggedShape *shape,
159159
int64_t state_value = states_values_data[state_idx01x],
160160
context_state = state_value / num_graph_states,
161161
exp = decoder_history_len - col,
162-
state = context_state % (int64_t)powf(vocab_size, exp);
163-
state = state / (int64_t)powf(vocab_size, exp - 1);
162+
state = context_state % (int64_t)pow(vocab_size, exp);
163+
state = state / (int64_t)pow(vocab_size, exp - 1);
164164
contexts_acc(row, col) = state;
165165
});
166166
}
@@ -540,7 +540,7 @@ void RnntDecodingStreams::Advance(const Array2<float> &logprobs) {
540540
// can be done with `358 % 10^2`, then we append 6 to 58, that can be
541541
// done with `58 * 10 + 6`.
542542
context_state = this_context_state %
543-
(int64_t)powf(vocab_size, decoder_history_len - 1);
543+
(int64_t)pow(vocab_size, decoder_history_len - 1);
544544
context_state = context_state * vocab_size + arc.label;
545545
}
546546

0 commit comments

Comments
 (0)