Skip to content

Commit

Permalink
61 update openai client (#62)
Browse files Browse the repository at this point in the history
* #61: Create OpenAI Chat Response Schema

* #61: Test unmarshalling

* #61: unmarshal openai response

* #61: Update CohereChatCompletion struct and related types

* #61: Update Cohere chat response mapping

* #61: lint

* #61: lint

---------

Co-authored-by: Max <mkrueger190@gmail.com>
  • Loading branch information
mkrueger12 and mkrueger12 authored Jan 5, 2024
1 parent ff3dc22 commit 4b66e2f
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 91 deletions.
115 changes: 97 additions & 18 deletions pkg/api/schemas/language.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,27 @@ type UnifiedChatRequest struct {

// UnifiedChatResponse defines Glide's Chat Response Schema unified across all language models
type UnifiedChatResponse struct {
	ID            string           `json:"id,omitempty"`       // provider-assigned response ID
	Created       int              `json:"created,omitempty"`  // Unix timestamp (seconds)
	Provider      string           `json:"provider,omitempty"` // provider name, e.g. "openai" or "cohere"
	Router        string           `json:"router,omitempty"`   // router that handled the request
	Model         string           `json:"model,omitempty"`    // model that produced the response
	Cached        bool             `json:"cached,omitempty"`   // true when served from cache
	ModelResponse ProviderResponse `json:"modelResponse,omitempty"`
}

// ProviderResponse contains data from the chosen provider
// ProviderResponse is the unified response from the provider.

type ProviderResponse struct {
ResponseID map[string]string `json:"response_id,omitempty"`
ResponseID map[string]string `json:"responseID,omitempty"`
Message ChatMessage `json:"message"`
TokenCount TokenCount `json:"token_count"`
TokenCount TokenCount `json:"tokenCount"`
}

// TokenCount reports token usage for a single chat exchange,
// unified across providers.
type TokenCount struct {
	PromptTokens   float64 `json:"promptTokens"`
	ResponseTokens float64 `json:"responseTokens"`
	TotalTokens    float64 `json:"totalTokens"`
}

// ChatMessage is a message in a chat request.
Expand All @@ -41,15 +42,93 @@ type ChatMessage struct {
Name string `json:"name,omitempty"`
}

// ChatChoice is a choice in a chat response.
type ChatChoice struct {
// OpenAIChatCompletion mirrors the JSON body of an OpenAI chat completion
// response (the /v1/chat/completions endpoint).
// TODO: Should this live here?
type OpenAIChatCompletion struct {
	ID                string   `json:"id"`
	Object            string   `json:"object"`
	Created           int      `json:"created"` // Unix timestamp (seconds)
	Model             string   `json:"model"`
	SystemFingerprint string   `json:"system_fingerprint"` // backend configuration fingerprint
	Choices           []Choice `json:"choices"`            // one or more generated completions
	Usage             Usage    `json:"usage"`
}

// Choice is a single generated completion within an OpenAI chat response.
type Choice struct {
	Index        int         `json:"index"`
	Message      ChatMessage `json:"message"`
	Logprobs     interface{} `json:"logprobs"` // kept opaque; not consumed by the client
	FinishReason string      `json:"finish_reason"`
}

// Usage is the token-usage section of an OpenAI chat completion response.
type Usage struct {
	PromptTokens     float64 `json:"prompt_tokens"`
	CompletionTokens float64 `json:"completion_tokens"`
	TotalTokens      float64 `json:"total_tokens"`
}

// CohereChatCompletion mirrors the JSON body of a Cohere chat response.
type CohereChatCompletion struct {
	Text          string                 `json:"text"` // generated reply text
	GenerationID  string                 `json:"generation_id"`
	ResponseID    string                 `json:"response_id"`
	TokenCount    CohereTokenCount       `json:"token_count"`
	Citations     []Citation             `json:"citations"`
	Documents     []Documents            `json:"documents"`
	SearchQueries []SearchQuery          `json:"search_queries"`
	SearchResults []SearchResults        `json:"search_results"`
	Meta          Meta                   `json:"meta"`
	ToolInputs    map[string]interface{} `json:"tool_inputs"` // kept opaque; shape not pinned down here
}

// CohereTokenCount is the token-usage section of a Cohere chat response.
type CohereTokenCount struct {
	PromptTokens   float64 `json:"prompt_tokens"`
	ResponseTokens float64 `json:"response_tokens"`
	TotalTokens    float64 `json:"total_tokens"`
	BilledTokens   float64 `json:"billed_tokens"` // tokens actually billed, may differ from total
}

// Meta carries API-version and billing metadata from a Cohere chat response.
type Meta struct {
	APIVersion struct {
		Version string `json:"version"`
	} `json:"api_version"`
	BilledUnits struct {
		InputTokens  int `json:"input_tokens"`
		OutputTokens int `json:"output_tokens"`
	} `json:"billed_units"`
}

// Citation marks a span of the generated text ([Start, End)) attributed to
// one or more source documents.
type Citation struct {
	Start int    `json:"start"`
	End   int    `json:"end"`
	Text  string `json:"text"`
	// NOTE(review): tag "documentId" is camelCase while the other Cohere
	// fields in this file use snake_case — confirm against the Cohere API
	// (which documents "document_ids"); a mismatch would leave this empty.
	DocumentID []string `json:"documentId"`
}

// Documents is a source document referenced by a Cohere chat response.
type Documents struct {
	ID   string            `json:"id"`
	Data map[string]string `json:"data"` // TODO: This needs to be updated
}

// SearchQuery is a search query generated during a Cohere chat turn.
type SearchQuery struct {
	Text string `json:"text"`
	// NOTE(review): tag "generationId" is camelCase; Cohere's API documents
	// "generation_id" — confirm before relying on this field.
	GenerationID string `json:"generationId"`
}

// SearchResults ties a search query to the connectors that executed it and
// the documents it produced.
type SearchResults struct {
	// NOTE(review): the camelCase tags below ("searchQuery", "connectors",
	// "documentId") should be checked against the Cohere response payload,
	// which is snake_case elsewhere in this schema.
	SearchQuery []SearchQueryObject  `json:"searchQuery"`
	Connectors  []ConnectorsResponse `json:"connectors"`
	DocumentID  []string             `json:"documentId"`
}

// SearchQueryObject is the query echoed back inside a Cohere search result.
type SearchQueryObject struct {
	Text string `json:"text"`
	// NOTE(review): "generationId" vs Cohere's documented "generation_id" —
	// confirm the tag matches the live payload.
	GenerationID string `json:"generationId"`
}

// ConnectorsResponse describes a connector used to run a search query.
type ConnectorsResponse struct {
	ID              string `json:"id"`
	UserAccessToken string `json:"user_access_token"`
	// NOTE(review): Cohere documents continue_on_failure as a boolean;
	// declared string here — confirm this unmarshals correctly.
	ContOnFail string            `json:"continue_on_failure"`
	Options    map[string]string `json:"options"`
}
62 changes: 30 additions & 32 deletions pkg/providers/cohere/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func (c *Client) Chat(ctx context.Context, request *schemas.UnifiedChatRequest)
return nil, err
}

if len(chatResponse.ProviderResponse.Message.Content) == 0 {
if len(chatResponse.ModelResponse.Message.Content) == 0 {
return nil, ErrEmptyResponse
}

Expand Down Expand Up @@ -166,41 +166,39 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
return nil, err
}

// Parse response
var response schemas.UnifiedChatResponse

var responsePayload schemas.ProviderResponse

var tokenCount schemas.TokenCount

messageStruct := schemas.ChatMessage{
Role: "Model",
Content: responseJSON["text"].(string),
}
// Parse the response JSON
var cohereCompletion schemas.CohereChatCompletion

tokenCount = schemas.TokenCount{
PromptTokens: responseJSON["token_count"].(map[string]interface{})["prompt_tokens"].(float64),
ResponseTokens: responseJSON["token_count"].(map[string]interface{})["response_tokens"].(float64),
TotalTokens: responseJSON["token_count"].(map[string]interface{})["total_tokens"].(float64),
err = json.Unmarshal(bodyBytes, &cohereCompletion)
if err != nil {
c.telemetry.Logger.Error("failed to parse openai chat response", zap.Error(err))
return nil, err
}

responsePayload = schemas.ProviderResponse{
ResponseID: map[string]string{
"response_id": responseJSON["response_id"].(string),
"generation_id": responseJSON["generation_id"].(string),
// Map response to UnifiedChatResponse schema
response := schemas.UnifiedChatResponse{
ID: cohereCompletion.ResponseID,
Created: int(time.Now().UTC().Unix()), // Cohere doesn't provide this
Provider: providerName,
Router: "chat", // TODO: this will be the router used
Model: "command-light", // TODO: this needs to come from config or router as Cohere doesn't provide this
Cached: false,
ModelResponse: schemas.ProviderResponse{
ResponseID: map[string]string{
"generationId": cohereCompletion.GenerationID,
"responseId": cohereCompletion.ResponseID,
},
Message: schemas.ChatMessage{
Role: "model", // TODO: Does this need to change?
Content: cohereCompletion.Text,
Name: "",
},
TokenCount: schemas.TokenCount{
PromptTokens: cohereCompletion.TokenCount.PromptTokens,
ResponseTokens: cohereCompletion.TokenCount.ResponseTokens,
TotalTokens: cohereCompletion.TokenCount.TotalTokens,
},
},
Message: messageStruct,
TokenCount: tokenCount,
}

response = schemas.UnifiedChatResponse{
ID: responseJSON["response_id"].(string),
Created: float64(time.Now().Unix()),
Provider: "cohere",
Router: "chat", // TODO: change this to router name
Model: payload.Model, // Should this be derived from somehwhere else? Cohere doesn't specify it in response
Cached: false,
ProviderResponse: responsePayload,
}

return &response, nil
Expand Down
63 changes: 26 additions & 37 deletions pkg/providers/openai/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"fmt"
"io"
"net/http"
"time"

"glide/pkg/providers/errs"

Expand Down Expand Up @@ -84,7 +83,7 @@ func (c *Client) Chat(ctx context.Context, request *schemas.UnifiedChatRequest)
return nil, err
}

if len(chatResponse.ProviderResponse.Message.Content) == 0 {
if len(chatResponse.ModelResponse.Message.Content) == 0 {
return nil, ErrEmptyResponse
}

Expand Down Expand Up @@ -154,47 +153,37 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
}

// Parse the response JSON
var responseJSON map[string]interface{}
var openAICompletion schemas.OpenAIChatCompletion

err = json.Unmarshal(bodyBytes, &responseJSON)
err = json.Unmarshal(bodyBytes, &openAICompletion)
if err != nil {
c.telemetry.Logger.Error("failed to parse openai chat response", zap.Error(err))
return nil, err
}

// Parse response
var response schemas.UnifiedChatResponse

var responsePayload schemas.ProviderResponse

var tokenCount schemas.TokenCount

message := responseJSON["choices"].([]interface{})[0].(map[string]interface{})["message"].(map[string]interface{})
messageStruct := schemas.ChatMessage{
Role: message["role"].(string),
Content: message["content"].(string),
}

tokenCount = schemas.TokenCount{
PromptTokens: responseJSON["usage"].(map[string]interface{})["prompt_tokens"].(float64),
ResponseTokens: responseJSON["usage"].(map[string]interface{})["completion_tokens"].(float64),
TotalTokens: responseJSON["usage"].(map[string]interface{})["total_tokens"].(float64),
}

responsePayload = schemas.ProviderResponse{
ResponseID: map[string]string{"system_fingerprint": responseJSON["system_fingerprint"].(string)},
Message: messageStruct,
TokenCount: tokenCount,
}

response = schemas.UnifiedChatResponse{
ID: responseJSON["id"].(string),
Created: float64(time.Now().Unix()),
Provider: "openai",
Router: "chat",
Model: responseJSON["model"].(string),
Cached: false,
ProviderResponse: responsePayload,
// Map response to UnifiedChatResponse schema
response := schemas.UnifiedChatResponse{
ID: openAICompletion.ID,
Created: openAICompletion.Created,
Provider: providerName,
Router: "chat", // TODO: this will be the router used
Model: openAICompletion.Model,
Cached: false,
ModelResponse: schemas.ProviderResponse{
ResponseID: map[string]string{
"system_fingerprint": openAICompletion.SystemFingerprint,
},
Message: schemas.ChatMessage{
Role: openAICompletion.Choices[0].Message.Role,
Content: openAICompletion.Choices[0].Message.Content,
Name: "",
},
TokenCount: schemas.TokenCount{
PromptTokens: openAICompletion.Usage.PromptTokens,
ResponseTokens: openAICompletion.Usage.CompletionTokens,
TotalTokens: openAICompletion.Usage.TotalTokens,
},
},
}

return &response, nil
Expand Down
8 changes: 4 additions & 4 deletions pkg/providers/openai/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,11 @@ func (p *Params) UnmarshalYAML(unmarshal func(interface{}) error) error {
}

type Config struct {
BaseURL string `yaml:"base_url" json:"baseUrl" validate:"required"`
ChatEndpoint string `yaml:"chat_endpoint" json:"chatEndpoint" validate:"required"`
BaseURL string `yaml:"baseUrl" json:"baseUrl" validate:"required"`
ChatEndpoint string `yaml:"chatEndpoint" json:"chatEndpoint" validate:"required"`
Model string `yaml:"model" json:"model" validate:"required"`
APIKey fields.Secret `yaml:"api_key" json:"-" validate:"required"`
DefaultParams *Params `yaml:"default_params,omitempty" json:"defaultParams"`
APIKey fields.Secret `yaml:"apiKey" json:"-" validate:"required"`
DefaultParams *Params `yaml:"defaultParams,omitempty" json:"defaultParams"`
}

// DefaultConfig for OpenAI models
Expand Down

0 comments on commit 4b66e2f

Please sign in to comment.