Skip to content

Commit

Permalink
Merge pull request #285 from gs-snagaraj/main
Browse files Browse the repository at this point in the history
Support OpenAIFunction Custom object Schema
  • Loading branch information
johnoliver authored Mar 3, 2025
2 parents 63db2ab + b71458a commit 3c3b2e7
Show file tree
Hide file tree
Showing 6 changed files with 160 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,19 @@
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.microsoft.semantickernel.exceptions.SKException;
import com.microsoft.semantickernel.orchestration.responseformat.ResponseSchemaGenerator;
import com.microsoft.semantickernel.semanticfunctions.InputVariable;
import com.microsoft.semantickernel.semanticfunctions.KernelFunctionMetadata;
import org.apache.commons.lang3.StringUtils;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
Expand Down Expand Up @@ -159,14 +164,17 @@ private static String getSchemaForFunctionParameter(@Nullable InputVariable para
entries.add("\"type\":\"" + type + "\"");

// Add description if present
String description =null;
if (parameter != null && parameter.getDescription() != null && !parameter.getDescription()
.isEmpty()) {
String description = parameter.getDescription();
description = parameter.getDescription();
description = description.replaceAll("\\r?\\n|\\r", "");
description = description.replace("\"", "\\\"");

description = String.format("\"description\":\"%s\"", description);
entries.add(description);
entries.add(String.format("\"description\":\"%s\"", description));
}
// If custom type, generate schema
if("object".equalsIgnoreCase(type)) {
return getObjectSchema(parameter.getType(), description);
}

// Add enum options if parameter is an enum
Expand Down Expand Up @@ -219,4 +227,20 @@ private static String getJavaTypeToOpenAiFunctionType(String javaType) {
return "object";
}
}

private static String getObjectSchema(String type, String description){
String schema= "{ \"type\" : \"object\" }";
try {
Class<?> clazz = Class.forName(type);
schema = ResponseSchemaGenerator.jacksonGenerator().generateSchema(clazz);

} catch (ClassNotFoundException | SKException ignored) {

}
Map<String, Object> properties = BinaryData.fromString(schema).toObject(Map.class);
if(StringUtils.isNotBlank(description)) {
properties.put("description", description);
}
return BinaryData.fromObject(properties).toString();
}
}
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
// Copyright (c) Microsoft. All rights reserved.
package com.microsoft.semantickernel.aiservices.openai.chatcompletion;

import com.fasterxml.jackson.annotation.JsonPropertyDescription;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.microsoft.semantickernel.orchestration.responseformat.JsonSchemaResponseFormat;
import com.microsoft.semantickernel.plugin.KernelPlugin;
import com.microsoft.semantickernel.plugin.KernelPluginFactory;
import com.microsoft.semantickernel.semanticfunctions.KernelFunction;
import com.microsoft.semantickernel.semanticfunctions.annotations.DefineKernelFunction;
import com.microsoft.semantickernel.semanticfunctions.annotations.KernelFunctionParameter;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Mono;

public class JsonSchemaTest {

Expand All @@ -24,4 +31,86 @@ public void jacksonGenerationTest() throws JsonProcessingException {
"\"type\":\"object\",\"properties\":{\"bar\":{}}"));
}

@Test
public void openAIFunctionTest() {
KernelPlugin plugin = KernelPluginFactory.createFromObject(
new TestPlugin(),
"test");

Assertions.assertNotNull(plugin);
Assertions.assertEquals(plugin.getName(), "test");
Assertions.assertEquals(plugin.getFunctions().size(), 3);

KernelFunction<?> testFunction = plugin.getFunctions()
.get("asyncPersonFunction");
OpenAIFunction openAIFunction = OpenAIFunction.build(
testFunction.getMetadata(),
plugin.getName());

String parameters = "{\"type\":\"object\",\"required\":[\"person\",\"input\"],\"properties\":{\"input\":{\"type\":\"string\",\"description\":\"input string\"},\"person\":{\"type\":\"object\",\"properties\":{\"age\":{\"type\":\"integer\",\"description\":\"The age of the person.\"},\"name\":{\"type\":\"string\",\"description\":\"The name of the person.\"},\"title\":{\"type\":\"string\",\"enum\":[\"MS\",\"MRS\",\"MR\"],\"description\":\"The title of the person.\"}},\"required\":[\"age\",\"name\",\"title\"],\"additionalProperties\":false,\"description\":\"input person\"}}}";
Assertions.assertEquals(parameters, openAIFunction.getFunctionDefinition().getParameters().toString());

}


public static class TestPlugin {

@DefineKernelFunction
public String testFunction(
@KernelFunctionParameter(name = "input", description = "input string") String input) {
return "test" + input;
}

@DefineKernelFunction(returnType = "int")
public Mono<Integer> asyncTestFunction(
@KernelFunctionParameter(name = "input") String input) {
return Mono.just(1);
}

@DefineKernelFunction(returnType = "int", description = "test function description",
name = "asyncPersonFunction", returnDescription = "test return description")
public Mono<Integer> asyncPersonFunction(
@KernelFunctionParameter(name = "person",description = "input person", type = Person.class) Person person,
@KernelFunctionParameter(name = "input", description = "input string") String input) {
return Mono.just(1);
}
}

private static enum Title {
MS,
MRS,
MR
}

public static class Person {
@JsonPropertyDescription("The name of the person.")
private String name;
@JsonPropertyDescription("The age of the person.")
private int age;
@JsonPropertyDescription("The title of the person.")
private Title title;


public Person(String name, int age) {
this.name = name;
this.age = age;
}

public String getName() {
return name;
}

public int getAge() {
return age;
}

public Title getTitle() {
return title;
}

public void setTitle(Title title) {
this.title = title;
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,7 @@ public static void main(String[] args) throws Exception {
ChatCompletionService.class);

ContextVariableTypes
.addGlobalConverter(ContextVariableTypeConverter.builder(LightModel.class)
.toPromptString(new Gson()::toJson)
.build());
.addGlobalConverter(new LightModelTypeConverter());

KernelHooks hook = new KernelHooks();

Expand All @@ -99,9 +97,7 @@ public static void main(String[] args) throws Exception {
InvocationContext invocationContext = new Builder()
.withReturnMode(InvocationReturnMode.LAST_MESSAGE_ONLY)
.withToolCallBehavior(ToolCallBehavior.allowAllKernelFunctions(true))
.withContextVariableConverter(ContextVariableTypeConverter.builder(LightModel.class)
.toPromptString(new Gson()::toJson)
.build())
.withContextVariableConverter(new LightModelTypeConverter())
.build();

// Create a history to store the conversation
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
// Copyright (c) Microsoft. All rights reserved.
package com.microsoft.semantickernel.samples.demos.lights;

import com.fasterxml.jackson.annotation.JsonPropertyDescription;

public class LightModel {

@JsonPropertyDescription("The unique identifier of the light")
private int id;

@JsonPropertyDescription("The name of the light")
private String name;

@JsonPropertyDescription("The state of the light")
private Boolean isOn;

public LightModel(int id, String name, Boolean isOn) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package com.microsoft.semantickernel.samples.demos.lights;

import com.google.gson.Gson;
import com.microsoft.semantickernel.contextvariables.ContextVariableTypeConverter;

public class LightModelTypeConverter extends ContextVariableTypeConverter<LightModel> {
private static final Gson gson = new Gson();

public LightModelTypeConverter() {
super(
LightModel.class,
obj -> {
if(obj instanceof String) {
return gson.fromJson((String)obj, LightModel.class);
} else {
return gson.fromJson(gson.toJson(obj), LightModel.class);
}
},
(types, lightModel) -> gson.toJson(lightModel),
json -> gson.fromJson(json, LightModel.class)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,17 @@ public List<LightModel> getLights() {
return lights;
}

@DefineKernelFunction(name = "add_light", description = "Adds a new light")
public String addLight(
@KernelFunctionParameter(name = "newLight", description = "new Light Details", type = LightModel.class) LightModel light) {
if( light != null) {
System.out.println("Adding light " + light.getName());
lights.add(light);
return "Light added";
}
return "Light failed to added";
}

@DefineKernelFunction(name = "change_state", description = "Changes the state of the light")
public LightModel changeState(
@KernelFunctionParameter(name = "id", description = "The ID of the light to change", type = int.class) int id,
Expand Down

0 comments on commit 3c3b2e7

Please sign in to comment.