feat: added support for anthropic #84

Merged 4 commits on Aug 6, 2024
gptme/cli.py (2 changes: 1 addition & 1 deletion)

@@ -76,7 +76,7 @@
     "--llm",
     default="openai",
     help="LLM to use.",
-    type=click.Choice(["openai", "azure", "local"]),
+    type=click.Choice(["openai", "azure", "anthropic", "local"]),
 )
 @click.option(
     "--model",
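
With "anthropic" added to the Choice list, click validates the provider before init_llm ever runs. A minimal standalone sketch of the changed option (illustrative only, not part of the diff; the echo body is made up):

import click

@click.command()
@click.option(
    "--llm",
    default="openai",
    help="LLM to use.",
    type=click.Choice(["openai", "azure", "anthropic", "local"]),
)
def main(llm: str) -> None:
    # click rejects any value outside the Choice list with a usage error,
    # so downstream code only ever sees one of the four known providers
    click.echo(f"provider: {llm}")

if __name__ == "__main__":
    main()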
gptme/llm.py (148 changes: 124 additions & 24 deletions)
@@ -1,10 +1,11 @@
 import logging
 import shutil
 import sys
 from typing import Generator, Iterator, Tuple

 import openai
+from anthropic import Anthropic
 from openai import AzureOpenAI, OpenAI
 from openai.types.chat.chat_completion_chunk import ChoiceDelta
 from rich import print

 from .config import config_path, get_config, set_config_value
@@ -22,10 +23,11 @@
 logger = logging.getLogger(__name__)

 oai_client: OpenAI | None = None
+anthropic_client: Anthropic | None = None


 def init_llm(llm: str, interactive: bool):
-    global oai_client
+    global oai_client, anthropic_client

     # set up API_KEY (if openai) and API_BASE (if local)
     config = get_config()
@@ -50,16 +52,22 @@ def init_llm(llm: str, interactive: bool):
             api_version="2023-07-01-preview",
             azure_endpoint=azure_endpoint,
         )

+    elif llm == "anthropic":
+        api_key = config.get_env_required("ANTHROPIC_API_KEY")
+        anthropic_client = Anthropic(
+            api_key=api_key,
+        )

     elif llm == "local":
         api_key = config.get_env("OPENAI_API_BASE", "local")
         api_base = config.get_env_required("OPENAI_API_BASE")
-        oai_client = OpenAI(api_key="local", base_url=api_base)
+        oai_client = OpenAI(api_key="ollama", base_url=api_base)
     else:
         print(f"Error: Unknown LLM: {llm}")
         sys.exit(1)

     # ensure we have initialized the client
-    assert oai_client
+    assert oai_client or anthropic_client


 def ask_for_api_key():
@@ -82,10 +90,11 @@ def reply(messages: list[Message], model: str, stream: bool = False) -> Message:
         print(f"{PROMPT_ASSISTANT}: Thinking...", end="\r")
         response = _chat_complete(messages, model)
         print(" " * shutil.get_terminal_size().columns, end="\r")
+        print(f"{PROMPT_ASSISTANT}: {response}")
         return Message("assistant", response)


-def _chat_complete(messages: list[Message], model: str) -> str:
+def _chat_complete_openai(messages: list[Message], model: str) -> str:
     # This will generate code and such, so we need appropriate temperature and top_p params
     # top_p controls diversity, temperature controls randomness
     assert oai_client, "LLM not initialized"
@@ -100,42 +109,134 @@ def _chat_complete(messages: list[Message], model: str) -> str:
     return content


-def _reply_stream(messages: list[Message], model: str) -> Message:
-    print(f"{PROMPT_ASSISTANT}: Thinking...", end="\r")
+def _chat_complete_anthropic(messages: list[Message], model: str) -> str:
+    assert anthropic_client, "LLM not initialized"
+    messages, system_message = _transform_system_messages_anthropic(messages)
+    response = anthropic_client.messages.create(
+        model=model,
+        messages=msgs2dicts(messages),  # type: ignore
+        system=system_message,
+        temperature=temperature,
+        top_p=top_p,
+        max_tokens=4096,
+    )
+    # TODO: rewrite handling of response to support anthropic API
+    content = response.content
+    assert content
+    assert len(content) == 1
+    return content[0].text  # type: ignore
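
Aside: Anthropic's Messages API returns a list of content blocks rather than a bare string, which is why the code above asserts a single block and takes its .text. The same call outside gptme, as a rough sketch (assumes the anthropic package and ANTHROPIC_API_KEY; prompt and max_tokens are made up):

from anthropic import Anthropic

client = Anthropic()  # picks up ANTHROPIC_API_KEY from the environment
response = client.messages.create(
    model="claude-3-5-sonnet-20240620",
    system="You are a helpful assistant.",
    messages=[{"role": "user", "content": "Say hello"}],
    max_tokens=256,
)
# response.content is a list of blocks, e.g. [TextBlock(text="Hello!", type="text")]
print(response.content[0].text)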


+def _chat_complete(messages: list[Message], model: str) -> str:
+    if oai_client:
+        return _chat_complete_openai(messages, model)
+    elif anthropic_client:
+        return _chat_complete_anthropic(messages, model)
+    else:
+        raise ValueError("LLM not initialized")


+def _transform_system_messages_anthropic(
+    messages: list[Message],
+) -> Tuple[list[Message], str]:
+    # transform system messages into a system kwarg for anthropic:
+    # the first system message becomes the system kwarg
+    assert messages[0].role == "system"
+    system_prompt = messages[0].content
+    messages.pop(0)

+    # any subsequent system messages are transformed into <system>-tagged user messages
+    for i, message in enumerate(messages):
+        if message.role == "system":
+            messages[i] = Message(
+                "user",
+                content=f"<system>{message.content}</system>",
+            )

+    # find consecutive user messages and merge them into a single user message
+    messages_new: list[Message] = []
+    while messages:
+        message = messages.pop(0)
+        if messages_new and messages_new[-1].role == "user":
+            messages_new[-1] = Message(
+                "user",
+                content=f"{messages_new[-1].content}\n{message.content}",
+            )
+        else:
+            messages_new.append(message)
+    messages = messages_new

+    return messages, system_prompt
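
To make the transformation concrete: Anthropic takes the system prompt as a top-level system parameter and requires strictly alternating user/assistant roles, hence the tagging and merging above. Roughly (a sketch with made-up message contents):

messages_in = [
    Message("system", "You are gptme."),   # becomes the system kwarg
    Message("user", "hi"),
    Message("system", "Patch applied."),   # becomes a <system>-tagged user message
    Message("assistant", "ok"),
]
messages_out, system_prompt = _transform_system_messages_anthropic(messages_in)
# system_prompt == "You are gptme."
# messages_out == [
#     Message("user", "hi\n<system>Patch applied.</system>"),  # consecutive user messages merged
#     Message("assistant", "ok"),
# ]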


+def _stream(messages: list[Message], model: str) -> Iterator[str]:
+    if oai_client:
+        return _stream_openai(messages, model)
+    elif anthropic_client:
+        return _stream_anthropic(messages, model)
+    else:
+        raise ValueError("LLM not initialized")


+def _stream_openai(messages: list[Message], model: str) -> Generator[str, None, None]:
     assert oai_client, "LLM not initialized"
-    response = oai_client.chat.completions.create(
+    stop_reason = None
+    for chunk in oai_client.chat.completions.create(
         model=model,
         messages=msgs2dicts(messages),  # type: ignore
         temperature=temperature,
         top_p=top_p,
         stream=True,
         # the llama-cpp-python server needs this explicitly set, otherwise unreliable results
         max_tokens=1000 if not model.startswith("gpt-") else None,
-    )
+    ):
+        if not chunk.choices:  # type: ignore
+            # Got a chunk with no choices, Azure always sends one of these at the start
+            continue
+        stop_reason = chunk.choices[0].finish_reason  # type: ignore
+        yield chunk.choices[0].delta.content  # type: ignore
+    logger.debug(f"Stop reason: {stop_reason}")

-    def deltas_to_str(deltas: list[ChoiceDelta]):
-        return "".join([d.content or "" for d in deltas])


+def _stream_anthropic(
+    messages: list[Message], model: str
+) -> Generator[str, None, None]:
+    messages, system_prompt = _transform_system_messages_anthropic(messages)
+    assert anthropic_client, "LLM not initialized"
+    with anthropic_client.messages.stream(
+        model=model,
+        messages=msgs2dicts(messages),  # type: ignore
+        system=system_prompt,
+        temperature=temperature,
+        top_p=top_p,
+        max_tokens=4096,
+    ) as stream:
+        yield from stream.text_stream
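
The SDK's stream helper handles the event protocol; the caller just iterates text deltas. A minimal consumer, sketched outside gptme (assumes the anthropic package and ANTHROPIC_API_KEY; the prompt is made up):

from anthropic import Anthropic

client = Anthropic()
with client.messages.stream(
    model="claude-3-5-sonnet-20240620",
    messages=[{"role": "user", "content": "Count to three"}],
    max_tokens=64,
) as stream:
    for text in stream.text_stream:
        print(text, end="", flush=True)  # flush so partial output shows immediately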


+def _reply_stream(messages: list[Message], model: str) -> Message:
+    print(f"{PROMPT_ASSISTANT}: Thinking...", end="\r")

+    def deltas_to_str(deltas: list[str]):
+        return "".join([d or "" for d in deltas])

     def print_clear():
         print(" " * shutil.get_terminal_size().columns, end="\r")

-    deltas: list[ChoiceDelta] = []
+    deltas: list[str] = []
     print_clear()
     print(f"{PROMPT_ASSISTANT}: ", end="")
-    stop_reason = None
     try:
-        for chunk in response:
-            if isinstance(chunk, tuple):
-                print("Got a tuple, expected Chunk")
+        for delta in _stream(messages, model):
+            if isinstance(delta, tuple):
+                print("Got a tuple, expected str")
                 continue
-            if not chunk.choices:
-                # Got a chunk with no choices, Azure always sends one of these at the start
+            if isinstance(delta, tuple):
+                print("Got a Chunk, expected str")
                 continue
-            delta = chunk.choices[0].delta
             deltas.append(delta)
             delta_str = deltas_to_str(deltas)
-            stop_reason = chunk.choices[0].finish_reason
-            print(deltas_to_str([delta]), end="")
+            print(deltas_to_str([deltas[-1]]), end="")
             # need to flush stdout to get the print to show up
             sys.stdout.flush()

@@ -157,14 +258,13 @@ def print_clear():
             if patch_started and patch_finished:
                 if "```" not in delta_str[-10:]:
                     print("\n```", end="")
-                    deltas.append(ChoiceDelta(content="\n```"))
+                    deltas.append("\n```")
                 print("\n")
                 break
     except KeyboardInterrupt:
         return Message("assistant", deltas_to_str(deltas) + "... ^C Interrupted")
     finally:
         print_clear()
-    logger.debug(f"Stop reason: {stop_reason}")
     return Message("assistant", deltas_to_str(deltas))


gptme/models.py (12 changes: 11 additions & 1 deletion)
@@ -1,6 +1,6 @@
 import logging
 from dataclasses import dataclass
-from typing import TypedDict
+from typing import Optional, TypedDict

 from typing_extensions import NotRequired
@@ -12,6 +12,7 @@ class ModelMeta:
     provider: str
     model: str
     context: int
+    max_output: Optional[int] = None

     # price in USD per 1k tokens
     # if price is not set, it is assumed to be 0
@@ -21,6 +22,7 @@

 class _ModelDictMeta(TypedDict):
     context: int
+    max_output: NotRequired[int]

     # price in USD per 1k tokens
     price_input: NotRequired[float]
@@ -74,6 +76,14 @@ class _ModelDictMeta(TypedDict):
             "context": 128_000,
         },
     },
+    "anthropic": {
+        "claude-3-5-sonnet-20240620": {
+            "context": 200_000,
+            "max_output": 4096,
+            "price_input": 0.003,
+            "price_output": 0.015,
+        },
+    },
     "local": {
         # 8B
         "llama3": {
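
Since prices are listed in USD per 1k tokens, estimating the cost of a claude-3-5-sonnet call from the new anthropic entry above is a one-liner (a sketch; the token counts are made up):

price_input, price_output = 0.003, 0.015  # USD per 1k tokens, from the entry above
prompt_tokens, completion_tokens = 12_000, 800
cost = prompt_tokens / 1000 * price_input + completion_tokens / 1000 * price_output
print(f"${cost:.4f}")  # -> $0.0480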