Skip to content

Commit

Permalink
fix: update to use latest Sonnet model by default, improve typing
Browse files: browse the repository at this point in the history
  • Loading branch information
ErikBjare committed Oct 22, 2024
1 parent 2be45a8 commit 6e70168
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 12 deletions.
12 changes: 9 additions & 3 deletions gptme/init.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
import atexit
import logging
import readline
from typing import cast

from dotenv import load_dotenv

from .config import config_path, load_config, set_config_value
from .dirs import get_readline_history_file
from .llm import init_llm
from .models import PROVIDERS, get_recommended_model, set_default_model
from .models import (
PROVIDERS,
Provider,
get_recommended_model,
set_default_model,
)
from .tabcomplete import register_tabcomplete
from .tools import init_tools
from .util import console
@@ -53,9 +59,9 @@ def init(model: str | None, interactive: bool, tool_allowlist: list[str] | None)
raise ValueError("No API key found, couldn't auto-detect provider")

if any(model.startswith(f"{provider}/") for provider in PROVIDERS):
provider, model = model.split("/", 1)
provider, model = cast(tuple[Provider, str], model.split("/", 1))
else:
provider, model = model, None
provider, model = cast(tuple[Provider, str], (model, None))

# set up API_KEY and API_BASE, needs to be done before loading history to avoid saving API_KEY
init_llm(provider)
Expand Down
29 changes: 20 additions & 9 deletions gptme/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import logging
from dataclasses import dataclass
from typing import Literal, TypedDict, get_args
from typing import (
Literal,
TypedDict,
cast,
get_args,
)

from typing_extensions import NotRequired

@@ -40,7 +45,7 @@ class _ModelDictMeta(TypedDict):

# known models metadata
# TODO: can we get this from the API?
MODELS: dict[str, dict[str, _ModelDictMeta]] = {
MODELS: dict[Provider, dict[str, _ModelDictMeta]] = {
"openai": OPENAI_MODELS,
"anthropic": {
"claude-3-opus-20240229": {
@@ -49,6 +54,12 @@ class _ModelDictMeta(TypedDict):
"price_input": 15,
"price_output": 75,
},
"claude-3-5-sonnet-20241022": {
"context": 200_000,
"max_output": 4096,
"price_input": 3,
"price_output": 15,
},
"claude-3-5-sonnet-20240620": {
"context": 200_000,
"max_output": 4096,
@@ -85,12 +96,12 @@ def get_model(model: str | None = None) -> ModelMeta:

# if only provider is given, get recommended model
if model in PROVIDERS:
provider = model
provider = cast(Provider, model)
model = get_recommended_model(provider)
return get_model(f"{provider}/{model}")

if any(f"{provider}/" in model for provider in PROVIDERS):
provider, model = model.split("/", 1)
provider, model = cast(tuple[Provider, str], model.split("/", 1))
if provider not in MODELS or model not in MODELS[provider]:
if provider not in ["openrouter", "local"]:
logger.warning(
@@ -113,23 +124,23 @@ def get_model(model: str | None = None) -> ModelMeta:
)


def get_recommended_model(provider: str) -> str: # pragma: no cover
def get_recommended_model(provider: Provider) -> str: # pragma: no cover
if provider == "openai":
return "gpt-4o"
elif provider == "openrouter":
return "meta-llama/llama-3.1-405b-instruct"
elif provider == "anthropic":
return "claude-3-5-sonnet-20240620"
return "claude-3-5-sonnet-20241022"
else:
raise ValueError(f"Unknown provider {provider}")
raise ValueError(f"Provider {provider} did not have a recommended model")


def get_summary_model(provider: str) -> str: # pragma: no cover
def get_summary_model(provider: Provider) -> str: # pragma: no cover
if provider == "openai":
return "gpt-4o-mini"
elif provider == "openrouter":
return "meta-llama/llama-3.1-8b-instruct"
elif provider == "anthropic":
return "claude-3-haiku-20240307"
else:
raise ValueError(f"Unknown provider {provider}")
raise ValueError(f"Provider {provider} did not have a summary model")

0 comments on commit 6e70168

Please sign in to comment.