diff --git a/.env.sample b/.env.sample
index cc933d9b..d3ccf547 100644
--- a/.env.sample
+++ b/.env.sample
@@ -12,6 +12,9 @@ AWS_REGION=
 # Azure
 AZURE_API_KEY=
 
+# Cerebras
+CEREBRAS_API_KEY=
+
 # Google Cloud
 GOOGLE_APPLICATION_CREDENTIALS=./google-adc
 GOOGLE_REGION=
diff --git a/.gitignore b/.gitignore
index a0974550..dc939d13 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,6 +4,7 @@ __pycache__/
 env/
 .env
 .google-adc
+*.whl
 
 # Testing
 .coverage
diff --git a/README.md b/README.md
index add8b851..b7e564d9 100644
--- a/README.md
+++ b/README.md
@@ -8,7 +8,7 @@ Simple, unified interface to multiple Generative AI providers.
 
 `aisuite` makes it easy for developers to use multiple LLM through a standardized interface. Using an interface similar to OpenAI's, `aisuite` makes it easy to interact with the most popular LLMs and compare the results. It is a thin wrapper around python client libraries, and allows creators to seamlessly swap out and test responses from different LLM providers without changing their code. Today, the library is primarily focussed on chat completions. We will expand it cover more use cases in near future. Currently supported providers are -
-OpenAI, Anthropic, Azure, Google, AWS, Groq, Mistral, HuggingFace Ollama, Sambanova and Watsonx.
+Cerebras, OpenAI, Anthropic, Azure, Google, AWS, Groq, Mistral, HuggingFace, Ollama, Sambanova and Watsonx.
 To maximize stability, `aisuite` uses either the HTTP endpoint or the SDK for making calls to the provider.
 
 ## Installation
diff --git a/aisuite/providers/cerebras_provider.py b/aisuite/providers/cerebras_provider.py
new file mode 100644
index 00000000..f266f97a
--- /dev/null
+++ b/aisuite/providers/cerebras_provider.py
@@ -0,0 +1,42 @@
+import os
+import cerebras.cloud.sdk as cerebras
+from aisuite.provider import Provider, LLMError
+from aisuite.providers.message_converter import OpenAICompliantMessageConverter
+
+
+class CerebrasMessageConverter(OpenAICompliantMessageConverter):
+    """
+    Cerebras-specific message converter if needed.
+    """
+
+    pass
+
+
+class CerebrasProvider(Provider):
+    def __init__(self, **config):
+        self.client = cerebras.Cerebras(**config)
+        self.transformer = CerebrasMessageConverter()
+
+    def chat_completions_create(self, model, messages, **kwargs):
+        """
+        Makes a request to the Cerebras chat completions endpoint using the official client.
+        """
+        try:
+            response = self.client.chat.completions.create(
+                model=model,
+                messages=messages,
+                **kwargs,  # Pass any additional arguments to the Cerebras API.
+            )
+            return self.transformer.convert_response(response.model_dump())
+
+        # Re-raise Cerebras API-specific exceptions (`cerebras` is already the `cerebras.cloud.sdk` module).
+        except cerebras.PermissionDeniedError:
+            raise
+        except cerebras.AuthenticationError:
+            raise
+        except cerebras.RateLimitError:
+            raise
+
+        # Wrap all other exceptions in LLMError.
+        except Exception as e:
+            raise LLMError(f"An error occurred: {e}")
diff --git a/guides/cerebras.md b/guides/cerebras.md
new file mode 100644
index 00000000..9a51edfe
--- /dev/null
+++ b/guides/cerebras.md
@@ -0,0 +1,277 @@
+# Cerebras AI Suite Provider Guide
+
+## About Cerebras
+
+At Cerebras, we've developed the world's largest and fastest AI processor, the Wafer-Scale Engine-3 (WSE-3). The Cerebras CS-3 system, powered by the WSE-3, represents a new class of AI supercomputer that sets the standard for generative AI training and inference with unparalleled performance and scalability.
+
+With Cerebras as your inference provider, you can:
+- Achieve unprecedented speed for AI inference workloads
+- Build commercially with high throughput
+- Effortlessly scale your AI workloads with our seamless clustering technology
+
+Our CS-3 systems can be quickly and easily clustered to create the largest AI supercomputers in the world, making it simple to place and run the largest models. Leading corporations, research institutions, and governments are already using Cerebras solutions to develop proprietary models and train popular open-source models.
+
+Want to experience the power of Cerebras? Check out our [website](https://cerebras.net) for more resources and explore options for accessing our technology through the Cerebras Cloud or on-premise deployments!
+
+> [!NOTE]
+> This SDK has a mechanism that sends a few requests to `/v1/tcp_warming` upon construction to reduce the TTFT (time to first token). If this behaviour is not desired, set `warm_tcp_connection=False` in the constructor.
+>
+> If you are repeatedly reconstructing the SDK instance, it will lead to poor performance. It is recommended that you construct the SDK once and reuse the instance if possible.
+
+## Documentation
+
+The REST API documentation can be found on [inference-docs.cerebras.ai](https://inference-docs.cerebras.ai).
+
+
+## Usage
+Get an API key from [cloud.cerebras.ai](https://cloud.cerebras.ai/) and add it to your environment variables:
+
+```shell
+export CEREBRAS_API_KEY="your-cerebras-api-key"
+```
+
+Use the Python client.
+
+```python
+import aisuite as ai
+client = ai.Client()
+
+model = "cerebras:llama3.1-8b"
+
+messages = [
+    {"role": "system", "content": "Respond in Pirate English."},
+    {"role": "user", "content": "Tell me a joke."},
+]
+
+response = client.chat.completions.create(
+    model=model,
+    messages=messages,
+    temperature=0.75
+)
+print(response.choices[0].message.content)
+
+```
+
+The full API of this library can be found at https://inference-docs.cerebras.ai/api-reference.
+
+Note that the remaining examples call the Cerebras Python SDK client directly rather than the `aisuite` client; the Retries and Timeouts sections below show how it is constructed.
+
+### Chat Completion
+
+```python
+chat_completion = client.chat.completions.create(
+    messages=[
+        {
+            "role": "user",
+            "content": "Why is fast inference important?",
+        }
+    ],
+    model="llama3.1-8b",
+)
+
+print(chat_completion)
+```
+
+### Text Completion
+
+```python
+completion = client.completions.create(
+    prompt="It was a dark and stormy ",
+    max_tokens=100,
+    model="llama3.1-8b",
+)
+
+print(completion)
+```
+
+## Streaming responses
+
+We provide support for streaming responses using Server-Sent Events (SSE).
+
+Note that when streaming, `usage` and `time_info` information will only be included in the final chunk.
+
+### Chat Completion
+
+```python
+stream = client.chat.completions.create(
+    messages=[
+        {
+            "role": "user",
+            "content": "Why is fast inference important?",
+        }
+    ],
+    model="llama3.1-8b",
+    stream=True,
+)
+
+for chunk in stream:
+    print(chunk.choices[0].delta.content or "", end="")
+```
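+
+Since `usage` and `time_info` arrive only on the final chunk, here is a minimal sketch of reading the final-chunk usage stats (an assumption modeled on OpenAI-style SDKs: `usage` is `None` on intermediate chunks, and the final chunk may carry no choices):
+
+```python
+stream = client.chat.completions.create(
+    messages=[
+        {
+            "role": "user",
+            "content": "Why is fast inference important?",
+        }
+    ],
+    model="llama3.1-8b",
+    stream=True,
+)
+
+for chunk in stream:
+    # Content deltas arrive on intermediate chunks.
+    if chunk.choices and chunk.choices[0].delta.content:
+        print(chunk.choices[0].delta.content, end="")
+    # Usage stats are expected only on the final chunk.
+    if getattr(chunk, "usage", None):
+        print(f"\nTotal tokens: {chunk.usage.total_tokens}")
+```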
+
+### Text Completion
+
+```python
+stream = client.completions.create(
+    prompt="It was a dark and stormy ",
+    max_tokens=100,
+    model="llama3.1-8b",
+    stream=True,
+)
+
+for chunk in stream:
+    print(chunk.choices[0].text or "", end="")
+```
+
+### Retries
+
+Certain errors are automatically retried 2 times by default, with a short exponential backoff.
+Connection errors (for example, due to a network connectivity problem), 408 Request Timeout, 409 Conflict,
+429 Rate Limit, and >=500 Internal errors are all retried by default.
+
+You can use the `max_retries` option to configure or disable retry settings:
+
+
+```python
+from cerebras.cloud.sdk import Cerebras
+
+# Configure the default for all requests:
+client = Cerebras(
+    # default is 2
+    max_retries=0,
+)
+
+# Or, configure per-request:
+client.with_options(max_retries=5).chat.completions.create(
+    messages=[
+        {
+            "role": "user",
+            "content": "Why is fast inference important?",
+        }
+    ],
+    model="llama3.1-8b",
+)
+```
+
+### Timeouts
+
+By default, requests time out after 1 minute. You can configure this with a `timeout` option,
+which accepts a float or an [`httpx.Timeout`](https://www.python-httpx.org/advanced/#fine-tuning-the-configuration) object:
+
+
+```python
+from cerebras.cloud.sdk import Cerebras
+import httpx
+
+# Configure the default for all requests:
+client = Cerebras(
+    # 20 seconds (default is 1 minute)
+    timeout=20.0,
+)
+
+# More granular control:
+client = Cerebras(
+    timeout=httpx.Timeout(60.0, read=5.0, write=10.0, connect=2.0),
+)
+
+# Override per-request:
+client.with_options(timeout=5.0).chat.completions.create(
+    messages=[
+        {
+            "role": "user",
+            "content": "Why is fast inference important?",
+        }
+    ],
+    model="llama3.1-8b",
+)
+```
+
+On timeout, an `APITimeoutError` is thrown.
+
+Note that requests that time out are [retried twice by default](#retries).
+
+## Advanced
+
+### Logging
+
+We use the standard library [`logging`](https://docs.python.org/3/library/logging.html) module.
+
+You can enable logging by setting the environment variable `CEREBRAS_LOG` to `info`.
+
+```shell
+$ export CEREBRAS_LOG=info
+```
+
+Or set it to `debug` for more verbose logging.
+
+#### Undocumented request params
+
+If you want to explicitly send an extra param, you can do so with the `extra_query`, `extra_body`, and `extra_headers` request
+options.
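+
+For example, a minimal sketch (the header and body keys below are illustrative placeholders, not documented Cerebras parameters):
+
+```python
+completion = client.chat.completions.create(
+    messages=[{"role": "user", "content": "Why is fast inference important?"}],
+    model="llama3.1-8b",
+    # Sent as extra HTTP headers on this request.
+    extra_headers={"x-my-trace-id": "abc-123"},
+    # Merged into the JSON request body.
+    extra_body={"my_undocumented_param": True},
+)
+```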
+
+### Configuring the HTTP client
+
+You can directly override the [httpx client](https://www.python-httpx.org/api/#client) to customize it for your use case, including:
+
+- Support for [proxies](https://www.python-httpx.org/advanced/proxies/)
+- Custom [transports](https://www.python-httpx.org/advanced/transports/)
+- Additional [advanced](https://www.python-httpx.org/advanced/clients/) functionality
+
+```python
+import httpx
+from cerebras.cloud.sdk import Cerebras, DefaultHttpxClient
+
+client = Cerebras(
+    # Or use the `CEREBRAS_BASE_URL` env var
+    base_url="http://my.test.server.example.com:8083",
+    http_client=DefaultHttpxClient(
+        proxy="http://my.test.proxy.example.com",
+        transport=httpx.HTTPTransport(local_address="0.0.0.0"),
+    ),
+)
+```
+
+You can also customize the client on a per-request basis by using `with_options()`:
+
+```python
+client.with_options(http_client=DefaultHttpxClient(...))
+```
+
+## Requirements
+
+Python 3.8 or higher.
diff --git a/poetry.lock b/poetry.lock
index 493a6cd1..7455b3a3 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -616,6 +616,27 @@ files = [
     {file = "cachetools-5.5.1.tar.gz", hash = "sha256:70f238fbba50383ef62e55c6aff6d9673175fe59f7c6782c7a0b9e38f4a9df95"},
 ]
 
+[[package]]
+name = "cerebras-cloud-sdk"
+version = "1.19.0"
+description = "The official Python library for the cerebras API"
+optional = false
+python-versions = ">=3.8"
+groups = ["main", "dev"]
+markers = "python_version <= \"3.11\" or python_version >= \"3.12\""
+files = [
+    {file = "cerebras_cloud_sdk-1.19.0-py3-none-any.whl", hash = "sha256:783dea26b72e9e6e545d5b49f4ccfe43595976af5ee4492db55ab5c8237a1144"},
+    {file = "cerebras_cloud_sdk-1.19.0.tar.gz", hash = "sha256:7e4efc55799141bb29114d97e42da83024019c0be278d2f63135991fd1e2fe76"},
+]
+
+[package.dependencies]
+anyio = ">=3.5.0,<5"
+distro = ">=1.7.0,<2"
+httpx = ">=0.23.0,<1"
+pydantic = ">=1.9.0,<3"
+sniffio = "*"
+typing-extensions = ">=4.10,<5"
+
 [[package]]
 name = "certifi"
 version = "2024.12.14"
@@ -7400,10 +7421,11 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools",
 type = ["pytest-mypy"]
 
 [extras]
-all = ["anthropic", "cohere", "groq", "openai"]
+all = ["anthropic", "cerebras_cloud_sdk", "cohere", "groq", "openai"]
 anthropic = ["anthropic"]
 aws = ["boto3"]
 azure = []
+cerebras = ["cerebras_cloud_sdk"]
 cohere = ["cohere"]
 deepseek = ["openai"]
 google = ["vertexai"]
@@ -7417,4 +7439,4 @@ watsonx = ["ibm-watsonx-ai"]
 [metadata]
 lock-version = "2.1"
 python-versions = "^3.10"
-content-hash = "491e4f71e526533bfacd230d96ad5204c2409a61826794e7ca9e4f2fd3e6998e"
+content-hash = "cb66b6c5aa0fe3cce45c3d2f57f339ee01996c8ea559af612a173a9ede12792e"
diff --git a/pyproject.toml b/pyproject.toml
index f9635ec6..d432f471 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -16,6 +16,7 @@ mistralai = { version = "^1.0.3", optional = true }
 openai = { version = "^1.35.8", optional = true }
 ibm-watsonx-ai = { version = "^1.1.16", optional = true }
 docstring-parser = { version = "^0.14.0", optional = true }
+cerebras_cloud_sdk = { version = "^1.19.0", optional = true }
 
 # Optional dependencies for different providers
 httpx = "~0.27.0"
@@ -23,6 +24,7 @@ httpx = "~0.27.0"
 anthropic = ["anthropic"]
 aws = ["boto3"]
 azure = []
+cerebras = ["cerebras_cloud_sdk"]
 cohere = ["cohere"]
 deepseek = ["openai"]
 google = ["vertexai"]
@@ -32,7 +34,7 @@ mistral = ["mistralai"]
 ollama = []
 openai = ["openai"]
 watsonx = ["ibm-watsonx-ai"]
-all = ["anthropic", "aws", "google", "groq", "mistral", "openai", "cohere", "watsonx"] # To install all providers
+all = ["anthropic", "aws", "cerebras_cloud_sdk", "google", "groq", "mistral", "openai", "cohere", "watsonx"] # To install all providers
 
 [tool.poetry.group.dev.dependencies]
 pre-commit = "^3.7.1"
@@ -51,6 +53,7 @@ sentence-transformers = "^3.0.1"
 datasets = "^2.20.0"
 vertexai = "^1.63.0"
 ibm-watsonx-ai = "^1.1.16"
+cerebras_cloud_sdk = "^1.19.0"
 
 [tool.poetry.group.test]
 optional = true
diff --git a/tests/providers/test_cerebras_provider.py b/tests/providers/test_cerebras_provider.py
new file mode 100644
index 00000000..5124a14e
--- /dev/null
+++ b/tests/providers/test_cerebras_provider.py
@@ -0,0 +1,66 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from aisuite.providers.cerebras_provider import CerebrasProvider
+
+
+@pytest.fixture(autouse=True)
+def set_api_key_env_var(monkeypatch):
+    """Fixture to set environment variables for tests."""
+    monkeypatch.setenv("CEREBRAS_API_KEY", "test-api-key")
+
+
+def test_cerebras_provider():
+    """High-level test that the provider is initialized and chat completions are requested successfully."""
+
+    user_greeting = "Hello!"
+    message_history = [{"role": "user", "content": user_greeting}]
+    selected_model = "our-favorite-model"
+    chosen_temperature = 0.75
+    response_text_content = "mocked-text-response-from-model"
+
+    provider = CerebrasProvider()
+    mock_response = MagicMock()
+    mock_response.model_dump.return_value = {
+        "choices": [{"message": {"content": response_text_content}}]
+    }
+
+    with patch.object(
+        provider.client.chat.completions,
+        "create",
+        return_value=mock_response,
+    ) as mock_create:
+        response = provider.chat_completions_create(
+            messages=message_history,
+            model=selected_model,
+            temperature=chosen_temperature,
+        )
+
+    mock_create.assert_called_with(
+        messages=message_history,
+        model=selected_model,
+        temperature=chosen_temperature,
+    )
+
+    assert response.choices[0].message.content == response_text_content
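+
+
+# A sketch of an additional check (assumes only what CerebrasProvider above shows):
+# unexpected client exceptions should be wrapped in aisuite's LLMError.
+def test_cerebras_provider_wraps_unexpected_errors():
+    """Unexpected exceptions from the client should surface as LLMError."""
+    from aisuite.provider import LLMError
+
+    provider = CerebrasProvider()
+
+    with patch.object(
+        provider.client.chat.completions,
+        "create",
+        side_effect=RuntimeError("boom"),
+    ):
+        with pytest.raises(LLMError):
+            provider.chat_completions_create(
+                messages=[{"role": "user", "content": "Hello!"}],
+                model="our-favorite-model",
+            )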