diff --git a/gptme/chat.py b/gptme/chat.py index 17949365..882d5974 100644 --- a/gptme/chat.py +++ b/gptme/chat.py @@ -27,12 +27,12 @@ from .tools.base import ConfirmFunc from .tools.browser import read_url from .util import ( - ask_execute, console, path_with_tilde, print_bell, rich_to_str, ) +from .util.ask_execute import ask_execute from .util.cost import log_costs from .util.readline import add_history diff --git a/gptme/cli.py b/gptme/cli.py index e8f54521..bd6c8dee 100644 --- a/gptme/cli.py +++ b/gptme/cli.py @@ -23,8 +23,14 @@ from .logmanager import ConversationMeta, get_user_conversations from .message import Message from .prompts import get_prompt -from .tools import all_tools, init_tools, ToolFormat, set_tool_format -from .util import epoch_to_age, generate_name +from .tools import ( + ToolFormat, + all_tools, + init_tools, + set_tool_format, +) +from .util import epoch_to_age +from .util.generate_name import generate_name from .util.readline import add_history logger = logging.getLogger(__name__) diff --git a/gptme/commands.py b/gptme/commands.py index 1517e5ce..e6a8700e 100644 --- a/gptme/commands.py +++ b/gptme/commands.py @@ -19,7 +19,7 @@ from .llm.models import get_model from .tools import ToolUse, execute_msg, loaded_tools from .tools.base import ConfirmFunc, get_tool_format -from .useredit import edit_text_with_editor +from .util.useredit import edit_text_with_editor logger = logging.getLogger(__name__) diff --git a/gptme/tools/patch.py b/gptme/tools/patch.py index 62f031fd..daad7caf 100644 --- a/gptme/tools/patch.py +++ b/gptme/tools/patch.py @@ -9,8 +9,13 @@ from pathlib import Path from ..message import Message -from ..util import print_preview -from .base import ConfirmFunc, Parameter, ToolSpec, ToolUse +from ..util.ask_execute import get_editable_text, set_editable_text, print_preview +from .base import ( + ConfirmFunc, + Parameter, + ToolSpec, + ToolUse, +) instructions = """ To patch/modify files, we use an adapted version of git conflict markers. @@ -188,6 +193,7 @@ def execute_patch( Applies the patch. """ + fn = None if code is not None and args is not None: fn = " ".join(args) if not fn: @@ -197,6 +203,9 @@ def execute_patch( code = kwargs.get("patch", "") fn = kwargs.get("path", "") + assert code is not None, "No patch provided" + assert fn is not None, "No path provided" + if code is None: yield Message("system", "No patch provided") return @@ -217,10 +226,16 @@ def execute_patch( # TODO: include patch headers to delimit multiple patches print_preview(patches_str, lang="diff") + # Make patch content editable before confirmation + set_editable_text(code, "patch") + if not confirm(f"Apply patch to {fn}?"): print("Patch not applied") return + # Get potentially edited content + code = get_editable_text() + try: with open(path) as f: original_content = f.read() diff --git a/gptme/tools/python.py b/gptme/tools/python.py index 6afdb23f..33436953 100644 --- a/gptme/tools/python.py +++ b/gptme/tools/python.py @@ -13,7 +13,7 @@ from typing import TYPE_CHECKING, TypeVar from ..message import Message -from ..util import print_preview +from ..util.ask_execute import print_preview from .base import ( ConfirmFunc, Parameter, diff --git a/gptme/tools/save.py b/gptme/tools/save.py index b2b97c26..7d7a2b5b 100644 --- a/gptme/tools/save.py +++ b/gptme/tools/save.py @@ -6,7 +6,12 @@ from pathlib import Path from ..message import Message -from ..util import print_preview +from ..util.ask_execute import ( + clear_editable_text, + get_editable_text, + set_editable_text, + print_preview, +) from .base import ( ConfirmFunc, Parameter, @@ -99,10 +104,20 @@ def execute_save( yield Message("system", "File already exists with identical content.") return - if not confirm(f"Save to {fn}?"): - # early return - yield Message("system", "Save cancelled.") - return + # Make content editable before confirmation + ext = Path(fn).suffix.lstrip(".") + set_editable_text(content, ext) + + try: + if not confirm(f"Save to {fn}?"): + # early return + yield Message("system", "Save cancelled.") + return + + # Get potentially edited content + content = get_editable_text() + finally: + clear_editable_text() # if the file exists, ask to overwrite if path.exists(): diff --git a/gptme/tools/shell.py b/gptme/tools/shell.py index 30c14170..61149353 100644 --- a/gptme/tools/shell.py +++ b/gptme/tools/shell.py @@ -16,7 +16,8 @@ from .base import Parameter from ..message import Message -from ..util import get_installed_programs, get_tokenizer, print_preview +from ..util import get_installed_programs, get_tokenizer +from ..util.ask_execute import print_preview from .base import ConfirmFunc, ToolSpec, ToolUse logger = logging.getLogger(__name__) diff --git a/gptme/tools/tmux.py b/gptme/tools/tmux.py index 2c04ee5f..e1e9cda0 100644 --- a/gptme/tools/tmux.py +++ b/gptme/tools/tmux.py @@ -13,8 +13,13 @@ from time import sleep from ..message import Message -from ..util import print_preview -from .base import ConfirmFunc, Parameter, ToolSpec, ToolUse +from ..util.ask_execute import print_preview +from .base import ( + ConfirmFunc, + Parameter, + ToolSpec, + ToolUse, +) logger = logging.getLogger(__name__) diff --git a/gptme/util/__init__.py b/gptme/util/__init__.py index 41979edb..e2a2c24f 100644 --- a/gptme/util/__init__.py +++ b/gptme/util/__init__.py @@ -5,12 +5,10 @@ import functools import io import logging -import random import re import shutil import subprocess import sys -import termios import textwrap from datetime import datetime, timedelta from functools import lru_cache @@ -19,14 +17,10 @@ from rich import print from rich.console import Console -from rich.syntax import Syntax - -from ..clipboard import copy, set_copytext EMOJI_WARN = "⚠️" logger = logging.getLogger(__name__) - console = Console(log_path=False) _warned_models = set() @@ -50,85 +44,6 @@ def get_tokenizer(model: str): return tiktoken.get_encoding("cl100k_base") -actions = [ - "running", - "jumping", - "walking", - "skipping", - "hopping", - "flying", - "swimming", - "crawling", - "sneaking", - "sprinting", - "sneaking", - "dancing", - "singing", - "laughing", -] -adjectives = [ - "funny", - "happy", - "sad", - "angry", - "silly", - "crazy", - "sneaky", - "sleepy", - "hungry", - # colors - "red", - "blue", - "green", - "pink", - "purple", - "yellow", - "orange", -] -nouns = [ - "cat", - "dog", - "rat", - "mouse", - "fish", - "elephant", - "dinosaur", - # birds - "bird", - "pelican", - # fictional - "dragon", - "unicorn", - "mermaid", - "monster", - "alien", - "robot", - # sea creatures - "whale", - "shark", - "walrus", - "octopus", - "squid", - "jellyfish", - "starfish", - "penguin", - "seal", -] - - -def generate_name(): - action = random.choice(actions) - adjective = random.choice(adjectives) - noun = random.choice(nouns) - return f"{action}-{adjective}-{noun}" - - -def is_generated_name(name: str) -> bool: - """if name is a name generated by generate_name""" - all_words = actions + adjectives + nouns - return name.count("-") == 2 and all(word in all_words for word in name.split("-")) - - def epoch_to_age(epoch, incl_date=False): # takes epoch and returns "x minutes ago", "3 hours ago", "yesterday", etc. age = datetime.now() - datetime.fromtimestamp(epoch) @@ -148,75 +63,6 @@ def epoch_to_age(epoch, incl_date=False): ) -copiable = False - - -def set_copiable(): - global copiable - copiable = True - - -def clear_copiable(): - global copiable - copiable = False - - -def print_preview(code: str, lang: str, copy: bool = False): # pragma: no cover - print() - print("[bold white]Preview[/bold white]") - - if copy: - set_copiable() - set_copytext(code) - - # NOTE: we can set background_color="default" to remove background - print(Syntax(code.strip("\n"), lang)) - print() - - -override_auto = False - - -def ask_execute(question="Execute code?", default=True) -> bool: # pragma: no cover - global override_auto - if override_auto: - return True - - print_bell() # Ring the bell just before asking for input - termios.tcflush(sys.stdin, termios.TCIFLUSH) # flush stdin - - choicestr = f"[{'Y' if default else 'y'}/{'n' if default else 'N'}{'/c' if copiable else ''}/?]" - answer = console.input( - f"[bold bright_yellow on red] {question} {choicestr} [/] ", - ) - - if not override_auto and copiable and "c" == answer.lower().strip(): - if copy(): - print("Copied to clipboard.") - return False - clear_copiable() - - # secret option to stop asking for the rest of the session - if answer.lower() in ["auto"]: - return (override_auto := True) - - # secret option to ask for help - if answer.lower() in ["help", "h", "?"]: - lines = [ - "Options:", - " y - execute the code", - " n - do not execute the code", - (" c - copy the code to the clipboard\n" if copiable else ""), - " auto - stop asking for the rest of the session", - f"Default is '{'y' if default else 'n'}' if answer is empty.", - ] - helptext = "\n".join(line for line in lines if line) - print(helptext) - return ask_execute(question, default) - - return answer.lower() in (["y", "yes"] + [""] if default else []) - - def clean_example(s: str, strict=False) -> str: orig = s s = re.sub( diff --git a/gptme/util/ask_execute.py b/gptme/util/ask_execute.py new file mode 100644 index 00000000..c5c98754 --- /dev/null +++ b/gptme/util/ask_execute.py @@ -0,0 +1,165 @@ +""" +Utilities for asking user confirmation and handling editable/copiable content. +""" + +import logging +import sys +import termios + +from rich import print +from rich.console import Console +from rich.syntax import Syntax + +from . import print_bell +from .clipboard import copy, set_copytext +from .useredit import edit_text_with_editor + +logger = logging.getLogger(__name__) +console = Console(log_path=False) + +# Global state +override_auto = False +copiable = False +editable = False + +# Editable text state +_editable_text = None +_editable_ext = None + + +def set_copiable(): + """Mark content as copiable.""" + global copiable + copiable = True + + +def clear_copiable(): + """Clear copiable state.""" + global copiable + copiable = False + + +def set_editable_text(text: str, ext: str | None = None): + """Set the text that can be edited and optionally its file extension.""" + global _editable_text, _editable_ext, editable + _editable_text = text + _editable_ext = ext + editable = True + + +def get_editable_text() -> str: + """Get the current editable text.""" + global _editable_text + if _editable_text is None: + raise RuntimeError("No editable text set") + return _editable_text + + +def get_editable_ext() -> str | None: + """Get the file extension for the editable text.""" + global _editable_ext + return _editable_ext + + +def set_edited_text(text: str): + """Update the editable text after editing.""" + global _editable_text + _editable_text = text + + +def clear_editable_text(): + """Clear the editable text and extension.""" + global _editable_text, _editable_ext, editable + _editable_text = None + _editable_ext = None + editable = False + + +def ask_execute(question="Execute code?", default=True) -> bool: + """Ask user for confirmation before executing code. + + Args: + question: The question to ask + default: The default answer if user just presses enter + + Returns: + bool: True if user confirms execution, False otherwise + """ + global override_auto, copiable, editable + + if override_auto: + return True + + print_bell() # Ring the bell just before asking for input + termios.tcflush(sys.stdin, termios.TCIFLUSH) # flush stdin + + # Build choice string with available options + choicestr = f"[{'Y' if default else 'y'}/{'n' if default else 'N'}" + if copiable: + choicestr += "/c" + if editable: + choicestr += "/e" + choicestr += "/?" + choicestr += "]" + + answer = ( + console.input( + f"[bold bright_yellow on red] {question} {choicestr} [/] ", + ) + .lower() + .strip() + ) + + if not override_auto: + if copiable and answer == "c": + if copy(): + print("Copied to clipboard.") + return False + clear_copiable() + elif editable and answer == "e": + edited = edit_text_with_editor(get_editable_text(), ext=get_editable_ext()) + if edited != get_editable_text(): + set_edited_text(edited) + print("Content updated.") + return ask_execute("Execute with changes?", default) + return False + + # secret option to stop asking for the rest of the session + if answer == "auto": + return (override_auto := True) + + # secret option to ask for help + if answer in ["help", "h", "?"]: + lines = [ + "Options:", + " y - execute the code", + " n - do not execute the code", + ] + if copiable: + lines.append(" c - copy the code to the clipboard") + if editable: + lines.append(" e - edit the code before executing") + lines.extend( + [ + " auto - stop asking for the rest of the session", + f"Default is '{'y' if default else 'n'}' if answer is empty.", + ] + ) + helptext = "\n".join(lines) + print(helptext) + return ask_execute(question, default) + + return answer in (["y", "yes"] + [""] if default else []) + + +def print_preview(code: str, lang: str, copy: bool = False): # pragma: no cover + print() + print("[bold white]Preview[/bold white]") + + if copy: + set_copiable() + set_copytext(code) + + # NOTE: we can set background_color="default" to remove background + print(Syntax(code.strip("\n"), lang)) + print() diff --git a/gptme/clipboard.py b/gptme/util/clipboard.py similarity index 100% rename from gptme/clipboard.py rename to gptme/util/clipboard.py diff --git a/gptme/util/generate_name.py b/gptme/util/generate_name.py new file mode 100644 index 00000000..ff47a8c1 --- /dev/null +++ b/gptme/util/generate_name.py @@ -0,0 +1,80 @@ +import random + +# Name generation lists +actions = [ + "running", + "jumping", + "walking", + "skipping", + "hopping", + "flying", + "swimming", + "crawling", + "sneaking", + "sprinting", + "sneaking", + "dancing", + "singing", + "laughing", +] +adjectives = [ + "funny", + "happy", + "sad", + "angry", + "silly", + "crazy", + "sneaky", + "sleepy", + "hungry", + # colors + "red", + "blue", + "green", + "pink", + "purple", + "yellow", + "orange", +] +nouns = [ + "cat", + "dog", + "rat", + "mouse", + "fish", + "elephant", + "dinosaur", + # birds + "bird", + "pelican", + # fictional + "dragon", + "unicorn", + "mermaid", + "monster", + "alien", + "robot", + # sea creatures + "whale", + "shark", + "walrus", + "octopus", + "squid", + "jellyfish", + "starfish", + "penguin", + "seal", +] + + +def generate_name(): + action = random.choice(actions) + adjective = random.choice(adjectives) + noun = random.choice(nouns) + return f"{action}-{adjective}-{noun}" + + +def is_generated_name(name: str) -> bool: + """if name is a name generated by generate_name""" + all_words = actions + adjectives + nouns + return name.count("-") == 2 and all(word in all_words for word in name.split("-")) diff --git a/gptme/useredit.py b/gptme/util/useredit.py similarity index 100% rename from gptme/useredit.py rename to gptme/util/useredit.py diff --git a/scripts/train/collect.py b/scripts/train/collect.py index d95c3ba0..db7156d2 100755 --- a/scripts/train/collect.py +++ b/scripts/train/collect.py @@ -12,7 +12,7 @@ import click import torch # type: ignore -from gptme.util import is_generated_name +from gptme.util.generate_name import is_generated_name from transformers import pipeline # type: ignore logger = logging.getLogger(__name__) diff --git a/tests/test_util.py b/tests/test_util.py index b4bf0054..58a16b42 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -3,10 +3,9 @@ from gptme.util import ( epoch_to_age, example_to_xml, - generate_name, - is_generated_name, transform_examples_to_chat_directives, ) +from gptme.util.generate_name import generate_name, is_generated_name def test_generate_name():