Skip to content

Commit

Permalink
Merge pull request ggml-org#9 from anon998/stopping-strings
Browse files Browse the repository at this point in the history
Fix stopping strings.
  • Loading branch information
digiwombat authored Jun 1, 2023
2 parents 342604b + e9b1f0b commit 5f6e16d
Showing 1 changed file with 78 additions and 16 deletions.
94 changes: 78 additions & 16 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,33 @@ static size_t common_part(const std::vector<llama_token> & a, const std::vector<
return i;
}

enum stop_type {
STOP_FULL,
STOP_PARTIAL,
};

bool ends_with(const std::string &str, const std::string &suffix)
{
return str.size() >= suffix.size() &&
0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
}

size_t find_partial_stop_string(const std::string &stop, const std::string &text)
{
if (!text.empty()) {
const char text_last_char = text.back();
for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) {
if (stop[char_index] == text_last_char) {
const std::string current_partial = stop.substr(0, char_index + 1);
if (ends_with(text, current_partial)) {
return text.size() - char_index - 1;
}
}
}
}
return std::string::npos;
}

struct llama_server_context
{
bool stream = false;
Expand Down Expand Up @@ -248,6 +275,31 @@ struct llama_server_context
return result;
}

size_t findStoppingStrings(const std::string &text, const size_t last_token_size,
const stop_type type)
{
size_t stop_pos = std::string::npos;
for (const std::string &word : params.antiprompt) {
size_t pos;
if (type == STOP_FULL) {
const size_t tmp = word.size() + last_token_size;
const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0;
pos = text.find(word, from_pos);
} else {
pos = find_partial_stop_string(word, text);
}
if (pos != std::string::npos &&
(stop_pos == std::string::npos || pos < stop_pos)) {
if (type == STOP_FULL) {
stopping_word = word;
has_next_token = false;
}
stop_pos = pos;
}
}
return stop_pos;
}

std::string doCompletion()
{
llama_token token = nextToken();
Expand All @@ -272,16 +324,6 @@ struct llama_server_context
stopping_word.c_str());
}

for (const std::string& word : params.antiprompt) {
size_t i = generated_text.find(word, generated_text.size() - (word.size() + token_text.size()));
if (i != std::string::npos) {
generated_text.erase(generated_text.begin() + i, generated_text.end());
stopping_word = word;
has_next_token = false;
break;
}
}

return token_text;
}

Expand Down Expand Up @@ -711,7 +753,14 @@ int main(int argc, char **argv)

if (!llama.stream) {
while (llama.has_next_token) {
llama.doCompletion();
const std::string token_text = llama.doCompletion();
const size_t stop_pos = llama.findStoppingStrings(
llama.generated_text, token_text.size(), STOP_FULL);

if (stop_pos != std::string::npos) {
llama.generated_text.erase(llama.generated_text.begin() + stop_pos,
llama.generated_text.end());
}
}

json data = {{"content", llama.generated_text},
Expand All @@ -724,7 +773,7 @@ int main(int argc, char **argv)

llama_print_timings(llama.ctx);

return res.set_content(
res.set_content(
data.dump(llama.json_indent, ' ', false, json::error_handler_t::replace),
"application/json");
} else {
Expand All @@ -733,7 +782,7 @@ int main(int argc, char **argv)
int32_t multibyte_pending = 0;

while (llama.has_next_token) {
std::string token_text = llama.doCompletion();
const std::string token_text = llama.doCompletion();

if (multibyte_pending > 0) {
multibyte_pending -= token_text.size();
Expand Down Expand Up @@ -761,8 +810,22 @@ int main(int argc, char **argv)
continue;
}

const size_t pos = std::min(sent_count, llama.generated_text.size());
std::string to_send = llama.generated_text.substr(pos);
size_t pos = std::min(sent_count, llama.generated_text.size());

const char *str_test = llama.generated_text.c_str() + pos;
size_t stop_pos =
llama.findStoppingStrings(str_test, token_text.size(), STOP_FULL);
if (stop_pos != std::string::npos) {
llama.generated_text.erase(
llama.generated_text.begin() + pos + stop_pos,
llama.generated_text.end());
pos = std::min(sent_count, llama.generated_text.size());
} else {
stop_pos = llama.findStoppingStrings(str_test, token_text.size(),
STOP_PARTIAL);
}

std::string to_send = llama.generated_text.substr(pos, stop_pos);
sent_count += to_send.size();

json data;
Expand Down Expand Up @@ -808,7 +871,6 @@ int main(int argc, char **argv)
}
});


svr.Post("/tokenize", [&llama](const Request &req, Response &res)
{
json body = json::parse(req.body);
Expand Down

0 comments on commit 5f6e16d

Please sign in to comment.