diff --git a/aisuite/providers/openai_provider.py b/aisuite/providers/openai_provider.py
index fdd32d4..493e2d4 100644
--- a/aisuite/providers/openai_provider.py
+++ b/aisuite/providers/openai_provider.py
@@ -1,4 +1,6 @@
 import openai
+from pydantic import BaseModel
+from typing import Type
 import os

 from aisuite.provider import Provider, LLMError
@@ -23,9 +25,21 @@ def __init__(self, **config):
         # Pass the entire config to the OpenAI client constructor
        self.client = openai.OpenAI(**config)

-    def chat_completions_create(self, model, messages, **kwargs):
+    def chat_completions_create(
+        self, model, messages, response_format: Type[BaseModel] | None = None, **kwargs
+    ):
         # Any exception raised by OpenAI will be returned to the caller.
         # Maybe we should catch them and raise a custom LLMError.
+        if response_format is not None:
+            response = self.client.beta.chat.completions.parse(
+                model=model,
+                messages=messages,
+                response_format=response_format,
+                **kwargs  # Pass any additional arguments to the OpenAI API
+            )
+            response.choices[0].message.content = response.choices[0].message.parsed
+            return response
+
         response = self.client.chat.completions.create(
             model=model,
             messages=messages,
diff --git a/tests/client/test_client.py b/tests/client/test_client.py
index 04d1243..34159d9 100644
--- a/tests/client/test_client.py
+++ b/tests/client/test_client.py
@@ -3,6 +3,7 @@
 import pytest

 from aisuite import Client
+from pydantic import BaseModel


 @pytest.fixture(scope="module")
@@ -107,6 +108,38 @@ def test_client_chat_completions(
     assert model_response == expected_response


+@pytest.mark.parametrize(
+    argnames=("patch_target", "provider", "model"),
+    argvalues=[
+        (
+            "aisuite.providers.openai_provider.OpenaiProvider.chat_completions_create",
+            "openai",
+            "gpt-4o",
+        ),
+    ],
+)
+def test_client_chat_completions_with_structured_output(
+    provider_configs: dict, patch_target: str, provider: str, model: str
+):
+    expected_response = f"{patch_target}_{provider}_{model}"
+    with patch(patch_target) as mock_provider:
+        mock_provider.return_value = expected_response
+        client = Client()
+        client.configure(provider_configs)
+        messages = [
+            {"role": "user", "content": "Tell me a pirate joke."},
+        ]
+
+        class Joke(BaseModel):
+            joke: str
+
+        model_str = f"{provider}:{model}"
+        model_response = client.chat.completions.create(
+            model_str, messages=messages, response_format=Joke
+        )
+        assert model_response == expected_response
+
+
 def test_invalid_provider_in_client_config():
     # Testing an invalid provider name in the configuration
     invalid_provider_configs = {
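
# Usage sketch (not part of the diff): a minimal example of how a caller
# might exercise the structured-output path added above. Assumes a valid
# OPENAI_API_KEY is available in the environment; the Joke model and the
# prompt mirror the test case and are illustrative only.
from pydantic import BaseModel

from aisuite import Client


class Joke(BaseModel):
    joke: str


client = Client()
response = client.chat.completions.create(
    "openai:gpt-4o",
    messages=[{"role": "user", "content": "Tell me a pirate joke."}],
    response_format=Joke,
)
# When response_format is a Pydantic model, the provider routes the call to
# client.beta.chat.completions.parse and copies the parsed instance into
# message.content, so this prints a Joke object rather than raw text.
print(response.choices[0].message.content)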