diff --git a/gptme/commands.py b/gptme/commands.py
index 70552962..c7613db2 100644
--- a/gptme/commands.py
+++ b/gptme/commands.py
@@ -134,7 +134,10 @@ def handle_cmd(
             yield from execute_msg(msg, confirm=lambda _: True)
         case "tokens":
             manager.undo(1, quiet=True)
-            n_tokens = len_tokens(manager.log.messages)
+            model = get_model()
+            n_tokens = len_tokens(
+                manager.log.messages, model.model if model else "gpt-4"
+            )
             print(f"Tokens used: {n_tokens}")
             model = get_model()
             if model:
@@ -149,7 +152,7 @@ def handle_cmd(
                     f"""
 # {tool.name}
 {tool.desc.rstrip(".")}
-tokens (example): {len_tokens(tool.get_examples(get_tool_format()))}"""
+tokens (example): {len_tokens(tool.get_examples(get_tool_format()), "gpt-4")}"""
                 )
         case "export":
             manager.undo(1, quiet=True)
diff --git a/gptme/eval/main.py b/gptme/eval/main.py
index 357b8256..7af64f9c 100644
--- a/gptme/eval/main.py
+++ b/gptme/eval/main.py
@@ -50,7 +50,8 @@ def print_model_results(model_results: dict[str, list[EvalResult]]):
     for model, results in model_results.items():
         print(f"\nResults for model: {model}")
         model_total_tokens = sum(
-            len_tokens(result.gen_stdout) + len_tokens(result.run_stdout)
+            len_tokens(result.gen_stdout, "gpt-4")
+            + len_tokens(result.run_stdout, "gpt-4")
             for result in results
         )
         print(f"Completed {len(results)} tests in {model_total_tokens}tok:")
@@ -60,8 +61,8 @@ def print_model_results(model_results: dict[str, list[EvalResult]]):
             duration_result = (
                 result.timings["gen"] + result.timings["run"] + result.timings["eval"]
             )
-            gen_tokens = len_tokens(result.gen_stdout)
-            run_tokens = len_tokens(result.run_stdout)
+            gen_tokens = len_tokens(result.gen_stdout, "gpt-4")
+            run_tokens = len_tokens(result.run_stdout, "gpt-4")
             result_total_tokens = gen_tokens + run_tokens
             print(
                 f"{checkmark} {result.name}: {duration_result:.0f}s/{result_total_tokens}tok "
@@ -94,8 +95,8 @@ def print_model_results_table(model_results: dict[str, list[EvalResult]]):
             passed = all(case.passed for case in result.results)
             checkmark = "✅" if result.status == "success" and passed else "❌"
             duration = sum(result.timings.values())
-            gen_tokens = len_tokens(result.gen_stdout)
-            run_tokens = len_tokens(result.run_stdout)
+            gen_tokens = len_tokens(result.gen_stdout, "gpt-4")
+            run_tokens = len_tokens(result.run_stdout, "gpt-4")
             reason = "timeout" if result.status == "timeout" else ""
             if reason:
                 row.append(f"{checkmark} {reason}")
@@ -125,8 +126,8 @@ def aggregate_and_display_results(result_files: list[str]):
                 }
                 all_results[model][result.name]["total"] += 1
                 all_results[model][result.name]["tokens"] += len_tokens(
-                    result.gen_stdout
-                ) + len_tokens(result.run_stdout)
+                    result.gen_stdout, "gpt-4"
+                ) + len_tokens(result.run_stdout, "gpt-4")
                 if result.status == "success" and all(
                     case.passed for case in result.results
                 ):
diff --git a/gptme/llm/__init__.py b/gptme/llm/__init__.py
index 13ee878c..a97902d8 100644
--- a/gptme/llm/__init__.py
+++ b/gptme/llm/__init__.py
@@ -136,7 +136,7 @@ def print_clear():
             f"Generation interrupted after {end_time - start_time:.1f}s "
             f"(ttft: {first_token_time - start_time:.2f}s, "
             f"gen: {end_time - first_token_time:.2f}s, "
-            f"tok/s: {len_tokens(output)/(end_time - first_token_time):.1f})"
+            f"tok/s: {len_tokens(output, model)/(end_time - first_token_time):.1f})"
         )
 
     return Message("assistant", output)
@@ -179,15 +179,15 @@ def _summarize_str(content: str) -> str:
     provider = _client_to_provider()
     model = get_summary_model(provider)
     context_limit = MODELS[provider][model]["context"]
-    if len_tokens(messages) > context_limit:
+    if len_tokens(messages, model) > context_limit:
         raise ValueError(
-            f"Cannot summarize more than {context_limit} tokens, got {len_tokens(messages)}"
+            f"Cannot summarize more than {context_limit} tokens, got {len_tokens(messages, model)}"
         )
 
     summary = _chat_complete(messages, model, None)
     assert summary
     logger.debug(
-        f"Summarized long output ({len_tokens(content)} -> {len_tokens(summary)} tokens): "
+        f"Summarized long output ({len_tokens(content, model)} -> {len_tokens(summary, model)} tokens): "
         + summary
     )
     return summary
@@ -260,7 +260,8 @@ def _summarize_helper(s: str, tok_max_start=400, tok_max_end=400) -> str:
     Helper function for summarizing long outputs.
     Truncates long outputs, then summarizes.
     """
-    if len_tokens(s) > tok_max_start + tok_max_end:
+    # Use gpt-4 as default model for summarization helper
+    if len_tokens(s, "gpt-4") > tok_max_start + tok_max_end:
         beginning = " ".join(s.split()[:tok_max_start])
         end = " ".join(s.split()[-tok_max_end:])
         summary = _summarize_str(beginning + "\n...\n" + end)
diff --git a/gptme/logmanager.py b/gptme/logmanager.py
index e4794aec..5480e45e 100644
--- a/gptme/logmanager.py
+++ b/gptme/logmanager.py
@@ -356,16 +356,19 @@ def prepare_messages(
     - Enhances it with context such as file contents
     - Transforms it to the format expected by LLM providers
     """
+    from .llm.models import get_model  # fmt: skip
+
     # Enrich with enabled context enhancements (RAG, fresh context)
     msgs = enrich_messages_with_context(msgs, workspace)
 
     # Then reduce and limit as before
     msgs_reduced = list(reduce_log(msgs))
-    if len_tokens(msgs) != len_tokens(msgs_reduced):
-        logger.info(
-            f"Reduced log from {len_tokens(msgs)//1} to {len_tokens(msgs_reduced)//1} tokens"
-        )
+    model = get_model()
+    if (len_from := len_tokens(msgs, model.model)) != (
+        len_to := len_tokens(msgs_reduced, model.model)
+    ):
+        logger.info(f"Reduced log from {len_from//1} to {len_to//1} tokens")
 
     msgs_limited = limit_log(msgs_reduced)
     if len(msgs_reduced) != len(msgs_limited):
         logger.info(
diff --git a/gptme/message.py b/gptme/message.py
index 5d93288f..32d0feaf 100644
--- a/gptme/message.py
+++ b/gptme/message.py
@@ -1,4 +1,5 @@
 import dataclasses
+import hashlib
 import logging
 import shutil
 import sys
@@ -19,11 +20,6 @@
 
 logger = logging.getLogger(__name__)
 
-# max tokens allowed in a single system message
-# if you hit this limit, you and/or I f-ed up, and should make the message shorter
-# maybe we should make it possible to store long outputs in files, and link/summarize it/preview it in the message
-max_system_len = 20000
-
 
 @dataclass(frozen=True, eq=False)
 class Message:
@@ -51,9 +47,6 @@ class Message:
 
     def __post_init__(self):
         assert isinstance(self.timestamp, datetime)
-        if self.role == "system":
-            if (length := len_tokens(self)) >= max_system_len:
-                logger.warning(f"System message too long: {length} tokens")
 
     def __repr__(self):
         content = textwrap.shorten(self.content, 20, placeholder="...")
@@ -302,11 +295,41 @@ def msgs2dicts(msgs: list[Message]) -> list[dict]:
     return [msg.to_dict(keys=["role", "content", "files"]) for msg in msgs]
 
 
-# TODO: remove model assumption
-def len_tokens(content: str | Message | list[Message], model: str = "gpt-4") -> int:
-    """Get the number of tokens in a string, message, or list of messages."""
+# Global cache mapping hashes to token counts
+_token_cache: dict[tuple[str, str], int] = {}
+
+
+def _hash_content(content: str) -> str:
+    """Create a hash of the content"""
+    return hashlib.sha256(content.encode()).hexdigest()
+
+
+def len_tokens(content: str | Message | list[Message], model: str) -> int:
+    """Get the number of tokens in a string, message, or list of messages.
+
+    Uses efficient caching with content hashing to minimize memory usage while
+    maintaining fast repeated calculations, which is especially important for
+    conversations with many messages.
+    """
     if isinstance(content, list):
-        return sum(len_tokens(msg.content, model) for msg in content)
+        return sum(len_tokens(msg, model) for msg in content)
     if isinstance(content, Message):
-        return len_tokens(content.content, model)
-    return len(get_tokenizer(model).encode(content))
+        content = content.content
+
+    assert isinstance(content, str), content
+    # Check cache using hash
+    content_hash = _hash_content(content)
+    cache_key = (content_hash, model)
+    if cache_key in _token_cache:
+        return _token_cache[cache_key]
+
+    # Calculate and cache
+    count = len(get_tokenizer(model).encode(content))
+    _token_cache[cache_key] = count
+
+    # Limit cache size by removing oldest entries if needed
+    if len(_token_cache) > 1000:
+        # Remove first item (oldest in insertion order)
+        _token_cache.pop(next(iter(_token_cache)))
+
+    return count
diff --git a/gptme/util/__init__.py b/gptme/util/__init__.py
index 53b8cde9..5830a287 100644
--- a/gptme/util/__init__.py
+++ b/gptme/util/__init__.py
@@ -26,6 +26,7 @@
 _warned_models = set()
 
 
+@lru_cache
 def get_tokenizer(model: str):
     import tiktoken  # fmt: skip
 
@@ -184,7 +185,8 @@ def decorator(func):  # pragma: no cover
 
             prompt = "\n\n".join([msg.content for msg in func(*args, **kwargs)])
             prompt = textwrap.indent(prompt, " ")
-            prompt_tokens = len_tokens(prompt)
+            # Use a default model for documentation purposes
+            prompt_tokens = len_tokens(prompt, model="gpt-4")
             kwargs_str = (
                 (" (" + ", ".join(f"{k}={v!r}" for k, v in kwargs.items()) + ")")
                 if kwargs
diff --git a/gptme/util/cost.py b/gptme/util/cost.py
index 0868296e..631f1d06 100644
--- a/gptme/util/cost.py
+++ b/gptme/util/cost.py
@@ -3,11 +3,14 @@
 
 
 def _tokens_inout(msgs: list[Message]) -> tuple[int, int]:
-    tokens_in, tokens_out = len_tokens(msgs[:-1]), 0
+    from ..llm.models import get_model  # fmt: skip
+
+    model = get_model()
+    tokens_in, tokens_out = len_tokens(msgs[:-1], model.model), 0
     if msgs[-1].role == "assistant":
-        tokens_out = len_tokens(msgs[-1])
+        tokens_out = len_tokens(msgs[-1], model.model)
     else:
-        tokens_in += len_tokens(msgs[-1])
+        tokens_in += len_tokens(msgs[-1], model.model)
     return tokens_in, tokens_out
 
 
diff --git a/gptme/util/reduce.py b/gptme/util/reduce.py
index 0b31616e..e88e8bf2 100644
--- a/gptme/util/reduce.py
+++ b/gptme/util/reduce.py
@@ -8,8 +8,8 @@
 from collections.abc import Generator
 
 from ..codeblock import Codeblock
+from ..llm.models import DEFAULT_MODEL, get_model
 from ..message import Message, len_tokens
-from ..llm.models import get_model
 
 logger = logging.getLogger(__name__)
 
@@ -21,11 +21,12 @@ def reduce_log(
 ) -> Generator[Message, None, None]:
     """Reduces log until it is below `limit` tokens by continually summarizing the longest messages until below the limit."""
     # get the token limit
+    model = DEFAULT_MODEL or get_model("gpt-4")
     if limit is None:
-        limit = 0.9 * get_model().context
+        limit = 0.9 * model.context
 
     # if we are below the limit, return the log as-is
-    tokens = len_tokens(log)
+    tokens = len_tokens(log, model=model.model)
     if tokens <= limit:
         yield from log
         return
@@ -34,7 +35,7 @@ def reduce_log(
     # filter out pinned messages
     i, longest_msg = max(
         [(i, m) for i, m in enumerate(log) if not m.pinned],
-        key=lambda t: len_tokens(t[1].content),
+        key=lambda t: len_tokens(t[1].content, model.model),
     )
 
     # attempt to truncate the longest message
@@ -53,7 +54,7 @@ def reduce_log(
 
     log = log[:i] + [summary_msg] + log[i + 1 :]
 
-    tokens = len_tokens(log)
+    tokens = len_tokens(log, model.model)
     if tokens <= limit:
         yield from log
     else:
@@ -105,7 +106,7 @@ def limit_log(log: list[Message]) -> list[Message]:
     then removes the last message to get below the limit.
     Will always pick the first few system messages.
     """
-    limit = get_model().context
+    model = get_model()
 
     # Always pick the first system messages
     initial_system_msgs = []
@@ -118,11 +119,11 @@ def limit_log(log: list[Message]) -> list[Message]:
     msgs = []
     for msg in reversed(log[len(initial_system_msgs) :]):
         msgs.append(msg)
-        if len_tokens(msgs) > limit:
+        if len_tokens(msgs, model.model) > model.context:
             break
 
     # Remove the message that put us over the limit
-    if len_tokens(msgs) > limit:
+    if len_tokens(msgs, model.model) > model.context:
         # skip the last message
         msgs.pop()
 
diff --git a/tests/test_prompts.py b/tests/test_prompts.py
index 0b9ae424..324964c1 100644
--- a/tests/test_prompts.py
+++ b/tests/test_prompts.py
@@ -12,12 +12,12 @@ def init():
 def test_get_prompt_full():
     prompt = get_prompt("full")
     # TODO: lower this significantly by selectively removing examples from the full prompt
-    assert 500 < len_tokens(prompt.content) < 5000
+    assert 500 < len_tokens(prompt.content, "gpt-4") < 5000
 
 
 def test_get_prompt_short():
     prompt = get_prompt("short")
-    assert 500 < len_tokens(prompt.content) < 2000
+    assert 500 < len_tokens(prompt.content, "gpt-4") < 2000
 
 
 def test_get_prompt_custom():
diff --git a/tests/test_reduce.py b/tests/test_reduce.py
index 271263bf..3ff65ae6 100644
--- a/tests/test_reduce.py
+++ b/tests/test_reduce.py
@@ -21,10 +21,10 @@ def test_truncate_msg():
-    len_pre = len_tokens(long_msg)
+    len_pre = len_tokens(long_msg, "gpt-4")
     truncated = truncate_msg(long_msg)
     assert truncated is not None
-    len_post = len_tokens(truncated)
+    len_post = len_tokens(truncated, "gpt-4")
     assert len_pre > len_post
     assert "[...]" in truncated.content
     assert "```cli.py" in truncated.content
@@ -38,12 +38,12 @@ def test_reduce_log():
         Message("user", content=" ".join(fn.name for fn in [readme, cli, htmlindex])),
         long_msg,
     ]
-    len_pre = len_tokens(msgs)
+    len_pre = len_tokens(msgs, "gpt-4")
     print(f"{len_pre=}")
 
     limit = 1000
     reduced = list(reduce_log(msgs, limit=limit))
-    len_post = len_tokens(reduced)
+    len_post = len_tokens(reduced, "gpt-4")
     print(f"{len_post=}")
     print(f"{reduced[-1].content=}")
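
For reviewers, a minimal usage sketch (not part of the patch) of the new explicit-model len_tokens signature and the content-hash cache introduced in gptme/message.py. The sample message text and the asserted relationships are illustrative assumptions, not output from the test suite, and they assume tiktoken is installed.

# Illustrative sketch only; assumes gptme.message as modified by this diff.
from gptme.message import Message, len_tokens

msg = Message("user", "hello world")

# First call tokenizes via get_tokenizer(model) (now wrapped in lru_cache) and
# stores the count in the module-level _token_cache under (sha256(content), model).
n1 = len_tokens(msg, "gpt-4")

# Same content and model: served from the cache; plain strings and Messages
# with identical content share the same cache entry.
n2 = len_tokens("hello world", "gpt-4")
assert n1 == n2

# Lists are summed per message, so repeated messages also hit the cache.
assert len_tokens([msg, msg], "gpt-4") == 2 * n1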