Skip to content

Commit

Permalink
Merge 'main' into 'any-tag-filter'
Browse files Browse the repository at this point in the history
  • Loading branch information
milderhc committed Oct 23, 2024
2 parents be089f9 + 3e7f6c8 commit c4bc8c6
Show file tree
Hide file tree
Showing 123 changed files with 1,407 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,18 @@

import javax.annotation.Nullable;

/**
* Makes a Gemini service available to the Semantic Kernel.
*/
public class GeminiService implements AIService {
private final VertexAI client;
private final String modelId;

/**
* Creates a new Gemini service.
* @param client The VertexAI client
* @param modelId The Gemini model ID
*/
protected GeminiService(VertexAI client, String modelId) {
this.client = client;
this.modelId = modelId;
Expand All @@ -27,6 +35,10 @@ public String getServiceId() {
return null;
}

/**
* Gets the VertexAI client.
* @return The VertexAI client
*/
protected VertexAI getClient() {
return client;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

/**
* Builder for a Gemini service.
* @param <T> The type of the service
* @param <U> The type of the builder
*/
public abstract class GeminiServiceBuilder<T, U extends GeminiServiceBuilder<T, U>> implements
SemanticKernelBuilder<T> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,18 @@
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

/**
* A chat completion service that uses the Gemini model to generate chat completions.
*/
public class GeminiChatCompletion extends GeminiService implements ChatCompletionService {

private static final Logger LOGGER = LoggerFactory.getLogger(GeminiChatCompletion.class);

/**
* Constructor for {@link GeminiChatCompletion}.
* @param client The VertexAI client
* @param modelId The model ID
*/
public GeminiChatCompletion(VertexAI client, String modelId) {
super(client, modelId);
}
Expand Down Expand Up @@ -391,6 +399,13 @@ private Tool getTool(@Nullable Kernel kernel, @Nullable ToolCallBehavior toolCal
return toolBuilder.build();
}

/**
* Invoke the Gemini function call.
* @param kernel The semantic kernel
* @param invocationContext Additional context for the invocation
* @param geminiFunction The Gemini function call
* @return The result of the function call
*/
public Mono<GeminiFunctionCall> performFunctionCall(@Nullable Kernel kernel,
@Nullable InvocationContext invocationContext, GeminiFunctionCall geminiFunction) {
if (kernel == null) {
Expand Down Expand Up @@ -433,6 +448,9 @@ public Mono<GeminiFunctionCall> performFunctionCall(@Nullable Kernel kernel,
.map(result -> new GeminiFunctionCall(geminiFunction.getFunctionCall(), result));
}

/**
* Builder for {@link GeminiChatCompletion}.
*/
public static class Builder extends GeminiServiceBuilder<GeminiChatCompletion, Builder> {

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
import javax.annotation.Nonnull;
import javax.annotation.Nullable;

/**
* Represents a function call in Gemini.
*/
public class GeminiFunctionCall {
@Nonnull
private final FunctionCall functionCall;
Expand All @@ -17,6 +20,11 @@ public class GeminiFunctionCall {
private final String pluginName;
private final String functionName;

/**
* Creates a new Gemini function call.
* @param functionCall The function call
* @param functionResult The result of the function invocation
*/
@SuppressFBWarnings("EI_EXPOSE_REP2")
public GeminiFunctionCall(
@Nonnull FunctionCall functionCall,
Expand All @@ -29,19 +37,35 @@ public GeminiFunctionCall(
this.functionName = name[1];
}

/**
* Gets the plugin name.
* @return The plugin name
*/
public String getPluginName() {
return pluginName;
}

/**
* Gets the function name.
* @return The function name
*/
public String getFunctionName() {
return functionName;
}

/**
* Gets the function call.
* @return The function call
*/
@SuppressFBWarnings("EI_EXPOSE_REP")
public FunctionCall getFunctionCall() {
return functionCall;
}

/**
* Gets the function result.
* @return The function result
*/
@Nullable
public FunctionResult<?> getFunctionResult() {
return functionResult;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
// Copyright (c) Microsoft. All rights reserved.
package com.microsoft.semantickernel.aiservices.google.chatcompletion;

/**
* Represents the role of a message in a Gemini conversation.
*/
public enum GeminiRole {
/**
* A user message is a message generated by the user.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ public class GeminiStreamingChatMessageContent<T> extends GeminiChatMessageConte
* @param innerContent The inner content.
* @param encoding The encoding.
* @param metadata The metadata.
* @param id The id of the message.
* @param geminiFunctionCalls The function calls.
*/
public GeminiStreamingChatMessageContent(AuthorRole authorRole, String content,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,26 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* Parses an XML prompt for a Gemini chat.
*/
public class GeminiXMLPromptParser {

private static final Logger LOGGER = LoggerFactory.getLogger(GeminiXMLPromptParser.class);

/**
* Represents a parsed prompt for Gemini chat.
*/
public static class GeminiParsedPrompt {

private final ChatHistory chatHistory;
private final List<FunctionDeclaration> functions;

/**
* Creates a new parsed prompt.
* @param parsedChatHistory The chat history
* @param parsedFunctions The functions declarations.
*/
protected GeminiParsedPrompt(
ChatHistory parsedChatHistory,
@Nullable List<FunctionDeclaration> parsedFunctions) {
Expand All @@ -36,10 +47,18 @@ protected GeminiParsedPrompt(
this.functions = parsedFunctions;
}

/**
* Gets the chat history.
* @return A copy of the chat history.
*/
public ChatHistory getChatHistory() {
return new ChatHistory(chatHistory.getMessages());
}

/**
* Gets the functions declarations.
* @return A copy of the functions declarations.
*/
public List<FunctionDeclaration> getFunctions() {
return Collections.unmodifiableList(functions);
}
Expand Down Expand Up @@ -131,6 +150,11 @@ public ChatPromptParseVisitor<GeminiParsedPrompt> reset() {
}
}

/**
* Create a GeminiParsedPrompt by parsing a raw prompt.
* @param rawPrompt the raw prompt to parse.
* @return The parsed prompt.
*/
public static GeminiParsedPrompt parse(String rawPrompt) {
ChatPromptParseVisitor<GeminiParsedPrompt> visitor = ChatXMLPromptParser.parse(
rawPrompt,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,27 @@
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

/**
* A Gemini service for text generation.
* @see TextGenerationService
*/
public class GeminiTextGenerationService extends GeminiService implements TextGenerationService {

private static final Logger LOGGER = LoggerFactory.getLogger(GeminiTextGenerationService.class);

/**
* Creates a new Gemini text generation service.
* @param client The VertexAI client
* @param modelId The Gemini model ID
*/
public GeminiTextGenerationService(VertexAI client, String modelId) {
super(client, modelId);
}

/**
* Creates a new builder for a Gemini text generation service.
* @return The builder
*/
public static Builder builder() {
return new Builder();
}
Expand Down Expand Up @@ -121,6 +134,9 @@ private GenerativeModel getGenerativeModel(
return modelBuilder.build();
}

/**
* Builder for a Gemini text generation service.
*/
public static class Builder extends
GeminiServiceBuilder<GeminiTextGenerationService, GeminiTextGenerationService.Builder> {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,21 @@
import reactor.core.publisher.Mono;
import javax.annotation.Nullable;

/**
* A client for the Hugging Face API.
*/
public class HuggingFaceClient {

private final KeyCredential key;
private final String endpoint;
private final HttpClient httpClient;

/**
* Creates a new Hugging Face client.
* @param key The key credential for endpoint authentication.
* @param endpoint The endpoint for the Hugging Face API.
* @param httpClient The HTTP client to use for requests.
*/
public HuggingFaceClient(
KeyCredential key,
String endpoint,
Expand Down Expand Up @@ -74,6 +83,12 @@ public GeneratedTextItemList(

}

/**
* Gets the text contents from the Hugging Face API.
* @param modelId The model ID.
* @param textGenerationRequest The text generation request.
* @return The generated text items.
*/
public Mono<List<GeneratedTextItem>> getTextContentsAsync(
String modelId,
TextGenerationRequest textGenerationRequest) {
Expand Down Expand Up @@ -131,10 +146,17 @@ private Mono<String> performRequest(String modelId,
return responseBody;
}

/**
* Creates a new builder for a Hugging Face client.
* @return The builder
*/
public static Builder builder() {
return new Builder();
}

/**
* Builder for a Hugging Face client.
*/
public static class Builder {

@Nullable
Expand All @@ -144,6 +166,10 @@ public static class Builder {
@Nullable
private HttpClient httpClient = null;

/**
* Builds the Hugging Face client.
* @return The client
*/
public HuggingFaceClient build() {
if (httpClient == null) {
httpClient = HttpClient.createDefault();
Expand All @@ -160,16 +186,31 @@ public HuggingFaceClient build() {
httpClient);
}

/**
* Sets the key credential for the client.
* @param key The key credential
* @return The builder
*/
public Builder credential(KeyCredential key) {
this.key = key;
return this;
}

/**
* Sets the endpoint for the client.
* @param endpoint The endpoint
* @return The builder
*/
public Builder endpoint(String endpoint) {
this.endpoint = endpoint;
return this;
}

/**
* Sets the HTTP client for the client.
* @param httpClient The HTTP client
* @return The builder
*/
public Builder httpClient(HttpClient httpClient) {
this.httpClient = httpClient;
return this;
Expand Down
Loading

0 comments on commit c4bc8c6

Please sign in to comment.