|
1 | 1 | // Copyright (c) Microsoft Corporation. All rights reserved.
|
2 | 2 | // Licensed under the MIT License.
|
3 | 3 | #import "GenAIGenerator.h"
|
| 4 | +#include <chrono> |
| 5 | +#include <vector> |
4 | 6 | #include "LocalLLM-Swift.h"
|
5 | 7 | #include "ort_genai.h"
|
6 | 8 | #include "ort_genai_c.h"
|
7 |
| -#include <chrono> |
8 |
| -#include <vector> |
| 9 | + |
| 10 | +@interface GenAIGenerator () { |
| 11 | + std::unique_ptr<OgaModel> model; |
| 12 | + std::unique_ptr<OgaTokenizer> tokenizer; |
| 13 | +} |
| 14 | +@end |
9 | 15 |
|
10 | 16 | @implementation GenAIGenerator
|
11 | 17 |
|
12 | 18 | typedef std::chrono::steady_clock Clock;
|
13 | 19 | typedef std::chrono::time_point<Clock> TimePoint;
|
14 |
| -static std::unique_ptr<OgaModel> model = nullptr; |
15 |
| -static std::unique_ptr<OgaTokenizer> tokenizer = nullptr; |
16 |
| - |
17 |
| -+ (void)generate:(nonnull NSString*)input_user_question { |
18 |
| - std::vector<long long> tokenTimes; // per-token generation times |
19 |
| - TimePoint startTime, firstTokenTime, tokenStartTime; |
20 |
| - |
21 |
| - @try { |
22 |
| - NSLog(@"Starting token generation..."); |
23 |
| - |
24 |
| - if (!model) { |
25 |
| - NSLog(@"Creating model..."); |
26 |
| - NSString* llmPath = [[NSBundle mainBundle] resourcePath]; |
27 |
| - const char* modelPath = llmPath.cString; |
28 |
| - model = OgaModel::Create(modelPath); // throws exception |
29 |
| - |
30 |
| - if (!model) { |
31 |
| - @throw [NSException exceptionWithName:@"ModelCreationError" reason:@"Failed to create model." userInfo:nil]; |
32 |
| - } |
33 |
| - } |
34 |
| - |
35 |
| - if (!tokenizer) { |
36 |
| - NSLog(@"Creating tokenizer..."); |
37 |
| - tokenizer = OgaTokenizer::Create(*model); // throws exception |
38 |
| - if (!tokenizer) { |
39 |
| - @throw [NSException exceptionWithName:@"TokenizerCreationError" reason:@"Failed to create tokenizer." userInfo:nil]; |
40 |
| - } |
41 |
| - } |
42 |
| - |
43 |
| - auto tokenizer_stream = OgaTokenizerStream::Create(*tokenizer); |
44 |
| - |
45 |
| - // Construct the prompt |
46 |
| - NSString* promptString = [NSString stringWithFormat:@"<|user|>\n%@<|end|>\n<|assistant|>", input_user_question]; |
47 |
| - const char* prompt = [promptString UTF8String]; |
48 |
| - |
49 |
| - // Encode the prompt |
50 |
| - auto sequences = OgaSequences::Create(); |
51 |
| - tokenizer->Encode(prompt, *sequences); |
52 |
| - |
53 |
| - size_t promptTokensCount = sequences->SequenceCount(0); |
54 |
| - |
55 |
| - NSLog(@"Setting generator parameters..."); |
56 |
| - auto params = OgaGeneratorParams::Create(*model); |
57 |
| - params->SetSearchOption("max_length", 200); |
58 |
| - params->SetInputSequences(*sequences); |
59 |
| - |
60 |
| - auto generator = OgaGenerator::Create(*model, *params); |
61 |
| - |
62 |
| - bool isFirstToken = true; |
63 |
| - NSLog(@"Starting token generation loop..."); |
64 |
| - |
65 |
| - startTime = Clock::now(); |
66 |
| - while (!generator->IsDone()) { |
67 |
| - tokenStartTime = Clock::now(); |
68 |
| - |
69 |
| - generator->ComputeLogits(); |
70 |
| - generator->GenerateNextToken(); |
71 |
| - |
72 |
| - if (isFirstToken) { |
73 |
| - firstTokenTime = Clock::now(); |
74 |
| - isFirstToken = false; |
75 |
| - } |
76 |
| - |
77 |
| - // Get the sequence data and decode the token |
78 |
| - const int32_t* seq = generator->GetSequenceData(0); |
79 |
| - size_t seq_len = generator->GetSequenceCount(0); |
80 |
| - const char* decode_tokens = tokenizer_stream->Decode(seq[seq_len - 1]); |
81 |
| - |
82 |
| - if (!decode_tokens) { |
83 |
| - @throw [NSException exceptionWithName:@"TokenDecodeError" reason:@"Token decoding failed." userInfo:nil]; |
84 |
| - } |
85 |
| - |
86 |
| - // Measure token generation time excluding logging |
87 |
| - TimePoint tokenEndTime = Clock::now(); |
88 |
| - auto tokenDuration = std::chrono::duration_cast<std::chrono::milliseconds>(tokenEndTime - tokenStartTime).count(); |
89 |
| - tokenTimes.push_back(tokenDuration); |
90 |
| - NSString* decodedTokenString = [NSString stringWithUTF8String:decode_tokens]; |
91 |
| - [SharedTokenUpdater.shared addDecodedToken:decodedTokenString]; |
92 |
| - } |
93 |
| - |
94 |
| - TimePoint endTime = Clock::now(); |
95 |
| - // Log token times |
96 |
| - NSLog(@"Per-token generation times: %@", [self formatTokenTimes:tokenTimes]); |
97 |
| - |
98 |
| - // Calculate metrics |
99 |
| - auto totalDuration = std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime).count(); |
100 |
| - auto firstTokenDuration = std::chrono::duration_cast<std::chrono::milliseconds>(firstTokenTime - startTime).count(); |
101 |
| - |
102 |
| - double promtProcTime = (double)promptTokensCount / firstTokenDuration; |
103 |
| - double tokenGenRate = (double)(tokenTimes.size() - 1) * 1000.0 / (totalDuration - firstTokenDuration); |
104 |
| - |
105 |
| - NSLog(@"Token generation completed. Total time: %lld ms, First token time: %lld ms, Total tokens: %zu", totalDuration, firstTokenDuration, tokenTimes.size()); |
106 |
| - NSLog(@"Prompt tokens: %zu, Prompt Processing Time: %f tokens/s", promptTokensCount, promtProcTime); |
107 |
| - NSLog(@"Generated tokens: %zu, Token Generation Rate: %f tokens/s", tokenTimes.size(), tokenGenRate); |
108 |
| - |
109 |
| - |
110 |
| - NSDictionary *stats = @{ |
111 |
| - @"tokenGenRate" : @(tokenGenRate), |
112 |
| - @"promptProcRate": @(promtProcTime) |
113 |
| - }; |
114 |
| - // notify main thread that token generation is complete |
115 |
| - dispatch_async(dispatch_get_main_queue(), ^{ |
116 |
| - [[NSNotificationCenter defaultCenter] postNotificationName:@"TokenGenerationStats" object:nil userInfo:stats]; |
117 |
| - [[NSNotificationCenter defaultCenter] postNotificationName:@"TokenGenerationCompleted" object:nil]; |
118 |
| - }); |
119 |
| - |
120 |
| - NSLog(@"Token generation completed."); |
121 |
| - |
122 |
| - } @catch (NSException* e) { |
123 |
| - NSString* errorMessage = e.reason; |
124 |
| - NSLog(@"Error during generation: %@", errorMessage); |
125 |
| - |
126 |
| - // Send error to the UI |
127 |
| - NSDictionary *errorInfo = @{@"error": errorMessage}; |
128 |
| - dispatch_async(dispatch_get_main_queue(), ^{ |
129 |
| - [[NSNotificationCenter defaultCenter] postNotificationName:@"TokenGenerationError" object:nil userInfo:errorInfo]; |
130 |
| - }); |
| 20 | + |
| 21 | +- (instancetype)init { |
| 22 | + self = [super init]; |
| 23 | + if (self) { |
| 24 | + self->model = nullptr; |
| 25 | + self->tokenizer = nullptr; |
| 26 | + } |
| 27 | + return self; |
| 28 | +} |
| 29 | + |
| 30 | +- (void)generate:(nonnull NSString*)input_user_question { |
| 31 | + std::vector<long long> tokenTimes; // per-token generation times |
| 32 | + TimePoint startTime, firstTokenTime, tokenStartTime; |
| 33 | + |
| 34 | + try { |
| 35 | + NSLog(@"Starting token generation..."); |
| 36 | + |
| 37 | + if (!self->model) { |
| 38 | + NSLog(@"Creating model..."); |
| 39 | + NSString* llmPath = [[NSBundle mainBundle] resourcePath]; |
| 40 | + const char* modelPath = llmPath.cString; |
| 41 | + self->model = OgaModel::Create(modelPath); // throws exception |
| 42 | + } |
| 43 | + |
| 44 | + if (!self->tokenizer) { |
| 45 | + NSLog(@"Creating tokenizer..."); |
| 46 | + self->tokenizer = OgaTokenizer::Create(*self->model); // throws exception |
| 47 | + } |
| 48 | + |
| 49 | + auto tokenizer_stream = OgaTokenizerStream::Create(*self->tokenizer); |
| 50 | + |
| 51 | + // Construct the prompt |
| 52 | + NSString* promptString = [NSString stringWithFormat:@"<|user|>\n%@<|end|>\n<|assistant|>", input_user_question]; |
| 53 | + const char* prompt = [promptString UTF8String]; |
| 54 | + |
| 55 | + // Encode the prompt |
| 56 | + auto sequences = OgaSequences::Create(); |
| 57 | + self->tokenizer->Encode(prompt, *sequences); |
| 58 | + |
| 59 | + size_t promptTokensCount = sequences->SequenceCount(0); |
| 60 | + |
| 61 | + NSLog(@"Setting generator parameters..."); |
| 62 | + auto params = OgaGeneratorParams::Create(*self->model); |
| 63 | + params->SetSearchOption("max_length", 200); |
| 64 | + params->SetInputSequences(*sequences); |
| 65 | + |
| 66 | + auto generator = OgaGenerator::Create(*self->model, *params); |
| 67 | + |
| 68 | + bool isFirstToken = true; |
| 69 | + NSLog(@"Starting token generation loop..."); |
| 70 | + |
| 71 | + startTime = Clock::now(); |
| 72 | + while (!generator->IsDone()) { |
| 73 | + tokenStartTime = Clock::now(); |
| 74 | + |
| 75 | + generator->ComputeLogits(); |
| 76 | + generator->GenerateNextToken(); |
| 77 | + |
| 78 | + if (isFirstToken) { |
| 79 | + firstTokenTime = Clock::now(); |
| 80 | + isFirstToken = false; |
| 81 | + } |
| 82 | + |
| 83 | + // Get the sequence data and decode the token |
| 84 | + const int32_t* seq = generator->GetSequenceData(0); |
| 85 | + size_t seq_len = generator->GetSequenceCount(0); |
| 86 | + const char* decode_tokens = tokenizer_stream->Decode(seq[seq_len - 1]); |
| 87 | + |
| 88 | + if (!decode_tokens) { |
| 89 | + @throw [NSException exceptionWithName:@"TokenDecodeError" reason:@"Token decoding failed." userInfo:nil]; |
| 90 | + } |
| 91 | + |
| 92 | + // Measure token generation time excluding logging |
| 93 | + TimePoint tokenEndTime = Clock::now(); |
| 94 | + auto tokenDuration = std::chrono::duration_cast<std::chrono::milliseconds>(tokenEndTime - tokenStartTime).count(); |
| 95 | + tokenTimes.push_back(tokenDuration); |
| 96 | + NSString* decodedTokenString = [NSString stringWithUTF8String:decode_tokens]; |
| 97 | + [SharedTokenUpdater.shared addDecodedToken:decodedTokenString]; |
131 | 98 | }
|
| 99 | + |
| 100 | + TimePoint endTime = Clock::now(); |
| 101 | + // Log token times |
| 102 | + NSLog(@"Per-token generation times: %@", [self formatTokenTimes:tokenTimes]); |
| 103 | + |
| 104 | + // Calculate metrics |
| 105 | + auto totalDuration = std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime).count(); |
| 106 | + auto firstTokenDuration = std::chrono::duration_cast<std::chrono::milliseconds>(firstTokenTime - startTime).count(); |
| 107 | + |
| 108 | + double promptProcRate = (double)promptTokensCount * 1000.0 / firstTokenDuration; |
| 109 | + double tokenGenRate = (double)(tokenTimes.size() - 1) * 1000.0 / (totalDuration - firstTokenDuration); |
| 110 | + |
| 111 | + NSLog(@"Token generation completed. Total time: %lld ms, First token time: %lld ms, Total tokens: %zu", |
| 112 | + totalDuration, firstTokenDuration, tokenTimes.size()); |
| 113 | + NSLog(@"Prompt tokens: %zu, Prompt Processing Rate: %f tokens/s", promptTokensCount, promptProcRate); |
| 114 | + NSLog(@"Generated tokens: %zu, Token Generation Rate: %f tokens/s", tokenTimes.size(), tokenGenRate); |
| 115 | + |
| 116 | + NSDictionary* stats = @{@"tokenGenRate" : @(tokenGenRate), @"promptProcRate" : @(promptProcRate)}; |
| 117 | + // notify main thread that token generation is complete |
| 118 | + dispatch_async(dispatch_get_main_queue(), ^{ |
| 119 | + [[NSNotificationCenter defaultCenter] postNotificationName:@"TokenGenerationStats" object:nil userInfo:stats]; |
| 120 | + [[NSNotificationCenter defaultCenter] postNotificationName:@"TokenGenerationCompleted" object:nil]; |
| 121 | + }); |
| 122 | + |
| 123 | + NSLog(@"Token generation completed."); |
| 124 | + |
| 125 | + } catch (const std::exception& e) { |
| 126 | + NSString* errorMessage = [NSString stringWithUTF8String:e.what()]; |
| 127 | + NSLog(@"Error during generation: %@", errorMessage); |
| 128 | + |
| 129 | + // Send error to the UI |
| 130 | + NSDictionary* errorInfo = @{@"error" : errorMessage}; |
| 131 | + dispatch_async(dispatch_get_main_queue(), ^{ |
| 132 | + [[NSNotificationCenter defaultCenter] postNotificationName:@"TokenGenerationError" object:nil userInfo:errorInfo]; |
| 133 | + }); |
| 134 | + } |
132 | 135 | }
|
133 | 136 |
|
134 | 137 | // Utility function to format token times for logging
|
135 |
| -+ (NSString*)formatTokenTimes:(const std::vector<long long>&)tokenTimes { |
136 |
| - NSMutableString *formattedTimes = [NSMutableString string]; |
137 |
| - for (size_t i = 0; i < tokenTimes.size(); i++) { |
138 |
| - [formattedTimes appendFormat:@"%lld ms, ", tokenTimes[i]]; |
139 |
| - } |
140 |
| - return [formattedTimes copy]; |
| 138 | +- (NSString*)formatTokenTimes:(const std::vector<long long>&)tokenTimes { |
| 139 | + NSMutableString* formattedTimes = [NSMutableString string]; |
| 140 | + for (size_t i = 0; i < tokenTimes.size(); i++) { |
| 141 | + [formattedTimes appendFormat:@"%lld ms, ", tokenTimes[i]]; |
| 142 | + } |
| 143 | + return [formattedTimes copy]; |
141 | 144 | }
|
142 | 145 |
|
143 | 146 | @end
|
0 commit comments