@@ -69,8 +69,8 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl {
69
69
std::vector<int32_t > punctuations;
70
70
int32_t last = -1 ;
71
71
for (int32_t i = 0 ; i != num_segments; ++i) {
72
- int32_t this_start = i * segment_size; // inclusive
73
- int32_t this_end = this_start + segment_size; // exclusive
72
+ int32_t this_start = i * segment_size; // included
73
+ int32_t this_end = this_start + segment_size; // not included
74
74
if (this_end > static_cast <int32_t >(token_ids.size ())) {
75
75
this_end = token_ids.size ();
76
76
}
@@ -113,7 +113,8 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl {
113
113
int32_t dot_index = -1 ;
114
114
int32_t comma_index = -1 ;
115
115
116
- for (int32_t m = this_punctuations.size () - 2 ; m >= 1 ; --m) {
116
+ for (int32_t m = static_cast <int32_t >(this_punctuations.size ()) - 2 ;
117
+ m >= 1 ; --m) {
117
118
int32_t punct_id = this_punctuations[m];
118
119
119
120
if (punct_id == meta_data.dot_id || punct_id == meta_data.quest_id ) {
@@ -137,13 +138,13 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl {
137
138
}
138
139
139
140
if (i == num_segments - 1 ) {
140
- dot_index = token_ids .size () - 1 ;
141
+ dot_index = static_cast < int32_t >(this_punctuations .size () ) - 1 ;
141
142
}
142
143
} else {
143
144
last = this_start + dot_index + 1 ;
144
145
}
145
146
146
- if (dot_index != 1 ) {
147
+ if (dot_index != - 1 ) {
147
148
punctuations.insert (punctuations.end (), this_punctuations.begin (),
148
149
this_punctuations.begin () + (dot_index + 1 ));
149
150
}
0 commit comments