Skip to content

Commit

Permalink
#51 Fixed circular import issue
Browse files Browse the repository at this point in the history
  • Loading branch information
roma-glushko committed Jan 1, 2024
1 parent 65405ef commit 65b409f
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 24 deletions.
24 changes: 11 additions & 13 deletions pkg/providers/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@ package providers
import (
"errors"
"fmt"
"time"

"glide/pkg/providers/openai"
"glide/pkg/telemetry"
"time"
)

var (
ErrProviderNotFound = errors.New("provider not found")
)
var ErrProviderNotFound = errors.New("provider not found")

type LangModelConfig struct {
ID string `yaml:"id"`
Expand All @@ -34,7 +33,6 @@ func DefaultLangModelConfig() *LangModelConfig {
func (c *LangModelConfig) ToModel(tel *telemetry.Telemetry) (LanguageModel, error) {
if c.OpenAI != nil {
client, err := openai.NewClient(c.OpenAI, tel)

if err != nil {
return nil, fmt.Errorf("error initing openai client: %v", err)
}
Expand All @@ -45,37 +43,37 @@ func (c *LangModelConfig) ToModel(tel *telemetry.Telemetry) (LanguageModel, erro
return nil, ErrProviderNotFound
}

func (m *LangModelConfig) validateOneProvider() error {
func (c *LangModelConfig) validateOneProvider() error {
providersConfigured := 0

if m.OpenAI != nil {
if c.OpenAI != nil {
providersConfigured++
}

// check other providers here
if providersConfigured == 0 {
return fmt.Errorf("exactly one provider must be cofigured for model \"%v\", none is configured", m.ID)
return fmt.Errorf("exactly one provider must be cofigured for model \"%v\", none is configured", c.ID)
}

if providersConfigured > 1 {
return fmt.Errorf(
"exactly one provider must be cofigured for model \"%v\", %v are configured",
m.ID,
c.ID,
providersConfigured,
)
}

return nil
}

func (m *LangModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
*m = *DefaultLangModelConfig()
func (c *LangModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
*c = *DefaultLangModelConfig()

type plain LangModelConfig // to avoid recursion

if err := unmarshal((*plain)(m)); err != nil {
if err := unmarshal((*plain)(c)); err != nil {
return err
}

return m.validateOneProvider()
return c.validateOneProvider()
}
5 changes: 5 additions & 0 deletions pkg/providers/errs/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package errs

import "errors"

var ErrProviderUnavailable = errors.New("provider is not available")
4 changes: 2 additions & 2 deletions pkg/providers/openai/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"io"
"net/http"

"glide/pkg/providers"
"glide/pkg/providers/errs"

"glide/pkg/api/schemas"
"go.uber.org/zap"
Expand Down Expand Up @@ -144,7 +144,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
zap.Any("headers", resp.Header),
)

return nil, providers.ErrProviderUnavailable
return nil, errs.ErrProviderUnavailable
}

// Parse response
Expand Down
4 changes: 1 addition & 3 deletions pkg/providers/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,10 @@ package providers

import (
"context"
"errors"

"glide/pkg/api/schemas"
)

var ErrProviderUnavailable = errors.New("provider is not available")

// ModelProvider defines an interface all model providers should support
type ModelProvider interface {
Provider() string
Expand Down
10 changes: 4 additions & 6 deletions pkg/routers/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,16 @@ package routers
import (
"context"
"errors"

"glide/pkg/providers"
"glide/pkg/providers/factory"
"go.uber.org/multierr"
"go.uber.org/zap"

"glide/pkg/api/schemas"
"glide/pkg/telemetry"
)

var (
ErrNoModels = errors.New("no models configured for router")
)
var ErrNoModels = errors.New("no models configured for router")

type LangRouter struct {
config *LangRouterConfig
Expand All @@ -35,6 +33,7 @@ func NewLangRouter(cfg *LangRouterConfig, tel *telemetry.Telemetry) (*LangRouter

func (r *LangRouter) BuildModels(modelConfigs []providers.LangModelConfig) error {
var errs error

models := make([]providers.LanguageModel, 0, len(modelConfigs))

for _, modelConfig := range modelConfigs {
Expand All @@ -45,8 +44,7 @@ func (r *LangRouter) BuildModels(modelConfigs []providers.LangModelConfig) error

r.telemetry.Logger.Debug("init lang model", zap.String("modelID", modelConfig.ID))

model, err := factory.NewModelFromConfig(modelConfig, r.telemetry)

model, err := modelConfig.ToModel(r.telemetry)
if err != nil {
errs = multierr.Append(errs, err)
continue
Expand Down

0 comments on commit 65b409f

Please sign in to comment.