From a563e799de8c7b7fdf07c7c2523d2bb02cbe7c32 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Erik=20Bj=C3=A4reholt?=
Date: Sun, 8 Dec 2024 12:33:16 +0100
Subject: [PATCH] feat: more progress on fresh context, refactored it into new
 context.py file

---
 gptme/chat.py       |  33 +++++---
 gptme/context.py    | 200 ++++++++++++++++++++++++++++++++++++++++++++
 gptme/logmanager.py | 111 ++++++------------------
 3 files changed, 250 insertions(+), 94 deletions(-)
 create mode 100644 gptme/context.py

diff --git a/gptme/chat.py b/gptme/chat.py
index 2e39dddd..b82fdaac 100644
--- a/gptme/chat.py
+++ b/gptme/chat.py
@@ -113,7 +113,7 @@ def confirm_func(msg) -> bool:
         # this way we can defer reading multiple stale versions of files in requests
         # we should probably ensure that the old file contents get included in exports and such
         # maybe we need seperate modes for this, but I think the refactor makes sense anyway
-        msg = _include_paths(msg)
+        msg = _include_paths(msg, workspace)
         manager.append(msg)
         # if prompt is a user-command, execute it
         if execute_cmd(msg, manager, confirm_func):
@@ -123,7 +123,9 @@ def confirm_func(msg) -> bool:
         while True:
             try:
                 set_interruptible()
-                response_msgs = list(step(manager.log, stream, confirm_func))
+                response_msgs = list(
+                    step(manager.log, stream, confirm_func, workspace)
+                )
             except KeyboardInterrupt:
                 console.log("Interrupted. Stopping current execution.")
                 manager.append(Message("system", "Interrupted"))
@@ -168,7 +170,9 @@ def confirm_func(msg) -> bool:
 
         # ask for input if no prompt, generate reply, and run tools
         clear_interruptible()  # Ensure we're not interruptible during user input
-        for msg in step(manager.log, stream, confirm_func):  # pragma: no cover
+        for msg in step(
+            manager.log, stream, confirm_func, workspace
+        ):  # pragma: no cover
             manager.append(msg)
             # run any user-commands, if msg is from user
             if msg.role == "user" and execute_cmd(msg, manager, confirm_func):
@@ -179,6 +183,7 @@ def step(
     log: Log | list[Message],
     stream: bool,
     confirm: ConfirmFunc,
+    workspace: Path | None = None,
 ) -> Generator[Message, None, None]:
     """Runs a single pass of the chat."""
     if isinstance(log, list):
@@ -197,7 +202,7 @@ def step(
     ):  # pragma: no cover
         inquiry = prompt_user()
         msg = Message("user", inquiry, quiet=True)
-        msg = _include_paths(msg)
+        msg = _include_paths(msg, workspace)
         yield msg
         log = log.append(msg)
 
@@ -206,7 +211,7 @@ def step(
     set_interruptible()
 
     # performs reduction/context trimming, if necessary
-    msgs = prepare_messages(log.messages)
+    msgs = prepare_messages(log.messages, workspace)
     for m in msgs:
         logger.debug(f"Prepared message: {m}")
 
@@ -257,15 +262,19 @@ def prompt_input(prompt: str, value=None) -> str:  # pragma: no cover
     return value
 
 
-def _include_paths(msg: Message) -> Message:
+def _include_paths(msg: Message, workspace: Path | None = None) -> Message:
     """
    Searches the message for any valid paths and:
     - In legacy mode (default):
       - appends the contents of such files as codeblocks
       - includes images as files
     - In fresh context mode (GPTME_FRESH_CONTEXT=1):
-      - only tracks paths in msg.files
+      - only tracks paths in msg.files (relative to workspace if provided)
       - contents are included fresh before each user message
+
+    Args:
+        msg: Message to process
+        workspace: If provided, paths will be stored relative to this directory
     """
     use_fresh_context = os.getenv("GPTME_FRESH_CONTEXT", "").lower() in (
         "1",
@@ -308,7 +317,11 @@ def _include_paths(msg: Message) -> Message:
             # Track files in msg.files
             file = _parse_prompt_files(word)
             if file:
-                msg.files.append(file)
+                # Store path relative to workspace if provided
+                if workspace and not file.is_absolute():
+                    msg.files.append(file.absolute().relative_to(workspace))
+                else:
+                    msg.files.append(file)
 
     # append the message with the file contents
     if append_msg:
@@ -333,7 +346,7 @@ def _parse_prompt(prompt: str) -> str | None:
         # check if prompt is a path, if so, replace it with the contents of that file
         f = Path(prompt).expanduser()
         if f.exists() and f.is_file():
-            return f"```{prompt}\n{Path(prompt).expanduser().read_text()}\n```"
+            return f"```{prompt}\n{f.read_text()}\n```"
     except OSError as oserr:
         # some prompts are too long to be a path, so we can't read them
         if oserr.errno != errno.ENAMETOOLONG:
@@ -399,7 +412,7 @@ def _parse_prompt_files(prompt: str) -> Path | None:
         return None
 
     try:
-        p = Path(prompt)
+        p = Path(prompt).expanduser()
         if not (p.exists() and p.is_file()):
             return None
 
diff --git a/gptme/context.py b/gptme/context.py
new file mode 100644
index 00000000..81660b7c
--- /dev/null
+++ b/gptme/context.py
@@ -0,0 +1,200 @@
+import json
+import logging
+import os
+import shutil
+import subprocess
+from collections import Counter
+from dataclasses import replace
+from pathlib import Path
+
+from .message import Message
+
+logger = logging.getLogger(__name__)
+
+
+def file_to_display_path(f: Path, workspace: Path | None = None) -> Path:
+    """
+    Determine how to display the path:
+    - If file and pwd is in workspace, show path relative to pwd
+    - Otherwise, show absolute path
+    """
+    cwd = Path.cwd()
+    if workspace and workspace in f.parents and workspace in [cwd, *cwd.parents]:
+        # NOTE: walk_up only available in Python 3.12+
+        try:
+            return f.relative_to(cwd)
+        except ValueError:
+            # If relative_to fails, try to find a common parent
+            for parent in cwd.parents:
+                try:
+                    if workspace in parent.parents or workspace == parent:
+                        return f.relative_to(parent)
+                except ValueError:
+                    continue
+            return f.absolute()
+    elif Path.home() in f.parents:
+        return Path("~") / f.relative_to(os.path.expanduser("~"))
+    return f
+
+
+def textfile_as_codeblock(path: Path) -> str | None:
+    """Include file content as a codeblock."""
+    try:
+        if path.exists() and path.is_file():
+            try:
+                return f"```{path}\n{path.read_text()}\n```"
+            except UnicodeDecodeError:
+                return None
+    except OSError:
+        return None
+    return None
+
+
+def append_file_content(msg: Message, workspace: Path | None = None) -> Message:
+    """Append file content to a message."""
+    files = [file_to_display_path(f, workspace) for f in msg.files]
+    text_files = {f: content for f in files if (content := textfile_as_codeblock(f))}
+    return replace(
+        msg,
+        content=msg.content + "\n\n".join(text_files.values()),
+        files=[f for f in files if f not in text_files],
+    )
+
+
+def git_branch() -> str | None:
+    """Get the current git branch name."""
+    if shutil.which("git"):
+        try:
+            branch = subprocess.run(
+                ["git", "rev-parse", "--abbrev-ref", "HEAD"],
+                capture_output=True,
+                text=True,
+                check=True,
+            )
+            if branch.returncode == 0:
+                return branch.stdout.strip()
+        except subprocess.CalledProcessError:
+            logger.error("Failed to get git branch")
+            return None
+    return None
+
+
+def gh_pr_status() -> str | None:
+    """Get GitHub PR status if available."""
+    branch = git_branch()
+    if shutil.which("gh") and branch and branch not in ["main", "master"]:
+        try:
+            p = subprocess.run(
+                ["gh", "pr", "view", "--json", "number,title,url,body,comments"],
+                capture_output=True,
+                text=True,
+                check=True,
+            )
+            p_diff = subprocess.run(
+                ["gh", "pr", "diff"],
+                capture_output=True,
+                text=True,
+                check=True,
+            )
+        except subprocess.CalledProcessError as e:
+            logger.error(f"Failed to get PR info: {e}")
+            return None
+
+        pr = json.loads(p.stdout)
+        return f"""Pull Request #{pr["number"]}: {pr["title"]} ({branch})
+{pr["url"]}
+
+{pr["body"]}
+
+<diff>
+{p_diff.stdout}
+</diff>
+
+<comments>
+{pr["comments"]}
+</comments>
+"""
+
+    return None
+
+
+def git_status() -> str | None:
+    """Get git status if in a repository."""
+    try:
+        git_status = subprocess.run(
+            ["git", "status", "-vv"], capture_output=True, text=True, check=True
+        )
+        if git_status.returncode == 0:
+            logger.debug("Including git status in context")
+            return f"```git status -vv\n{git_status.stdout}```"
+    except (subprocess.CalledProcessError, FileNotFoundError):
+        logger.debug("Not in a git repository or git not available")
+    return None
+
+
+def gather_fresh_context(msgs: list[Message], workspace: Path | None) -> Message:
+    """Gather fresh context from files and git status."""
+
+    # Get files mentioned in conversation
+    workspace_abs = workspace.resolve() if workspace else None
+    files: Counter[Path] = Counter()
+    for msg in msgs:
+        for f in msg.files:
+            # If path is relative and we have a workspace, make it absolute relative to workspace
+            if not f.is_absolute() and workspace_abs:
+                f = (workspace_abs / f).resolve()
+            else:
+                f = f.resolve()
+            files[f] += 1
+    logger.info(
+        f"Files mentioned in conversation (workspace: {workspace_abs}): {dict(files)}"
+    )
+
+    # Sort by mentions and recency
+    def file_score(f: Path) -> tuple[int, float]:
+        try:
+            mtime = f.stat().st_mtime
+            return (files[f], mtime)
+        except FileNotFoundError:
+            return (files[f], 0)
+
+    mentioned_files = sorted(files.keys(), key=file_score, reverse=True)
+    sections = []
+
+    if git_status_output := git_status():
+        sections.append(git_status_output)
+
+    if pr_status_output := gh_pr_status():
+        sections.append(pr_status_output)
+
+    # Read contents of most relevant files
+    for f in mentioned_files[:10]:  # Limit to top 10 files
+        if f.exists():
+            logger.info(f"Including fresh content from: {f}")
+            try:
+                with open(f) as file:
+                    content = file.read()
+            except UnicodeDecodeError:
+                logger.debug(f"Skipping binary file: {f}")
+                content = ""
+            display_path = file_to_display_path(f, workspace)
+            logger.info(f"Reading file: {display_path}")
+            sections.append(f"```{display_path}\n{content}\n```")
+        else:
+            logger.info(f"File not found: {f}")
+
+    cwd = Path.cwd()
+    return Message(
+        "system",
+        f"""# Context
+Working directory: {cwd}
+
+This context message is inserted right before your last message.
+It contains the current state of relevant files and git status at the time of processing.
+The file contents shown in this context message are the source of truth.
+Any file contents shown elsewhere in the conversation history may be outdated.
+This context message will be removed and replaced with fresh context in the next message.
+
+"""
+        + "\n\n".join(sections),
+    )
diff --git a/gptme/logmanager.py b/gptme/logmanager.py
index 97523300..c8a9a1b1 100644
--- a/gptme/logmanager.py
+++ b/gptme/logmanager.py
@@ -2,10 +2,9 @@
 import logging
 import os
 import shutil
-import subprocess
 import textwrap
-from collections import Counter
 from collections.abc import Generator
+from copy import copy
 from dataclasses import dataclass, field, replace
 from datetime import datetime
 from itertools import islice, zip_longest
@@ -15,6 +14,7 @@
 
 from rich import print
 
+from .context import append_file_content, gather_fresh_context
 from .dirs import get_logs_dir
 from .message import Message, len_tokens, print_msg
 from .prompts import get_prompt
@@ -75,7 +75,6 @@ def __init__(
         branch: str | None = None,
     ):
         self.current_branch = branch or "main"
-
         if logdir:
             self.logdir = Path(logdir)
         else:
@@ -102,6 +101,11 @@ def __init__(
 
         # TODO: Check if logfile has contents, then maybe load, or should it overwrite?
 
+    @property
+    def workspace(self) -> Path:
+        """Path to workspace directory (resolves symlink if exists)."""
+        return (self.logdir / "workspace").resolve()
+
     @property
     def log(self) -> Log:
         return self._branches[self.current_branch]
@@ -307,69 +311,20 @@ def to_dict(self, branches=False) -> dict:
         return d
 
 
-def gather_fresh_context(msgs: list[Message]) -> Message:
-    """Gather fresh context from files and git status."""
-    # Get files mentioned in conversation
-    files = Counter([f for msg in msgs for f in msg.files])
-    logger.debug(f"Files mentioned in conversation: {dict(files)}")
-
-    # Sort by mentions and recency
-    def file_score(f: Path) -> tuple[int, float]:
-        try:
-            mtime = Path(f).stat().st_mtime
-            return (files[f], mtime)
-        except FileNotFoundError:
-            return (files[f], 0)
-
-    mentioned_files = sorted(files.keys(), key=file_score, reverse=True)
-
-    # Read contents of most relevant files
-    context = ""
-    for f in mentioned_files[:10]:  # Limit to top 10 files
-        if Path(f).exists():
-            logger.debug(f"Including fresh content from: {f}")
-            context += f"```{f}\n"
-            try:
-                with open(f) as file:
-                    context += file.read()
-            except UnicodeDecodeError:
-                logger.debug(f"Skipping binary file: {f}")
-                context += ""
-            context += "\n```\n"
-        else:
-            logger.debug(f"File not found: {f}")
-
-    # Add git status if in repo
-    try:
-        git_status = subprocess.run(
-            ["git", "status", "-vv"], capture_output=True, text=True, check=True
-        )
-        if git_status.returncode == 0:
-            logger.debug("Including git status in context")
-            context += "\nGit status:\n```\n" + git_status.stdout + "```\n"
-    except (subprocess.CalledProcessError, FileNotFoundError):
-        logger.debug("Not in a git repository or git not available")
-
-    return Message("system", "Fresh context:\n\n" + context)
-
-
-def include_file_content(path: Path) -> str | None:
-    """Include file content as a codeblock."""
-    try:
-        if path.exists() and path.is_file():
-            try:
-                return f"```{path}\n{path.read_text()}\n```"
-            except UnicodeDecodeError:
-                return None
-    except OSError:
-        return None
-    return None
-
-
-def prepare_messages(msgs: list[Message]) -> list[Message]:
-    """Prepares the messages before sending to the LLM."""
+def prepare_messages(
+    msgs: list[Message], workspace: Path | None = None
+) -> list[Message]:
+    """
+    Prepares the messages before sending to the LLM.
+    - Takes the stored gptme conversation log
+    - Enhances it with context such as file contents
+    - Transforms it to the format expected by LLM providers
+    """
     from .tools._rag_context import _HAS_RAG, enhance_messages  # fmt: skip
 
+    # make a copy to avoid mutating the original
+    msgs = copy(msgs)
+
     # First enhance messages with context
     if _HAS_RAG:
         msgs = enhance_messages(msgs)
@@ -385,28 +340,16 @@ def prepare_messages(msgs: list[Message]) -> list[Message]:
         logger.debug("Using fresh context mode")
         # Add fresh context
         # TODO: remove gathered context from `files` before sending to LLM
-        if msgs and msgs[-1].role == "user":
-            msgs.insert(-1, gather_fresh_context(msgs))
+        last_user_idx = next(
+            (i for i, msg in enumerate(msgs[::-1]) if msg.role == "user"), None
+        )
+        # insert message right before the last user message
+        fresh_content_msg = gather_fresh_context(msgs, workspace)
+        msgs.insert(-last_user_idx if last_user_idx else -1, fresh_content_msg)
     else:
         # Legacy mode: Include file contents where they were mentioned
-        text_files = {
-            f: content
-            for msg in msgs
-            for f in msg.files
-            if (content := include_file_content(f))
-        }
-        msgs = [
-            (
-                replace(
-                    msg,
-                    content=msg.content + "\n\n".join(text_files.values()),
-                    files=[f for f in msg.files if f not in text_files],
-                )
-                if msg.role == "user" and msg.files
-                else msg
-            )
-            for msg in msgs
-        ]
+        # FIXME: this doesn't include the versions of files as they were at the time of the message
+        msgs = [(append_file_content(msg, workspace)) for msg in msgs]
 
     # Then reduce and limit as before
     msgs_reduced = list(reduce_log(msgs))
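
Reviewer note, not part of the patch: a minimal sketch of how the new fresh-context pieces fit together. It assumes `Message` accepts a `files=` keyword (implied by the `msg.files` usage above) and that the current directory is a git checkout; in the full flow, `prepare_messages(msgs, workspace)` inserts the gathered context right before the last user message when `GPTME_FRESH_CONTEXT=1`.

```python
from pathlib import Path

from gptme.context import gather_fresh_context
from gptme.message import Message

# Hypothetical conversation in which the user mentioned a file,
# so it was tracked in msg.files by _include_paths().
workspace = Path.cwd()
msgs = [
    Message("system", "You are gptme."),
    Message("user", "Please review chat.py", files=[Path("gptme/chat.py")]),
]

# Builds the "# Context" system message: git status, optional PR info,
# and the current contents of the most-mentioned files, resolved
# against the workspace.
ctx = gather_fresh_context(msgs, workspace)
print(ctx.content)
```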
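
A second sketch, covering the path helpers added in `gptme/context.py`: `file_to_display_path` prefers workspace-relative (or `~`-relative) paths for display, and `append_file_content` is what legacy mode now uses to inline tracked text files as codeblocks. The file names here are hypothetical.

```python
from pathlib import Path

from gptme.context import append_file_content, file_to_display_path
from gptme.message import Message

workspace = Path.cwd()  # pretend the current directory is the workspace

# Files under the workspace are shown relative to the current directory,
# files under $HOME get a ~ prefix, everything else stays as-is.
print(file_to_display_path(workspace / "README.md", workspace))
print(file_to_display_path(Path.home() / "notes.txt", workspace))

# Legacy mode: inline the tracked text file as a codeblock and drop it
# from msg.files so it is not attached twice.
msg = Message("user", "What does the README say?", files=[Path("README.md")])
print(append_file_content(msg, workspace).content)
```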