diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index f5f9127cf2074a..eaa540389f791e 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -1663,7 +1663,7 @@ struct llama_sampler_dry {
 
 // Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
 static void get_overlapping_token_sequences(const llama_vocab & vocab, const std::string& str, std::unordered_multimap<llama_token, std::vector<llama_token>>& token_sequences, int max_tail_len = -1) {
-    for (llama_token token_id = 0; token_id < (llama_token)vocab.n_vocab; token_id++) {
+    for (llama_token token_id = 0; token_id < (llama_token) vocab.n_vocab(); token_id++) {
         std::string word = vocab.detokenize({token_id}, true);
         if (word.find(str) != std::string::npos) {
             token_sequences.emplace(token_id, std::vector<llama_token>());
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
index 549139c4dba9d9..55d107ba1125a2 100644
--- a/src/llama-vocab.cpp
+++ b/src/llama-vocab.cpp
@@ -208,7 +208,7 @@ struct llm_tokenizer_spm_session {
             return;
         }
 
-        if (static_cast<uint32_t>(token) >= vocab.n_vocab) {
+        if (static_cast<uint32_t>(token) >= vocab.n_vocab()) {
             return;
         }
 
@@ -734,7 +734,7 @@ struct llm_tokenizer_ugm : llm_tokenizer {
            prefix_replacements_size = precompiled_charsmap.size() - charsmap_offset;
        }
 
-        for (uint32_t id = 0; id < vocab.n_vocab; ++id) {
+        for (uint32_t id = 0; id < vocab.n_vocab(); ++id) {
            const auto & token_data = vocab.get_token_data(id);
 
            if (vocab.is_normal(id)) {
@@ -1119,7 +1119,7 @@ struct llm_tokenizer_rwkv : llm_tokenizer {
        // For now, we decode the vocab here into the lookup we'll use for tokenization.
 
        // build trie
-        for (uint32_t id = 0; id < vocab.n_vocab; ++id) {
+        for (uint32_t id = 0; id < vocab.n_vocab(); ++id) {
            const auto & data = vocab.get_token_data(id);
            const auto text = llama_unescape_rwkv_token(data.text);
            token_matcher.insert((const char *) text.data(), text.size(), id);
@@ -1204,6 +1204,8 @@ struct fragment_buffer_variant {
 };
 
 struct llama_vocab::impl {
+    uint32_t n_vocab = 0;
+
     std::unordered_map<std::string, llama_token> token_to_id;
     std::vector<token_data> id_to_token;
 
@@ -1283,6 +1285,13 @@ llama_vocab::~llama_vocab() {
 void llama_vocab::load(llama_model_loader & ml, const LLM_KV & kv) {
     struct gguf_context * ctx = ml.meta.get();
 
+    auto & n_vocab = pimpl->n_vocab;
+    auto & id_to_token = pimpl->id_to_token;
+    auto & token_to_id = pimpl->token_to_id;
+    auto & special_eog_ids = pimpl->special_eog_ids;
+    auto & cache_special_tokens = pimpl->cache_special_tokens;
+    auto & cache_token_to_piece = pimpl->cache_token_to_piece;
+
     // determine vocab type
     {
         std::string tokenizer_model;
@@ -1589,12 +1598,6 @@ void llama_vocab::load(llama_model_loader & ml, const LLM_KV & kv) {
         toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx);
     }
 
-    auto & id_to_token = pimpl->id_to_token;
-    auto & token_to_id = pimpl->token_to_id;
-    auto & special_eog_ids = pimpl->special_eog_ids;
-    auto & cache_special_tokens = pimpl->cache_special_tokens;
-    auto & cache_token_to_piece = pimpl->cache_token_to_piece;
-
     n_vocab = gguf_get_arr_n(ctx, token_idx);
     id_to_token.resize(n_vocab);
 
@@ -1908,7 +1911,7 @@ void llama_vocab::load(llama_model_loader & ml, const LLM_KV & kv) {
 
     // build special tokens cache
     {
-        for (llama_token id = 0; id < (llama_token)n_vocab; ++id) {
+        for (llama_token id = 0; id < (llama_token) n_vocab; ++id) {
             if (id_to_token[id].attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN)) {
                 cache_special_tokens.push_back(id);
             }
@@ -2002,6 +2005,10 @@ enum llama_vocab_pre_type llama_vocab::get_pre_type() const {
     return pre_type;
 }
 
+uint32_t llama_vocab::n_vocab() const {
+    return (uint32_t) pimpl->id_to_token.size();
+}
+
 std::string llama_vocab::type_name() const{
     switch (type) {
         case LLAMA_VOCAB_TYPE_NONE: return "no vocab";
@@ -2366,8 +2373,8 @@ int llama_vocab::max_token_text_len() const {
 
 void llama_vocab::print_info() const {
     LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, type_name().c_str());
-    LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, n_vocab);
-    LLAMA_LOG_INFO("%s: n_merges = %u\n", __func__, (int) pimpl->bpe_ranks.size());
+    LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, pimpl->n_vocab);
+    LLAMA_LOG_INFO("%s: n_merges = %u\n", __func__, (uint32_t) pimpl->bpe_ranks.size());
 
     auto & id_to_token = pimpl->id_to_token;
     auto & special_eog_ids = pimpl->special_eog_ids;
diff --git a/src/llama-vocab.h b/src/llama-vocab.h
index 434294258c9f23..84bd7c4402e7b3 100644
--- a/src/llama-vocab.h
+++ b/src/llama-vocab.h
@@ -4,9 +4,6 @@
 
 #include <string>
 #include <vector>
-#include <unordered_map>
-#include <map>
-#include <set>
 #include <memory>
 
 struct LLM_KV;
@@ -19,8 +16,6 @@ struct llama_vocab {
         llama_token_attr attr;
     };
 
-    uint32_t n_vocab = 0; // TODO: not great because has to keep in sync with hparams.n_vocab
-
     llama_vocab();
     ~llama_vocab();
 
@@ -29,6 +24,9 @@ struct llama_vocab {
     enum llama_vocab_type get_type() const;
     enum llama_vocab_pre_type get_pre_type() const;
 
+    // TODO: how to deduplicate with llama_hparams.n_vocab ?
+    uint32_t n_vocab() const;
+
     std::string type_name() const;
 
     bool is_normal (llama_token id) const;
diff --git a/src/llama.cpp b/src/llama.cpp
index 3012692ce19a0d..1d36b592a5f681 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -66,7 +66,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
         model.print_info();
 
         if (model.vocab.get_type() != LLAMA_VOCAB_TYPE_NONE &&
-            model.hparams.n_vocab != model.vocab.n_vocab) {
+            model.hparams.n_vocab != model.vocab.n_vocab()) {
             throw std::runtime_error("vocab size mismatch");
         }
 
@@ -8317,7 +8317,7 @@ static int llama_decode_impl(
 
     if (batch.token) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            if (batch.token[i] < 0 || (uint32_t)batch.token[i] >= model.vocab.n_vocab) {
+            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_vocab()) {
                 LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
                 return -1;
             }
@@ -8652,7 +8652,7 @@ static int llama_encode_impl(
 
     if (batch.token) {
         for (uint32_t i = 0; i < n_tokens; ++i) {
-            if (batch.token[i] < 0 || (uint32_t)batch.token[i] >= model.vocab.n_vocab) {
+            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_vocab()) {
                 LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
                 return -1;
             }
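
For context, the pattern the diff applies is: the public `n_vocab` field is dropped, the count moves behind the pimpl, and callers go through a `n_vocab()` accessor that derives the size from the token table. A minimal standalone sketch of that pattern follows; the names `vocab` and `token_entry` are illustrative only and not part of llama.cpp.

```cpp
// Sketch of the pimpl-plus-accessor pattern used above (hypothetical types,
// not the actual llama.cpp sources).
#include <cstdint>
#include <memory>
#include <string>
#include <vector>

struct token_entry {
    std::string text;
    float       score;
};

struct vocab {
    struct impl {
        std::vector<token_entry> id_to_token;
    };

    vocab() : pimpl(new impl()) {}

    // Accessor replaces a public `uint32_t n_vocab` member: the reported size
    // is always derived from the token table, so it cannot drift out of sync.
    uint32_t n_vocab() const {
        return (uint32_t) pimpl->id_to_token.size();
    }

private:
    std::unique_ptr<impl> pimpl;
};
```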