Skip to content

Commit a4e1e5a

Browse files
authored
Merge pull request #1086 from mathiasb/feature-mistral-embedding-pgvector
llms/mistral: Implementing embeddings.EmbedderClient for Mistral and an example with PGVector
2 parents 71ded3c + c62063e commit a4e1e5a

File tree

7 files changed

+520
-4
lines changed

7 files changed

+520
-4
lines changed
+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
module github.com/tmc/langchaingo/examples/mistral-embedding-example
2+
3+
go 1.22.0
4+
5+
toolchain go1.22.1
6+
7+
require github.com/tmc/langchaingo v0.1.12
8+
9+
replace github.com/tmc/langchaingo => ../../
10+
11+
require (
12+
github.com/dlclark/regexp2 v1.10.0 // indirect
13+
github.com/gage-technologies/mistral-go v1.1.0 // indirect
14+
github.com/google/uuid v1.6.0 // indirect
15+
github.com/jackc/pgpassfile v1.0.0 // indirect
16+
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
17+
github.com/jackc/pgx/v5 v5.5.5 // indirect
18+
github.com/pgvector/pgvector-go v0.1.1 // indirect
19+
github.com/pkoukk/tiktoken-go v0.1.6 // indirect
20+
golang.org/x/crypto v0.23.0 // indirect
21+
golang.org/x/text v0.15.0 // indirect
22+
)

examples/mistral-embedding-example/go.sum

+248
Large diffs are not rendered by default.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"flag"
6+
"fmt"
7+
"log"
8+
"time"
9+
10+
"github.com/tmc/langchaingo/embeddings"
11+
"github.com/tmc/langchaingo/llms/mistral"
12+
"github.com/tmc/langchaingo/schema"
13+
"github.com/tmc/langchaingo/vectorstores"
14+
"github.com/tmc/langchaingo/vectorstores/pgvector"
15+
)
16+
17+
func main() {
18+
var dsn string
19+
flag.StringVar(&dsn, "dsn", "", "PGvector connection string")
20+
flag.Parse()
21+
model, err := mistral.New()
22+
if err != nil {
23+
panic(err)
24+
}
25+
26+
e, err := embeddings.NewEmbedder(model)
27+
28+
if err != nil {
29+
panic(err)
30+
}
31+
32+
// Create a new pgvector store.
33+
ctx := context.Background()
34+
store, err := pgvector.New(
35+
ctx,
36+
pgvector.WithConnectionURL(dsn),
37+
pgvector.WithEmbedder(e),
38+
)
39+
if err != nil {
40+
log.Fatal("pgvector.New", err)
41+
}
42+
43+
// Add documents to the pgvector store.
44+
_, err = store.AddDocuments(context.Background(), []schema.Document{
45+
{
46+
PageContent: "Tokyo",
47+
Metadata: map[string]any{
48+
"population": 38,
49+
"area": 2190,
50+
},
51+
},
52+
{
53+
PageContent: "Paris",
54+
Metadata: map[string]any{
55+
"population": 11,
56+
"area": 105,
57+
},
58+
},
59+
{
60+
PageContent: "London",
61+
Metadata: map[string]any{
62+
"population": 9.5,
63+
"area": 1572,
64+
},
65+
},
66+
{
67+
PageContent: "Santiago",
68+
Metadata: map[string]any{
69+
"population": 6.9,
70+
"area": 641,
71+
},
72+
},
73+
{
74+
PageContent: "Buenos Aires",
75+
Metadata: map[string]any{
76+
"population": 15.5,
77+
"area": 203,
78+
},
79+
},
80+
{
81+
PageContent: "Rio de Janeiro",
82+
Metadata: map[string]any{
83+
"population": 13.7,
84+
"area": 1200,
85+
},
86+
},
87+
{
88+
PageContent: "Sao Paulo",
89+
Metadata: map[string]any{
90+
"population": 22.6,
91+
"area": 1523,
92+
},
93+
},
94+
})
95+
if err != nil {
96+
log.Fatal("store.AddDocuments:\n", err)
97+
}
98+
time.Sleep(1 * time.Second)
99+
100+
// Search for similar documents.
101+
docs, err := store.SimilaritySearch(ctx, "japan", 1)
102+
if err != nil {
103+
log.Fatal("store.SimilaritySearch1:\n", err)
104+
}
105+
fmt.Println("store.SimilaritySearch1:\n", docs)
106+
107+
time.Sleep(2 * time.Second) // Don't trigger cloudflare
108+
109+
// Search for similar documents using score threshold.
110+
docs, err = store.SimilaritySearch(ctx, "only cities in south america", 3, vectorstores.WithScoreThreshold(0.50))
111+
if err != nil {
112+
log.Fatal("store.SimilaritySearch2:\n", err)
113+
}
114+
fmt.Println("store.SimilaritySearch2:\n", docs)
115+
116+
time.Sleep(3 * time.Second) // Don't trigger cloudflare
117+
118+
// Search for similar documents using score threshold and metadata filter.
119+
// Metadata filter for pgvector only supports key-value pairs for now.
120+
filter := map[string]any{"area": "1523"} // Sao Paulo
121+
122+
docs, err = store.SimilaritySearch(ctx, "only cities in south america",
123+
3,
124+
vectorstores.WithScoreThreshold(0.50),
125+
vectorstores.WithFilters(filter),
126+
)
127+
if err != nil {
128+
log.Fatal("store.SimilaritySearch3:\n", err)
129+
}
130+
fmt.Println("store.SimilaritySearch3:\n", docs)
131+
}

go.mod

+2-2
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,6 @@ require (
158158
gitlab.com/golang-commonmark/linkify v0.0.0-20191026162114-a0c2df6c8f82 // indirect
159159
gitlab.com/golang-commonmark/mdurl v0.0.0-20191124015652-932350d1cb84 // indirect
160160
gitlab.com/golang-commonmark/puny v0.0.0-20191124015043-9f83538fa04f // indirect
161-
go.mongodb.org/mongo-driver/v2 v2.0.0-beta1 // indirect
162161
go.opencensus.io v0.24.0 // indirect
163162
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.51.0 // indirect
164163
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect
@@ -196,7 +195,7 @@ require (
196195
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.8.1
197196
github.com/cohere-ai/tokenizer v1.1.2
198197
github.com/fatih/color v1.17.0
199-
github.com/gage-technologies/mistral-go v1.0.0
198+
github.com/gage-technologies/mistral-go v1.1.0
200199
github.com/getzep/zep-go v1.0.4
201200
github.com/go-openapi/strfmt v0.21.3
202201
github.com/go-sql-driver/mysql v1.7.1
@@ -220,6 +219,7 @@ require (
220219
github.com/weaviate/weaviate-go-client/v4 v4.13.1
221220
gitlab.com/golang-commonmark/markdown v0.0.0-20211110145824-bf3e522c626a
222221
go.mongodb.org/mongo-driver v1.14.0
222+
go.mongodb.org/mongo-driver/v2 v2.0.0-beta1
223223
go.starlark.net v0.0.0-20230302034142-4b1e35fe2254
224224
golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1
225225
golang.org/x/tools v0.14.0

go.sum

+2-2
Original file line numberDiff line numberDiff line change
@@ -196,8 +196,8 @@ github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSw
196196
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
197197
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
198198
github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
199-
github.com/gage-technologies/mistral-go v1.0.0 h1:Hwk0uJO+Iq4kMX/EwbfGRUq9zkO36w7HZ/g53N4N73A=
200-
github.com/gage-technologies/mistral-go v1.0.0/go.mod h1:tF++Xt7U975GcLlzhrjSQb8l/x+PrriO9QEdsgm9l28=
199+
github.com/gage-technologies/mistral-go v1.1.0 h1:POv1wM9jA/9OBXGV2YdPi9Y/h09+MjCbUF+9hRYlVUI=
200+
github.com/gage-technologies/mistral-go v1.1.0/go.mod h1:tF++Xt7U975GcLlzhrjSQb8l/x+PrriO9QEdsgm9l28=
201201
github.com/gavv/httpexpect v2.0.0+incompatible/go.mod h1:x+9tiU1YnrOvnB725RkpoLv1M62hOWzwo5OXotisrKc=
202202
github.com/getsentry/raven-go v0.2.0/go.mod h1:KungGk8q33+aIAZUIVWZDr2OfAEBsO49PX4NzFV5kcQ=
203203
github.com/getsentry/sentry-go v0.12.0 h1:era7g0re5iY13bHSdN/xMkyV+5zZppjRVQhZrXCaEIk=

llms/mistral/mistralembed.go

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package mistral
2+
3+
import (
	"context"
	"errors"
	"fmt"
)
7+
8+
// ErrEmptyEmbeddings indicates that an embeddings response contained no
// embedding values where some were expected.
var ErrEmptyEmbeddings = errors.New("empty embeddings")
9+
10+
// convertFloat64ToFloat32 returns a newly allocated slice holding every value
// of input narrowed to float32, in the same order. The result is always
// non-nil and has the same length as input.
func convertFloat64ToFloat32(input []float64) []float32 {
	narrowed := make([]float32, 0, len(input))
	for _, value := range input {
		narrowed = append(narrowed, float32(value))
	}
	return narrowed
}
21+
22+
// CreateEmbedding implements the embeddings.EmbedderClient interface and creates embeddings for the given input texts.
23+
func (m *Model) CreateEmbedding(_ context.Context, inputTexts []string) ([][]float32, error) {
24+
embsRes, err := m.client.Embeddings("mistral-embed", inputTexts)
25+
if err != nil {
26+
return nil, errors.New("failed to create embeddings: " + err.Error())
27+
}
28+
allEmbds := make([][]float32, len(embsRes.Data))
29+
for i, embs := range embsRes.Data {
30+
if len(embs.Embedding) == 0 {
31+
return nil, ErrEmptyEmbeddings
32+
}
33+
allEmbds[i] = convertFloat64ToFloat32(embs.Embedding)
34+
}
35+
return allEmbds, nil
36+
}

llms/mistral/mistralembed_test.go

+79
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
package mistral
2+
3+
import (
4+
"context"
5+
"os"
6+
"testing"
7+
8+
"github.com/stretchr/testify/require"
9+
"github.com/tmc/langchaingo/embeddings"
10+
)
11+
12+
// TestConvertFloat64ToFloat32 tests the ConvertFloat64ToFloat32 function using table-driven tests.
13+
func TestConvertFloat64ToFloat32(t *testing.T) {
14+
t.Parallel()
15+
tests := []struct {
16+
name string
17+
input []float64
18+
expected []float32
19+
}{
20+
{
21+
name: "empty slice",
22+
input: []float64{},
23+
expected: []float32{},
24+
},
25+
{
26+
name: "single element",
27+
input: []float64{3.14},
28+
expected: []float32{3.14},
29+
},
30+
{
31+
name: "multiple elements",
32+
input: []float64{1.23, 4.56, 7.89},
33+
expected: []float32{1.23, 4.56, 7.89},
34+
},
35+
{
36+
name: "zero values",
37+
input: []float64{0.0, 0.0, 0.0},
38+
expected: []float32{0.0, 0.0, 0.0},
39+
},
40+
}
41+
42+
for _, tt := range tests {
43+
t.Run(tt.name, func(t *testing.T) {
44+
t.Parallel()
45+
output := convertFloat64ToFloat32(tt.input)
46+
47+
require.Equal(t, len(tt.expected), len(output), "length mismatch")
48+
for i := range output {
49+
require.Equal(t, tt.expected[i], output[i], "at index %d", i)
50+
}
51+
})
52+
}
53+
}
54+
55+
func TestMistralEmbed(t *testing.T) {
56+
t.Parallel()
57+
envVar := "MISTRAL_API_KEY"
58+
59+
// Get the value of the environment variable
60+
value := os.Getenv(envVar)
61+
62+
// Check if it is set (non-empty)
63+
if value == "" {
64+
t.Skipf("Environment variable %s is not set, so skipping the test", envVar)
65+
return
66+
}
67+
68+
model, err := New()
69+
require.NoError(t, err)
70+
71+
e, err := embeddings.NewEmbedder(model)
72+
require.NoError(t, err)
73+
74+
_, err = e.EmbedDocuments(context.Background(), []string{"Hello world"})
75+
require.NoError(t, err)
76+
77+
_, err = e.EmbedQuery(context.Background(), "Hello world")
78+
require.NoError(t, err)
79+
}

0 commit comments

Comments
 (0)