llama : vocab cleanup
ggml-ci
ggerganov committed Jan 8, 2025
1 parent 0f71186 commit 403dee8
Showing 4 changed files with 26 additions and 21 deletions.
2 changes: 1 addition & 1 deletion src/llama-sampling.cpp
@@ -1663,7 +1663,7 @@ struct llama_sampler_dry {

// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
static void get_overlapping_token_sequences(const llama_vocab & vocab, const std::string& str, std::unordered_multimap<llama_token, std::vector<llama_token>>& token_sequences, int max_tail_len = -1) {
- for (llama_token token_id = 0; token_id < (llama_token)vocab.n_vocab; token_id++) {
+ for (llama_token token_id = 0; token_id < (llama_token) vocab.n_vocab(); token_id++) {
std::string word = vocab.detokenize({token_id}, true);
if (word.find(str) != std::string::npos) {
token_sequences.emplace(token_id, std::vector<llama_token>());
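For context, the loop above is the whole-vocabulary scan the DRY sampler uses to find tokens whose text contains a sequence-breaker string; the only change in this file is that the bound now comes from the accessor. A minimal, self-contained sketch of that pattern, using a made-up toy_vocab type rather than the real llama_vocab API:

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

struct toy_vocab {
    std::vector<std::string> id_to_token;

    uint32_t n_vocab() const { return (uint32_t) id_to_token.size(); }
    std::string detokenize(int32_t id) const { return id_to_token[(size_t) id]; }
};

int main() {
    const toy_vocab vocab = { { "hello", " world", "wor", "ld", "!" } };
    const std::string str = "wor";

    // collect every token whose text contains `str`, mirroring the DRY sampler's scan
    std::vector<int32_t> hits;
    for (int32_t id = 0; id < (int32_t) vocab.n_vocab(); ++id) {
        if (vocab.detokenize(id).find(str) != std::string::npos) {
            hits.push_back(id);
        }
    }

    for (int32_t id : hits) {
        std::cout << id << ": '" << vocab.detokenize(id) << "'\n";
    }
    return 0;
}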
31 changes: 19 additions & 12 deletions src/llama-vocab.cpp
@@ -208,7 +208,7 @@ struct llm_tokenizer_spm_session {
return;
}

- if (static_cast<uint32_t>(token) >= vocab.n_vocab) {
+ if (static_cast<uint32_t>(token) >= vocab.n_vocab()) {
return;
}

@@ -734,7 +734,7 @@ struct llm_tokenizer_ugm : llm_tokenizer {
prefix_replacements_size = precompiled_charsmap.size() - charsmap_offset;
}

- for (uint32_t id = 0; id < vocab.n_vocab; ++id) {
+ for (uint32_t id = 0; id < vocab.n_vocab(); ++id) {
const auto & token_data = vocab.get_token_data(id);

if (vocab.is_normal(id)) {
@@ -1119,7 +1119,7 @@ struct llm_tokenizer_rwkv : llm_tokenizer {
// For now, we decode the vocab here into the lookup we'll use for tokenization.

// build trie
- for (uint32_t id = 0; id < vocab.n_vocab; ++id) {
+ for (uint32_t id = 0; id < vocab.n_vocab(); ++id) {
const auto & data = vocab.get_token_data(id);
const auto text = llama_unescape_rwkv_token(data.text);
token_matcher.insert((const char *) text.data(), text.size(), id);
@@ -1204,6 +1204,8 @@ struct fragment_buffer_variant {
};

struct llama_vocab::impl {
+ uint32_t n_vocab = 0;
+
std::unordered_map<std::string, llama_token> token_to_id;
std::vector<token_data> id_to_token;

@@ -1283,6 +1285,13 @@ llama_vocab::~llama_vocab() {
void llama_vocab::load(llama_model_loader & ml, const LLM_KV & kv) {
struct gguf_context * ctx = ml.meta.get();

+ auto & n_vocab = pimpl->n_vocab;
+ auto & id_to_token = pimpl->id_to_token;
+ auto & token_to_id = pimpl->token_to_id;
+ auto & special_eog_ids = pimpl->special_eog_ids;
+ auto & cache_special_tokens = pimpl->cache_special_tokens;
+ auto & cache_token_to_piece = pimpl->cache_token_to_piece;
+
// determine vocab type
{
std::string tokenizer_model;
@@ -1589,12 +1598,6 @@ void llama_vocab::load(llama_model_loader & ml, const LLM_KV & kv) {
toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx);
}

- auto & id_to_token = pimpl->id_to_token;
- auto & token_to_id = pimpl->token_to_id;
- auto & special_eog_ids = pimpl->special_eog_ids;
- auto & cache_special_tokens = pimpl->cache_special_tokens;
- auto & cache_token_to_piece = pimpl->cache_token_to_piece;
-
n_vocab = gguf_get_arr_n(ctx, token_idx);
id_to_token.resize(n_vocab);

@@ -1908,7 +1911,7 @@ void llama_vocab::load(llama_model_loader & ml, const LLM_KV & kv) {

// build special tokens cache
{
- for (llama_token id = 0; id < (llama_token)n_vocab; ++id) {
+ for (llama_token id = 0; id < (llama_token) n_vocab; ++id) {
if (id_to_token[id].attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN)) {
cache_special_tokens.push_back(id);
}
@@ -2002,6 +2005,10 @@ enum llama_vocab_pre_type llama_vocab::get_pre_type() const {
return pre_type;
}

+ uint32_t llama_vocab::n_vocab() const {
+ return (uint32_t) pimpl->id_to_token.size();
+ }
+
std::string llama_vocab::type_name() const{
switch (type) {
case LLAMA_VOCAB_TYPE_NONE: return "no vocab";
@@ -2366,8 +2373,8 @@ int llama_vocab::max_token_text_len() const {

void llama_vocab::print_info() const {
LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, type_name().c_str());
LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, n_vocab);
LLAMA_LOG_INFO("%s: n_merges = %u\n", __func__, (int) pimpl->bpe_ranks.size());
LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, pimpl->n_vocab);
LLAMA_LOG_INFO("%s: n_merges = %u\n", __func__, (uint32_t) pimpl->bpe_ranks.size());

auto & id_to_token = pimpl->id_to_token;
auto & special_eog_ids = pimpl->special_eog_ids;
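Taken together, the llama-vocab.cpp changes move the token count behind the pimpl and keep load() working through local references bound to the pimpl members. A rough sketch of that shape, with a hypothetical toy_vocab standing in for llama_vocab (the field names mirror the diff, everything else is simplified):

#include <cstdint>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

class toy_vocab {
public:
    toy_vocab() : pimpl(new impl()) {}

    // load() binds local references to the pimpl members, as the patched
    // llama_vocab::load() does, then fills the token tables
    void load(const std::vector<std::string> & texts) {
        auto & n_vocab     = pimpl->n_vocab;
        auto & id_to_token = pimpl->id_to_token;
        auto & token_to_id = pimpl->token_to_id;

        n_vocab = (uint32_t) texts.size();
        id_to_token.resize(n_vocab);
        for (uint32_t id = 0; id < n_vocab; ++id) {
            id_to_token[id] = texts[id];
            token_to_id[texts[id]] = (int32_t) id;
        }
    }

    // public accessor: the count is derived from the token table rather than
    // exposed as a field that callers would have to keep in sync
    uint32_t n_vocab() const { return (uint32_t) pimpl->id_to_token.size(); }

private:
    struct impl {
        uint32_t n_vocab = 0;
        std::unordered_map<std::string, int32_t> token_to_id;
        std::vector<std::string> id_to_token;
    };

    std::unique_ptr<impl> pimpl;
};

int main() {
    toy_vocab vocab;
    vocab.load({ "<s>", "</s>", "hello" });
    return vocab.n_vocab() == 3 ? 0 : 1;
}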
8 changes: 3 additions & 5 deletions src/llama-vocab.h
@@ -4,9 +4,6 @@

#include <string>
#include <vector>
- #include <unordered_map>
- #include <map>
- #include <set>
#include <memory>

struct LLM_KV;
@@ -19,8 +16,6 @@ struct llama_vocab {
llama_token_attr attr;
};

- uint32_t n_vocab = 0; // TODO: not great because has to keep in sync with hparams.n_vocab
-
llama_vocab();
~llama_vocab();

@@ -29,6 +24,9 @@ struct llama_vocab {
enum llama_vocab_type get_type() const;
enum llama_vocab_pre_type get_pre_type() const;

+ // TODO: how to deduplicate with llama_hparams.n_vocab ?
+ uint32_t n_vocab() const;
+
std::string type_name() const;

bool is_normal (llama_token id) const;
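The TODO about deduplicating with llama_hparams.n_vocab hints at why the public counter was dropped: a cached count has to be updated by hand wherever the token table changes, while an accessor derived from the table cannot go stale. A throwaway illustration, with hypothetical bad_vocab/good_vocab types that are not part of llama.cpp:

#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// cached count: must be updated by hand whenever the table changes
struct bad_vocab {
    uint32_t n_vocab = 0;
    std::vector<std::string> id_to_token;
};

// derived count: always reflects the table
struct good_vocab {
    std::vector<std::string> id_to_token;
    uint32_t n_vocab() const { return (uint32_t) id_to_token.size(); }
};

int main() {
    bad_vocab b;
    b.id_to_token = { "a", "b", "c" };          // forgot to update b.n_vocab
    assert(b.n_vocab != b.id_to_token.size());  // the cached count is silently stale

    good_vocab g;
    g.id_to_token = { "a", "b", "c" };
    assert(g.n_vocab() == 3);                   // cannot drift
    return 0;
}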
6 changes: 3 additions & 3 deletions src/llama.cpp
@@ -66,7 +66,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
model.print_info();

if (model.vocab.get_type() != LLAMA_VOCAB_TYPE_NONE &&
- model.hparams.n_vocab != model.vocab.n_vocab) {
+ model.hparams.n_vocab != model.vocab.n_vocab()) {
throw std::runtime_error("vocab size mismatch");
}

@@ -8317,7 +8317,7 @@ static int llama_decode_impl(

if (batch.token) {
for (uint32_t i = 0; i < n_tokens_all; ++i) {
- if (batch.token[i] < 0 || (uint32_t)batch.token[i] >= model.vocab.n_vocab) {
+ if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_vocab()) {
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
return -1;
}
@@ -8652,7 +8652,7 @@ static int llama_encode_impl(

if (batch.token) {
for (uint32_t i = 0; i < n_tokens; ++i) {
- if (batch.token[i] < 0 || (uint32_t)batch.token[i] >= model.vocab.n_vocab) {
+ if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_vocab()) {
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
return -1;
}
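Both decode and encode now bound-check incoming token ids against vocab.n_vocab(). A small stand-alone sketch of that validation step, with simplified stand-in types (toy_token and batch_tokens_valid are made up for illustration):

#include <cstdint>
#include <cstdio>
#include <vector>

using toy_token = int32_t;

// reject any id that is negative or >= n_vocab, mirroring the checks in
// llama_decode_impl()/llama_encode_impl()
static bool batch_tokens_valid(const std::vector<toy_token> & tokens, uint32_t n_vocab) {
    for (size_t i = 0; i < tokens.size(); ++i) {
        if (tokens[i] < 0 || (uint32_t) tokens[i] >= n_vocab) {
            std::fprintf(stderr, "invalid token[%zu] = %d\n", i, tokens[i]);
            return false;
        }
    }
    return true;
}

int main() {
    const uint32_t n_vocab = 32000;                               // e.g. vocab.n_vocab()
    const std::vector<toy_token> batch = { 1, 2, 31999, 32000 };  // last id is out of range
    return batch_tokens_valid(batch, n_vocab) ? 0 : 1;
}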
