diff --git a/gptme/init.py b/gptme/init.py
index a813cde4..06a29e30 100644
--- a/gptme/init.py
+++ b/gptme/init.py
@@ -1,13 +1,19 @@
 import atexit
 import logging
 import readline
+from typing import cast
 
 from dotenv import load_dotenv
 
 from .config import config_path, load_config, set_config_value
 from .dirs import get_readline_history_file
 from .llm import init_llm
-from .models import PROVIDERS, get_recommended_model, set_default_model
+from .models import (
+    PROVIDERS,
+    Provider,
+    get_recommended_model,
+    set_default_model,
+)
 from .tabcomplete import register_tabcomplete
 from .tools import init_tools
 from .util import console
@@ -53,9 +59,9 @@ def init(model: str | None, interactive: bool, tool_allowlist: list[str] | None)
             raise ValueError("No API key found, couldn't auto-detect provider")
 
     if any(model.startswith(f"{provider}/") for provider in PROVIDERS):
-        provider, model = model.split("/", 1)
+        provider, model = cast(tuple[Provider, str], model.split("/", 1))
     else:
-        provider, model = model, None
+        provider, model = cast(tuple[Provider, str], (model, None))
 
     # set up API_KEY and API_BASE, needs to be done before loading history to avoid saving API_KEY
     init_llm(provider)
diff --git a/gptme/models.py b/gptme/models.py
index bf7c0084..3f144c85 100644
--- a/gptme/models.py
+++ b/gptme/models.py
@@ -1,6 +1,11 @@
 import logging
 from dataclasses import dataclass
-from typing import Literal, TypedDict, get_args
+from typing import (
+    Literal,
+    TypedDict,
+    cast,
+    get_args,
+)
 
 from typing_extensions import NotRequired
 
@@ -40,7 +45,7 @@ class _ModelDictMeta(TypedDict):
 
 # known models metadata
 # TODO: can we get this from the API?
-MODELS: dict[str, dict[str, _ModelDictMeta]] = {
+MODELS: dict[Provider, dict[str, _ModelDictMeta]] = {
     "openai": OPENAI_MODELS,
     "anthropic": {
         "claude-3-opus-20240229": {
@@ -49,6 +54,12 @@ class _ModelDictMeta(TypedDict):
             "price_input": 15,
             "price_output": 75,
         },
+        "claude-3-5-sonnet-20241022": {
+            "context": 200_000,
+            "max_output": 4096,
+            "price_input": 3,
+            "price_output": 15,
+        },
         "claude-3-5-sonnet-20240620": {
             "context": 200_000,
             "max_output": 4096,
@@ -85,12 +96,12 @@ def get_model(model: str | None = None) -> ModelMeta:
 
     # if only provider is given, get recommended model
     if model in PROVIDERS:
-        provider = model
+        provider = cast(Provider, model)
         model = get_recommended_model(provider)
         return get_model(f"{provider}/{model}")
 
     if any(f"{provider}/" in model for provider in PROVIDERS):
-        provider, model = model.split("/", 1)
+        provider, model = cast(tuple[Provider, str], model.split("/", 1))
         if provider not in MODELS or model not in MODELS[provider]:
             if provider not in ["openrouter", "local"]:
                 logger.warning(
@@ -113,18 +124,18 @@ def get_model(model: str | None = None) -> ModelMeta:
     )
 
 
-def get_recommended_model(provider: str) -> str:  # pragma: no cover
+def get_recommended_model(provider: Provider) -> str:  # pragma: no cover
     if provider == "openai":
         return "gpt-4o"
     elif provider == "openrouter":
         return "meta-llama/llama-3.1-405b-instruct"
     elif provider == "anthropic":
-        return "claude-3-5-sonnet-20240620"
+        return "claude-3-5-sonnet-20241022"
     else:
-        raise ValueError(f"Unknown provider {provider}")
+        raise ValueError(f"Provider {provider} did not have a recommended model")
 
 
-def get_summary_model(provider: str) -> str:  # pragma: no cover
+def get_summary_model(provider: Provider) -> str:  # pragma: no cover
     if provider == "openai":
         return "gpt-4o-mini"
     elif provider == "openrouter":
@@ -132,4 +143,4 @@ def get_summary_model(provider: str) -> str:  # pragma: no cover
     elif provider == "anthropic":
         return "claude-3-haiku-20240307"
     else:
-        raise ValueError(f"Unknown provider {provider}")
+        raise ValueError(f"Provider {provider} did not have a summary model")