Skip to content

Commit f3603e5

Browse files
authoredNov 6, 2024
Merge pull request ChatGPTNextWeb#5769 from ryanhex53/fix-model-multi@
Custom model names can include the `@` symbol by itself.
2 parents 00d6cb2 + 8e2484f commit f3603e5

File tree

6 files changed

+60
-11
lines changed

6 files changed

+60
-11
lines changed
 

‎app/api/common.ts

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import { NextRequest, NextResponse } from "next/server";
22
import { getServerSideConfig } from "../config/server";
33
import { OPENAI_BASE_URL, ServiceProvider } from "../constant";
4-
import { isModelAvailableInServer } from "../utils/model";
54
import { cloudflareAIGatewayUrl } from "../utils/cloudflare";
5+
import { getModelProvider, isModelAvailableInServer } from "../utils/model";
66

77
const serverConfig = getServerSideConfig();
88

@@ -71,7 +71,7 @@ export async function requestOpenai(req: NextRequest) {
7171
.filter((v) => !!v && !v.startsWith("-") && v.includes(modelName))
7272
.forEach((m) => {
7373
const [fullName, displayName] = m.split("=");
74-
const [_, providerName] = fullName.split("@");
74+
const [_, providerName] = getModelProvider(fullName);
7575
if (providerName === "azure" && !displayName) {
7676
const [_, deployId] = (serverConfig?.azureUrl ?? "").split(
7777
"deployments/",

‎app/components/chat.tsx

+2-1
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ import { createTTSPlayer } from "../utils/audio";
120120
import { MsEdgeTTS, OUTPUT_FORMAT } from "../utils/ms_edge_tts";
121121

122122
import { isEmpty } from "lodash-es";
123+
import { getModelProvider } from "../utils/model";
123124

124125
const localStorage = safeLocalStorage();
125126

@@ -645,7 +646,7 @@ export function ChatActions(props: {
645646
onClose={() => setShowModelSelector(false)}
646647
onSelection={(s) => {
647648
if (s.length === 0) return;
648-
const [model, providerName] = s[0].split("@");
649+
const [model, providerName] = getModelProvider(s[0]);
649650
chatStore.updateCurrentSession((session) => {
650651
session.mask.modelConfig.model = model as ModelType;
651652
session.mask.modelConfig.providerName =

‎app/components/model-config.tsx

+7-2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import { ListItem, Select } from "./ui-lib";
77
import { useAllModels } from "../utils/hooks";
88
import { groupBy } from "lodash-es";
99
import styles from "./model-config.module.scss";
10+
import { getModelProvider } from "../utils/model";
1011

1112
export function ModelConfigList(props: {
1213
modelConfig: ModelConfig;
@@ -28,7 +29,9 @@ export function ModelConfigList(props: {
2829
value={value}
2930
align="left"
3031
onChange={(e) => {
31-
const [model, providerName] = e.currentTarget.value.split("@");
32+
const [model, providerName] = getModelProvider(
33+
e.currentTarget.value,
34+
);
3235
props.updateConfig((config) => {
3336
config.model = ModalConfigValidator.model(model);
3437
config.providerName = providerName as ServiceProvider;
@@ -247,7 +250,9 @@ export function ModelConfigList(props: {
247250
aria-label={Locale.Settings.CompressModel.Title}
248251
value={compressModelValue}
249252
onChange={(e) => {
250-
const [model, providerName] = e.currentTarget.value.split("@");
253+
const [model, providerName] = getModelProvider(
254+
e.currentTarget.value,
255+
);
251256
props.updateConfig((config) => {
252257
config.compressModel = ModalConfigValidator.model(model);
253258
config.compressProviderName = providerName as ServiceProvider;

‎app/store/access.ts

+3-2
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import { getClientConfig } from "../config/client";
2121
import { createPersistStore } from "../utils/store";
2222
import { ensure } from "../utils/clone";
2323
import { DEFAULT_CONFIG } from "./config";
24+
import { getModelProvider } from "../utils/model";
2425

2526
let fetchState = 0; // 0 not fetch, 1 fetching, 2 done
2627

@@ -226,9 +227,9 @@ export const useAccessStore = createPersistStore(
226227
.then((res) => {
227228
const defaultModel = res.defaultModel ?? "";
228229
if (defaultModel !== "") {
229-
const [model, providerName] = defaultModel.split("@");
230+
const [model, providerName] = getModelProvider(defaultModel);
230231
DEFAULT_CONFIG.modelConfig.model = model;
231-
DEFAULT_CONFIG.modelConfig.providerName = providerName;
232+
DEFAULT_CONFIG.modelConfig.providerName = providerName as any;
232233
}
233234

234235
return res;

‎app/utils/model.ts

+15-4
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,17 @@ const sortModelTable = (models: ReturnType<typeof collectModels>) =>
3737
}
3838
});
3939

40+
/**
41+
* get model name and provider from a formatted string,
42+
* e.g. `gpt-4@OpenAi` or `claude-3-5-sonnet@20240620@Google`
43+
* @param modelWithProvider model name with provider separated by last `@` char,
44+
* @returns [model, provider] tuple, if no `@` char found, provider is undefined
45+
*/
46+
export function getModelProvider(modelWithProvider: string): [string, string?] {
47+
const [model, provider] = modelWithProvider.split(/@(?!.*@)/);
48+
return [model, provider];
49+
}
50+
4051
export function collectModelTable(
4152
models: readonly LLMModel[],
4253
customModels: string,
@@ -79,10 +90,10 @@ export function collectModelTable(
7990
);
8091
} else {
8192
// 1. find model by name, and set available value
82-
const [customModelName, customProviderName] = name.split("@");
93+
const [customModelName, customProviderName] = getModelProvider(name);
8394
let count = 0;
8495
for (const fullName in modelTable) {
85-
const [modelName, providerName] = fullName.split("@");
96+
const [modelName, providerName] = getModelProvider(fullName);
8697
if (
8798
customModelName == modelName &&
8899
(customProviderName === undefined ||
@@ -102,7 +113,7 @@ export function collectModelTable(
102113
}
103114
// 2. if model not exists, create new model with available value
104115
if (count === 0) {
105-
let [customModelName, customProviderName] = name.split("@");
116+
let [customModelName, customProviderName] = getModelProvider(name);
106117
const provider = customProvider(
107118
customProviderName || customModelName,
108119
);
@@ -139,7 +150,7 @@ export function collectModelTableWithDefaultModel(
139150
for (const key of Object.keys(modelTable)) {
140151
if (
141152
modelTable[key].available &&
142-
key.split("@").shift() == defaultModel
153+
getModelProvider(key)[0] == defaultModel
143154
) {
144155
modelTable[key].isDefault = true;
145156
break;

‎test/model-provider.test.ts

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import { getModelProvider } from "../app/utils/model";
2+
3+
describe("getModelProvider", () => {
4+
test("should return model and provider when input contains '@'", () => {
5+
const input = "model@provider";
6+
const [model, provider] = getModelProvider(input);
7+
expect(model).toBe("model");
8+
expect(provider).toBe("provider");
9+
});
10+
11+
test("should return model and undefined provider when input does not contain '@'", () => {
12+
const input = "model";
13+
const [model, provider] = getModelProvider(input);
14+
expect(model).toBe("model");
15+
expect(provider).toBeUndefined();
16+
});
17+
18+
test("should handle multiple '@' characters correctly", () => {
19+
const input = "model@provider@extra";
20+
const [model, provider] = getModelProvider(input);
21+
expect(model).toBe("model@provider");
22+
expect(provider).toBe("extra");
23+
});
24+
25+
test("should return empty strings when input is empty", () => {
26+
const input = "";
27+
const [model, provider] = getModelProvider(input);
28+
expect(model).toBe("");
29+
expect(provider).toBeUndefined();
30+
});
31+
});

0 commit comments

Comments
 (0)
Please sign in to comment.