Skip to content

Commit

Permalink
fix: refactor summarize, added /save command
Browse files Browse the repository at this point in the history
  • Loading branch information
ErikBjare committed Oct 11, 2023
1 parent 05f74ad commit a78cc91
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 32 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ A local alternative to ChatGPT's "Advanced Data Analysis" (previously "Code Inte

## 🎥 Demo

> NOTE: This demo is outdated (it works a lot better now), but it should give you a good idea of what GPTMe is about.
[![demo screencast with asciinema](https://github.com/ErikBjare/gptme/assets/1405370/5dda4240-bb7d-4cfa-8dd1-cd1218ccf571)](https://asciinema.org/a/606375)

<details>
Expand Down
24 changes: 22 additions & 2 deletions gptme/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
"summarize",
"context",
"load",
"save",
"shell",
"python",
"replay",
Expand All @@ -88,6 +89,7 @@
"edit": "Edit previous messages",
"summarize": "Summarize the conversation so far",
"load": "Load a file",
"save": "Save the most recent code block to a file",
"shell": "Execute a shell command",
"python": "Execute a Python command",
"replay": "Re-execute past commands in the conversation (does not store output in log)",
Expand Down Expand Up @@ -151,6 +153,8 @@ def handle_cmd(
# print context msg
print(_gen_context_msg())
case "undo":
# undo the '/undo' command itself
log.undo(1, quiet=True)
# if int, undo n messages
n = int(args[0]) if args and args[0].isdigit() else 1
log.undo(n)
Expand All @@ -159,6 +163,23 @@ def handle_cmd(
with open(filename) as f:
contents = f.read()
yield Message("system", f"# filename: {filename}\n\n{contents}")
case "save":
# undo the '/save' command itself, as done for '/undo'
log.undo(1, quiet=True)

# save the most recent code block to a file
code = log.get_last_code_block()
if not code:
print("No code block found")
return
filename = args[0] if args else input("Filename: ")
if Path(filename).exists():
ans = input("File already exists, overwrite? [y/N] ")
if ans.lower() != "y":
return
with open(filename, "w") as f:
f.write(code)
print(f"Saved code block to {filename}")
case "exit":
sys.exit(0)
case "replay":
Expand Down Expand Up @@ -220,8 +241,7 @@ def handle_cmd(
@click.option(
"--model",
default="gpt-4",
help="Model to use (gpt-3.5 not recommended)",
type=click.Choice(["gpt-4", "gpt-3.5-turbo", "wizardcoder-..."]),
help="Model to use (gpt-3.5 not recommended). Can be: gpt-4, gpt-3.5-turbo, wizardcoder-..., etc.",
)
@click.option(
"--stream/--no-stream",
Expand Down
28 changes: 27 additions & 1 deletion gptme/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from .config import get_config
from .constants import PROMPT_ASSISTANT
from .message import Message
from .util import msgs2dicts
from .util import len_tokens, msgs2dicts

# Optimized for code
# Discussion here: https://community.openai.com/t/cheat-sheet-mastering-temperature-and-top-p-in-chatgpt-api-a-few-tips-and-tricks-on-controlling-the-creativity-deterministic-output-of-prompt-responses/172683
Expand All @@ -24,6 +24,7 @@ def init_llm(llm: str):
# set up API_KEY (if openai) and API_BASE (if local)
config = get_config()

# TODO: use llm/model from config if specified and not passed as args
if llm == "openai":
if "OPENAI_API_KEY" in os.environ:
api_key = os.environ["OPENAI_API_KEY"]
Expand Down Expand Up @@ -105,3 +106,28 @@ def print_clear():
print_clear()
logger.debug(f"Stop reason: {stop_reason}")
return Message("assistant", deltas_to_str(deltas))


def summarize(content: str) -> str:
    """
    Summarizes a long text using a LLM.

    To summarize messages or the conversation log,
    use `gptme.tools.summarize` instead (which wraps this).

    Args:
        content: The plain text to summarize.

    Returns:
        The summary text returned by the completion model, or the literal
        string "error" if the request raised `openai.APIError`.
    """
    prompt = "Please summarize the following:\n" + content + "\n\nSummary:"
    try:
        response = openai.Completion.create(
            model="text-davinci-003",
            prompt=prompt,
            temperature=0,
            max_tokens=256,
        )
    except openai.APIError:
        # Best-effort: the caller gets a placeholder summary instead of an
        # exception. The log message now states what is actually returned.
        logger.exception("OpenAI API error, returning 'error' as summary")
        return "error"
    summary = response.choices[0].text
    logger.debug(
        "Summarized long output (%s -> %s tokens): %s",
        len_tokens(content),
        len_tokens(summary),
        summary,
    )
    return summary
16 changes: 15 additions & 1 deletion gptme/logmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import textwrap
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import TypeAlias

from rich import print
Expand All @@ -24,7 +25,11 @@ def __init__(
show_hidden=False,
):
self.log = log or []
assert logfile is not None, "logfile must be specified"
if logfile is None:
# generate tmpfile
fpath = NamedTemporaryFile(delete=False).name
print(f"[yellow]No logfile specified, using tmpfile {fpath}.[/]")
logfile = Path(fpath)
self.logfile = logfile
self.show_hidden = show_hidden
# TODO: Check if logfile has contents, then maybe load, or should it overwrite?
Expand Down Expand Up @@ -108,6 +113,15 @@ def load(
msgs = initial_msgs
return cls(msgs, logfile=logfile, **kwargs)

def get_last_code_block(self) -> str | None:
"""Returns the last code block in the log, if any."""
for msg in self.log[::-1]:
# check if message contains a code block
backtick_count = msg.content.count("```")
if backtick_count >= 2:
return msg.content.split("```")[-2].split("\n", 1)[-1]
return None


def write_log(msg_or_log: Message | list[Message], logfile: PathLike) -> None:
"""
Expand Down
42 changes: 14 additions & 28 deletions gptme/tools/summarize.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,37 @@
import logging
from functools import lru_cache

import openai

from ..llm import summarize as _summarize
from ..message import Message, format_msgs
from ..util import len_tokens

logger = logging.getLogger(__name__)


@lru_cache(maxsize=100)
def _llm_summarize(content: str) -> str:
    """Ask the completion model for a summary of `content`; results are memoized."""
    prompt = "Please summarize the following:\n" + content + "\n\nSummary:"
    completion = openai.Completion.create(
        model="text-davinci-003",
        prompt=prompt,
        temperature=0,
        max_tokens=256,
    )
    result = completion.choices[0].text
    # token counts before and after, for the debug log only
    tokens_in = len_tokens(content)
    tokens_out = len_tokens(result)
    logger.debug(
        f"Summarized long output ({tokens_in} -> {tokens_out} tokens): " + result
    )
    return result


def summarize(msg: Message | list[Message]) -> Message:
    """
    Uses a cheap LLM to summarize long outputs.

    Args:
        msg: A single message, or a list of messages, to summarize.

    Returns:
        A new system message whose content is the summary, prefixed with
        "Summary of the conversation:".
    """
    # construct plaintext from message(s)
    msgs = msg if isinstance(msg, list) else [msg]
    content = "\n".join(format_msgs(msgs))
    # summarize exactly once; the helper trims overly long input first
    summary = _summarize_helper(content)
    # construct message from summary (no stray ")" after the summary text)
    return Message(role="system", content=f"Summary of the conversation:\n{summary}")


def _summarize(s: str) -> str:
if len_tokens(s) > 200:
# first 100 tokens
beginning = " ".join(s.split()[:150])
# last 100 tokens
end = " ".join(s.split()[-100:])
summary = _llm_summarize(beginning + "\n...\n" + end)
@lru_cache(maxsize=128)
def _summarize_helper(s: str, tok_max_start=500, tok_max_end=500) -> str:
    """
    Helper function for summarizing long outputs.

    If `s` is longer than `tok_max_start + tok_max_end` tokens, only its
    beginning and its end are summarized, joined by a "..." line. Otherwise
    the whole text is summarized. Results are memoized per argument
    combination.

    NOTE: the length check counts tokens (via `len_tokens`), while trimming
    keeps whitespace-separated words, so the limits are approximate.

    Args:
        s: The text to summarize.
        tok_max_start: Number of leading words kept when trimming.
        tok_max_end: Number of trailing words kept when trimming.

    Returns:
        The summary text.
    """
    limit = tok_max_start + tok_max_end
    if len_tokens(s) <= limit:
        return _summarize(s)

    words = s.split()
    if len(words) <= limit:
        # Head and tail would overlap and repeat words; the full text is
        # no longer than what the overlapping slices would have contained.
        return _summarize(s)

    beginning = " ".join(words[:tok_max_start])
    # Explicit start index: `words[-0:]` would be the whole list when
    # tok_max_end == 0.
    end = " ".join(words[len(words) - tok_max_end :])
    return _summarize(beginning + "\n...\n" + end)
21 changes: 21 additions & 0 deletions tests/test_logmanager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from gptme.logmanager import LogManager, Message


def test_get_last_code_block():
    """The body of the final fenced block is returned, without its fence/language line."""
    content = """
```python
print('hello')
```
```python
print('world')
```
"""
    manager = LogManager()
    manager.append(Message("assistant", content))
    # only the second block's body is expected, trailing newline included
    assert manager.get_last_code_block() == "print('world')\n"

0 comments on commit a78cc91

Please sign in to comment.