Skip to content

Commit

Permalink
fix: anthropic fixes and ci, auto-detect provider+model from env (#85)
Browse files Browse the repository at this point in the history
* fix: auto-detect provider from API keys in env, auto-select model for provider, fixes for summarization and server support with anthropic

* ci: added anthropic to test matrix

* fix: fixed bug in model selection
  • Loading branch information
ErikBjare authored Aug 7, 2024
1 parent dd00d4b commit 8c20800
Show file tree
Hide file tree
Showing 8 changed files with 81 additions and 47 deletions.
10 changes: 9 additions & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@ on:

env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}

jobs:
build:
name: Test on ${{ matrix.os }}, py-${{ matrix.python_version }}, extras-${{ matrix.extras }}
name: Test ${{ matrix.os }} with `${{ matrix.extras }}` on ${{ matrix.provider }}
runs-on: ${{ matrix.os }}
env:
RELEASE: false
Expand All @@ -21,6 +22,12 @@ jobs:
os: [ubuntu-latest]
python_version: ['3.10']
extras: ['-E server', '-E browser', '-E all']
provider: ['openai']
include:
- os: ubuntu-latest
python_version: '3.10'
extras: '-E all'
provider: 'anthropic'
steps:
- uses: actions/checkout@v3
with:
Expand Down Expand Up @@ -54,6 +61,7 @@ jobs:
uses: nick-fields/retry@v2
env:
TERM: xterm
PROVIDER: ${{ matrix.provider }}
with:
timeout_minutes: 5
max_attempts: 2
Expand Down
23 changes: 12 additions & 11 deletions gptme/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@
from .commands import CMDFIX, action_descriptions, execute_cmd
from .constants import MULTIPROMPT_SEPARATOR, PROMPT_USER
from .dirs import get_logs_dir
from .init import init, init_logging
from .init import PROVIDERS, init, init_logging
from .llm import reply
from .logmanager import LogManager, _conversations
from .message import Message
from .models import get_model
from .prompts import get_prompt
from .tools import execute_msg
from .tools.shell import ShellSession, set_shell
Expand All @@ -32,7 +33,8 @@
logger = logging.getLogger(__name__)
print_builtin = __builtins__["print"] # type: ignore

LLMChoice = Literal["openai", "local"]
# TODO: these are a bit redundant/incorrect
LLMChoice = Literal["openai", "anthropic", "local"]
ModelChoice = Literal["gpt-3.5-turbo", "gpt-4", "gpt-4-1106-preview"]


Expand Down Expand Up @@ -74,13 +76,13 @@
)
@click.option(
"--llm",
default="openai",
help="LLM to use.",
type=click.Choice(["openai", "azure", "anthropic", "local"]),
default=None,
help="LLM provider to use.",
type=click.Choice(PROVIDERS),
)
@click.option(
"--model",
default="gpt-4",
default=None,
help="Model to use.",
)
@click.option(
Expand Down Expand Up @@ -182,8 +184,8 @@ def chat(
prompt_msgs: list[Message],
initial_msgs: list[Message],
name: str,
llm: str,
model: str,
llm: str | None,
model: str | None,
stream: bool = True,
no_confirm: bool = False,
interactive: bool = True,
Expand Down Expand Up @@ -239,7 +241,7 @@ def chat(
break

# ask for input if no prompt, generate reply, and run tools
for msg in step(log, no_confirm, model, stream=stream): # pragma: no cover
for msg in step(log, no_confirm, stream=stream): # pragma: no cover
log.append(msg)
# run any user-commands, if msg is from user
if msg.role == "user" and execute_cmd(msg, log):
Expand All @@ -249,7 +251,6 @@ def chat(
def step(
log: LogManager,
no_confirm: bool,
model: str,
stream: bool = True,
) -> Generator[Message, None, None]:
"""Runs a single pass of the chat."""
Expand Down Expand Up @@ -287,7 +288,7 @@ def step(
logger.debug(f"Prepared message: {m}")

# generate response
msg_response = reply(msgs, model, stream)
msg_response = reply(msgs, get_model().model, stream)

# log response and run tools
if msg_response:
Expand Down
12 changes: 6 additions & 6 deletions gptme/cli_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import click

from .init import init, init_logging
from .init import PROVIDERS, init, init_logging

logger = logging.getLogger(__name__)

Expand All @@ -11,16 +11,16 @@
@click.option("-v", "--verbose", is_flag=True, help="Verbose output.")
@click.option(
"--llm",
default="openai",
help="LLM to use.",
type=click.Choice(["openai", "local"]),
default=None,
help="LLM provider to use.",
type=click.Choice(PROVIDERS),
)
@click.option(
"--model",
default="gpt-4",
default=None,
help="Model to use by default, can be overridden in each request.",
)
def main(verbose, llm, model): # pragma: no cover
def main(verbose: bool, llm: str | None, model: str | None): # pragma: no cover
"""
Starts a server and web UI for gptme.
Expand Down
26 changes: 23 additions & 3 deletions gptme/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,20 @@

from dotenv import load_dotenv

from .config import load_config
from .dirs import get_readline_history_file
from .llm import init_llm
from .llm import get_recommended_model, init_llm
from .models import set_default_model
from .tabcomplete import register_tabcomplete
from .tools import init_tools

logger = logging.getLogger(__name__)
_init_done = False

PROVIDERS = ["openai", "anthropic", "azure", "local"]

def init(llm: str, model: str, interactive: bool):

def init(provider: str | None, model: str | None, interactive: bool):
global _init_done
if _init_done:
logger.warning("init() called twice, ignoring")
Expand All @@ -25,8 +28,25 @@ def init(llm: str, model: str, interactive: bool):
logger.debug("Started")
load_dotenv()

config = load_config()
if not provider:
provider = config.get_env("PROVIDER")
if not provider:
# auto-detect depending on if OPENAI_API_KEY or ANTHROPIC_API_KEY is set
if config.get_env("OPENAI_API_KEY"):
print("Found OpenAI API key, using OpenAI provider")
provider = "openai"
elif config.get_env("ANTHROPIC_API_KEY"):
print("Found Anthropic API key, using Anthropic provider")
provider = "anthropic"
else:
raise ValueError("No API key found, couldn't auto-detect provider")

# set up API_KEY and API_BASE, needs to be done before loading history to avoid saving API_KEY
init_llm(llm, interactive)
init_llm(provider, interactive)

if not model:
model = config.get_env("MODEL") or get_recommended_model()
set_default_model(model)

if interactive: # pragma: no cover
Expand Down
44 changes: 22 additions & 22 deletions gptme/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import sys
from typing import Generator, Iterator, Tuple

import openai
from anthropic import Anthropic
from openai import AzureOpenAI, OpenAI
from rich import print
Expand Down Expand Up @@ -268,42 +267,39 @@ def print_clear():
return Message("assistant", deltas_to_str(deltas))


def get_recommended_model() -> str:
assert oai_client or anthropic_client, "LLM not initialized"
return "gpt-4-turbo" if oai_client else "claude-3-5-sonnet-20240620"


def get_summary_model() -> str:
assert oai_client or anthropic_client, "LLM not initialized"
return "gpt-4o-mini" if oai_client else "claude-3-haiku-20240307"


def summarize(content: str) -> str:
"""
Summarizes a long text using a LLM.
To summarize messages or the conversation log,
use `gptme.tools.summarize` instead (which wraps this).
"""
assert oai_client, "LLM not initialized"
messages = [
Message(
"system",
content="You are ChatGPT, a large language model by OpenAI. You summarize messages.",
content="You are a helpful assistant that helps summarize messages with an AI assistant through a tool called gptme.",
),
Message("user", content=f"Summarize this:\n{content}"),
]

# model selection
model = "gpt-3.5-turbo"
if len_tokens(messages) > MODELS["openai"][model]["context"]:
model = "gpt-3.5-turbo-16k"
if len_tokens(messages) > MODELS["openai"][model]["context"]:
model = get_summary_model()
context_limit = MODELS["openai" if oai_client else "anthropic"][model]["context"]
if len_tokens(messages) > context_limit:
raise ValueError(
f"Cannot summarize more than 16385 tokens, got {len_tokens(messages)}"
f"Cannot summarize more than {context_limit} tokens, got {len_tokens(messages)}"
)

try:
response = oai_client.chat.completions.create(
model=model,
messages=msgs2dicts(messages), # type: ignore
temperature=0,
max_tokens=256,
)
except openai.APIError:
logger.error("OpenAI API error, returning empty summary: ", exc_info=True)
return "error"
summary = response.choices[0].message.content
summary = _chat_complete(messages, model)
assert summary
logger.debug(
f"Summarized long output ({len_tokens(content)} -> {len_tokens(summary)} tokens): "
Expand All @@ -325,15 +321,19 @@ def generate_name(msgs: list[Message]) -> str:
"""
The following is a conversation between a user and an assistant. Which we will generate a name for.
The name should be 2-5 words describing the conversation, separated by dashes. Examples:
The name should be 3-6 words describing the conversation, separated by dashes. Examples:
- install-llama
- implement-game-of-life
- capitalize-words-in-python
Focus on the main and/or initial topic of the conversation. Avoid using names that are too generic or too specific.
IMPORTANT: output only the name, no preamble or postamble.
""",
)
]
+ msgs
+ [Message("user", "Now, generate a name for this conversation.")]
)
name = _chat_complete(msgs, model="gpt-3.5-turbo").strip()
name = _chat_complete(msgs, model=get_summary_model()).strip()
return name
6 changes: 6 additions & 0 deletions gptme/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,12 @@ class _ModelDictMeta(TypedDict):
"price_input": 3,
"price_output": 15,
},
"claude-3-haiku-20240307": {
"context": 200_000,
"max_output": 4096,
"price_input": 0.25,
"price_output": 1.25,
},
},
"local": {
# 8B
Expand Down
2 changes: 0 additions & 2 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ def args(name: str) -> list[str]:
return [
"--name",
name,
"--model",
"gpt-4-1106-preview",
]


Expand Down
5 changes: 3 additions & 2 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@
# noreorder
from flask.testing import FlaskClient # fmt: skip
from gptme.cli import init # fmt: skip
from gptme.models import get_model # fmt: skip
from gptme.server import create_app # fmt: skip


@pytest.fixture(autouse=True)
def init_():
init(llm="openai", model="gpt-3.5-turbo", interactive=False)
init(None, None, interactive=False)


@pytest.fixture
Expand Down Expand Up @@ -72,7 +73,7 @@ def test_api_conversation_generate(conv: str, client: FlaskClient):

response = client.post(
f"/api/conversations/{conv}/generate",
json={"model": "gpt-3.5-turbo"},
json={"model": get_model().model},
)
assert response.status_code == 200
msgs = response.get_json()
Expand Down

0 comments on commit 8c20800

Please sign in to comment.