Browse Source

feat: close #1789 add user input template

Yidadaa 1 year ago
parent
commit
be597a551d
6 changed files with 114 additions and 55 deletions
  1. 16 1
      app/components/model-config.tsx
  2. 7 0
      app/constant.ts
  3. 5 0
      app/locales/cn.ts
  4. 6 0
      app/locales/en.ts
  5. 75 51
      app/store/chat.ts
  6. 5 3
      app/store/config.ts

+ 16 - 1
app/components/model-config.tsx

@@ -2,7 +2,7 @@ import { ALL_MODELS, ModalConfigValidator, ModelConfig } from "../store";
 
 import Locale from "../locales";
 import { InputRange } from "./input-range";
-import { List, ListItem, Select } from "./ui-lib";
+import { ListItem, Select } from "./ui-lib";
 
 export function ModelConfigList(props: {
   modelConfig: ModelConfig;
@@ -109,6 +109,21 @@ export function ModelConfigList(props: {
         ></InputRange>
       </ListItem>
 
+      <ListItem
+        title={Locale.Settings.InputTemplate.Title}
+        subTitle={Locale.Settings.InputTemplate.SubTitle}
+      >
+        <input
+          type="text"
+          value={props.modelConfig.template}
+          onChange={(e) =>
+            props.updateConfig(
+              (config) => (config.template = e.currentTarget.value),
+            )
+          }
+        ></input>
+      </ListItem>
+
       <ListItem
         title={Locale.Settings.HistoryCount.Title}
         subTitle={Locale.Settings.HistoryCount.SubTitle}

+ 7 - 0
app/constant.ts

@@ -52,3 +52,10 @@ export const OpenaiPath = {
   UsagePath: "dashboard/billing/usage",
   SubsPath: "dashboard/billing/subscription",
 };
+
+export const DEFAULT_INPUT_TEMPLATE = `
+Act as a virtual assistant powered by model: '{{model}}', my input is:
+'''
+{{input}}
+'''
+`;

+ 5 - 0
app/locales/cn.ts

@@ -115,6 +115,11 @@ const cn = {
       SubTitle: "聊天内容的字体大小",
     },
 
+    InputTemplate: {
+      Title: "用户输入预处理",
+      SubTitle: "用户最新的一条消息会填充到此模板",
+    },
+
     Update: {
       Version: (x: string) => `当前版本:${x}`,
       IsLatest: "已是最新版本",

+ 6 - 0
app/locales/en.ts

@@ -116,6 +116,12 @@ const en: LocaleType = {
       Title: "Font Size",
       SubTitle: "Adjust font size of chat content",
     },
+
+    InputTemplate: {
+      Title: "Input Template",
+      SubTitle: "Newest message will be filled to this template",
+    },
+
     Update: {
       Version: (x: string) => `Version: ${x}`,
       IsLatest: "Latest version",

+ 75 - 51
app/store/chat.ts

@@ -3,11 +3,11 @@ import { persist } from "zustand/middleware";
 
 import { trimTopic } from "../utils";
 
-import Locale from "../locales";
+import Locale, { getLang } from "../locales";
 import { showToast } from "../components/ui-lib";
-import { ModelType } from "./config";
+import { ModelConfig, ModelType } from "./config";
 import { createEmptyMask, Mask } from "./mask";
-import { StoreKey } from "../constant";
+import { DEFAULT_INPUT_TEMPLATE, StoreKey } from "../constant";
 import { api, RequestMessage } from "../client/api";
 import { ChatControllerPool } from "../client/controller";
 import { prettyObject } from "../utils/format";
@@ -106,6 +106,29 @@ function countMessages(msgs: ChatMessage[]) {
   return msgs.reduce((pre, cur) => pre + estimateTokenLength(cur.content), 0);
 }
 
+function fillTemplateWith(input: string, modelConfig: ModelConfig) {
+  const vars = {
+    model: modelConfig.model,
+    time: new Date().toLocaleString(),
+    lang: getLang(),
+    input: input,
+  };
+
+  let output = modelConfig.template ?? DEFAULT_INPUT_TEMPLATE;
+
+  // must contains {{input}}
+  const inputVar = "{{input}}";
+  if (!output.includes(inputVar)) {
+    output += "\n" + inputVar;
+  }
+
+  Object.entries(vars).forEach(([name, value]) => {
+    output = output.replaceAll(`{{${name}}}`, value);
+  });
+
+  return output;
+}
+
 export const useChatStore = create<ChatStore>()(
   persist(
     (set, get) => ({
@@ -238,9 +261,12 @@ export const useChatStore = create<ChatStore>()(
         const session = get().currentSession();
         const modelConfig = session.mask.modelConfig;
 
+        const userContent = fillTemplateWith(content, modelConfig);
+        console.log("[User Input] fill with template: ", userContent);
+
         const userMessage: ChatMessage = createMessage({
           role: "user",
-          content,
+          content: userContent,
         });
 
         const botMessage: ChatMessage = createMessage({
@@ -250,31 +276,22 @@ export const useChatStore = create<ChatStore>()(
           model: modelConfig.model,
         });
 
-        const systemInfo = createMessage({
-          role: "system",
-          content: `IMPORTANT: You are a virtual assistant powered by the ${
-            modelConfig.model
-          } model, now time is ${new Date().toLocaleString()}}`,
-          id: botMessage.id! + 1,
-        });
-
         // get recent messages
-        const systemMessages = [];
-        // if user define a mask with context prompts, wont send system info
-        if (session.mask.context.length === 0) {
-          systemMessages.push(systemInfo);
-        }
-
         const recentMessages = get().getMessagesWithMemory();
-        const sendMessages = systemMessages.concat(
-          recentMessages.concat(userMessage),
-        );
+        const sendMessages = recentMessages.concat(userMessage);
         const sessionIndex = get().currentSessionIndex;
         const messageIndex = get().currentSession().messages.length + 1;
 
         // save user's and bot's message
         get().updateCurrentSession((session) => {
-          session.messages = session.messages.concat([userMessage, botMessage]);
+          const savedUserMessage = {
+            ...userMessage,
+            content,
+          };
+          session.messages = session.messages.concat([
+            savedUserMessage,
+            botMessage,
+          ]);
         });
 
         // make request
@@ -350,55 +367,62 @@ export const useChatStore = create<ChatStore>()(
       getMessagesWithMemory() {
         const session = get().currentSession();
         const modelConfig = session.mask.modelConfig;
+        const clearContextIndex = session.clearContextIndex ?? 0;
+        const messages = session.messages.slice();
+        const totalMessageCount = session.messages.length;
 
-        // wont send cleared context messages
-        const clearedContextMessages = session.messages.slice(
-          session.clearContextIndex ?? 0,
-        );
-        const messages = clearedContextMessages.filter((msg) => !msg.isError);
-        const n = messages.length;
-
-        const context = session.mask.context.slice();
+        // in-context prompts
+        const contextPrompts = session.mask.context.slice();
 
         // long term memory
-        if (
+        const shouldSendLongTermMemory =
           modelConfig.sendMemory &&
           session.memoryPrompt &&
-          session.memoryPrompt.length > 0
-        ) {
-          const memoryPrompt = get().getMemoryPrompt();
-          context.push(memoryPrompt);
-        }
-
-        // get short term and unmemorized long term memory
-        const shortTermMemoryMessageIndex = Math.max(
+          session.memoryPrompt.length > 0 &&
+          session.lastSummarizeIndex <= clearContextIndex;
+        const longTermMemoryPrompts = shouldSendLongTermMemory
+          ? [get().getMemoryPrompt()]
+          : [];
+        const longTermMemoryStartIndex = session.lastSummarizeIndex;
+
+        // short term memory
+        const shortTermMemoryStartIndex = Math.max(
           0,
-          n - modelConfig.historyMessageCount,
+          totalMessageCount - modelConfig.historyMessageCount,
         );
-        const longTermMemoryMessageIndex = session.lastSummarizeIndex;
 
-        // try to concat history messages
+        // lets concat send messages, including 4 parts:
+        // 1. long term memory: summarized memory messages
+        // 2. pre-defined in-context prompts
+        // 3. short term memory: latest n messages
+        // 4. newest input message
         const memoryStartIndex = Math.min(
-          shortTermMemoryMessageIndex,
-          longTermMemoryMessageIndex,
+          longTermMemoryStartIndex,
+          shortTermMemoryStartIndex,
         );
-        const threshold = modelConfig.max_tokens;
+        // and if user has cleared history messages, we should exclude the memory too.
+        const contextStartIndex = Math.max(clearContextIndex, memoryStartIndex);
+        const maxTokenThreshold = modelConfig.max_tokens;
 
-        // get recent messages as many as possible
+        // get recent messages as much as possible
         const reversedRecentMessages = [];
         for (
-          let i = n - 1, count = 0;
-          i >= memoryStartIndex && count < threshold;
+          let i = totalMessageCount - 1, tokenCount = 0;
+          i >= contextStartIndex && tokenCount < maxTokenThreshold;
           i -= 1
         ) {
           const msg = messages[i];
           if (!msg || msg.isError) continue;
-          count += estimateTokenLength(msg.content);
+          tokenCount += estimateTokenLength(msg.content);
           reversedRecentMessages.push(msg);
         }
 
-        // concat
-        const recentMessages = context.concat(reversedRecentMessages.reverse());
+        // concat all messages
+        const recentMessages = [
+          ...longTermMemoryPrompts,
+          ...contextPrompts,
+          ...reversedRecentMessages.reverse(),
+        ];
 
         return recentMessages;
       },

+ 5 - 3
app/store/config.ts

@@ -1,7 +1,7 @@
 import { create } from "zustand";
 import { persist } from "zustand/middleware";
 import { getClientConfig } from "../config/client";
-import { StoreKey } from "../constant";
+import { DEFAULT_INPUT_TEMPLATE, StoreKey } from "../constant";
 
 export enum SubmitKey {
   Enter = "Enter",
@@ -39,6 +39,7 @@ export const DEFAULT_CONFIG = {
     sendMemory: true,
     historyMessageCount: 4,
     compressMessageLengthThreshold: 1000,
+    template: DEFAULT_INPUT_TEMPLATE,
   },
 };
 
@@ -176,15 +177,16 @@ export const useAppConfig = create<ChatConfigStore>()(
     }),
     {
       name: StoreKey.Config,
-      version: 3,
+      version: 3.1,
       migrate(persistedState, version) {
-        if (version === 3) return persistedState as any;
+        if (version === 3.1) return persistedState as any;
 
         const state = persistedState as ChatConfig;
         state.modelConfig.sendMemory = true;
         state.modelConfig.historyMessageCount = 4;
         state.modelConfig.compressMessageLengthThreshold = 1000;
         state.modelConfig.frequency_penalty = 0;
+        state.modelConfig.template = DEFAULT_INPUT_TEMPLATE;
         state.dontShowMaskSplashScreen = false;
 
         return state;