Skip to content

Commit

Permalink
Update package
Browse files Browse the repository at this point in the history
  • Loading branch information
vmanot committed Jan 23, 2025
1 parent 2b1f499 commit 4021a8f
Show file tree
Hide file tree
Showing 7 changed files with 36 additions and 36 deletions.
3 changes: 3 additions & 0 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,9 @@ let package = Package(
"Swallow"
],
path: "Sources/HuggingFace",
resources: [
.process("Resources")
],
swiftSettings: [
.enableExperimentalFeature("AccessLevelOnImport")
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,28 +78,23 @@ extension HuggingFace {
}

@discardableResult
func waitUntilDone() throws -> URL {
// It's either this, or stream the bytes ourselves (add to a buffer, save to disk, etc; boring and finicky)
let semaphore = DispatchSemaphore(value: 0)
stateSubscriber = downloadState.sink { state in
func waitUntilDone() async throws -> URL {
for try await state in downloadState.toAsyncStream().eraseToAnyAsyncSequence() {
switch state {
case .completed:
semaphore.signal()
case .failed:
semaphore.signal()
default:
break
case .completed, .failed:
break
default:
continue
}
}
semaphore.wait()

switch downloadState.value {
case .completed(let url):
return url
case .failed(let error):
throw error
default:
throw DownloadError.unexpectedError
case .completed(let url):
return url
case .failed(let error):
throw error
default:
throw DownloadError.unexpectedError
}
}

Expand Down Expand Up @@ -141,10 +136,6 @@ extension HuggingFace.Downloader: URLSessionDownloadDelegate {
) {
if let error = error {
downloadState.value = .failed(error)
// } else if let response = task.response as? HTTPURLResponse {
// print("HTTP response status code: \(response.statusCode)")
// let headers = response.allHeaderFields
// print("HTTP response headers: \(headers)")
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Copyright (c) Preternatural AI, Inc.
//

import Combine
import CoreMI
import CorePersistence
import FoundationX
Expand All @@ -18,13 +19,11 @@ extension HuggingFace.Hub {
public var downloadBase: URL
public var hfToken: String?
public var endpoint: String
public var useBackgroundSession: Bool

public init(
downloadBase: URL? = nil,
hfToken: String? = nil,
endpoint: String = "https://huggingface.co",
useBackgroundSession: Bool = false
endpoint: String = "https://huggingface.co"
) {
self.hfToken = hfToken

Expand All @@ -37,7 +36,6 @@ extension HuggingFace.Hub {
}

self.endpoint = endpoint
self.useBackgroundSession = useBackgroundSession
}

public convenience init(
Expand Down Expand Up @@ -178,12 +176,11 @@ extension HuggingFace.Hub.Client {
}

public struct HubFileDownloader {
public let repo: Repo
public let repoDestination: URL
public let repo: Repo
public let repoDestination: URL
public let relativeFilename: String
public let hfToken: String?
public let hfToken: String?
public let endpoint: String?
public let backgroundSession: Bool

public var source: URL {
// https://huggingface.co/coreml-projects/Llama-2-7b-chat-coreml/resolve/main/tokenizer.json?download=true
Expand Down Expand Up @@ -215,18 +212,28 @@ extension HuggingFace.Hub.Client {
// (See for example PipelineLoader in swift-coreml-diffusers)
@discardableResult
func download(outputHandler: @escaping (Double) -> Void) async throws -> URL {
guard !downloaded else { return destination }
guard !downloaded else {
return destination
}

try prepareDestination()
let downloader = HuggingFace.Downloader(from: source, to: destination, using: hfToken, inBackground: backgroundSession)
let downloadSubscriber = downloader.downloadState.sink { state in
let downloader = HuggingFace.Downloader(from: source, to: destination, using: hfToken)

let progressSubscription: AnyCancellable = downloader.downloadState.throttle(
for: .milliseconds(50),
scheduler: .mainThread,
latest: true
)
.sink { state in
if case .downloading(let progress) = state {
outputHandler(progress)
}
}
_ = try withExtendedLifetime(downloadSubscriber) {
try downloader.waitUntilDone()

try await _asyncWithExtendedLifetime(progressSubscription) {

Check failure on line 233 in Sources/HuggingFace/Intramodular/HuggingFace.Hub.Client.swift

View workflow job for this annotation

GitHub Actions / build / preternatural-build

cannot find '_asyncWithExtendedLifetime' in scope
try await downloader.waitUntilDone()
}

return destination
}
}
Expand Down Expand Up @@ -273,8 +280,7 @@ extension HuggingFace.Hub.Client {
repoDestination: repoDestination,
relativeFilename: filename,
hfToken: hfToken,
endpoint: endpoint,
backgroundSession: useBackgroundSession
endpoint: endpoint
)
try await downloader.download { fractionDownloaded in
fileProgress.completedUnitCount = Int64(100 * fractionDownloaded)
Expand Down
File renamed without changes.
File renamed without changes.

0 comments on commit 4021a8f

Please sign in to comment.