perf: optimize token counting with caching and explicit model param (#325)

- Add caching to len_tokens using content hashing
- Make get_tokenizer lru_cached
- Pass model parameter explicitly to len_tokens
- Remove redundant token counting
ErikBjare authored Dec 11, 2024
1 parent 202869e commit c05f9c6
Showing 10 changed files with 87 additions and 50 deletions.
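
For orientation before the per-file diffs, here is a minimal, self-contained sketch of the approach the change description outlines: token counts are cached under a (content hash, model) key, and tokenizer lookup is memoized with `lru_cache`. This is an illustration rather than the exact gptme implementation; the tiktoken fallback to `cl100k_base` and the `MAX_CACHE_ENTRIES` name are assumptions, while `_token_cache`, the sha256 hashing, and the 1000-entry cap mirror the `gptme/message.py` diff below.

```python
import hashlib
from functools import lru_cache

import tiktoken

# Illustrative name; the diff below hardcodes a 1000-entry cap.
MAX_CACHE_ENTRIES = 1000

# Cache token counts keyed by (sha256 of content, model name).
_token_cache: dict[tuple[str, str], int] = {}


@lru_cache
def get_tokenizer(model: str):
    """Memoize tokenizer construction per model name."""
    try:
        return tiktoken.encoding_for_model(model)
    except KeyError:
        # Assumed fallback for model names tiktoken does not know about.
        return tiktoken.get_encoding("cl100k_base")


def len_tokens(content: str, model: str) -> int:
    """Count tokens, reusing cached counts for previously seen content."""
    key = (hashlib.sha256(content.encode()).hexdigest(), model)
    if key in _token_cache:
        return _token_cache[key]
    count = len(get_tokenizer(model).encode(content))
    _token_cache[key] = count
    if len(_token_cache) > MAX_CACHE_ENTRIES:
        # Evict the oldest entry (dicts preserve insertion order).
        _token_cache.pop(next(iter(_token_cache)))
    return count
```

Hashing the content instead of keying on the string itself keeps the cache memory-light for long messages, at the cost of one sha256 pass per lookup.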
7 changes: 5 additions & 2 deletions gptme/commands.py
@@ -134,7 +134,10 @@ def handle_cmd(
             yield from execute_msg(msg, confirm=lambda _: True)
         case "tokens":
             manager.undo(1, quiet=True)
-            n_tokens = len_tokens(manager.log.messages)
+            model = get_model()
+            n_tokens = len_tokens(
+                manager.log.messages, model.model if model else "gpt-4"
+            )
             print(f"Tokens used: {n_tokens}")
             model = get_model()
             if model:
@@ -149,7 +152,7 @@ def handle_cmd(
                    f"""
 # {tool.name}
 {tool.desc.rstrip(".")}
-tokens (example): {len_tokens(tool.get_examples(get_tool_format()))}"""
+tokens (example): {len_tokens(tool.get_examples(get_tool_format()), "gpt-4")}"""
                )
         case "export":
             manager.undo(1, quiet=True)

15 changes: 8 additions & 7 deletions gptme/eval/main.py
@@ -50,7 +50,8 @@ def print_model_results(model_results: dict[str, list[EvalResult]]):
     for model, results in model_results.items():
         print(f"\nResults for model: {model}")
         model_total_tokens = sum(
-            len_tokens(result.gen_stdout) + len_tokens(result.run_stdout)
+            len_tokens(result.gen_stdout, "gpt-4")
+            + len_tokens(result.run_stdout, "gpt-4")
             for result in results
         )
         print(f"Completed {len(results)} tests in {model_total_tokens}tok:")
@@ -60,8 +61,8 @@ def print_model_results(model_results: dict[str, list[EvalResult]]):
             duration_result = (
                 result.timings["gen"] + result.timings["run"] + result.timings["eval"]
             )
-            gen_tokens = len_tokens(result.gen_stdout)
-            run_tokens = len_tokens(result.run_stdout)
+            gen_tokens = len_tokens(result.gen_stdout, "gpt-4")
+            run_tokens = len_tokens(result.run_stdout, "gpt-4")
             result_total_tokens = gen_tokens + run_tokens
             print(
                 f"{checkmark} {result.name}: {duration_result:.0f}s/{result_total_tokens}tok "
@@ -94,8 +95,8 @@ def print_model_results_table(model_results: dict[str, list[EvalResult]]):
             passed = all(case.passed for case in result.results)
             checkmark = "✅" if result.status == "success" and passed else "❌"
             duration = sum(result.timings.values())
-            gen_tokens = len_tokens(result.gen_stdout)
-            run_tokens = len_tokens(result.run_stdout)
+            gen_tokens = len_tokens(result.gen_stdout, "gpt-4")
+            run_tokens = len_tokens(result.run_stdout, "gpt-4")
             reason = "timeout" if result.status == "timeout" else ""
             if reason:
                 row.append(f"{checkmark} {reason}")
@@ -125,8 +126,8 @@ def aggregate_and_display_results(result_files: list[str]):
             }
             all_results[model][result.name]["total"] += 1
             all_results[model][result.name]["tokens"] += len_tokens(
-                result.gen_stdout
-            ) + len_tokens(result.run_stdout)
+                result.gen_stdout, "gpt-4"
+            ) + len_tokens(result.run_stdout, "gpt-4")
             if result.status == "success" and all(
                 case.passed for case in result.results
             ):

11 changes: 6 additions & 5 deletions gptme/llm/__init__.py
@@ -136,7 +136,7 @@ def print_clear():
                 f"Generation interrupted after {end_time - start_time:.1f}s "
                 f"(ttft: {first_token_time - start_time:.2f}s, "
                 f"gen: {end_time - first_token_time:.2f}s, "
-                f"tok/s: {len_tokens(output)/(end_time - first_token_time):.1f})"
+                f"tok/s: {len_tokens(output, model)/(end_time - first_token_time):.1f})"
             )

     return Message("assistant", output)
@@ -179,15 +179,15 @@ def _summarize_str(content: str) -> str:
     provider = _client_to_provider()
     model = get_summary_model(provider)
     context_limit = MODELS[provider][model]["context"]
-    if len_tokens(messages) > context_limit:
+    if len_tokens(messages, model) > context_limit:
         raise ValueError(
-            f"Cannot summarize more than {context_limit} tokens, got {len_tokens(messages)}"
+            f"Cannot summarize more than {context_limit} tokens, got {len_tokens(messages, model)}"
         )

     summary = _chat_complete(messages, model, None)
     assert summary
     logger.debug(
-        f"Summarized long output ({len_tokens(content)} -> {len_tokens(summary)} tokens): "
+        f"Summarized long output ({len_tokens(content, model)} -> {len_tokens(summary, model)} tokens): "
         + summary
     )
     return summary
@@ -260,7 +260,8 @@ def _summarize_helper(s: str, tok_max_start=400, tok_max_end=400) -> str:
     Helper function for summarizing long outputs.
     Truncates long outputs, then summarizes.
     """
-    if len_tokens(s) > tok_max_start + tok_max_end:
+    # Use gpt-4 as default model for summarization helper
+    if len_tokens(s, "gpt-4") > tok_max_start + tok_max_end:
         beginning = " ".join(s.split()[:tok_max_start])
         end = " ".join(s.split()[-tok_max_end:])
         summary = _summarize_str(beginning + "\n...\n" + end)

11 changes: 7 additions & 4 deletions gptme/logmanager.py
@@ -356,16 +356,19 @@ def prepare_messages(
     - Enhances it with context such as file contents
     - Transforms it to the format expected by LLM providers
     """
+    from .llm.models import get_model  # fmt: skip
+
     # Enrich with enabled context enhancements (RAG, fresh context)
     msgs = enrich_messages_with_context(msgs, workspace)

     # Then reduce and limit as before
     msgs_reduced = list(reduce_log(msgs))

-    if len_tokens(msgs) != len_tokens(msgs_reduced):
-        logger.info(
-            f"Reduced log from {len_tokens(msgs)//1} to {len_tokens(msgs_reduced)//1} tokens"
-        )
+    model = get_model()
+    if (len_from := len_tokens(msgs, model.model)) != (
+        len_to := len_tokens(msgs_reduced, model.model)
+    ):
+        logger.info(f"Reduced log from {len_from//1} to {len_to//1} tokens")
     msgs_limited = limit_log(msgs_reduced)
     if len(msgs_reduced) != len(msgs_limited):
         logger.info(

51 changes: 37 additions & 14 deletions gptme/message.py
@@ -1,4 +1,5 @@
 import dataclasses
+import hashlib
 import logging
 import shutil
 import sys
@@ -19,11 +20,6 @@

 logger = logging.getLogger(__name__)

-# max tokens allowed in a single system message
-# if you hit this limit, you and/or I f-ed up, and should make the message shorter
-# maybe we should make it possible to store long outputs in files, and link/summarize it/preview it in the message
-max_system_len = 20000
-

 @dataclass(frozen=True, eq=False)
 class Message:
@@ -51,9 +47,6 @@ class Message:

     def __post_init__(self):
         assert isinstance(self.timestamp, datetime)
-        if self.role == "system":
-            if (length := len_tokens(self)) >= max_system_len:
-                logger.warning(f"System message too long: {length} tokens")

     def __repr__(self):
         content = textwrap.shorten(self.content, 20, placeholder="...")
@@ -302,11 +295,41 @@ def msgs2dicts(msgs: list[Message]) -> list[dict]:
     return [msg.to_dict(keys=["role", "content", "files"]) for msg in msgs]


-# TODO: remove model assumption
-def len_tokens(content: str | Message | list[Message], model: str = "gpt-4") -> int:
-    """Get the number of tokens in a string, message, or list of messages."""
+# Global cache mapping hashes to token counts
+_token_cache: dict[tuple[str, str], int] = {}
+
+
+def _hash_content(content: str) -> str:
+    """Create a hash of the content"""
+    return hashlib.sha256(content.encode()).hexdigest()
+
+
+def len_tokens(content: str | Message | list[Message], model: str) -> int:
+    """Get the number of tokens in a string, message, or list of messages.
+
+    Uses efficient caching with content hashing to minimize memory usage while
+    maintaining fast repeated calculations, which is especially important for
+    conversations with many messages.
+    """
     if isinstance(content, list):
-        return sum(len_tokens(msg.content, model) for msg in content)
+        return sum(len_tokens(msg, model) for msg in content)
     if isinstance(content, Message):
-        return len_tokens(content.content, model)
-    return len(get_tokenizer(model).encode(content))
+        content = content.content
+
+    assert isinstance(content, str), content
+    # Check cache using hash
+    content_hash = _hash_content(content)
+    cache_key = (content_hash, model)
+    if cache_key in _token_cache:
+        return _token_cache[cache_key]
+
+    # Calculate and cache
+    count = len(get_tokenizer(model).encode(content))
+    _token_cache[cache_key] = count
+
+    # Limit cache size by removing oldest entries if needed
+    if len(_token_cache) > 1000:
+        # Remove first item (oldest in insertion order)
+        _token_cache.pop(next(iter(_token_cache)))
+
+    return count

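
At call sites the net effect is that the model name is now always passed explicitly, and repeated counts of the same content are served from the cache. A rough usage sketch (the example messages are hypothetical; the `Message(role, content=...)` construction follows the tests later in this diff):

```python
from gptme.message import Message, len_tokens

msgs = [
    Message("user", content="What does this traceback mean?"),
    Message("assistant", content="It means the file was not found."),
]

first = len_tokens(msgs, "gpt-4")   # tokenizes each message and fills the cache
second = len_tokens(msgs, "gpt-4")  # same content hashes and model -> cache hits
assert first == second
```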
4 changes: 3 additions & 1 deletion gptme/util/__init__.py
@@ -26,6 +26,7 @@
 _warned_models = set()


+@lru_cache
 def get_tokenizer(model: str):
     import tiktoken  # fmt: skip

@@ -184,7 +185,8 @@ def decorator(func):  # pragma: no cover

         prompt = "\n\n".join([msg.content for msg in func(*args, **kwargs)])
         prompt = textwrap.indent(prompt, " ")
-        prompt_tokens = len_tokens(prompt)
+        # Use a default model for documentation purposes
+        prompt_tokens = len_tokens(prompt, model="gpt-4")
         kwargs_str = (
             (" (" + ", ".join(f"{k}={v!r}" for k, v in kwargs.items()) + ")")
             if kwargs

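
Since `get_tokenizer` is now wrapped in `lru_cache`, repeated lookups for the same model name return the same memoized tokenizer object instead of rebuilding the encoding each time. A small sketch of that effect (calling tiktoken directly; `cache_info` comes from functools, not from gptme):

```python
from functools import lru_cache

import tiktoken


@lru_cache
def get_tokenizer(model: str):
    # Expensive the first time, free afterwards for the same model name.
    return tiktoken.encoding_for_model(model)


enc_a = get_tokenizer("gpt-4")
enc_b = get_tokenizer("gpt-4")
assert enc_a is enc_b               # same cached Encoding object
print(get_tokenizer.cache_info())   # CacheInfo(hits=1, misses=1, ...)
```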
9 changes: 6 additions & 3 deletions gptme/util/cost.py
@@ -3,11 +3,14 @@


 def _tokens_inout(msgs: list[Message]) -> tuple[int, int]:
-    tokens_in, tokens_out = len_tokens(msgs[:-1]), 0
+    from ..llm.models import get_model  # fmt: skip
+
+    model = get_model()
+    tokens_in, tokens_out = len_tokens(msgs[:-1], model.model), 0
     if msgs[-1].role == "assistant":
-        tokens_out = len_tokens(msgs[-1])
+        tokens_out = len_tokens(msgs[-1], model.model)
     else:
-        tokens_in += len_tokens(msgs[-1])
+        tokens_in += len_tokens(msgs[-1], model.model)
     return tokens_in, tokens_out

17 changes: 9 additions & 8 deletions gptme/util/reduce.py
@@ -8,8 +8,8 @@
 from collections.abc import Generator

 from ..codeblock import Codeblock
+from ..llm.models import DEFAULT_MODEL, get_model
 from ..message import Message, len_tokens
-from ..llm.models import get_model

 logger = logging.getLogger(__name__)

@@ -21,11 +21,12 @@ def reduce_log(
 ) -> Generator[Message, None, None]:
     """Reduces log until it is below `limit` tokens by continually summarizing the longest messages until below the limit."""
     # get the token limit
+    model = DEFAULT_MODEL or get_model("gpt-4")
     if limit is None:
-        limit = 0.9 * get_model().context
+        limit = 0.9 * model.context

     # if we are below the limit, return the log as-is
-    tokens = len_tokens(log)
+    tokens = len_tokens(log, model=model.model)
     if tokens <= limit:
         yield from log
         return
@@ -34,7 +35,7 @@
     # filter out pinned messages
     i, longest_msg = max(
         [(i, m) for i, m in enumerate(log) if not m.pinned],
-        key=lambda t: len_tokens(t[1].content),
+        key=lambda t: len_tokens(t[1].content, model.model),
     )

     # attempt to truncate the longest message
@@ -53,7 +54,7 @@

     log = log[:i] + [summary_msg] + log[i + 1 :]

-    tokens = len_tokens(log)
+    tokens = len_tokens(log, model.model)
     if tokens <= limit:
         yield from log
     else:
@@ -105,7 +106,7 @@ def limit_log(log: list[Message]) -> list[Message]:
     then removes the last message to get below the limit.
     Will always pick the first few system messages.
     """
-    limit = get_model().context
+    model = get_model()

     # Always pick the first system messages
     initial_system_msgs = []
@@ -118,11 +119,11 @@
     msgs = []
     for msg in reversed(log[len(initial_system_msgs) :]):
         msgs.append(msg)
-        if len_tokens(msgs) > limit:
+        if len_tokens(msgs, model.model) > model.context:
             break

     # Remove the message that put us over the limit
-    if len_tokens(msgs) > limit:
+    if len_tokens(msgs, model.model) > model.context:
         # skip the last message
         msgs.pop()

4 changes: 2 additions & 2 deletions tests/test_prompts.py
@@ -12,12 +12,12 @@ def init():
 def test_get_prompt_full():
     prompt = get_prompt("full")
     # TODO: lower this significantly by selectively removing examples from the full prompt
-    assert 500 < len_tokens(prompt.content) < 5000
+    assert 500 < len_tokens(prompt.content, "gpt-4") < 5000


 def test_get_prompt_short():
     prompt = get_prompt("short")
-    assert 500 < len_tokens(prompt.content) < 2000
+    assert 500 < len_tokens(prompt.content, "gpt-4") < 2000


 def test_get_prompt_custom():

8 changes: 4 additions & 4 deletions tests/test_reduce.py
@@ -21,10 +21,10 @@


 def test_truncate_msg():
-    len_pre = len_tokens(long_msg)
+    len_pre = len_tokens(long_msg, "gpt-4")
     truncated = truncate_msg(long_msg)
     assert truncated is not None
-    len_post = len_tokens(truncated)
+    len_post = len_tokens(truncated, "gpt-4")
     assert len_pre > len_post
     assert "[...]" in truncated.content
     assert "```cli.py" in truncated.content
@@ -38,12 +38,12 @@ def test_reduce_log():
         Message("user", content=" ".join(fn.name for fn in [readme, cli, htmlindex])),
         long_msg,
     ]
-    len_pre = len_tokens(msgs)
+    len_pre = len_tokens(msgs, "gpt-4")
     print(f"{len_pre=}")

     limit = 1000
     reduced = list(reduce_log(msgs, limit=limit))
-    len_post = len_tokens(reduced)
+    len_post = len_tokens(reduced, "gpt-4")
     print(f"{len_post=}")
     print(f"{reduced[-1].content=}")

