From 321a31dd6c16663eaffdf046b79f178fb0ab7359 Mon Sep 17 00:00:00 2001 From: lloydchang Date: Sun, 24 Nov 2024 02:02:36 -0800 Subject: [PATCH] feat: add Gemini fix https://github.com/ErikBjare/gptme/issues/161 --- docs/providers.rst | 10 ++++++++++ gptme/eval/main.py | 1 + gptme/init.py | 5 +++-- gptme/llm/__init__.py | 5 +++++ gptme/llm/llm_openai.py | 3 +++ gptme/llm/models.py | 19 +++++++++++++++++-- 6 files changed, 39 insertions(+), 4 deletions(-) diff --git a/docs/providers.rst b/docs/providers.rst index e65dea4e..e336239e 100644 --- a/docs/providers.rst +++ b/docs/providers.rst @@ -10,6 +10,7 @@ To select a provider and model, run ``gptme`` with the ``--model`` flag set to ` gptme --model openai/gpt-4o "hello" gptme --model anthropic "hello" # if model part unspecified, will fall back to the provider default gptme --model openrouter/meta-llama/llama-3.1-70b-instruct "hello" + gptme --model gemini/gemini-1.5-flash-latest "hello" gptme --model local/llama3.2:1b "hello" On first startup, if ``--model`` is not set, and no API keys are set in the config or environment it will be prompted for. It will then auto-detect the provider, and save the key in the configuration file. @@ -43,6 +44,15 @@ To use OpenRouter, set your API key: export OPENROUTER_API_KEY="your-api-key" +Gemini +---------- + +To use Gemini, set your API key: + +.. 
code-block:: sh + + export GEMINI_API_KEY="your-api-key" + Groq ---- diff --git a/gptme/eval/main.py b/gptme/eval/main.py index 330da232..b1355319 100644 --- a/gptme/eval/main.py +++ b/gptme/eval/main.py @@ -203,6 +203,7 @@ def main( "anthropic/claude-3-5-sonnet-20241022", "anthropic/claude-3-5-haiku-20241022", "openrouter/meta-llama/llama-3.1-405b-instruct", + "gemini/gemini-1.5-flash-latest", ] results_files = [] diff --git a/gptme/init.py b/gptme/init.py index 2c247c8e..7895c6b0 100644 --- a/gptme/init.py +++ b/gptme/init.py @@ -92,7 +92,7 @@ def cleanup_logging(): def _prompt_api_key() -> tuple[str, str, str]: # pragma: no cover - api_key = input("Your OpenAI, Anthropic, or OpenRouter API key: ").strip() + api_key = input("Your OpenAI, Anthropic, OpenRouter, or Gemini API key: ").strip() if (found_model_tuple := get_model_from_api_key(api_key)) is not None: return found_model_tuple else: @@ -102,12 +102,13 @@ def _prompt_api_key() -> tuple[str, str, str]: # pragma: no cover def ask_for_api_key(): # pragma: no cover """Interactively ask user for API key""" - console.print("No API key set for OpenAI, Anthropic, or OpenRouter.") + console.print("No API key set for OpenAI, Anthropic, OpenRouter, or Gemini.") console.print( """You can get one at: - OpenAI: https://platform.openai.com/account/api-keys - Anthropic: https://console.anthropic.com/settings/keys - OpenRouter: https://openrouter.ai/settings/keys + - Gemini: https://aistudio.google.com/app/apikey """ ) # Save to config diff --git a/gptme/llm/__init__.py b/gptme/llm/__init__.py index 5bf56f47..ab2e95da 100644 --- a/gptme/llm/__init__.py +++ b/gptme/llm/__init__.py @@ -121,6 +121,8 @@ def _client_to_provider() -> Provider: return "openai" elif "openrouter" in openai_client.base_url.host: return "openrouter" + elif "gemini" in openai_client.base_url.host: + return "gemini" else: return "azure" elif anthropic_client: @@ -252,6 +254,9 @@ def guess_model_from_config() -> Provider | None: elif 
config.get_env("OPENROUTER_API_KEY"): console.log("Found OpenRouter API key, using OpenRouter provider") return "openrouter" + elif config.get_env("GEMINI_API_KEY"): + console.log("Found Gemini API key, using Gemini provider") + return "gemini" return None diff --git a/gptme/llm/llm_openai.py b/gptme/llm/llm_openai.py index b31b05f0..92d43fa7 100644 --- a/gptme/llm/llm_openai.py +++ b/gptme/llm/llm_openai.py @@ -45,6 +45,9 @@ def init(provider: Provider, config: Config): elif provider == "openrouter": api_key = config.get_env_required("OPENROUTER_API_KEY") openai = OpenAI(api_key=api_key, base_url="https://openrouter.ai/api/v1") + elif provider == "gemini": + api_key = config.get_env_required("GEMINI_API_KEY") + openai = OpenAI(api_key=api_key, base_url="https://generativelanguage.googleapis.com/v1beta/openai/") elif provider == "xai": api_key = config.get_env_required("XAI_API_KEY") openai = OpenAI(api_key=api_key, base_url="https://api.x.ai/v1") diff --git a/gptme/llm/models.py b/gptme/llm/models.py index 8d4f087b..10884ca6 100644 --- a/gptme/llm/models.py +++ b/gptme/llm/models.py @@ -15,11 +15,11 @@ # available providers Provider = Literal[ - "openai", "anthropic", "azure", "openrouter", "groq", "xai", "deepseek", "local" + "openai", "anthropic", "azure", "openrouter", "gemini", "groq", "xai", "deepseek", "local" ] PROVIDERS: list[Provider] = cast(list[Provider], get_args(Provider)) PROVIDERS_OPENAI: list[Provider] -PROVIDERS_OPENAI = ["openai", "azure", "openrouter", "xai", "groq", "deepseek", "local"] +PROVIDERS_OPENAI = ["openai", "azure", "openrouter", "gemini", "xai", "groq", "deepseek", "local"] @dataclass(frozen=True) @@ -90,6 +90,17 @@ class _ModelDictMeta(TypedDict): "price_output": 1.25, }, }, + # https://ai.google.dev/gemini-api/docs/models/gemini#gemini-1.5-flash + # https://ai.google.dev/pricing#1_5flash + "gemini": { + "gemini-1.5-flash-latest": { + "context": 1_048_576, + "max_output": 8192, + "price_input": 0.15, + "price_output": 0.60, 
"supports_vision": True, + }, + }, "local": {}, } @@ -137,6 +148,8 @@ def get_recommended_model(provider: Provider) -> str: # pragma: no cover return "gpt-4o" elif provider == "openrouter": return "meta-llama/llama-3.1-405b-instruct" + elif provider == "gemini": + return "gemini-1.5-flash-latest" elif provider == "anthropic": return "claude-3-5-sonnet-20241022" else: @@ -148,6 +161,8 @@ def get_summary_model(provider: Provider) -> str: # pragma: no cover return "gpt-4o-mini" elif provider == "openrouter": return "meta-llama/llama-3.1-8b-instruct" + elif provider == "gemini": + return "gemini-1.5-flash-latest" elif provider == "anthropic": return "claude-3-haiku-20240307" else: