Add OpenAIRealtimeExtensions with ToConversationFunctionTool (#5666)
SteveSandersonMS authored Nov 19, 2024
1 parent 9cfd5ff commit f802390
Showing 9 changed files with 415 additions and 14 deletions.
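At a glance, the new OpenAIRealtimeExtensions class bridges Microsoft.Extensions.AI's AIFunction and the OpenAI realtime conversation API. The following is a minimal sketch of the usage this commit enables, not part of the diff itself: GetRoomCapacity is an illustrative stand-in (the integration test at the end of this commit defines a fuller version), and the using directives are included only for context.

using System.ComponentModel;
using Microsoft.Extensions.AI;
using OpenAI.RealtimeConversation;

// Illustrative tool; any delegate wrapped by AIFunctionFactory works the same way.
[Description("Returns the number of people that can fit in a room.")]
static int GetRoomCapacity(string roomName) => 12_000;

var roomCapacityTool = AIFunctionFactory.Create(GetRoomCapacity);

// The new extension method maps the AIFunction's name, description, and parameter
// schema onto an OpenAI ConversationFunctionTool so a realtime session can call it.
var sessionOptions = new ConversationSessionOptions
{
    Tools = { roomCapacityTool.ToConversationFunctionTool() },
};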
2 changes: 1 addition & 1 deletion eng/packages/General.props
@@ -11,7 +11,7 @@
<PackageVersion Include="Microsoft.CodeAnalysis" Version="$(MicrosoftCodeAnalysisVersion)" />
<PackageVersion Include="Microsoft.IO.RecyclableMemoryStream" Version="3.0.0" />
<PackageVersion Include="Newtonsoft.Json" Version="13.0.3" />
<PackageVersion Include="OpenAI" Version="2.0.0" />
<PackageVersion Include="OpenAI" Version="2.1.0-beta.2" />
<PackageVersion Include="Polly" Version="8.4.2" />
<PackageVersion Include="Polly.Core" Version="8.4.2" />
<PackageVersion Include="Polly.Extensions" Version="8.4.2" />
2 changes: 1 addition & 1 deletion eng/packages/TestOnly.props
@@ -2,7 +2,7 @@
<Project xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<ItemGroup>
<PackageVersion Include="AutoFixture.AutoMoq" Version="4.17.0" />
<PackageVersion Include="Azure.AI.OpenAI" Version="2.0.0" />
<PackageVersion Include="Azure.AI.OpenAI" Version="2.1.0-beta.2" />
<PackageVersion Include="autofixture" Version="4.17.0" />
<PackageVersion Include="BenchmarkDotNet" Version="0.13.5" />
<PackageVersion Include="FluentAssertions" Version="6.11.0" />
Microsoft.Extensions.AI.OpenAI.csproj
@@ -15,7 +15,7 @@

<PropertyGroup>
<TargetFrameworks>$(TargetFrameworks);netstandard2.0</TargetFrameworks>
- <NoWarn>$(NoWarn);CA1063;CA1508;CA2227;SA1316;S1121;S3358;EA0002</NoWarn>
+ <NoWarn>$(NoWarn);CA1063;CA1508;CA2227;SA1316;S1121;S3358;EA0002;OPENAI002</NoWarn>
<TreatWarningsAsErrors>true</TreatWarningsAsErrors>
<DisableNETStandardCompatErrors>true</DisableNETStandardCompatErrors>
</PropertyGroup>
14 changes: 3 additions & 11 deletions src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIChatClient.cs
@@ -24,7 +24,7 @@
namespace Microsoft.Extensions.AI;

/// <summary>Represents an <see cref="IChatClient"/> for an OpenAI <see cref="OpenAIClient"/> or <see cref="OpenAI.Chat.ChatClient"/>.</summary>
- public sealed partial class OpenAIChatClient : IChatClient
+ public sealed class OpenAIChatClient : IChatClient
{
private static readonly JsonElement _defaultParameterSchema = JsonDocument.Parse("{}").RootElement;

@@ -513,14 +513,14 @@ strictObj is bool strictValue ?
}

resultParameters = BinaryData.FromBytes(
- JsonSerializer.SerializeToUtf8Bytes(tool, JsonContext.Default.OpenAIChatToolJson));
+ JsonSerializer.SerializeToUtf8Bytes(tool, OpenAIJsonContext.Default.OpenAIChatToolJson));
}

return ChatTool.CreateFunctionTool(aiFunction.Metadata.Name, aiFunction.Metadata.Description, resultParameters, strict);
}

/// <summary>Used to create the JSON payload for an OpenAI chat tool description.</summary>
- private sealed class OpenAIChatToolJson
+ internal sealed class OpenAIChatToolJson
{
/// <summary>Gets a singleton JSON data for empty parameters. Optimization for the reasonably common case of a parameterless function.</summary>
public static BinaryData ZeroFunctionParametersSchema { get; } = new("""{"type":"object","required":[],"properties":{}}"""u8.ToArray());
@@ -681,12 +681,4 @@ private static FunctionCallContent ParseCallContentFromBinaryData(BinaryData ut8
FunctionCallContent.CreateFromParsedArguments(ut8Json, callId, name,
argumentParser: static json => JsonSerializer.Deserialize(json,
(JsonTypeInfo<IDictionary<string, object>>)AIJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IDictionary<string, object>)))!);

- /// <summary>Source-generated JSON type information.</summary>
- [JsonSourceGenerationOptions(JsonSerializerDefaults.Web,
- UseStringEnumConverter = true,
- DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
- WriteIndented = true)]
- [JsonSerializable(typeof(OpenAIChatToolJson))]
- private sealed partial class JsonContext : JsonSerializerContext;
}
16 changes: 16 additions & 0 deletions src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIJsonContext.cs
@@ -0,0 +1,16 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Text.Json;
using System.Text.Json.Serialization;

namespace Microsoft.Extensions.AI;

/// <summary>Source-generated JSON type information.</summary>
[JsonSourceGenerationOptions(JsonSerializerDefaults.Web,
UseStringEnumConverter = true,
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
WriteIndented = true)]
[JsonSerializable(typeof(OpenAIChatClient.OpenAIChatToolJson))]
[JsonSerializable(typeof(OpenAIRealtimeExtensions.ConversationFunctionToolParametersSchema))]
internal sealed partial class OpenAIJsonContext : JsonSerializerContext;
OpenAIRealtimeExtensions.cs
@@ -0,0 +1,156 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Serialization.Metadata;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Shared.Diagnostics;
using OpenAI.RealtimeConversation;

namespace Microsoft.Extensions.AI;

/// <summary>
/// Provides extension methods for working with <see cref="RealtimeConversationSession"/> and related types.
/// </summary>
public static class OpenAIRealtimeExtensions
{
private static readonly JsonElement _defaultParameterSchema = JsonDocument.Parse("{}").RootElement;

/// <summary>
/// Converts a <see cref="AIFunction"/> into a <see cref="ConversationFunctionTool"/> so that
/// it can be used with <see cref="RealtimeConversationClient"/>.
/// </summary>
/// <returns>A <see cref="ConversationFunctionTool"/> that can be used with <see cref="RealtimeConversationClient"/>.</returns>
public static ConversationFunctionTool ToConversationFunctionTool(this AIFunction aiFunction)
{
_ = Throw.IfNull(aiFunction);

var parametersSchema = new ConversationFunctionToolParametersSchema
{
Type = "object",
Properties = aiFunction.Metadata.Parameters
.ToDictionary(p => p.Name, GetParameterSchema),
Required = aiFunction.Metadata.Parameters
.Where(p => p.IsRequired)
.Select(p => p.Name),
};

return new ConversationFunctionTool
{
Name = aiFunction.Metadata.Name,
Description = aiFunction.Metadata.Description,
Parameters = new BinaryData(JsonSerializer.SerializeToUtf8Bytes(
parametersSchema, OpenAIJsonContext.Default.ConversationFunctionToolParametersSchema))
};
}

/// <summary>
/// Handles tool calls.
///
/// If the <paramref name="update"/> represents a tool call, calls the corresponding tool and
/// adds the result to the <paramref name="session"/>.
///
/// If the <paramref name="update"/> represents the end of a response, checks if this was due
/// to a tool call and if so, instructs the <paramref name="session"/> to begin responding to it.
/// </summary>
/// <param name="session">The <see cref="RealtimeConversationSession"/>.</param>
/// <param name="update">The <see cref="ConversationUpdate"/> being processed.</param>
/// <param name="tools">The available tools.</param>
/// <param name="detailedErrors">An optional flag specifying whether to disclose detailed exception information to the model. The default value is <see langword="false"/>.</param>
/// <param name="jsonSerializerOptions">An optional <see cref="JsonSerializerOptions"/> that controls JSON handling.</param>
/// <param name="cancellationToken">An optional <see cref="CancellationToken"/>.</param>
/// <returns>A <see cref="Task"/> that represents the completion of processing, including invoking any asynchronous tools.</returns>
public static async Task HandleToolCallsAsync(
this RealtimeConversationSession session,
ConversationUpdate update,
IReadOnlyList<AIFunction> tools,
bool? detailedErrors = false,
JsonSerializerOptions? jsonSerializerOptions = null,
CancellationToken cancellationToken = default)
{
_ = Throw.IfNull(session);
_ = Throw.IfNull(update);
_ = Throw.IfNull(tools);

if (update is ConversationItemStreamingFinishedUpdate itemFinished)
{
// If we need to call a tool to update the model, do so
if (!string.IsNullOrEmpty(itemFinished.FunctionName)
&& await itemFinished.GetFunctionCallOutputAsync(tools, detailedErrors, jsonSerializerOptions, cancellationToken).ConfigureAwait(false) is { } output)
{
await session.AddItemAsync(output, cancellationToken).ConfigureAwait(false);
}
}
else if (update is ConversationResponseFinishedUpdate responseFinished)
{
// If we added one or more function call results, instruct the model to respond to them
if (responseFinished.CreatedItems.Any(item => !string.IsNullOrEmpty(item.FunctionName)))
{
await session!.StartResponseAsync(cancellationToken).ConfigureAwait(false);
}
}
}

private static JsonElement GetParameterSchema(AIFunctionParameterMetadata parameterMetadata)
{
return parameterMetadata switch
{
{ Schema: JsonElement jsonElement } => jsonElement,
_ => _defaultParameterSchema,
};
}

private static async Task<ConversationItem?> GetFunctionCallOutputAsync(
this ConversationItemStreamingFinishedUpdate update,
IReadOnlyList<AIFunction> tools,
bool? detailedErrors = false,
JsonSerializerOptions? jsonSerializerOptions = null,
CancellationToken cancellationToken = default)
{
if (!string.IsNullOrEmpty(update.FunctionName)
&& tools.FirstOrDefault(t => t.Metadata.Name == update.FunctionName) is AIFunction aiFunction)
{
var jsonOptions = jsonSerializerOptions ?? AIJsonUtilities.DefaultOptions;

var functionCallContent = FunctionCallContent.CreateFromParsedArguments(
update.FunctionCallArguments, update.FunctionCallId, update.FunctionName,
argumentParser: json => JsonSerializer.Deserialize(json,
(JsonTypeInfo<IDictionary<string, object>>)jsonOptions.GetTypeInfo(typeof(IDictionary<string, object>)))!);

try
{
var result = await aiFunction.InvokeAsync(functionCallContent.Arguments, cancellationToken).ConfigureAwait(false);
var resultJson = JsonSerializer.Serialize(result, jsonOptions.GetTypeInfo(typeof(object)));
return ConversationItem.CreateFunctionCallOutput(update.FunctionCallId, resultJson);
}
catch (JsonException)
{
return ConversationItem.CreateFunctionCallOutput(update.FunctionCallId, "Invalid JSON");
}
catch (Exception e) when (!cancellationToken.IsCancellationRequested)
{
var message = "Error calling tool";

if (detailedErrors == true)
{
message += $": {e.Message}";
}

return ConversationItem.CreateFunctionCallOutput(update.FunctionCallId, message);
}
}

return null;
}

internal sealed class ConversationFunctionToolParametersSchema
{
public string? Type { get; set; }
public IDictionary<string, JsonElement>? Properties { get; set; }
public IEnumerable<string>? Required { get; set; }
}
}
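Continuing the sketch from the top of this commit, HandleToolCallsAsync is designed to sit inside the session's update loop. In the snippet below, sessionOptions and roomCapacityTool come from that earlier sketch, realtimeClient is an assumed, already-constructed RealtimeConversationClient, and the integration test further down exercises the same pattern end to end.

// Continues the earlier sketch: sessionOptions and roomCapacityTool are defined there,
// and realtimeClient is assumed to be an existing RealtimeConversationClient.
using var session = await realtimeClient.StartConversationSessionAsync();
await session.ConfigureSessionAsync(sessionOptions);

await foreach (var update in session.ReceiveUpdatesAsync())
{
    // Application-specific handling of other update kinds (session started,
    // streamed text or audio deltas, errors) goes here.

    // When a function-call item finishes, the matching AIFunction is invoked and its
    // JSON-serialized output is added to the session; when a response that contained
    // function calls finishes, the session is instructed to start a new response.
    await session.HandleToolCallsAsync(update, [roomCapacityTool]);
}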
Microsoft.Extensions.AI.OpenAI.Tests.csproj
@@ -2,10 +2,12 @@
<PropertyGroup>
<RootNamespace>Microsoft.Extensions.AI</RootNamespace>
<Description>Unit tests for Microsoft.Extensions.AI.OpenAI</Description>
<NoWarn>$(NoWarn);OPENAI002</NoWarn>
</PropertyGroup>

<PropertyGroup>
<TreatWarningsAsErrors>true</TreatWarningsAsErrors>
<InjectDiagnosticAttributesOnLegacy>true</InjectDiagnosticAttributesOnLegacy>
</PropertyGroup>

<ItemGroup>
OpenAIRealtimeIntegrationTests.cs
@@ -0,0 +1,114 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Threading.Tasks;
using Azure.AI.OpenAI;
using Microsoft.TestUtilities;
using OpenAI.RealtimeConversation;
using Xunit;

namespace Microsoft.Extensions.AI;

public class OpenAIRealtimeIntegrationTests
{
private RealtimeConversationClient? _conversationClient;

public OpenAIRealtimeIntegrationTests()
{
_conversationClient = CreateConversationClient();
}

[ConditionalFact]
public async Task CanPerformFunctionCall()
{
SkipIfNotEnabled();

var roomCapacityTool = AIFunctionFactory.Create(GetRoomCapacity);
var sessionOptions = new ConversationSessionOptions
{
Instructions = "You help with booking appointments",
Tools = { roomCapacityTool.ToConversationFunctionTool() },
ContentModalities = ConversationContentModalities.Text,
};

using var session = await _conversationClient.StartConversationSessionAsync();
await session.ConfigureSessionAsync(sessionOptions);

await foreach (var update in session.ReceiveUpdatesAsync())
{
switch (update)
{
case ConversationSessionStartedUpdate:
await session.AddItemAsync(
ConversationItem.CreateUserMessage(["""
What type of room can hold the most people?
Reply with the full name of the biggest venue and its capacity only.
Do not mention the other venues.
"""]));
await session.StartResponseAsync();
break;

case ConversationResponseFinishedUpdate responseFinished:
var content = responseFinished.CreatedItems
.SelectMany(i => i.MessageContentParts ?? [])
.OfType<ConversationContentPart>()
.FirstOrDefault();
if (content is not null)
{
Assert.Contains("VehicleAssemblyBuilding", content.Text.Replace(" ", string.Empty));
Assert.Contains("12000", content.Text.Replace(",", string.Empty));
return;
}

break;
}

await session.HandleToolCallsAsync(update, [roomCapacityTool]);
}
}

[Description("Returns the number of people that can fit in a room.")]
private static int GetRoomCapacity(RoomType roomType)
{
return roomType switch
{
RoomType.ShuttleSimulator => throw new InvalidOperationException("No longer available"),
RoomType.NorthAtlantisLawn => 450,
RoomType.VehicleAssemblyBuilding => 12000,
_ => throw new NotSupportedException($"Unknown room type: {roomType}"),
};
}

private enum RoomType
{
ShuttleSimulator,
NorthAtlantisLawn,
VehicleAssemblyBuilding,
}

[MemberNotNull(nameof(_conversationClient))]
protected void SkipIfNotEnabled()
{
if (_conversationClient is null)
{
throw new SkipTestException("Client is not enabled.");
}
}

private static RealtimeConversationClient? CreateConversationClient()
{
var realtimeModel = Environment.GetEnvironmentVariable("OPENAI_REALTIME_MODEL");
if (string.IsNullOrEmpty(realtimeModel))
{
return null;
}

var openAiClient = (AzureOpenAIClient?)IntegrationTestHelpers.GetOpenAIClient();
return openAiClient?.GetRealtimeConversationClient(realtimeModel);
}
}