Skip to content

Commit

Permalink
Gemini file fixes (#53)
Browse files Browse the repository at this point in the history
* file and test fixes

* added method for refering to multiple files
  • Loading branch information
NatashaTheRobot authored Dec 26, 2024
1 parent ce18503 commit ce8a069
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 38 deletions.
14 changes: 7 additions & 7 deletions Sources/_Gemini/Intramodular/Models/_Gemini.File.swift
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ import Foundation

extension _Gemini {

public struct File: Codable {
public struct File: Codable, Hashable {
public let createTime: String?
public let expirationTime: String?
public let mimeType: String?
public let name: _Gemini.File.Name?
public let name: _Gemini.File.Name
public let sha256Hash: String?
public let sizeBytes: String?
public let state: State
Expand All @@ -26,20 +26,20 @@ extension _Gemini {
case active = "ACTIVE"
}

public struct VideoMetadata: Codable {
public struct VideoMetadata: Codable, Hashable {
public let videoDuration: String
}
}

public struct FileList: Codable {
public let files: [_Gemini.File]
public struct FileList: Codable, Hashable {
public let files: [_Gemini.File]?
// A token that can be sent as a pageToken into a subsequent files.list call.
public let nextPageToken: String
public let nextPageToken: String?
}
}

extension _Gemini.File {
public struct Name: Codable, RawRepresentable {
public struct Name: Codable, RawRepresentable, Hashable {
public let rawValue: String

public init(rawValue: String) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,45 @@ extension _Gemini.Client {
configuration: configuration
)
}

public func generateContent(
messages: [_Gemini.Message] = [],
files: [_Gemini.File],
model: _Gemini.Model,
configuration: _Gemini.GenerationConfiguration = configDefault
) async throws -> _Gemini.Content {

let systemInstruction = extractSystemInstruction(from: messages)
let messages: [_Gemini.Message] = messages.filter({ $0.role != .system })
var contents: [_Gemini.APISpecification.RequestBodies.Content] = []

try files.forEach { file in
guard let mimeType = file.mimeType else {
throw _Gemini.APIError.unknown(message: "Invalid MIME type")
}

contents.append(
_Gemini.APISpecification.RequestBodies.Content(
role: "user",
parts: [.file(url: file.uri, mimeType: mimeType)]
)
)
}

contents.append(contentsOf: messages.map { message in
_Gemini.APISpecification.RequestBodies.Content(
role: message.role.rawValue,
parts: [.text(message.content)]
)
})

return try await generateContent(
contents: contents,
systemInstruction: systemInstruction,
model: model,
configuration: configuration
)
}

internal func generateContent(
contents: [_Gemini.APISpecification.RequestBodies.Content],
Expand Down
44 changes: 17 additions & 27 deletions Tests/_Gemini/Intramodular/_GeminiTests+Files.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,13 @@ import _Gemini

private let fileURL = URL(string: "https://upload.wikimedia.org/wikipedia/en/7/77/EricCartman.png")!
private var fileData: Data? = nil
private var fileName: _Gemini.File.Name? = nil

@Test mutating func testUploadFileFromData() async throws {
let file: _Gemini.File = try await uploadFile()

print(file)
#expect(file.name.isNotNil)
#expect(((file.name?.rawValue.starts(with: "files/")) == true))
#expect(((file.name.rawValue.starts(with: "files/")) == true))
try await client.deleteFile(fileURL: file.uri)
}

@Test mutating func testUploadFileFromRemoteURL() async throws {
Expand All @@ -32,8 +31,8 @@ import _Gemini
)

print(file)
#expect(file.name.isNotNil)
#expect(((file.name?.rawValue.starts(with: "files/")) == true))
#expect(((file.name.rawValue.starts(with: "files/")) == true))
try await client.deleteFile(fileURL: file.uri)
}

@Test mutating func testUploadFileFromLocalURL() async throws {
Expand All @@ -51,46 +50,39 @@ import _Gemini
)

print(file)
#expect(file.name.isNotNil)
#expect(((file.name?.rawValue.starts(with: "files/")) == true))
#expect(((file.name.rawValue.starts(with: "files/")) == true))
try await client.deleteFile(fileURL: file.uri)
}

@Test mutating func testGetFile() async throws {
let uploadedFile: _Gemini.File = try await uploadFile()

if fileName == nil {
fileName = try await uploadFile().name
}

guard let fileName = fileName else {
#expect(Bool(false), "The file name is invalid")
return
}

let file: _Gemini.File = try await client.getFile(name: fileName)
#expect(file.name?.rawValue == fileName.rawValue)
let file: _Gemini.File = try await client.getFile(name: uploadedFile.name)
#expect(file.name.rawValue == uploadedFile.name.rawValue)
try await client.deleteFile(fileURL: file.uri)
}

@Test mutating func testListFiles() async throws {
let file = try await uploadFile()
let fileList: _Gemini.FileList = try await client.listFiles()
let files: [_Gemini.File] = fileList.files

guard let fileName = file.name else {
#expect(Bool(false), "The uploaded file has no valid name.")
guard let files = fileList.files else {
#expect(Bool(false), "The file was uploaded, so there must be files")
return
}

let uploadedFileIsPresent = files.contains { $0.name! == fileName }
let uploadedFileIsPresent = files.contains { $0.name == file.name }
#expect(uploadedFileIsPresent, "Expected the newly uploaded file to be in the returned file list.")
try await client.deleteFile(fileURL: file.uri)
}

@Test mutating func testDeleteFile() async throws {
let uploadedFile = try await uploadFile()
let file: _Gemini.File = try await client.getFile(name: uploadedFile.name!)
let file: _Gemini.File = try await client.getFile(name: uploadedFile.name)

try await client.deleteFile(fileURL: file.uri)
do {
let _ = try await client.getFile(name: uploadedFile.name!)
let _ = try await client.getFile(name: uploadedFile.name)
#expect(Bool(false), "Expected getFile to throw when fetching a deleted file, but it succeeded.")
} catch {
#expect(Bool(true), "getFile threw an error, as expected, when trying to retrieve a deleted file.")
Expand All @@ -108,9 +100,7 @@ extension GeminiFileTests {
mimeType: .custom("image/png"),
displayName: UUID().uuidString
)

self.fileName = file.name


return file
}

Expand Down
57 changes: 53 additions & 4 deletions Tests/_Gemini/Intramodular/_GeminiTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import NetworkKit
displayName: nil
)
print("File successfully uploaded: \(String(describing: file.name))")
let activeFile = try await client.pollFileUntilActive(name: file.name!)
let activeFile = try await client.pollFileUntilActive(name: file.name)

let messages = [_Gemini.Message(role: .user, content: "What is happening in this video?")]

Expand All @@ -42,6 +42,7 @@ import NetworkKit
if let tokenUsage = content.tokenUsage {
print("Token usage - Total: \(tokenUsage.total)")
}
try await client.deleteFile(fileURL: file.uri)
} catch let error as GeminiTestError {
print("Detailed error: \(error.localizedDescription)")
#expect(Bool(false), "Video content generation failed: \(error)")
Expand All @@ -62,8 +63,8 @@ import NetworkKit
mimeType: mimeType,
displayName: nil
)
print("File successfully uploaded: \(String(describing: file.name))")
let activeFile = try await client.pollFileUntilActive(name: file.name!)
print("File successfully uploaded: \(file.name.rawValue)")
let activeFile = try await client.pollFileUntilActive(name: file.name)

let messages = [_Gemini.Message(role: .user, content: "What is being said in this audio?")]

Expand All @@ -77,6 +78,7 @@ import NetworkKit
print("Generated content: \(content)")

#expect(!content.text.isEmpty)
try await client.deleteFile(fileURL: file.uri)
} catch let error as GeminiTestError {
print("Detailed error: \(error.localizedDescription)")
#expect(Bool(false), "Audio content generation failed: \(error)")
Expand All @@ -97,7 +99,7 @@ import NetworkKit
mimeType: mimeType,
displayName: nil
)
let activeFile = try await client.pollFileUntilActive(name: file.name!)
let activeFile = try await client.pollFileUntilActive(name: file.name)

let messages = [_Gemini.Message(role: .user, content: "What is in this image?")]

Expand All @@ -111,6 +113,53 @@ import NetworkKit
print("Generated content: \(content)")

#expect(!content.text.isEmpty)
try await client.deleteFile(fileURL: file.uri)
} catch let error as GeminiTestError {
print("Detailed error: \(error.localizedDescription)")
#expect(Bool(false), "Image content generation failed: \(error)")
} catch {
throw GeminiTestError.imageProcessingError(error)
}
}

@Test func testMultipleContentGeneration() async throws {
do {
guard let imageURL = URL(string: "https://upload.wikimedia.org/wikipedia/en/7/77/EricCartman.png") else {
throw GeminiTestError.invalidURL("https://upload.wikimedia.org/wikipedia/en/7/77/EricCartman.png")
}

let mimeType: HTTPMediaType = .custom("image/png")
let imageFile = try await client.uploadFile(
from: imageURL,
mimeType: mimeType,
displayName: nil
)
let activeImageFile = try await client.pollFileUntilActive(name: imageFile.name)

guard let image2URL = URL(string: "https://upload.wikimedia.org/wikipedia/en/2/25/KyleBroflovski.png") else {
throw GeminiTestError.invalidURL("https://upload.wikimedia.org/wikipedia/en/2/25/KyleBroflovski.png")
}

let image2File = try await client.uploadFile(
from: image2URL,
mimeType: mimeType,
displayName: nil
)
print("File successfully uploaded: \(String(describing: image2File.name))")
let activeImage2File = try await client.pollFileUntilActive(name: image2File.name)

let messages = [_Gemini.Message(role: .user, content: "What do these two images have in common?")]

let content = try await client.generateContent(
messages: messages,
files: [activeImageFile, activeImage2File],
model: .gemini_2_0_flash_exp
)
print("Generated content: \(content)")

#expect(!content.text.isEmpty)
try await client.deleteFile(fileURL: imageFile.uri)
try await client.deleteFile(fileURL: image2File.uri)
} catch let error as GeminiTestError {
print("Detailed error: \(error.localizedDescription)")
#expect(Bool(false), "Image content generation failed: \(error)")
Expand Down

0 comments on commit ce8a069

Please sign in to comment.