Skip to content

Commit 24e7b76

Browse files
[azopenai] Fixing some issues with incorrect/incomplete types in generation (#22119)
Fixes: - ToolChoice was unmodeled. - ResponseFormat for ChatCompletions wasn't settable using the swagger as we had it (it's an object, not a string)
1 parent e96bba7 commit 24e7b76

11 files changed

+319
-32
lines changed

sdk/ai/azopenai/CHANGELOG.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Release History
22

3-
## 0.4.0 (2023-12-07)
3+
## 0.4.0 (2023-12-11)
44

55
Support for many of the features mentioned in OpenAI's November Dev Day and Microsoft's 2023 Ignite conference
66

sdk/ai/azopenai/assets.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22
"AssetsRepo": "Azure/azure-sdk-assets",
33
"AssetsRepoPrefixPath": "go",
44
"TagPrefix": "go/ai/azopenai",
5-
"Tag": "go/ai/azopenai_9ed7d01267"
5+
"Tag": "go/ai/azopenai_d4fd4783ec"
66
}

sdk/ai/azopenai/autorest.md

+33-2
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@ These settings apply only when `--go` is specified on the command line.
44

55
``` yaml
66
input-file:
7-
- https://github.com/Azure/azure-rest-api-specs/blob/d402f685809d6d08be9c0b45065cadd7d78ab870/specification/cognitiveservices/data-plane/AzureOpenAI/inference/preview/2023-12-01-preview/generated.json
8-
7+
- https://github.com/Azure/azure-rest-api-specs/blob/3e0e2a93ddb3c9c44ff1baf4952baa24ca98e9db/specification/cognitiveservices/data-plane/AzureOpenAI/inference/preview/2023-12-01-preview/generated.json
98
output-folder: ../azopenai
109
clear-output-folder: false
1110
module: github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai
@@ -98,6 +97,20 @@ directive:
9897
transform: return $.replace(/InternalOYDAuthTypeRename/g, "configType")
9998
```
10099

100+
`ChatCompletionsResponseFormat.Type`
101+
102+
```yaml
103+
directive:
104+
- from: swagger-document
105+
where: $.definitions.ChatCompletionsResponseFormat
106+
transform: $.properties.type["x-ms-client-name"] = "InternalChatCompletionsResponseFormat"
107+
- from:
108+
- models.go
109+
- models_serde.go
110+
where: $
111+
transform: return $.replace(/InternalChatCompletionsResponseFormat/g, "respType")
112+
```
113+
101114
## Model -> DeploymentName
102115

103116
```yaml
@@ -571,3 +584,21 @@ directive:
571584
572585
return $.replace(/(func \(c ChatCompletionsOptions\) MarshalJSON\(\).+?populate\(objectMap, "frequency_penalty", c.FrequencyPenalty\))/s, "$1\n" + populateLines)
573586
```
587+
588+
Fix ToolChoice discriminated union
589+
590+
```yaml
591+
directive:
592+
- from: swagger-document
593+
where: $.definitions.ChatCompletionsOptions.properties
594+
transform: $["tool_choice"]["x-ms-client-name"] = "ToolChoiceRenameMe"
595+
- from:
596+
- models.go
597+
- models_serde.go
598+
where: $
599+
transform: |
600+
return $
601+
.replace(/^\s+ToolChoiceRenameMe.+$/m, "ToolChoice *ChatCompletionsToolChoice") // update the name _and_ type for the field
602+
.replace(/ToolChoiceRenameMe/g, "ToolChoice") // rename all other references
603+
.replace(/populateAny\(objectMap, "tool_choice", c\.ToolChoice\)/, 'populate(objectMap, "tool_choice", c.ToolChoice)'); // treat field as typed so nil means omit.
604+
```

sdk/ai/azopenai/client_chat_completions_test.go

+38
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ package azopenai_test
88

99
import (
1010
"context"
11+
"encoding/json"
1112
"errors"
1213
"io"
1314
"net/http"
@@ -262,3 +263,40 @@ func TestClient_OpenAI_GetChatCompletions_Vision(t *testing.T) {
262263

263264
t.Logf(*resp.Choices[0].Message.Content)
264265
}
266+
267+
func TestGetChatCompletions_usingResponseFormatForJSON(t *testing.T) {
268+
testFn := func(t *testing.T, chatClient *azopenai.Client, deploymentName string) {
269+
body := azopenai.ChatCompletionsOptions{
270+
DeploymentName: &deploymentName,
271+
Messages: []azopenai.ChatRequestMessageClassification{
272+
&azopenai.ChatRequestSystemMessage{Content: to.Ptr("You are a helpful assistant designed to output JSON.")},
273+
&azopenai.ChatRequestUserMessage{
274+
Content: azopenai.NewChatRequestUserMessageContent("List capital cities and their states"),
275+
},
276+
},
277+
// Without this format directive you end up getting JSON, but with a non-JSON preamble, like this:
278+
// "I'm happy to help! Here are some examples of capital cities and their corresponding states:\n\n```json\n{\n" (etc)
279+
ResponseFormat: &azopenai.ChatCompletionsJSONResponseFormat{},
280+
Temperature: to.Ptr[float32](0.0),
281+
}
282+
283+
resp, err := chatClient.GetChatCompletions(context.Background(), body, nil)
284+
require.NoError(t, err)
285+
286+
// validate that it came back as JSON data
287+
var v any
288+
err = json.Unmarshal([]byte(*resp.Choices[0].Message.Content), &v)
289+
require.NoError(t, err)
290+
require.NotEmpty(t, v)
291+
}
292+
293+
t.Run("OpenAI", func(t *testing.T) {
294+
chatClient := newOpenAIClientForTest(t)
295+
testFn(t, chatClient, "gpt-3.5-turbo-1106")
296+
})
297+
298+
t.Run("AzureOpenAI", func(t *testing.T) {
299+
chatClient := newTestClient(t, azureOpenAI.DallE.Endpoint)
300+
testFn(t, chatClient, "gpt-4-1106-preview")
301+
})
302+
}

sdk/ai/azopenai/client_functions_test.go

+36-4
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,46 @@ type ParamProperty struct {
2828
func TestGetChatCompletions_usingFunctions(t *testing.T) {
2929
// https://platform.openai.com/docs/guides/gpt/function-calling
3030

31+
useSpecificTool := azopenai.NewChatCompletionsToolChoice(
32+
azopenai.ChatCompletionsToolChoiceFunction{Name: "get_current_weather"},
33+
)
34+
3135
t.Run("OpenAI", func(t *testing.T) {
3236
chatClient := newOpenAIClientForTest(t)
33-
testChatCompletionsFunctions(t, chatClient, openAI.ChatCompletions)
34-
testChatCompletionsFunctions(t, chatClient, openAI.ChatCompletionsLegacyFunctions)
37+
38+
testData := []struct {
39+
Model string
40+
ToolChoice *azopenai.ChatCompletionsToolChoice
41+
}{
42+
// all of these variants use the tool provided - auto just also works since we did provide
43+
// a tool reference and ask a question to use it.
44+
{Model: openAI.ChatCompletions, ToolChoice: nil},
45+
{Model: openAI.ChatCompletions, ToolChoice: azopenai.ChatCompletionsToolChoiceAuto},
46+
{Model: openAI.ChatCompletionsLegacyFunctions, ToolChoice: useSpecificTool},
47+
}
48+
49+
for _, td := range testData {
50+
testChatCompletionsFunctions(t, chatClient, td.Model, td.ToolChoice)
51+
}
3552
})
3653

3754
t.Run("AzureOpenAI", func(t *testing.T) {
3855
chatClient := newAzureOpenAIClientForTest(t, azureOpenAI)
39-
testChatCompletionsFunctions(t, chatClient, azureOpenAI.ChatCompletions)
56+
57+
testData := []struct {
58+
Model string
59+
ToolChoice *azopenai.ChatCompletionsToolChoice
60+
}{
61+
// all of these variants use the tool provided - auto just also works since we did provide
62+
// a tool reference and ask a question to use it.
63+
{Model: azureOpenAI.ChatCompletions, ToolChoice: nil},
64+
{Model: azureOpenAI.ChatCompletions, ToolChoice: azopenai.ChatCompletionsToolChoiceAuto},
65+
{Model: azureOpenAI.ChatCompletions, ToolChoice: useSpecificTool},
66+
}
67+
68+
for _, td := range testData {
69+
testChatCompletionsFunctions(t, chatClient, td.Model, td.ToolChoice)
70+
}
4071
})
4172
}
4273

@@ -120,7 +151,7 @@ func testChatCompletionsFunctionsOlderStyle(t *testing.T, client *azopenai.Clien
120151
require.Equal(t, location{Location: "Boston, MA", Unit: "celsius"}, *funcParams)
121152
}
122153

123-
func testChatCompletionsFunctions(t *testing.T, chatClient *azopenai.Client, deploymentName string) {
154+
func testChatCompletionsFunctions(t *testing.T, chatClient *azopenai.Client, deploymentName string, toolChoice *azopenai.ChatCompletionsToolChoice) {
124155
body := azopenai.ChatCompletionsOptions{
125156
DeploymentName: &deploymentName,
126157
Messages: []azopenai.ChatRequestMessageClassification{
@@ -150,6 +181,7 @@ func testChatCompletionsFunctions(t *testing.T, chatClient *azopenai.Client, dep
150181
},
151182
},
152183
},
184+
ToolChoice: toolChoice,
153185
Temperature: to.Ptr[float32](0.0),
154186
}
155187

sdk/ai/azopenai/constants.go

-20
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

sdk/ai/azopenai/custom_models.go

+53
Original file line numberDiff line numberDiff line change
@@ -132,3 +132,56 @@ func (e *Error) Error() string {
132132

133133
return *e.message
134134
}
135+
136+
// ChatCompletionsToolChoice controls which tool is used for this ChatCompletions call.
137+
// You can choose between:
138+
// - [ChatCompletionsToolChoiceAuto] means the model can pick between generating a message or calling a function.
139+
// - [ChatCompletionsToolChoiceNone] means the model will not call a function and instead generates a message
140+
// - Use the [NewChatCompletionsToolChoice] function to specify a specific tool.
141+
type ChatCompletionsToolChoice struct {
142+
value any
143+
}
144+
145+
var (
146+
// ChatCompletionsToolChoiceAuto means the model can pick between generating a message or calling a function.
147+
ChatCompletionsToolChoiceAuto *ChatCompletionsToolChoice = &ChatCompletionsToolChoice{value: "auto"}
148+
149+
// ChatCompletionsToolChoiceNone means the model will not call a function and instead generates a message.
150+
ChatCompletionsToolChoiceNone *ChatCompletionsToolChoice = &ChatCompletionsToolChoice{value: "none"}
151+
)
152+
153+
// NewChatCompletionsToolChoice creates a ChatCompletionsToolChoice for a specific tool.
154+
func NewChatCompletionsToolChoice[T ChatCompletionsToolChoiceFunction](v T) *ChatCompletionsToolChoice {
155+
return &ChatCompletionsToolChoice{value: v}
156+
}
157+
158+
// ChatCompletionsToolChoiceFunction can be used to force the model to call a particular function.
159+
type ChatCompletionsToolChoiceFunction struct {
160+
// Name is the name of the function to call.
161+
Name string
162+
}
163+
164+
// MarshalJSON implements the json.Marshaller interface for type ChatCompletionsToolChoiceFunction.
165+
func (tf ChatCompletionsToolChoiceFunction) MarshalJSON() ([]byte, error) {
166+
type jsonInnerFunc struct {
167+
Name string `json:"name"`
168+
}
169+
170+
type jsonFormat struct {
171+
Type string `json:"type"`
172+
Function jsonInnerFunc `json:"function"`
173+
}
174+
175+
return json.Marshal(jsonFormat{
176+
Type: "function",
177+
//nolint:gosimple,can't use the ChatCompletionsToolChoiceFunction here or marshalling will be circular!
178+
Function: jsonInnerFunc{
179+
Name: tf.Name,
180+
},
181+
})
182+
}
183+
184+
// MarshalJSON implements the json.Marshaller interface for type ChatCompletionsToolChoice.
185+
func (tc ChatCompletionsToolChoice) MarshalJSON() ([]byte, error) {
186+
return json.Marshal(tc.value)
187+
}

sdk/ai/azopenai/interfaces.go

+9
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

sdk/ai/azopenai/models.go

+42-2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)