Skip to content

Commit 2dbc216

Browse files
added DALLE 3 API (#13)
1 parent cc64565 commit 2dbc216

9 files changed

+218
-4
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
//
2+
// Copyright (c) Vatsal Manot
3+
//
4+
5+
import Foundation
6+
import Swallow
7+
8+
/// Request parameters for an automatic speech recognition (ASR) call.
///
/// NOTE(review): currently an empty placeholder — no fields are declared in
/// this commit; presumably populated in a follow-up. Confirm before using.
public struct AutomaticSpeechRecognitionRequest {
    
}

Sources/OpenAI/Intramodular/API/OpenAI.APISpecification.RequestBodies.swift

+46-2
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,7 @@ extension OpenAI.APISpecification.RequestBodies {
408408
}
409409

410410
extension OpenAI.APISpecification.RequestBodies {
411-
struct CreateSpeech: Codable {
411+
struct CreateSpeech: Codable {
412412
enum ResponseFormat: String, Codable, CaseIterable {
413413
case mp3
414414
case opus
@@ -484,7 +484,7 @@ extension OpenAI.APISpecification.RequestBodies {
484484
case timestampGranularities = "timestamp_granularities[]"
485485
case responseFormat = "response_format"
486486
}
487-
487+
488488
/// The audio file object to transcribe, in one of these formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
489489
let file: Data
490490
let filename: String
@@ -605,6 +605,50 @@ extension OpenAI.APISpecification.RequestBodies {
605605
}
606606
}
607607

608+
extension OpenAI.APISpecification.RequestBodies {
    /// Request body for `POST /v1/images/generations` (DALL·E image creation).
    ///
    /// Typed options (`Quality`, `Size`, `Style`) are accepted at the
    /// initializer and flattened to their raw wire strings for encoding.
    struct CreateImage: Codable {
        enum CodingKeys: String, CodingKey {
            case prompt
            case model
            case numberOfImages = "n"
            case responseFormat = "response_format"
            case quality
            case size
            case style
            case user
        }
        
        /// The text description of the desired image(s).
        let prompt: String
        /// The DALL·E model that should service the request.
        let model: OpenAI.Model.DALL_E
        /// Number of images to generate (encoded as `n`).
        let numberOfImages: Int
        /// Whether the API returns URLs or inline base64 payloads.
        let responseFormat: OpenAI.APIClient.ImageResponseFormat
        /// Raw wire value of `OpenAI.Image.Quality`.
        let quality: String
        /// Raw wire value of `OpenAI.Image.Size`.
        let size: String
        /// Raw wire value of `OpenAI.Image.Style`.
        let style: String
        /// Optional end-user identifier forwarded to OpenAI.
        let user: String?
        
        init(
            prompt: String,
            model: OpenAI.Model.DALL_E,
            responseFormat: OpenAI.APIClient.ImageResponseFormat,
            numberOfImages: Int,
            quality: OpenAI.Image.Quality = .standard,
            size: OpenAI.Image.Size = .w1024h1024,
            style: OpenAI.Image.Style = .vivid,
            user: String? = nil
        ) {
            self.prompt = prompt
            self.model = model
            self.numberOfImages = numberOfImages
            self.responseFormat = responseFormat
            // The wire format expects plain strings; unwrap the typed enums once here.
            self.quality = quality.rawValue
            self.size = size.rawValue
            self.style = style.rawValue
            self.user = user
        }
    }
}
651+
608652
// MARK: - Auxiliary
609653

610654
extension OpenAI.APISpecification.RequestBodies.CreateChatCompletion {

Sources/OpenAI/Intramodular/API/OpenAI.APISpecification.swift

+8-1
Original file line numberDiff line numberDiff line change
@@ -180,12 +180,19 @@ extension OpenAI {
180180
Void
181181
>()
182182

183-
// MARK: Whisper
183+
// MARK: Audio Transcription
184184

185185
@POST
186186
@Path("/v1/audio/transcriptions")
187187
@Body(multipart: .input)
188188
var createAudioTranscription = Endpoint<RequestBodies.CreateTranscription, ResponseBodies.CreateTranscription, Void>()
189+
190+
// MARK: Image Generation
191+
192+
@POST
193+
@Path("/v1/images/generations")
194+
@Body(json: .input, keyEncodingStrategy: .convertToSnakeCase)
195+
var createImage = Endpoint<RequestBodies.CreateImage, OpenAI.List<OpenAI.Image>, Void>()
189196
}
190197
}
191198

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
//
2+
// Copyright (c) Vatsal Manot
3+
//
4+
5+
import CorePersistence
6+
import Diagnostics
7+
import Swift
8+
9+
extension OpenAI {
    /// An image returned by the images API (`/v1/images/generations`).
    public final class Image: OpenAI.Object {
        enum CodingKeys: String, CodingKey {
            case data
        }
        
        /// A hosted URL for the generated image, when the response format is
        /// URL-based. NOTE(review): per the API docs such URLs expire roughly
        /// 60 minutes after generation — confirm callers don't persist them.
        public let url: String?
        /// The prompt as revised by the model (DALL·E 3 may rewrite prompts).
        public let revisedPrompt: String?
        
        public required init(from decoder: Decoder) throws {
            let container = try decoder.container(keyedBy: CodingKeys.self)
            
            // NOTE(review): decoding a `data` key on each element looks
            // suspect — the list wrapper usually owns `data`; verify against
            // how `OpenAI.List` decodes its elements.
            let data: [String: String] = try container.decode(forKey: .data)
            
            self.url = data["url"]
            // The API emits `revised_prompt`; accept both the raw wire key and
            // the camelCase form (produced when a snake-case-converting
            // decoder rewrites dictionary keys).
            self.revisedPrompt = data["revisedPrompt"] ?? data["revised_prompt"]
            
            super.init(type: .image)
        }
        
        public init(data: [String : Any]) {
            // Conditional casts: a missing or non-string value yields `nil`
            // instead of trapping (the previous `as!` crashed on unexpected
            // payload shapes).
            self.url = data["url"] as? String
            self.revisedPrompt = (data["revisedPrompt"] ?? data["revised_prompt"]) as? String
            
            super.init(type: .image)
        }
    }
}
37+
38+
extension OpenAI.Image {
    /// Rendering quality of a generated image.
    ///
    /// `hd` yields finer detail and greater cross-image consistency; it is
    /// only honored by `dall-e-3`. The API default is `standard`.
    public enum Quality: String, Codable, CaseIterable {
        case standard, hd
    }
    
    /// Output dimensions of a generated image.
    ///
    /// `dall-e-2` accepts 256x256, 512x512, or 1024x1024; `dall-e-3` accepts
    /// 1024x1024, 1792x1024, or 1024x1792. The API default is 1024x1024.
    public enum Size: String, Codable, CaseIterable {
        case w1024h1024 = "1024x1024"
        case w1792h1024 = "1792x1024"
        case w1024h1792 = "1024x1792"
    }
    
    /// Visual style of a generated image (`dall-e-3` only).
    ///
    /// `vivid` leans toward hyper-real, dramatic output; `natural` produces
    /// less hyper-real results. The API default is `vivid`.
    public enum Style: String, Codable, CaseIterable {
        case vivid, natural
    }
}

Sources/OpenAI/Intramodular/Models/OpenAI.Model.swift

+26
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ extension OpenAI {
3232
case chat(Chat)
3333
case speech(Speech)
3434
case whisper(Whisper)
35+
case dall_e(DALL_E)
3536

3637
/// Deprecated by OpenAI.
3738
case feature(Feature)
@@ -64,6 +65,8 @@ extension OpenAI {
6465
return value
6566
case .whisper(let value):
6667
return value
68+
case .dall_e(let value):
69+
return value
6770
case .unknown:
6871
assertionFailure(.unimplemented)
6972

@@ -327,6 +330,27 @@ extension OpenAI.Model {
327330
}
328331
}
329332

333+
extension OpenAI.Model {
    /// DALL·E image-generation model identifiers.
    public enum DALL_E: String, Named, OpenAI._ModelType, CaseIterable {
        case dalle3 = "dall-e-3"
        
        /// The model used when callers do not specify one.
        public static var `default`: Self {
            .dalle3
        }
        
        /// NOTE(review): 4000 matches the dall-e-3 prompt *character* limit,
        /// not a token context window — confirm callers expect that meaning.
        public var contextSize: Int {
            4000
        }
        
        /// Display name; identical to the API identifier for every case.
        public var name: String {
            rawValue
        }
    }
}
353+
330354
// MARK: - Conformances
331355

332356
extension OpenAI.Model: Codable {
@@ -378,6 +402,8 @@ extension OpenAI.Model: RawRepresentable {
378402
return model.rawValue
379403
case .whisper(let model):
380404
return model.rawValue
405+
case .dall_e(let model):
406+
return model.rawValue
381407
case .unknown(let rawValue):
382408
return rawValue
383409
}

Sources/OpenAI/Intramodular/Models/OpenAI.Object.swift

+3
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ extension OpenAI {
2020
case assistant = "assistant"
2121
case assistantFile = "assistant.file"
2222
case run = "thread.run"
23+
case image
2324

2425
public func resolveType() -> Any.Type {
2526
switch self {
@@ -37,6 +38,8 @@ extension OpenAI {
3738
return OpenAI.Speech.self
3839
case .transcription:
3940
return OpenAI.AudioTranscription.self
41+
case .image:
42+
return OpenAI.Image.self
4043
case .file:
4144
return OpenAI.File.self
4245
case .thread:
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
//
2+
// Copyright (c) Vatsal Manot
3+
//
4+
5+
import CorePersistence
6+
import Swallow
7+
8+
extension OpenAI.APIClient {
    /// The format in which the generated images are returned.
    public enum ImageResponseFormat: String, Codable, CaseIterable {
        /// A hosted URL; only valid for 60 minutes after the image has been
        /// generated.
        case ephemeralURL = "url"
        /// The image bytes inlined as base64-encoded JSON.
        ///
        /// FIX: the wire value the API accepts is `b64_json`; the previous
        /// implicit raw value ("base64JSON") would be rejected by the server.
        case base64JSON = "b64_json"
    }
    
    /// Create an image using DALL·E.
    ///
    /// The maximum prompt length is `1000` characters for `dall-e-2` and
    /// `4000` characters for `dall-e-3`.
    ///
    /// - Parameters:
    ///   - prompt: Text description of the desired image(s).
    ///   - responseFormat: URL or inline base64 delivery (default: URL).
    ///   - numberOfImages: How many images to request. NOTE(review): the API
    ///     only accepts `n == 1` for `dall-e-3` — confirm before passing more.
    ///   - quality: Rendering quality (default `.standard`).
    ///   - size: Output dimensions (default 1024x1024).
    ///   - style: Vivid vs. natural rendering (default `.vivid`).
    ///   - user: Optional end-user identifier forwarded to OpenAI.
    /// - Returns: The list of generated images from the API.
    /// - Throws: Any transport or decoding error surfaced by `run(_:with:)`.
    public func createImage(
        prompt: String,
        responseFormat: ImageResponseFormat = .ephemeralURL,
        numberOfImages: Int = 1,
        quality: OpenAI.Image.Quality = .standard,
        size: OpenAI.Image.Size = .w1024h1024,
        style: OpenAI.Image.Style = .vivid,
        user: String? = nil
    ) async throws -> OpenAI.List<OpenAI.Image> {
        let requestBody = OpenAI.APISpecification.RequestBodies.CreateImage(
            prompt: prompt,
            model: .dalle3,
            responseFormat: responseFormat,
            numberOfImages: numberOfImages,
            quality: quality,
            size: size,
            style: style,
            user: user
        )
        
        return try await run(\.createImage, with: requestBody)
    }
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
//
2+
// DalleTests.swift
3+
//
4+
//
5+
// Created by Natasha Murashev on 5/1/24.
6+
//
7+
8+
import OpenAI
9+
import XCTest
10+
11+
final class DalleTests: XCTestCase {
    
    /// Smoke test: issues a live image-generation request with all default
    /// parameters and passes as long as no error is thrown.
    ///
    /// NOTE(review): this hits the real API and asserts nothing about the
    /// response payload — consider verifying the returned list is non-empty.
    func testCreateImage() async throws {
        _ = try await client.createImage(
            prompt: "a kitten playing with yarn"
        )
    }
}
21+

Tests/OpenAI/module.swift

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import Foundation
import OpenAI
66

77
public var OPENAI_API_KEY: String {
    // SECURITY: never hard-code an API key in source. The previous revision
    // committed a live `sk-proj-…` secret to the repository — that key must be
    // revoked. Read the key from the environment instead, falling back to the
    // empty string (the pre-leak behavior) when it is unset.
    ProcessInfo.processInfo.environment["OPENAI_API_KEY"] ?? ""
}
1010

1111
public var client: OpenAI.APIClient {

0 commit comments

Comments
 (0)