Browse Source

feat: close #3187 use CUSTOM_MODELS to control model list

Yidadaa 1 year ago
parent
commit
d93f05f511

+ 7 - 0
README.md

@@ -197,6 +197,13 @@ If you do want users to query balance, set this value to 1, or you should set it
 
 
 If you want to disable parse settings from url, set this to 1.
 If you want to disable parse settings from url, set this to 1.
 
 
+### `CUSTOM_MODELS` (optional)
+
+> Default: Empty
+> Example: `+llama,+claude-2,-gpt-3.5-turbo` means add `llama, claude-2` to model list, and remove `gpt-3.5-turbo` from list.
+
+To control custom models, use `+` to add a custom model, use `-` to hide a model, separated by comma.
+
 ## Requirements
 ## Requirements
 
 
 NodeJS >= 18, Docker >= 20
 NodeJS >= 18, Docker >= 20

+ 6 - 0
README_CN.md

@@ -106,6 +106,12 @@ OpenAI 接口代理 URL,如果你手动配置了 openai 接口代理,请填
 
 
 如果你想禁用从链接解析预制设置,将此环境变量设置为 1 即可。
 如果你想禁用从链接解析预制设置,将此环境变量设置为 1 即可。
 
 
+### `CUSTOM_MODELS` (可选)
+
+> 示例:`+qwen-7b-chat,+glm-6b,-gpt-3.5-turbo` 表示增加 `qwen-7b-chat` 和 `glm-6b` 到模型列表,而从列表中删除 `gpt-3.5-turbo`。
+
+用来控制模型列表,使用 `+` 增加一个模型,使用 `-` 来隐藏一个模型,用英文逗号隔开。
+
 ## 开发
 ## 开发
 
 
 点击下方按钮,开始二次开发:
 点击下方按钮,开始二次开发:

+ 16 - 15
app/api/common.ts

@@ -1,10 +1,9 @@
 import { NextRequest, NextResponse } from "next/server";
 import { NextRequest, NextResponse } from "next/server";
+import { getServerSideConfig } from "../config/server";
+import { DEFAULT_MODELS, OPENAI_BASE_URL } from "../constant";
+import { collectModelTable, collectModels } from "../utils/model";
 
 
-export const OPENAI_URL = "api.openai.com";
-const DEFAULT_PROTOCOL = "https";
-const PROTOCOL = process.env.PROTOCOL || DEFAULT_PROTOCOL;
-const BASE_URL = process.env.BASE_URL || OPENAI_URL;
-const DISABLE_GPT4 = !!process.env.DISABLE_GPT4;
+const serverConfig = getServerSideConfig();
 
 
 export async function requestOpenai(req: NextRequest) {
 export async function requestOpenai(req: NextRequest) {
   const controller = new AbortController();
   const controller = new AbortController();
@@ -14,10 +13,10 @@ export async function requestOpenai(req: NextRequest) {
     "",
     "",
   );
   );
 
 
-  let baseUrl = BASE_URL;
+  let baseUrl = serverConfig.baseUrl ?? OPENAI_BASE_URL;
 
 
   if (!baseUrl.startsWith("http")) {
   if (!baseUrl.startsWith("http")) {
-    baseUrl = `${PROTOCOL}://${baseUrl}`;
+    baseUrl = `https://${baseUrl}`;
   }
   }
 
 
   if (baseUrl.endsWith("/")) {
   if (baseUrl.endsWith("/")) {
@@ -26,10 +25,7 @@ export async function requestOpenai(req: NextRequest) {
 
 
   console.log("[Proxy] ", openaiPath);
   console.log("[Proxy] ", openaiPath);
   console.log("[Base Url]", baseUrl);
   console.log("[Base Url]", baseUrl);
-
-  if (process.env.OPENAI_ORG_ID) {
-    console.log("[Org ID]", process.env.OPENAI_ORG_ID);
-  }
+  console.log("[Org ID]", serverConfig.openaiOrgId);
 
 
   const timeoutId = setTimeout(
   const timeoutId = setTimeout(
     () => {
     () => {
@@ -58,18 +54,23 @@ export async function requestOpenai(req: NextRequest) {
   };
   };
 
 
   // #1815 try to refuse gpt4 request
   // #1815 try to refuse gpt4 request
-  if (DISABLE_GPT4 && req.body) {
+  if (serverConfig.customModels && req.body) {
     try {
     try {
+      const modelTable = collectModelTable(
+        DEFAULT_MODELS,
+        serverConfig.customModels,
+      );
       const clonedBody = await req.text();
       const clonedBody = await req.text();
       fetchOptions.body = clonedBody;
       fetchOptions.body = clonedBody;
 
 
-      const jsonBody = JSON.parse(clonedBody);
+      const jsonBody = JSON.parse(clonedBody) as { model?: string };
 
 
-      if ((jsonBody?.model ?? "").includes("gpt-4")) {
+      // not undefined and is false
+      if (modelTable[jsonBody?.model ?? ""] === false) {
         return NextResponse.json(
         return NextResponse.json(
           {
           {
             error: true,
             error: true,
-            message: "you are not allowed to use gpt-4 model",
+            message: `you are not allowed to use ${jsonBody?.model} model`,
           },
           },
           {
           {
             status: 403,
             status: 403,

+ 1 - 0
app/api/config/route.ts

@@ -12,6 +12,7 @@ const DANGER_CONFIG = {
   disableGPT4: serverConfig.disableGPT4,
   disableGPT4: serverConfig.disableGPT4,
   hideBalanceQuery: serverConfig.hideBalanceQuery,
   hideBalanceQuery: serverConfig.hideBalanceQuery,
   disableFastLink: serverConfig.disableFastLink,
   disableFastLink: serverConfig.disableFastLink,
+  customModels: serverConfig.customModels,
 };
 };
 
 
 declare global {
 declare global {

+ 4 - 8
app/components/chat.tsx

@@ -88,6 +88,7 @@ import { ChatCommandPrefix, useChatCommand, useCommand } from "../command";
 import { prettyObject } from "../utils/format";
 import { prettyObject } from "../utils/format";
 import { ExportMessageModal } from "./exporter";
 import { ExportMessageModal } from "./exporter";
 import { getClientConfig } from "../config/client";
 import { getClientConfig } from "../config/client";
+import { useAllModels } from "../utils/hooks";
 
 
 const Markdown = dynamic(async () => (await import("./markdown")).Markdown, {
 const Markdown = dynamic(async () => (await import("./markdown")).Markdown, {
   loading: () => <LoadingIcon />,
   loading: () => <LoadingIcon />,
@@ -430,14 +431,9 @@ export function ChatActions(props: {
 
 
   // switch model
   // switch model
   const currentModel = chatStore.currentSession().mask.modelConfig.model;
   const currentModel = chatStore.currentSession().mask.modelConfig.model;
-  const models = useMemo(
-    () =>
-      config
-        .allModels()
-        .filter((m) => m.available)
-        .map((m) => m.name),
-    [config],
-  );
+  const models = useAllModels()
+    .filter((m) => m.available)
+    .map((m) => m.name);
   const [showModelSelector, setShowModelSelector] = useState(false);
   const [showModelSelector, setShowModelSelector] = useState(false);
 
 
   return (
   return (

+ 4 - 3
app/components/model-config.tsx

@@ -1,14 +1,15 @@
-import { ModalConfigValidator, ModelConfig, useAppConfig } from "../store";
+import { ModalConfigValidator, ModelConfig } from "../store";
 
 
 import Locale from "../locales";
 import Locale from "../locales";
 import { InputRange } from "./input-range";
 import { InputRange } from "./input-range";
 import { ListItem, Select } from "./ui-lib";
 import { ListItem, Select } from "./ui-lib";
+import { useAllModels } from "../utils/hooks";
 
 
 export function ModelConfigList(props: {
 export function ModelConfigList(props: {
   modelConfig: ModelConfig;
   modelConfig: ModelConfig;
   updateConfig: (updater: (config: ModelConfig) => void) => void;
   updateConfig: (updater: (config: ModelConfig) => void) => void;
 }) {
 }) {
-  const config = useAppConfig();
+  const allModels = useAllModels();
 
 
   return (
   return (
     <>
     <>
@@ -24,7 +25,7 @@ export function ModelConfigList(props: {
             );
             );
           }}
           }}
         >
         >
-          {config.allModels().map((v, i) => (
+          {allModels.map((v, i) => (
             <option value={v.name} key={i} disabled={!v.available}>
             <option value={v.name} key={i} disabled={!v.available}>
               {v.name}
               {v.name}
             </option>
             </option>

+ 16 - 1
app/config/server.ts

@@ -1,4 +1,5 @@
 import md5 from "spark-md5";
 import md5 from "spark-md5";
+import { DEFAULT_MODELS } from "../constant";
 
 
 declare global {
 declare global {
   namespace NodeJS {
   namespace NodeJS {
@@ -7,6 +8,7 @@ declare global {
       CODE?: string;
       CODE?: string;
       BASE_URL?: string;
       BASE_URL?: string;
       PROXY_URL?: string;
       PROXY_URL?: string;
+      OPENAI_ORG_ID?: string;
       VERCEL?: string;
       VERCEL?: string;
       HIDE_USER_API_KEY?: string; // disable user's api key input
       HIDE_USER_API_KEY?: string; // disable user's api key input
       DISABLE_GPT4?: string; // allow user to use gpt-4 or not
       DISABLE_GPT4?: string; // allow user to use gpt-4 or not
@@ -14,6 +16,7 @@ declare global {
       BUILD_APP?: string; // is building desktop app
       BUILD_APP?: string; // is building desktop app
       ENABLE_BALANCE_QUERY?: string; // allow user to query balance or not
       ENABLE_BALANCE_QUERY?: string; // allow user to query balance or not
       DISABLE_FAST_LINK?: string; // disallow parse settings from url or not
       DISABLE_FAST_LINK?: string; // disallow parse settings from url or not
+      CUSTOM_MODELS?: string; // to control custom models
     }
     }
   }
   }
 }
 }
@@ -38,6 +41,16 @@ export const getServerSideConfig = () => {
     );
     );
   }
   }
 
 
+  let disableGPT4 = !!process.env.DISABLE_GPT4;
+  let customModels = process.env.CUSTOM_MODELS ?? "";
+
+  if (disableGPT4) {
+    if (customModels) customModels += ",";
+    customModels += DEFAULT_MODELS.filter((m) => m.name.startsWith("gpt-4"))
+      .map((m) => "-" + m.name)
+      .join(",");
+  }
+
   return {
   return {
     apiKey: process.env.OPENAI_API_KEY,
     apiKey: process.env.OPENAI_API_KEY,
     code: process.env.CODE,
     code: process.env.CODE,
@@ -45,10 +58,12 @@ export const getServerSideConfig = () => {
     needCode: ACCESS_CODES.size > 0,
     needCode: ACCESS_CODES.size > 0,
     baseUrl: process.env.BASE_URL,
     baseUrl: process.env.BASE_URL,
     proxyUrl: process.env.PROXY_URL,
     proxyUrl: process.env.PROXY_URL,
+    openaiOrgId: process.env.OPENAI_ORG_ID,
     isVercel: !!process.env.VERCEL,
     isVercel: !!process.env.VERCEL,
     hideUserApiKey: !!process.env.HIDE_USER_API_KEY,
     hideUserApiKey: !!process.env.HIDE_USER_API_KEY,
-    disableGPT4: !!process.env.DISABLE_GPT4,
+    disableGPT4,
     hideBalanceQuery: !process.env.ENABLE_BALANCE_QUERY,
     hideBalanceQuery: !process.env.ENABLE_BALANCE_QUERY,
     disableFastLink: !!process.env.DISABLE_FAST_LINK,
     disableFastLink: !!process.env.DISABLE_FAST_LINK,
+    customModels,
   };
   };
 };
 };

+ 1 - 6
app/store/access.ts

@@ -17,6 +17,7 @@ const DEFAULT_ACCESS_STATE = {
   hideBalanceQuery: false,
   hideBalanceQuery: false,
   disableGPT4: false,
   disableGPT4: false,
   disableFastLink: false,
   disableFastLink: false,
+  customModels: "",
 
 
   openaiUrl: DEFAULT_OPENAI_URL,
   openaiUrl: DEFAULT_OPENAI_URL,
 };
 };
@@ -52,12 +53,6 @@ export const useAccessStore = createPersistStore(
         .then((res: DangerConfig) => {
         .then((res: DangerConfig) => {
           console.log("[Config] got config from server", res);
           console.log("[Config] got config from server", res);
           set(() => ({ ...res }));
           set(() => ({ ...res }));
-
-          if (res.disableGPT4) {
-            DEFAULT_MODELS.forEach(
-              (m: any) => (m.available = !m.name.startsWith("gpt-4")),
-            );
-          }
         })
         })
         .catch(() => {
         .catch(() => {
           console.error("[Config] failed to fetch config");
           console.error("[Config] failed to fetch config");

+ 1 - 9
app/store/config.ts

@@ -128,15 +128,7 @@ export const useAppConfig = createPersistStore(
       }));
       }));
     },
     },
 
 
-    allModels() {
-      const customModels = get()
-        .customModels.split(",")
-        .filter((v) => !!v && v.length > 0)
-        .map((m) => ({ name: m, available: true }));
-      const allModels = get().models.concat(customModels);
-      allModels.sort((a, b) => (a.name < b.name ? -1 : 1));
-      return allModels;
-    },
+    allModels() {},
   }),
   }),
   {
   {
     name: StoreKey.Config,
     name: StoreKey.Config,

+ 16 - 0
app/utils/hooks.ts

@@ -0,0 +1,16 @@
+import { useMemo } from "react";
+import { useAccessStore, useAppConfig } from "../store";
+import { collectModels } from "./model";
+
+export function useAllModels() {
+  const accessStore = useAccessStore();
+  const configStore = useAppConfig();
+  const models = useMemo(() => {
+    return collectModels(
+      configStore.models,
+      [accessStore.customModels, configStore.customModels].join(","),
+    );
+  }, [accessStore.customModels, configStore.customModels, configStore.models]);
+
+  return models;
+}

+ 40 - 0
app/utils/model.ts

@@ -0,0 +1,40 @@
+import { LLMModel } from "../client/api";
+
+export function collectModelTable(
+  models: readonly LLMModel[],
+  customModels: string,
+) {
+  const modelTable: Record<string, boolean> = {};
+
+  // default models
+  models.forEach((m) => (modelTable[m.name] = m.available));
+
+  // server custom models
+  customModels
+    .split(",")
+    .filter((v) => !!v && v.length > 0)
+    .map((m) => {
+      if (m.startsWith("+")) {
+        modelTable[m.slice(1)] = true;
+      } else if (m.startsWith("-")) {
+        modelTable[m.slice(1)] = false;
+      } else modelTable[m] = true;
+    });
+  return modelTable;
+}
+
+/**
+ * Generate full model table.
+ */
+export function collectModels(
+  models: readonly LLMModel[],
+  customModels: string,
+) {
+  const modelTable = collectModelTable(models, customModels);
+  const allModels = Object.keys(modelTable).map((m) => ({
+    name: m,
+    available: modelTable[m],
+  }));
+
+  return allModels;
+}