Browse Source

feat: close #2192 use /list/models to get model ids

Yidadaa 1 year ago
parent
commit
4131fccbe0

+ 1 - 1
app/api/config/route.ts

@@ -9,7 +9,7 @@ const serverConfig = getServerSideConfig();
 const DANGER_CONFIG = {
   needCode: serverConfig.needCode,
   hideUserApiKey: serverConfig.hideUserApiKey,
-  enableGPT4: serverConfig.enableGPT4,
+  disableGPT4: serverConfig.disableGPT4,
   hideBalanceQuery: serverConfig.hideBalanceQuery,
 };
 

+ 26 - 1
app/api/openai/[...path]/route.ts

@@ -1,3 +1,5 @@
+import { type OpenAIListModelResponse } from "@/app/client/platforms/openai";
+import { getServerSideConfig } from "@/app/config/server";
 import { OpenaiPath } from "@/app/constant";
 import { prettyObject } from "@/app/utils/format";
 import { NextRequest, NextResponse } from "next/server";
@@ -6,6 +8,18 @@ import { requestOpenai } from "../../common";
 
 const ALLOWD_PATH = new Set(Object.values(OpenaiPath));
 
+function getModels(remoteModelRes: OpenAIListModelResponse) {
+  const config = getServerSideConfig();
+
+  if (config.disableGPT4) {
+    remoteModelRes.data = remoteModelRes.data.filter(
+      (m) => !m.id.startsWith("gpt-4"),
+    );
+  }
+
+  return remoteModelRes;
+}
+
 async function handle(
   req: NextRequest,
   { params }: { params: { path: string[] } },
@@ -39,7 +53,18 @@ async function handle(
   }
 
   try {
-    return await requestOpenai(req);
+    const response = await requestOpenai(req);
+
+    // list models
+    if (subpath === OpenaiPath.ListModelPath && response.status === 200) {
+      const resJson = (await response.json()) as OpenAIListModelResponse;
+      const availableModels = getModels(resJson);
+      return NextResponse.json(availableModels, {
+        status: response.status,
+      });
+    }
+
+    return response;
   } catch (e) {
     console.error("[OpenAI] ", e);
     return NextResponse.json(prettyObject(e));

+ 6 - 0
app/client/api.ts

@@ -38,9 +38,15 @@ export interface LLMUsage {
   total: number;
 }
 
+export interface LLMModel {
+  name: string;
+  available: boolean;
+}
+
 export abstract class LLMApi {
   abstract chat(options: ChatOptions): Promise<void>;
   abstract usage(): Promise<LLMUsage>;
+  abstract models(): Promise<LLMModel[]>;
 }
 
 type ProviderName = "openai" | "azure" | "claude" | "palm";

+ 31 - 1
app/client/platforms/openai.ts

@@ -5,7 +5,7 @@ import {
 } from "@/app/constant";
 import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
 
-import { ChatOptions, getHeaders, LLMApi, LLMUsage } from "../api";
+import { ChatOptions, getHeaders, LLMApi, LLMModel, LLMUsage } from "../api";
 import Locale from "../../locales";
 import {
   EventStreamContentType,
@@ -13,6 +13,15 @@ import {
 } from "@fortaine/fetch-event-source";
 import { prettyObject } from "@/app/utils/format";
 
+export interface OpenAIListModelResponse {
+  object: string;
+  data: Array<{
+    id: string;
+    object: string;
+    root: string;
+  }>;
+}
+
 export class ChatGPTApi implements LLMApi {
   path(path: string): string {
     let openaiUrl = useAccessStore.getState().openaiUrl;
@@ -22,6 +31,9 @@ export class ChatGPTApi implements LLMApi {
     if (openaiUrl.endsWith("/")) {
       openaiUrl = openaiUrl.slice(0, openaiUrl.length - 1);
     }
+    if (!openaiUrl.startsWith("http") && !openaiUrl.startsWith("/api/openai")) {
+      openaiUrl = "https://" + openaiUrl;
+    }
     return [openaiUrl, path].join("/");
   }
 
@@ -232,5 +244,23 @@ export class ChatGPTApi implements LLMApi {
       total: total.hard_limit_usd,
     } as LLMUsage;
   }
+
+  async models(): Promise<LLMModel[]> {
+    const res = await fetch(this.path(OpenaiPath.ListModelPath), {
+      method: "GET",
+      headers: {
+        ...getHeaders(),
+      },
+    });
+
+    const resJson = (await res.json()) as OpenAIListModelResponse;
+    const chatModels = resJson.data.filter((m) => m.id.startsWith("gpt-"));
+    console.log("[Models]", chatModels);
+
+    return chatModels.map((m) => ({
+      name: m.id,
+      available: true,
+    }));
+  }
 }
 export { OpenaiPath };

+ 3 - 4
app/components/chat.tsx

@@ -42,12 +42,11 @@ import {
   Theme,
   useAppConfig,
   DEFAULT_TOPIC,
-  ALL_MODELS,
+  ModelType,
 } from "../store";
 
 import {
   copyToClipboard,
-  downloadAs,
   selectOrCopy,
   autoGrowTextArea,
   useMobileScreen,
@@ -387,12 +386,12 @@ export function ChatActions(props: {
   // switch model
   const currentModel = chatStore.currentSession().mask.modelConfig.model;
   function nextModel() {
-    const models = ALL_MODELS.filter((m) => m.available).map((m) => m.name);
+    const models = config.models.filter((m) => m.available).map((m) => m.name);
     const modelIndex = models.indexOf(currentModel);
     const nextIndex = (modelIndex + 1) % models.length;
     const nextModel = models[nextIndex];
     chatStore.updateCurrentSession((session) => {
-      session.mask.modelConfig.model = nextModel;
+      session.mask.modelConfig.model = nextModel as ModelType;
       session.mask.syncGlobalConfig = false;
     });
   }

+ 14 - 0
app/components/home.tsx

@@ -27,6 +27,7 @@ import { SideBar } from "./sidebar";
 import { useAppConfig } from "../store/config";
 import { AuthPage } from "./auth";
 import { getClientConfig } from "../config/client";
+import { api } from "../client/api";
 
 export function Loading(props: { noLogo?: boolean }) {
   return (
@@ -152,8 +153,21 @@ function Screen() {
   );
 }
 
+export function useLoadData() {
+  const config = useAppConfig();
+
+  useEffect(() => {
+    (async () => {
+      const models = await api.llm.models();
+      config.mergeModels(models);
+    })();
+    // eslint-disable-next-line react-hooks/exhaustive-deps
+  }, []);
+}
+
 export function Home() {
   useSwitchTheme();
+  useLoadData();
 
   useEffect(() => {
     console.log("[Config] got config from build time", getClientConfig());

+ 4 - 2
app/components/model-config.tsx

@@ -1,4 +1,4 @@
-import { ALL_MODELS, ModalConfigValidator, ModelConfig } from "../store";
+import { ModalConfigValidator, ModelConfig, useAppConfig } from "../store";
 
 import Locale from "../locales";
 import { InputRange } from "./input-range";
@@ -8,6 +8,8 @@ export function ModelConfigList(props: {
   modelConfig: ModelConfig;
   updateConfig: (updater: (config: ModelConfig) => void) => void;
 }) {
+  const config = useAppConfig();
+
   return (
     <>
       <ListItem title={Locale.Settings.Model}>
@@ -22,7 +24,7 @@ export function ModelConfigList(props: {
             );
           }}
         >
-          {ALL_MODELS.map((v) => (
+          {config.models.map((v) => (
             <option value={v.name} key={v.name} disabled={!v.available}>
               {v.name}
             </option>

+ 32 - 29
app/components/settings.tsx

@@ -340,6 +340,10 @@ export function Settings() {
   };
   const [loadingUsage, setLoadingUsage] = useState(false);
   function checkUsage(force = false) {
+    if (accessStore.hideBalanceQuery) {
+      return;
+    }
+
     setLoadingUsage(true);
     updateStore.updateUsage(force).finally(() => {
       setLoadingUsage(false);
@@ -577,19 +581,34 @@ export function Settings() {
           )}
 
           {!accessStore.hideUserApiKey ? (
-            <ListItem
-              title={Locale.Settings.Token.Title}
-              subTitle={Locale.Settings.Token.SubTitle}
-            >
-              <PasswordInput
-                value={accessStore.token}
-                type="text"
-                placeholder={Locale.Settings.Token.Placeholder}
-                onChange={(e) => {
-                  accessStore.updateToken(e.currentTarget.value);
-                }}
-              />
-            </ListItem>
+            <>
+              <ListItem
+                title={Locale.Settings.Endpoint.Title}
+                subTitle={Locale.Settings.Endpoint.SubTitle}
+              >
+                <input
+                  type="text"
+                  value={accessStore.openaiUrl}
+                  placeholder="https://api.openai.com/"
+                  onChange={(e) =>
+                    accessStore.updateOpenAiUrl(e.currentTarget.value)
+                  }
+                ></input>
+              </ListItem>
+              <ListItem
+                title={Locale.Settings.Token.Title}
+                subTitle={Locale.Settings.Token.SubTitle}
+              >
+                <PasswordInput
+                  value={accessStore.token}
+                  type="text"
+                  placeholder={Locale.Settings.Token.Placeholder}
+                  onChange={(e) => {
+                    accessStore.updateToken(e.currentTarget.value);
+                  }}
+                />
+              </ListItem>
+            </>
           ) : null}
 
           {!accessStore.hideBalanceQuery ? (
@@ -617,22 +636,6 @@ export function Settings() {
               )}
             </ListItem>
           ) : null}
-
-          {!accessStore.hideUserApiKey ? (
-            <ListItem
-              title={Locale.Settings.Endpoint.Title}
-              subTitle={Locale.Settings.Endpoint.SubTitle}
-            >
-              <input
-                type="text"
-                value={accessStore.openaiUrl}
-                placeholder="https://api.openai.com/"
-                onChange={(e) =>
-                  accessStore.updateOpenAiUrl(e.currentTarget.value)
-                }
-              ></input>
-            </ListItem>
-          ) : null}
         </List>
 
         <List>

+ 1 - 1
app/config/server.ts

@@ -46,7 +46,7 @@ export const getServerSideConfig = () => {
     proxyUrl: process.env.PROXY_URL,
     isVercel: !!process.env.VERCEL,
     hideUserApiKey: !!process.env.HIDE_USER_API_KEY,
-    enableGPT4: !process.env.DISABLE_GPT4,
+    disableGPT4: !!process.env.DISABLE_GPT4,
     hideBalanceQuery: !!process.env.HIDE_BALANCE_QUERY,
   };
 };

+ 68 - 0
app/constant.ts

@@ -53,6 +53,7 @@ export const OpenaiPath = {
   ChatPath: "v1/chat/completions",
   UsagePath: "dashboard/billing/usage",
   SubsPath: "dashboard/billing/subscription",
+  ListModelPath: "v1/models",
 };
 
 export const DEFAULT_INPUT_TEMPLATE = `{{input}}`; // input / time / model / lang
@@ -61,3 +62,70 @@ You are ChatGPT, a large language model trained by OpenAI.
 Knowledge cutoff: 2021-09
 Current model: {{model}}
 Current time: {{time}}`;
+
+export const DEFAULT_MODELS = [
+  {
+    name: "gpt-4",
+    available: false,
+  },
+  {
+    name: "gpt-4-0314",
+    available: false,
+  },
+  {
+    name: "gpt-4-0613",
+    available: false,
+  },
+  {
+    name: "gpt-4-32k",
+    available: false,
+  },
+  {
+    name: "gpt-4-32k-0314",
+    available: false,
+  },
+  {
+    name: "gpt-4-32k-0613",
+    available: false,
+  },
+  {
+    name: "gpt-3.5-turbo",
+    available: true,
+  },
+  {
+    name: "gpt-3.5-turbo-0301",
+    available: true,
+  },
+  {
+    name: "gpt-3.5-turbo-0613",
+    available: true,
+  },
+  {
+    name: "gpt-3.5-turbo-16k",
+    available: true,
+  },
+  {
+    name: "gpt-3.5-turbo-16k-0613",
+    available: true,
+  },
+  {
+    name: "qwen-v1", // 通义千问
+    available: false,
+  },
+  {
+    name: "ernie", // 文心一言
+    available: false,
+  },
+  {
+    name: "spark", // 讯飞星火
+    available: false,
+  },
+  {
+    name: "llama", // llama
+    available: false,
+  },
+  {
+    name: "chatglm", // chatglm-6b
+    available: false,
+  },
+] as const;

+ 0 - 9
app/store/access.ts

@@ -3,7 +3,6 @@ import { persist } from "zustand/middleware";
 import { DEFAULT_API_HOST, StoreKey } from "../constant";
 import { getHeaders } from "../client/api";
 import { BOT_HELLO } from "./chat";
-import { ALL_MODELS } from "./config";
 import { getClientConfig } from "../config/client";
 
 export interface AccessControlStore {
@@ -76,14 +75,6 @@ export const useAccessStore = create<AccessControlStore>()(
             console.log("[Config] got config from server", res);
             set(() => ({ ...res }));
 
-            if (!res.enableGPT4) {
-              ALL_MODELS.forEach((model) => {
-                if (model.name.startsWith("gpt-4")) {
-                  (model as any).available = false;
-                }
-              });
-            }
-
             if ((res as any).botHello) {
               BOT_HELLO.content = (res as any).botHello;
             }

+ 28 - 73
app/store/config.ts

@@ -1,7 +1,10 @@
 import { create } from "zustand";
 import { persist } from "zustand/middleware";
+import { LLMModel } from "../client/api";
 import { getClientConfig } from "../config/client";
-import { DEFAULT_INPUT_TEMPLATE, StoreKey } from "../constant";
+import { DEFAULT_INPUT_TEMPLATE, DEFAULT_MODELS, StoreKey } from "../constant";
+
+export type ModelType = (typeof DEFAULT_MODELS)[number]["name"];
 
 export enum SubmitKey {
   Enter = "Enter",
@@ -30,6 +33,8 @@ export const DEFAULT_CONFIG = {
 
   dontShowMaskSplashScreen: false, // dont show splash screen when create chat
 
+  models: DEFAULT_MODELS as any as LLMModel[],
+
   modelConfig: {
     model: "gpt-3.5-turbo" as ModelType,
     temperature: 0.5,
@@ -49,81 +54,11 @@ export type ChatConfig = typeof DEFAULT_CONFIG;
 export type ChatConfigStore = ChatConfig & {
   reset: () => void;
   update: (updater: (config: ChatConfig) => void) => void;
+  mergeModels: (newModels: LLMModel[]) => void;
 };
 
 export type ModelConfig = ChatConfig["modelConfig"];
 
-const ENABLE_GPT4 = true;
-
-export const ALL_MODELS = [
-  {
-    name: "gpt-4",
-    available: ENABLE_GPT4,
-  },
-  {
-    name: "gpt-4-0314",
-    available: ENABLE_GPT4,
-  },
-  {
-    name: "gpt-4-0613",
-    available: ENABLE_GPT4,
-  },
-  {
-    name: "gpt-4-32k",
-    available: ENABLE_GPT4,
-  },
-  {
-    name: "gpt-4-32k-0314",
-    available: ENABLE_GPT4,
-  },
-  {
-    name: "gpt-4-32k-0613",
-    available: ENABLE_GPT4,
-  },
-  {
-    name: "gpt-3.5-turbo",
-    available: true,
-  },
-  {
-    name: "gpt-3.5-turbo-0301",
-    available: true,
-  },
-  {
-    name: "gpt-3.5-turbo-0613",
-    available: true,
-  },
-  {
-    name: "gpt-3.5-turbo-16k",
-    available: true,
-  },
-  {
-    name: "gpt-3.5-turbo-16k-0613",
-    available: true,
-  },
-  {
-    name: "qwen-v1", // 通义千问
-    available: false,
-  },
-  {
-    name: "ernie", // 文心一言
-    available: false,
-  },
-  {
-    name: "spark", // 讯飞星火
-    available: false,
-  },
-  {
-    name: "llama", // llama
-    available: false,
-  },
-  {
-    name: "chatglm", // chatglm-6b
-    available: false,
-  },
-] as const;
-
-export type ModelType = (typeof ALL_MODELS)[number]["name"];
-
 export function limitNumber(
   x: number,
   min: number,
@@ -138,7 +73,8 @@ export function limitNumber(
 }
 
 export function limitModel(name: string) {
-  return ALL_MODELS.some((m) => m.name === name && m.available)
+  const allModels = useAppConfig.getState().models;
+  return allModels.some((m) => m.name === name && m.available)
     ? name
     : "gpt-3.5-turbo";
 }
@@ -178,6 +114,25 @@ export const useAppConfig = create<ChatConfigStore>()(
         updater(config);
         set(() => config);
       },
+
+      mergeModels(newModels) {
+        const oldModels = get().models;
+        const modelMap: Record<string, LLMModel> = {};
+
+        for (const model of oldModels) {
+          model.available = false;
+          modelMap[model.name] = model;
+        }
+
+        for (const model of newModels) {
+          model.available = true;
+          modelMap[model.name] = model;
+        }
+
+        set(() => ({
+          models: Object.values(modelMap),
+        }));
+      },
     }),
     {
       name: StoreKey.Config,