Skip to content

Commit b45c44f

Browse files
committed
Refactor GenAIGenerator to use instance method for token generation
1 parent 46581da commit b45c44f

File tree

3 files changed

+133
-127
lines changed

3 files changed

+133
-127
lines changed

mobile/examples/phi-3/ios/LocalLLM/LocalLLM/ContentView.swift

+4-1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ struct ContentView: View {
1717
@State private var stats: String = "" // token generation stats
1818
@State private var showAlert: Bool = false
1919
@State private var errorMessage: String = ""
20+
21+
private let generator = GenAIGenerator()
2022

2123
var body: some View {
2224
VStack {
@@ -65,7 +67,7 @@ struct ContentView: View {
6567

6668

6769
DispatchQueue.global(qos: .background).async {
68-
GenAIGenerator.generate(prompt)
70+
generator.generate(prompt)
6971
}
7072
}) {
7173
Image(systemName: "paperplane.fill")
@@ -101,6 +103,7 @@ struct ContentView: View {
101103
.onReceive(NotificationCenter.default.publisher(for: NSNotification.Name("TokenGenerationError"))) { notification in
102104
if let userInfo = notification.userInfo, let error = userInfo["error"] as? String {
103105
errorMessage = error
106+
isGenerating = false
104107
showAlert = true
105108
}
106109
}

mobile/examples/phi-3/ios/LocalLLM/LocalLLM/GenAIGenerator.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ NS_ASSUME_NONNULL_BEGIN
1111

1212
@interface GenAIGenerator : NSObject
1313

14-
+ (void)generate:(NSString *)input_user_question;
14+
- (void)generate:(NSString *)input_user_question;
1515

1616
@end
1717

Original file line numberDiff line numberDiff line change
@@ -1,143 +1,146 @@
11
// Copyright (c) Microsoft Corporation. All rights reserved.
22
// Licensed under the MIT License.
33
#import "GenAIGenerator.h"
4+
#include <chrono>
5+
#include <vector>
46
#include "LocalLLM-Swift.h"
57
#include "ort_genai.h"
68
#include "ort_genai_c.h"
7-
#include <chrono>
8-
#include <vector>
9+
10+
@interface GenAIGenerator () {
11+
std::unique_ptr<OgaModel> model;
12+
std::unique_ptr<OgaTokenizer> tokenizer;
13+
}
14+
@end
915

1016
@implementation GenAIGenerator
1117

1218
typedef std::chrono::steady_clock Clock;
1319
typedef std::chrono::time_point<Clock> TimePoint;
14-
static std::unique_ptr<OgaModel> model = nullptr;
15-
static std::unique_ptr<OgaTokenizer> tokenizer = nullptr;
16-
17-
+ (void)generate:(nonnull NSString*)input_user_question {
18-
std::vector<long long> tokenTimes; // per-token generation times
19-
TimePoint startTime, firstTokenTime, tokenStartTime;
20-
21-
@try {
22-
NSLog(@"Starting token generation...");
23-
24-
if (!model) {
25-
NSLog(@"Creating model...");
26-
NSString* llmPath = [[NSBundle mainBundle] resourcePath];
27-
const char* modelPath = llmPath.cString;
28-
model = OgaModel::Create(modelPath); // throws exception
29-
30-
if (!model) {
31-
@throw [NSException exceptionWithName:@"ModelCreationError" reason:@"Failed to create model." userInfo:nil];
32-
}
33-
}
34-
35-
if (!tokenizer) {
36-
NSLog(@"Creating tokenizer...");
37-
tokenizer = OgaTokenizer::Create(*model); // throws exception
38-
if (!tokenizer) {
39-
@throw [NSException exceptionWithName:@"TokenizerCreationError" reason:@"Failed to create tokenizer." userInfo:nil];
40-
}
41-
}
42-
43-
auto tokenizer_stream = OgaTokenizerStream::Create(*tokenizer);
44-
45-
// Construct the prompt
46-
NSString* promptString = [NSString stringWithFormat:@"<|user|>\n%@<|end|>\n<|assistant|>", input_user_question];
47-
const char* prompt = [promptString UTF8String];
48-
49-
// Encode the prompt
50-
auto sequences = OgaSequences::Create();
51-
tokenizer->Encode(prompt, *sequences);
52-
53-
size_t promptTokensCount = sequences->SequenceCount(0);
54-
55-
NSLog(@"Setting generator parameters...");
56-
auto params = OgaGeneratorParams::Create(*model);
57-
params->SetSearchOption("max_length", 200);
58-
params->SetInputSequences(*sequences);
59-
60-
auto generator = OgaGenerator::Create(*model, *params);
61-
62-
bool isFirstToken = true;
63-
NSLog(@"Starting token generation loop...");
64-
65-
startTime = Clock::now();
66-
while (!generator->IsDone()) {
67-
tokenStartTime = Clock::now();
68-
69-
generator->ComputeLogits();
70-
generator->GenerateNextToken();
71-
72-
if (isFirstToken) {
73-
firstTokenTime = Clock::now();
74-
isFirstToken = false;
75-
}
76-
77-
// Get the sequence data and decode the token
78-
const int32_t* seq = generator->GetSequenceData(0);
79-
size_t seq_len = generator->GetSequenceCount(0);
80-
const char* decode_tokens = tokenizer_stream->Decode(seq[seq_len - 1]);
81-
82-
if (!decode_tokens) {
83-
@throw [NSException exceptionWithName:@"TokenDecodeError" reason:@"Token decoding failed." userInfo:nil];
84-
}
85-
86-
// Measure token generation time excluding logging
87-
TimePoint tokenEndTime = Clock::now();
88-
auto tokenDuration = std::chrono::duration_cast<std::chrono::milliseconds>(tokenEndTime - tokenStartTime).count();
89-
tokenTimes.push_back(tokenDuration);
90-
NSString* decodedTokenString = [NSString stringWithUTF8String:decode_tokens];
91-
[SharedTokenUpdater.shared addDecodedToken:decodedTokenString];
92-
}
93-
94-
TimePoint endTime = Clock::now();
95-
// Log token times
96-
NSLog(@"Per-token generation times: %@", [self formatTokenTimes:tokenTimes]);
97-
98-
// Calculate metrics
99-
auto totalDuration = std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime).count();
100-
auto firstTokenDuration = std::chrono::duration_cast<std::chrono::milliseconds>(firstTokenTime - startTime).count();
101-
102-
double promtProcTime = (double)promptTokensCount / firstTokenDuration;
103-
double tokenGenRate = (double)(tokenTimes.size() - 1) * 1000.0 / (totalDuration - firstTokenDuration);
104-
105-
NSLog(@"Token generation completed. Total time: %lld ms, First token time: %lld ms, Total tokens: %zu", totalDuration, firstTokenDuration, tokenTimes.size());
106-
NSLog(@"Prompt tokens: %zu, Prompt Processing Time: %f tokens/s", promptTokensCount, promtProcTime);
107-
NSLog(@"Generated tokens: %zu, Token Generation Rate: %f tokens/s", tokenTimes.size(), tokenGenRate);
108-
109-
110-
NSDictionary *stats = @{
111-
@"tokenGenRate" : @(tokenGenRate),
112-
@"promptProcRate": @(promtProcTime)
113-
};
114-
// notify main thread that token generation is complete
115-
dispatch_async(dispatch_get_main_queue(), ^{
116-
[[NSNotificationCenter defaultCenter] postNotificationName:@"TokenGenerationStats" object:nil userInfo:stats];
117-
[[NSNotificationCenter defaultCenter] postNotificationName:@"TokenGenerationCompleted" object:nil];
118-
});
119-
120-
NSLog(@"Token generation completed.");
121-
122-
} @catch (NSException* e) {
123-
NSString* errorMessage = e.reason;
124-
NSLog(@"Error during generation: %@", errorMessage);
125-
126-
// Send error to the UI
127-
NSDictionary *errorInfo = @{@"error": errorMessage};
128-
dispatch_async(dispatch_get_main_queue(), ^{
129-
[[NSNotificationCenter defaultCenter] postNotificationName:@"TokenGenerationError" object:nil userInfo:errorInfo];
130-
});
20+
21+
- (instancetype)init {
22+
self = [super init];
23+
if (self) {
24+
self->model = nullptr;
25+
self->tokenizer = nullptr;
26+
}
27+
return self;
28+
}
29+
30+
- (void)generate:(nonnull NSString*)input_user_question {
31+
std::vector<long long> tokenTimes; // per-token generation times
32+
TimePoint startTime, firstTokenTime, tokenStartTime;
33+
34+
try {
35+
NSLog(@"Starting token generation...");
36+
37+
if (!self->model) {
38+
NSLog(@"Creating model...");
39+
NSString* llmPath = [[NSBundle mainBundle] resourcePath];
40+
const char* modelPath = llmPath.cString;
41+
self->model = OgaModel::Create(modelPath); // throws exception
42+
}
43+
44+
if (!self->tokenizer) {
45+
NSLog(@"Creating tokenizer...");
46+
self->tokenizer = OgaTokenizer::Create(*self->model); // throws exception
47+
}
48+
49+
auto tokenizer_stream = OgaTokenizerStream::Create(*self->tokenizer);
50+
51+
// Construct the prompt
52+
NSString* promptString = [NSString stringWithFormat:@"<|user|>\n%@<|end|>\n<|assistant|>", input_user_question];
53+
const char* prompt = [promptString UTF8String];
54+
55+
// Encode the prompt
56+
auto sequences = OgaSequences::Create();
57+
self->tokenizer->Encode(prompt, *sequences);
58+
59+
size_t promptTokensCount = sequences->SequenceCount(0);
60+
61+
NSLog(@"Setting generator parameters...");
62+
auto params = OgaGeneratorParams::Create(*self->model);
63+
params->SetSearchOption("max_length", 200);
64+
params->SetInputSequences(*sequences);
65+
66+
auto generator = OgaGenerator::Create(*self->model, *params);
67+
68+
bool isFirstToken = true;
69+
NSLog(@"Starting token generation loop...");
70+
71+
startTime = Clock::now();
72+
while (!generator->IsDone()) {
73+
tokenStartTime = Clock::now();
74+
75+
generator->ComputeLogits();
76+
generator->GenerateNextToken();
77+
78+
if (isFirstToken) {
79+
firstTokenTime = Clock::now();
80+
isFirstToken = false;
81+
}
82+
83+
// Get the sequence data and decode the token
84+
const int32_t* seq = generator->GetSequenceData(0);
85+
size_t seq_len = generator->GetSequenceCount(0);
86+
const char* decode_tokens = tokenizer_stream->Decode(seq[seq_len - 1]);
87+
88+
if (!decode_tokens) {
89+
@throw [NSException exceptionWithName:@"TokenDecodeError" reason:@"Token decoding failed." userInfo:nil];
90+
}
91+
92+
// Measure token generation time excluding logging
93+
TimePoint tokenEndTime = Clock::now();
94+
auto tokenDuration = std::chrono::duration_cast<std::chrono::milliseconds>(tokenEndTime - tokenStartTime).count();
95+
tokenTimes.push_back(tokenDuration);
96+
NSString* decodedTokenString = [NSString stringWithUTF8String:decode_tokens];
97+
[SharedTokenUpdater.shared addDecodedToken:decodedTokenString];
13198
}
99+
100+
TimePoint endTime = Clock::now();
101+
// Log token times
102+
NSLog(@"Per-token generation times: %@", [self formatTokenTimes:tokenTimes]);
103+
104+
// Calculate metrics
105+
auto totalDuration = std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime).count();
106+
auto firstTokenDuration = std::chrono::duration_cast<std::chrono::milliseconds>(firstTokenTime - startTime).count();
107+
108+
double promptProcRate = (double)promptTokensCount * 1000.0 / firstTokenDuration;
109+
double tokenGenRate = (double)(tokenTimes.size() - 1) * 1000.0 / (totalDuration - firstTokenDuration);
110+
111+
NSLog(@"Token generation completed. Total time: %lld ms, First token time: %lld ms, Total tokens: %zu",
112+
totalDuration, firstTokenDuration, tokenTimes.size());
113+
NSLog(@"Prompt tokens: %zu, Prompt Processing Rate: %f tokens/s", promptTokensCount, promptProcRate);
114+
NSLog(@"Generated tokens: %zu, Token Generation Rate: %f tokens/s", tokenTimes.size(), tokenGenRate);
115+
116+
NSDictionary* stats = @{@"tokenGenRate" : @(tokenGenRate), @"promptProcRate" : @(promptProcRate)};
117+
// notify main thread that token generation is complete
118+
dispatch_async(dispatch_get_main_queue(), ^{
119+
[[NSNotificationCenter defaultCenter] postNotificationName:@"TokenGenerationStats" object:nil userInfo:stats];
120+
[[NSNotificationCenter defaultCenter] postNotificationName:@"TokenGenerationCompleted" object:nil];
121+
});
122+
123+
NSLog(@"Token generation completed.");
124+
125+
} catch (const std::exception& e) {
126+
NSString* errorMessage = [NSString stringWithUTF8String:e.what()];
127+
NSLog(@"Error during generation: %@", errorMessage);
128+
129+
// Send error to the UI
130+
NSDictionary* errorInfo = @{@"error" : errorMessage};
131+
dispatch_async(dispatch_get_main_queue(), ^{
132+
[[NSNotificationCenter defaultCenter] postNotificationName:@"TokenGenerationError" object:nil userInfo:errorInfo];
133+
});
134+
}
132135
}
133136

134137
// Utility function to format token times for logging
135-
+ (NSString*)formatTokenTimes:(const std::vector<long long>&)tokenTimes {
136-
NSMutableString *formattedTimes = [NSMutableString string];
137-
for (size_t i = 0; i < tokenTimes.size(); i++) {
138-
[formattedTimes appendFormat:@"%lld ms, ", tokenTimes[i]];
139-
}
140-
return [formattedTimes copy];
138+
- (NSString*)formatTokenTimes:(const std::vector<long long>&)tokenTimes {
139+
NSMutableString* formattedTimes = [NSMutableString string];
140+
for (size_t i = 0; i < tokenTimes.size(); i++) {
141+
[formattedTimes appendFormat:@"%lld ms, ", tokenTimes[i]];
142+
}
143+
return [formattedTimes copy];
141144
}
142145

143146
@end

0 commit comments

Comments
 (0)