Skip to content

Commit cd97f77

Browse files
authored
Fix punctuation (k2-fsa#976)
1 parent 6c74975 commit cd97f77

File tree

2 files changed

+14
-11
lines changed

2 files changed

+14
-11
lines changed

.github/workflows/sanitizer.yaml

+8 -6
Original file line number | Diff line number | Diff line change
@@ -76,6 +76,14 @@ jobs:
7676
otool -L build/bin/sherpa-onnx
7777
otool -l build/bin/sherpa-onnx
7878
79+
- name: Test offline punctuation
80+
shell: bash
81+
run: |
82+
export PATH=$PWD/build/bin:$PATH
83+
export EXE=sherpa-onnx-offline-punctuation
84+
85+
.github/scripts/test-offline-punctuation.sh
86+
7987
- name: Test offline transducer
8088
shell: bash
8189
run: |
@@ -92,13 +100,7 @@ jobs:
92100
93101
.github/scripts/test-online-ctc.sh
94102
95-
- name: Test offline punctuation
96-
shell: bash
97-
run: |
98-
export PATH=$PWD/build/bin:$PATH
99-
export EXE=sherpa-onnx-offline-punctuation
100103
101-
.github/scripts/test-offline-punctuation.sh
102104
103105
- name: Test C API
104106
shell: bash

sherpa-onnx/csrc/offline-punctuation-ct-transformer-impl.h

+6 -5
Original file line number | Diff line number | Diff line change
@@ -69,8 +69,8 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl {
6969
std::vector<int32_t> punctuations;
7070
int32_t last = -1;
7171
for (int32_t i = 0; i != num_segments; ++i) {
72-
int32_t this_start = i * segment_size; // inclusive
73-
int32_t this_end = this_start + segment_size; // exclusive
72+
int32_t this_start = i * segment_size; // included
73+
int32_t this_end = this_start + segment_size; // not included
7474
if (this_end > static_cast<int32_t>(token_ids.size())) {
7575
this_end = token_ids.size();
7676
}
@@ -113,7 +113,8 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl {
113113
int32_t dot_index = -1;
114114
int32_t comma_index = -1;
115115

116-
for (int32_t m = this_punctuations.size() - 2; m >= 1; --m) {
116+
for (int32_t m = static_cast<int32_t>(this_punctuations.size()) - 2;
117+
m >= 1; --m) {
117118
int32_t punct_id = this_punctuations[m];
118119

119120
if (punct_id == meta_data.dot_id || punct_id == meta_data.quest_id) {
@@ -137,13 +138,13 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl {
137138
}
138139

139140
if (i == num_segments - 1) {
140-
dot_index = token_ids.size() - 1;
141+
dot_index = static_cast<int32_t>(this_punctuations.size()) - 1;
141142
}
142143
} else {
143144
last = this_start + dot_index + 1;
144145
}
145146

146-
if (dot_index != 1) {
147+
if (dot_index != -1) {
147148
punctuations.insert(punctuations.end(), this_punctuations.begin(),
148149
this_punctuations.begin() + (dot_index + 1));
149150
}

0 commit comments

Comments
 (0)