Skip to content

Commit 13863ec

Browse files
added embeddings for mistral
1 parent 97041fe commit 13863ec

14 files changed

+208
-57
lines changed

Package.swift

+11
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,17 @@ let package = Package(
210210
.enableExperimentalFeature("AccessLevelOnImport")
211211
]
212212
),
213+
.testTarget(
214+
name: "MistralTests",
215+
dependencies: [
216+
"AI",
217+
"Swallow"
218+
],
219+
path: "Tests/Mistral",
220+
swiftSettings: [
221+
.enableExperimentalFeature("AccessLevelOnImport")
222+
]
223+
),
213224
.testTarget(
214225
name: "OpenAITests",
215226
dependencies: [

Sources/CoreMI/Intramodular/Model Identifier/ModelIdentifier+Utilities.swift

+1
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ extension ModelIdentifier {
4949
case mistral_tiny = "mistral-tiny"
5050
case mistral_small = "mistral-small"
5151
case mistral_medium = "mistral-medium"
52+
case mistral_embed = "mistral-embed"
5253
}
5354

5455
private enum _OpenAI_Model: String, CaseIterable {

Sources/Mistral/Intramodular/Mistral.APISpecification.swift Sources/Mistral/Intramodular/API/Mistral.APISpecification.RequestBodies.swift

+15-42
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ extension Mistral {
4141
@POST
4242
@Path("chat/completions")
4343
public var chatCompletions = Endpoint<RequestBodies.ChatCompletions, ResponseBodies.ChatCompletion, Void>()
44+
45+
@POST
46+
@Path("embeddings")
47+
public var createEmbeddings = Endpoint<RequestBodies.CreateEmbedding, Mistral.Embeddings, Void>()
4448
}
4549
}
4650

@@ -121,24 +125,6 @@ extension Mistral.APISpecification {
121125
}
122126
}
123127

124-
extension Mistral {
125-
public struct ChatMessage: Codable, Hashable, Sendable {
126-
public enum Role: String, Codable, Hashable, Sendable {
127-
case system
128-
case user
129-
case assistant
130-
}
131-
132-
public var role: Role
133-
public var content: String
134-
135-
public init(role: Role, content: String) {
136-
self.role = role
137-
self.content = content
138-
}
139-
}
140-
}
141-
142128
extension Mistral.APISpecification.RequestBodies {
143129
/// https://docs.mistral.ai/api#operation/createChatCompletion
144130
public struct ChatCompletions: Codable, Hashable, Sendable {
@@ -152,31 +138,18 @@ extension Mistral.APISpecification.RequestBodies {
152138
}
153139
}
154140

155-
extension Mistral.APISpecification.ResponseBodies {
156-
public struct ChatCompletion: Codable, Hashable, Sendable {
157-
public struct Choice: Codable, Hashable, Sendable {
158-
public enum FinishReason: String, Codable, Hashable, Sendable {
159-
case stop = "stop"
160-
case length = "length"
161-
case modelLength = "model_length"
162-
}
163-
164-
public let index: Int
165-
public let message: Mistral.ChatMessage
166-
public let finishReason: FinishReason
167-
}
141+
extension Mistral.APISpecification.RequestBodies {
142+
public struct CreateEmbedding: Codable, Hashable {
143+
public let model: Mistral.Model
144+
public let input: [String]
145+
public let encodingFormat: String
168146

169-
public struct Usage: Codable, Hashable, Sendable {
170-
public let promptTokens: Int
171-
public let completionTokens: Int
172-
public let totalTokens: Int
147+
init(input: [String]) {
148+
self.model = Mistral.Model.mistral_embed
149+
self.input = input
150+
self.encodingFormat = "Float"
173151
}
174-
175-
public var id: String
176-
public var object: String
177-
public var created: Date
178-
public var model: String
179-
public var choices: [Choice]
180-
public let usage: Usage
181152
}
182153
}
154+
155+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
//
2+
// Copyright (c) Vatsal Manot
3+
//
4+
5+
import NetworkKit
6+
import FoundationX
7+
import Swallow
8+
9+
extension Mistral.APISpecification.ResponseBodies {
10+
public struct ChatCompletion: Codable, Hashable, Sendable {
11+
public struct Choice: Codable, Hashable, Sendable {
12+
public enum FinishReason: String, Codable, Hashable, Sendable {
13+
case stop = "stop"
14+
case length = "length"
15+
case modelLength = "model_length"
16+
}
17+
18+
public let index: Int
19+
public let message: Mistral.ChatMessage
20+
public let finishReason: FinishReason
21+
}
22+
23+
public struct Usage: Codable, Hashable, Sendable {
24+
public let promptTokens: Int
25+
public let completionTokens: Int
26+
public let totalTokens: Int
27+
}
28+
29+
public var id: String
30+
public var object: String
31+
public var created: Date
32+
public var model: String
33+
public var choices: [Choice]
34+
public let usage: Usage
35+
}
36+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
//
2+
// Copyright (c) Vatsal Manot
3+
//
4+
5+
import CoreMI
6+
import CorePersistence
7+
8+
extension Mistral.Client: TextEmbeddingsRequestHandling {
9+
10+
public func fulfill(
11+
_ request: TextEmbeddingsRequest
12+
) async throws -> TextEmbeddings {
13+
let model = Mistral.Model.mistral_embed.__conversion()
14+
15+
let embeddings: Mistral.Embeddings = try await createEmbeddings(for: request.input)
16+
let textEmbeddings = embeddings.data.map {
17+
TextEmbeddings.Element(
18+
text: $0.object,
19+
embedding: $0.embedding,
20+
model: model)
21+
}
22+
23+
return TextEmbeddings(
24+
model: model,
25+
data: textEmbeddings
26+
)
27+
}
28+
}

Sources/Mistral/Intramodular/Mistral.Client.swift

+8
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,11 @@ extension Mistral.Client: _MIService {
5050
self.init(apiKey: credential.apiKey)
5151
}
5252
}
53+
54+
extension Mistral.Client {
55+
public func createEmbeddings(
56+
for input: [String]
57+
) async throws -> Mistral.Embeddings {
58+
try await run(\.createEmbeddings, with: .init(input: input))
59+
}
60+
}

Sources/Mistral/Intramodular/Mistral.Model.swift

+9-6
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,18 @@ extension Mistral {
1212
case mistral_tiny = "mistral-tiny"
1313
case mistral_small = "mistral-small"
1414
case mistral_medium = "mistral-medium"
15+
case mistral_embed = "mistral-embed"
1516

1617
public var name: String {
1718
switch self {
18-
case .mistral_tiny:
19-
return "Tiny"
20-
case .mistral_small:
21-
return "Small"
22-
case .mistral_medium:
23-
return "Medium"
19+
case .mistral_tiny:
20+
return "Tiny"
21+
case .mistral_small:
22+
return "Small"
23+
case .mistral_medium:
24+
return "Medium"
25+
case .mistral_embed:
26+
return "Embed"
2427
}
2528
}
2629
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
//
2+
// Copyright (c) Vatsal Manot
3+
//
4+
5+
import Foundation
6+
7+
extension Mistral {
8+
public struct ChatMessage: Codable, Hashable, Sendable {
9+
public enum Role: String, Codable, Hashable, Sendable {
10+
case system
11+
case user
12+
case assistant
13+
}
14+
15+
public var role: Role
16+
public var content: String
17+
18+
public init(role: Role, content: String) {
19+
self.role = role
20+
self.content = content
21+
}
22+
}
23+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
//
2+
// Copyright (c) Vatsal Manot
3+
//
4+
5+
import Foundation
6+
7+
extension Mistral {
8+
public struct Embeddings: Codable, Hashable, Sendable {
9+
public let id: String
10+
public let object: String
11+
public let data: [EmbeddingData]
12+
public let model: String
13+
public let usage: Usage
14+
}
15+
}
16+
17+
extension Mistral.Embeddings {
18+
public struct EmbeddingData: Codable, Hashable, Sendable {
19+
public let object: String
20+
public let embedding: [Double]
21+
public let index: Int
22+
}
23+
}
24+
25+
extension Mistral.Embeddings {
26+
public struct Usage: Codable, Hashable, Sendable {
27+
public enum Role: String, Codable, Hashable, Sendable {
28+
case promptTokens = "prompt_tokens"
29+
case totalTokens = "total_tokens"
30+
}
31+
32+
public let promptTokens: Int
33+
public let totalTokens: Int
34+
}
35+
}

Tests/Groq/Intramodular/CompletionTests.swift

+1-4
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
//
2-
// File.swift
3-
//
4-
//
5-
// Created by Natasha Murashev on 5/26/24.
2+
// Copyright (c) Vatsal Manot
63
//
74

85
import LargeLanguageModels

Tests/Groq/module.swift

+2-5
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,11 @@
11
//
2-
// File.swift
3-
//
4-
//
5-
// Created by Natasha Murashev on 5/26/24.
2+
// Copyright (c) Vatsal Manot
63
//
74

85
import Groq
96

107
public var GROQ_API_KEY: String {
11-
"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"
8+
""
129
}
1310

1411
public var client: Groq.Client {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
//
2+
// Copyright (c) Vatsal Manot
3+
//
4+
5+
import LargeLanguageModels
6+
import Groq
7+
import XCTest
8+
9+
final class CompletionTests: XCTestCase {
10+
11+
let llm: any LLMRequestHandling = client
12+
13+
func testTextEmbeddings() async {
14+
let textInput = ["Hello", "World"]
15+
do {
16+
let embeddings = try await client.createEmbeddings(for: textInput)
17+
let embeddingsData = embeddings.data
18+
XCTAssertTrue(!embeddingsData.isEmpty)
19+
XCTAssertTrue(embeddingsData.first!.object == "embedding")
20+
} catch {
21+
print(error)
22+
XCTFail(error.localizedDescription)
23+
}
24+
}
25+
26+
}

Tests/Mistral/module.swift

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
//
2+
// Copyright (c) Vatsal Manot
3+
//
4+
5+
import Mistral
6+
7+
public var MISTRAL_API_KEY: String {
8+
""
9+
}
10+
11+
public var client: Mistral.Client {
12+
Mistral.Client(apiKey: MISTRAL_API_KEY)
13+
}

0 commit comments

Comments
 (0)