Browse Source

feat: close #3187 use CUSTOM_MODELS to control model list

Yidadaa 1 year ago
parent
commit
d93f05f511

+ 7 - 0
README.md

@@ -197,6 +197,13 @@ If you do want users to query balance, set this value to 1, or you should set it
 
 If you want to disable parse settings from url, set this to 1.
 
+### `CUSTOM_MODELS` (optional)
+
+> Default: Empty
+> Example: `+llama,+claude-2,-gpt-3.5-turbo` means add `llama, claude-2` to model list, and remove `gpt-3.5-turbo` from list.
+
+To control custom models, use `+` to add a custom model, use `-` to hide a model, separated by comma.
+
 ## Requirements
 
 NodeJS >= 18, Docker >= 20

+ 6 - 0
README_CN.md

@@ -106,6 +106,12 @@ OpenAI 接口代理 URL,如果你手动配置了 openai 接口代理,请填
 
 如果你想禁用从链接解析预制设置,将此环境变量设置为 1 即可。
 
+### `CUSTOM_MODELS` (可选)
+
+> 示例:`+qwen-7b-chat,+glm-6b,-gpt-3.5-turbo` 表示增加 `qwen-7b-chat` 和 `glm-6b` 到模型列表,而从列表中删除 `gpt-3.5-turbo`。
+
+用来控制模型列表,使用 `+` 增加一个模型,使用 `-` 来隐藏一个模型,用英文逗号隔开。
+
 ## 开发
 
 点击下方按钮,开始二次开发:

+ 16 - 15
app/api/common.ts

@@ -1,10 +1,9 @@
 import { NextRequest, NextResponse } from "next/server";
+import { getServerSideConfig } from "../config/server";
+import { DEFAULT_MODELS, OPENAI_BASE_URL } from "../constant";
+import { collectModelTable, collectModels } from "../utils/model";
 
-export const OPENAI_URL = "api.openai.com";
-const DEFAULT_PROTOCOL = "https";
-const PROTOCOL = process.env.PROTOCOL || DEFAULT_PROTOCOL;
-const BASE_URL = process.env.BASE_URL || OPENAI_URL;
-const DISABLE_GPT4 = !!process.env.DISABLE_GPT4;
+const serverConfig = getServerSideConfig();
 
 export async function requestOpenai(req: NextRequest) {
   const controller = new AbortController();
@@ -14,10 +13,10 @@ export async function requestOpenai(req: NextRequest) {
     "",
   );
 
-  let baseUrl = BASE_URL;
+  let baseUrl = serverConfig.baseUrl ?? OPENAI_BASE_URL;
 
   if (!baseUrl.startsWith("http")) {
-    baseUrl = `${PROTOCOL}://${baseUrl}`;
+    baseUrl = `https://${baseUrl}`;
   }
 
   if (baseUrl.endsWith("/")) {
@@ -26,10 +25,7 @@ export async function requestOpenai(req: NextRequest) {
 
   console.log("[Proxy] ", openaiPath);
   console.log("[Base Url]", baseUrl);
-
-  if (process.env.OPENAI_ORG_ID) {
-    console.log("[Org ID]", process.env.OPENAI_ORG_ID);
-  }
+  console.log("[Org ID]", serverConfig.openaiOrgId);
 
   const timeoutId = setTimeout(
     () => {
@@ -58,18 +54,23 @@ export async function requestOpenai(req: NextRequest) {
   };
 
   // #1815 try to refuse gpt4 request
-  if (DISABLE_GPT4 && req.body) {
+  if (serverConfig.customModels && req.body) {
     try {
+      const modelTable = collectModelTable(
+        DEFAULT_MODELS,
+        serverConfig.customModels,
+      );
       const clonedBody = await req.text();
       fetchOptions.body = clonedBody;
 
-      const jsonBody = JSON.parse(clonedBody);
+      const jsonBody = JSON.parse(clonedBody) as { model?: string };
 
-      if ((jsonBody?.model ?? "").includes("gpt-4")) {
+      // not undefined and is false
+      if (modelTable[jsonBody?.model ?? ""] === false) {
         return NextResponse.json(
           {
             error: true,
-            message: "you are not allowed to use gpt-4 model",
+            message: `you are not allowed to use ${jsonBody?.model} model`,
           },
           {
             status: 403,

+ 1 - 0
app/api/config/route.ts

@@ -12,6 +12,7 @@ const DANGER_CONFIG = {
   disableGPT4: serverConfig.disableGPT4,
   hideBalanceQuery: serverConfig.hideBalanceQuery,
   disableFastLink: serverConfig.disableFastLink,
+  customModels: serverConfig.customModels,
 };
 
 declare global {

+ 4 - 8
app/components/chat.tsx

@@ -88,6 +88,7 @@ import { ChatCommandPrefix, useChatCommand, useCommand } from "../command";
 import { prettyObject } from "../utils/format";
 import { ExportMessageModal } from "./exporter";
 import { getClientConfig } from "../config/client";
+import { useAllModels } from "../utils/hooks";
 
 const Markdown = dynamic(async () => (await import("./markdown")).Markdown, {
   loading: () => <LoadingIcon />,
@@ -430,14 +431,9 @@ export function ChatActions(props: {
 
   // switch model
   const currentModel = chatStore.currentSession().mask.modelConfig.model;
-  const models = useMemo(
-    () =>
-      config
-        .allModels()
-        .filter((m) => m.available)
-        .map((m) => m.name),
-    [config],
-  );
+  const models = useAllModels()
+    .filter((m) => m.available)
+    .map((m) => m.name);
   const [showModelSelector, setShowModelSelector] = useState(false);
 
   return (

+ 4 - 3
app/components/model-config.tsx

@@ -1,14 +1,15 @@
-import { ModalConfigValidator, ModelConfig, useAppConfig } from "../store";
+import { ModalConfigValidator, ModelConfig } from "../store";
 
 import Locale from "../locales";
 import { InputRange } from "./input-range";
 import { ListItem, Select } from "./ui-lib";
+import { useAllModels } from "../utils/hooks";
 
 export function ModelConfigList(props: {
   modelConfig: ModelConfig;
   updateConfig: (updater: (config: ModelConfig) => void) => void;
 }) {
-  const config = useAppConfig();
+  const allModels = useAllModels();
 
   return (
     <>
@@ -24,7 +25,7 @@ export function ModelConfigList(props: {
             );
           }}
         >
-          {config.allModels().map((v, i) => (
+          {allModels.map((v, i) => (
             <option value={v.name} key={i} disabled={!v.available}>
               {v.name}
             </option>

+ 16 - 1
app/config/server.ts

@@ -1,4 +1,5 @@
 import md5 from "spark-md5";
+import { DEFAULT_MODELS } from "../constant";
 
 declare global {
   namespace NodeJS {
@@ -7,6 +8,7 @@ declare global {
       CODE?: string;
       BASE_URL?: string;
       PROXY_URL?: string;
+      OPENAI_ORG_ID?: string;
       VERCEL?: string;
       HIDE_USER_API_KEY?: string; // disable user's api key input
       DISABLE_GPT4?: string; // allow user to use gpt-4 or not
@@ -14,6 +16,7 @@ declare global {
       BUILD_APP?: string; // is building desktop app
       ENABLE_BALANCE_QUERY?: string; // allow user to query balance or not
       DISABLE_FAST_LINK?: string; // disallow parse settings from url or not
+      CUSTOM_MODELS?: string; // to control custom models
     }
   }
 }
@@ -38,6 +41,16 @@ export const getServerSideConfig = () => {
     );
   }
 
+  let disableGPT4 = !!process.env.DISABLE_GPT4;
+  let customModels = process.env.CUSTOM_MODELS ?? "";
+
+  if (disableGPT4) {
+    if (customModels) customModels += ",";
+    customModels += DEFAULT_MODELS.filter((m) => m.name.startsWith("gpt-4"))
+      .map((m) => "-" + m.name)
+      .join(",");
+  }
+
   return {
     apiKey: process.env.OPENAI_API_KEY,
     code: process.env.CODE,
@@ -45,10 +58,12 @@ export const getServerSideConfig = () => {
     needCode: ACCESS_CODES.size > 0,
     baseUrl: process.env.BASE_URL,
     proxyUrl: process.env.PROXY_URL,
+    openaiOrgId: process.env.OPENAI_ORG_ID,
     isVercel: !!process.env.VERCEL,
     hideUserApiKey: !!process.env.HIDE_USER_API_KEY,
-    disableGPT4: !!process.env.DISABLE_GPT4,
+    disableGPT4,
     hideBalanceQuery: !process.env.ENABLE_BALANCE_QUERY,
     disableFastLink: !!process.env.DISABLE_FAST_LINK,
+    customModels,
   };
 };

+ 1 - 6
app/store/access.ts

@@ -17,6 +17,7 @@ const DEFAULT_ACCESS_STATE = {
   hideBalanceQuery: false,
   disableGPT4: false,
   disableFastLink: false,
+  customModels: "",
 
   openaiUrl: DEFAULT_OPENAI_URL,
 };
@@ -52,12 +53,6 @@ export const useAccessStore = createPersistStore(
         .then((res: DangerConfig) => {
           console.log("[Config] got config from server", res);
           set(() => ({ ...res }));
-
-          if (res.disableGPT4) {
-            DEFAULT_MODELS.forEach(
-              (m: any) => (m.available = !m.name.startsWith("gpt-4")),
-            );
-          }
         })
         .catch(() => {
           console.error("[Config] failed to fetch config");

+ 1 - 9
app/store/config.ts

@@ -128,15 +128,7 @@ export const useAppConfig = createPersistStore(
       }));
     },
 
-    allModels() {
-      const customModels = get()
-        .customModels.split(",")
-        .filter((v) => !!v && v.length > 0)
-        .map((m) => ({ name: m, available: true }));
-      const allModels = get().models.concat(customModels);
-      allModels.sort((a, b) => (a.name < b.name ? -1 : 1));
-      return allModels;
-    },
+    allModels() {},
   }),
   {
     name: StoreKey.Config,

+ 16 - 0
app/utils/hooks.ts

@@ -0,0 +1,16 @@
+import { useMemo } from "react";
+import { useAccessStore, useAppConfig } from "../store";
+import { collectModels } from "./model";
+
+export function useAllModels() {
+  const accessStore = useAccessStore();
+  const configStore = useAppConfig();
+  const models = useMemo(() => {
+    return collectModels(
+      configStore.models,
+      [accessStore.customModels, configStore.customModels].join(","),
+    );
+  }, [accessStore.customModels, configStore.customModels, configStore.models]);
+
+  return models;
+}

+ 40 - 0
app/utils/model.ts

@@ -0,0 +1,40 @@
+import { LLMModel } from "../client/api";
+
+export function collectModelTable(
+  models: readonly LLMModel[],
+  customModels: string,
+) {
+  const modelTable: Record<string, boolean> = {};
+
+  // default models
+  models.forEach((m) => (modelTable[m.name] = m.available));
+
+  // server custom models
+  customModels
+    .split(",")
+    .filter((v) => !!v && v.length > 0)
+    .map((m) => {
+      if (m.startsWith("+")) {
+        modelTable[m.slice(1)] = true;
+      } else if (m.startsWith("-")) {
+        modelTable[m.slice(1)] = false;
+      } else modelTable[m] = true;
+    });
+  return modelTable;
+}
+
+/**
+ * Generate full model table.
+ */
+export function collectModels(
+  models: readonly LLMModel[],
+  customModels: string,
+) {
+  const modelTable = collectModelTable(models, customModels);
+  const allModels = Object.keys(modelTable).map((m) => ({
+    name: m,
+    available: modelTable[m],
+  }));
+
+  return allModels;
+}