Skip to content

Commit

Permalink
feat: more progress on fresh context, refactored it into new context.…
Browse files Browse the repository at this point in the history
…py file
  • Loading branch information
ErikBjare committed Dec 8, 2024
1 parent d953629 commit a563e79
Show file tree
Hide file tree
Showing 3 changed files with 250 additions and 94 deletions.
33 changes: 23 additions & 10 deletions gptme/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def confirm_func(msg) -> bool:
# this way we can defer reading multiple stale versions of files in requests
# we should probably ensure that the old file contents get included in exports and such
# maybe we need separate modes for this, but I think the refactor makes sense anyway
msg = _include_paths(msg)
msg = _include_paths(msg, workspace)
manager.append(msg)
# if prompt is a user-command, execute it
if execute_cmd(msg, manager, confirm_func):
Expand All @@ -123,7 +123,9 @@ def confirm_func(msg) -> bool:
while True:
try:
set_interruptible()
response_msgs = list(step(manager.log, stream, confirm_func))
response_msgs = list(
step(manager.log, stream, confirm_func, workspace)
)
except KeyboardInterrupt:
console.log("Interrupted. Stopping current execution.")
manager.append(Message("system", "Interrupted"))
Expand Down Expand Up @@ -168,7 +170,9 @@ def confirm_func(msg) -> bool:

# ask for input if no prompt, generate reply, and run tools
clear_interruptible() # Ensure we're not interruptible during user input
for msg in step(manager.log, stream, confirm_func): # pragma: no cover
for msg in step(
manager.log, stream, confirm_func, workspace
): # pragma: no cover
manager.append(msg)
# run any user-commands, if msg is from user
if msg.role == "user" and execute_cmd(msg, manager, confirm_func):
Expand All @@ -179,6 +183,7 @@ def step(
log: Log | list[Message],
stream: bool,
confirm: ConfirmFunc,
workspace: Path | None = None,
) -> Generator[Message, None, None]:
"""Runs a single pass of the chat."""
if isinstance(log, list):
Expand All @@ -197,7 +202,7 @@ def step(
): # pragma: no cover
inquiry = prompt_user()
msg = Message("user", inquiry, quiet=True)
msg = _include_paths(msg)
msg = _include_paths(msg, workspace)
yield msg
log = log.append(msg)

Expand All @@ -206,7 +211,7 @@ def step(
set_interruptible()

# performs reduction/context trimming, if necessary
msgs = prepare_messages(log.messages)
msgs = prepare_messages(log.messages, workspace)

for m in msgs:
logger.debug(f"Prepared message: {m}")
Expand Down Expand Up @@ -257,15 +262,19 @@ def prompt_input(prompt: str, value=None) -> str: # pragma: no cover
return value


def _include_paths(msg: Message) -> Message:
def _include_paths(msg: Message, workspace: Path | None = None) -> Message:
"""
Searches the message for any valid paths and:
- In legacy mode (default):
- appends the contents of such files as codeblocks
- includes images as files
- In fresh context mode (GPTME_FRESH_CONTEXT=1):
- only tracks paths in msg.files
- only tracks paths in msg.files (relative to workspace if provided)
- contents are included fresh before each user message
Args:
msg: Message to process
workspace: If provided, paths will be stored relative to this directory
"""
use_fresh_context = os.getenv("GPTME_FRESH_CONTEXT", "").lower() in (
"1",
Expand Down Expand Up @@ -308,7 +317,11 @@ def _include_paths(msg: Message) -> Message:
# Track files in msg.files
file = _parse_prompt_files(word)
if file:
msg.files.append(file)
# Store path relative to workspace if provided
if workspace and not file.is_absolute():
msg.files.append(file.absolute().relative_to(workspace))
else:
msg.files.append(file)

# append the message with the file contents
if append_msg:
Expand All @@ -333,7 +346,7 @@ def _parse_prompt(prompt: str) -> str | None:
# check if prompt is a path, if so, replace it with the contents of that file
f = Path(prompt).expanduser()
if f.exists() and f.is_file():
return f"```{prompt}\n{Path(prompt).expanduser().read_text()}\n```"
return f"```{prompt}\n{f.read_text()}\n```"
except OSError as oserr:
# some prompts are too long to be a path, so we can't read them
if oserr.errno != errno.ENAMETOOLONG:
Expand Down Expand Up @@ -399,7 +412,7 @@ def _parse_prompt_files(prompt: str) -> Path | None:
return None

try:
p = Path(prompt)
p = Path(prompt).expanduser()
if not (p.exists() and p.is_file()):
return None

Expand Down
200 changes: 200 additions & 0 deletions gptme/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
import json
import logging
import os
import shutil
import subprocess
from collections import Counter
from dataclasses import replace
from pathlib import Path

from .message import Message

logger = logging.getLogger(__name__)


def file_to_display_path(f: Path, workspace: Path | None = None) -> Path:
"""
Determine how to display the path:
- If file and pwd is in workspace, show path relative to pwd
- Otherwise, show absolute path
"""
cwd = Path.cwd()
if workspace and workspace in f.parents and workspace in [cwd, *cwd.parents]:
# NOTE: walk_up only available in Python 3.12+
try:
return f.relative_to(cwd)
except ValueError:
# If relative_to fails, try to find a common parent
for parent in cwd.parents:
try:
if workspace in parent.parents or workspace == parent:
return f.relative_to(parent)
except ValueError:
continue
return f.absolute()
elif Path.home() in f.parents:
return Path("~") / f.relative_to(os.path.expanduser("~"))
return f


def textfile_as_codeblock(path: Path) -> str | None:
"""Include file content as a codeblock."""
try:
if path.exists() and path.is_file():
try:
return f"```{path}\n{path.read_text()}\n```"
except UnicodeDecodeError:
return None
except OSError:
return None
return None


def append_file_content(msg: Message, workspace: Path | None = None) -> Message:
    """Inline readable text files referenced by the message as codeblocks.

    Args:
        msg: Message whose ``files`` should be expanded.
        workspace: Workspace root used to compute display paths.

    Returns:
        A copy of the message with readable text files appended to the content
        as codeblocks and removed from ``files``; files that could not be read
        as text remain listed in ``files``.
    """
    files = [file_to_display_path(f, workspace) for f in msg.files]
    text_files = {f: content for f in files if (content := textfile_as_codeblock(f))}
    content = msg.content
    if text_files:
        # separate the original content from the appended codeblocks
        # (previously the first codeblock abutted the content directly)
        content += "\n\n" + "\n\n".join(text_files.values())
    return replace(
        msg,
        content=content,
        files=[f for f in files if f not in text_files],
    )


def git_branch() -> str | None:
    """Get the current git branch name, or None if git is unavailable or fails."""
    if not shutil.which("git"):
        return None
    try:
        # check=True raises CalledProcessError on non-zero exit, so a
        # successful run here always means returncode == 0 (the old
        # `if branch.returncode == 0` check was dead code)
        branch = subprocess.run(
            ["git", "rev-parse", "--abbrev-ref", "HEAD"],
            capture_output=True,
            text=True,
            check=True,
        )
        return branch.stdout.strip()
    except subprocess.CalledProcessError:
        logger.error("Failed to get git branch")
        return None


def gh_pr_status() -> str | None:
    """Get GitHub PR status if available."""
    branch = git_branch()
    # only look up PRs for feature branches, and only when the gh CLI exists
    if not (shutil.which("gh") and branch and branch not in ["main", "master"]):
        return None

    try:
        p = subprocess.run(
            ["gh", "pr", "view", "--json", "number,title,url,body,comments"],
            capture_output=True,
            text=True,
            check=True,
        )
        p_diff = subprocess.run(
            ["gh", "pr", "diff"],
            capture_output=True,
            text=True,
            check=True,
        )
    except subprocess.CalledProcessError as e:
        logger.error(f"Failed to get PR info: {e}")
        return None

    pr = json.loads(p.stdout)
    return f"""Pull Request #{pr["number"]}: {pr["title"]} ({branch})
{pr["url"]}
{pr["body"]}

<diff>
{p_diff.stdout}
</diff>

<comments>
{pr["comments"]}
</comments>
"""


def git_status() -> str | None:
    """Get git status if in a repository, as a codeblock; None otherwise."""
    try:
        # check=True raises on failure, so the old `returncode == 0` check
        # was redundant; also avoid shadowing the function's own name
        result = subprocess.run(
            ["git", "status", "-vv"], capture_output=True, text=True, check=True
        )
    except (subprocess.CalledProcessError, FileNotFoundError):
        logger.debug("Not in a git repository or git not available")
        return None
    logger.debug("Including git status in context")
    return f"```git status -vv\n{result.stdout}```"


def gather_fresh_context(msgs: list[Message], workspace: Path | None) -> Message:
    """Gather fresh context from files and git status.

    Collects the files mentioned across the conversation, ranks them by
    mention count and recency (mtime), and builds a system message with
    git/PR status plus the current contents of the most relevant files.

    Args:
        msgs: Conversation messages to scan for file references.
        workspace: Workspace root used to resolve relative paths.

    Returns:
        A "system" Message containing the assembled context.
    """

    # Get files mentioned in conversation
    workspace_abs = workspace.resolve() if workspace else None
    files: Counter[Path] = Counter()
    for msg in msgs:
        for f in msg.files:
            # If path is relative and we have a workspace, make it absolute relative to workspace
            if not f.is_absolute() and workspace_abs:
                f = (workspace_abs / f).resolve()
            else:
                f = f.resolve()
            files[f] += 1
    logger.info(
        f"Files mentioned in conversation (workspace: {workspace_abs}): {dict(files)}"
    )

    # Sort by mentions and recency
    def file_score(f: Path) -> tuple[int, float]:
        try:
            mtime = f.stat().st_mtime
            return (files[f], mtime)
        except FileNotFoundError:
            # deleted since being mentioned: rank by mention count only
            return (files[f], 0)

    mentioned_files = sorted(files.keys(), key=file_score, reverse=True)
    sections = []

    if git_status_output := git_status():
        sections.append(git_status_output)

    if pr_status_output := gh_pr_status():
        sections.append(pr_status_output)

    # Read contents of most relevant files
    for f in mentioned_files[:10]:  # Limit to top 10 files
        if f.exists():
            logger.info(f"Including fresh content from: {f}")
            try:
                # pathlib idiom instead of open()/read()
                content = f.read_text()
            except UnicodeDecodeError:
                logger.debug(f"Skipping binary file: {f}")
                content = "<binary file>"
            except OSError:
                # file disappeared or became unreadable between exists() and read
                logger.debug(f"Skipping unreadable file: {f}")
                content = "<unreadable file>"
            display_path = file_to_display_path(f, workspace)
            logger.info(f"Reading file: {display_path}")
            sections.append(f"```{display_path}\n{content}\n```")
        else:
            logger.info(f"File not found: {f}")

    cwd = Path.cwd()
    return Message(
        "system",
        f"""# Context
Working directory: {cwd}
This context message is inserted right before your last message.
It contains the current state of relevant files and git status at the time of processing.
The file contents shown in this context message are the source of truth.
Any file contents shown elsewhere in the conversation history may be outdated.
This context message will be removed and replaced with fresh context in the next message.
"""
        + "\n\n".join(sections),
    )
Loading

0 comments on commit a563e79

Please sign in to comment.