123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643 |
- import { trimTopic } from "../utils";
- import Locale, { getLang } from "../locales";
- import { showToast } from "../components/ui-lib";
- import { ModelConfig, ModelType, useAppConfig } from "./config";
- import { createEmptyMask, Mask } from "./mask";
- import {
- DEFAULT_INPUT_TEMPLATE,
- DEFAULT_SYSTEM_TEMPLATE,
- KnowledgeCutOffDate,
- StoreKey,
- SUMMARIZE_MODEL,
- } from "../constant";
- import { api, RequestMessage } from "../client/api";
- import { ChatControllerPool } from "../client/controller";
- import { prettyObject } from "../utils/format";
- import { estimateTokenLength } from "../utils/token";
- import { nanoid } from "nanoid";
- import { createPersistStore } from "../utils/store";
// A chat message as stored client-side: the wire-format RequestMessage plus
// local bookkeeping fields used by the UI and controller pool.
export type ChatMessage = RequestMessage & {
  date: string; // locale-formatted creation time (display only)
  streaming?: boolean; // true while an assistant reply is still arriving
  isError?: boolean; // set when a request failed (non-abort errors)
  id: string; // nanoid, used as key for the controller pool and React lists
  model?: ModelType; // model that produced an assistant reply
};
- export function createMessage(override: Partial<ChatMessage>): ChatMessage {
- return {
- id: nanoid(),
- date: new Date().toLocaleString(),
- role: "user",
- content: "",
- ...override,
- };
- }
// Per-session usage statistics.
// NOTE(review): only charCount is visibly updated in this file (see
// updateStat); token/word counts appear to be placeholders — confirm.
export interface ChatStat {
  tokenCount: number;
  wordCount: number;
  charCount: number; // cumulative characters of message content
}
// One conversation: its transcript, summarized memory, stats, and the mask
// (persona + model config) it runs under.
export interface ChatSession {
  id: string; // nanoid
  topic: string; // display title; DEFAULT_TOPIC until auto-generated
  memoryPrompt: string; // rolling LLM-generated summary of older messages
  messages: ChatMessage[];
  stat: ChatStat;
  lastUpdate: number; // epoch ms of last activity
  lastSummarizeIndex: number; // messages before this index are folded into memoryPrompt
  clearContextIndex?: number; // messages before this index are excluded from context
  mask: Mask;
}
// Placeholder title shown until summarizeSession() auto-generates one.
export const DEFAULT_TOPIC = Locale.Store.DefaultTopic;

// Canned assistant greeting; localized once at module load.
export const BOT_HELLO: ChatMessage = createMessage({
  role: "assistant",
  content: Locale.Store.BotHello,
});
- function createEmptySession(): ChatSession {
- return {
- id: nanoid(),
- topic: DEFAULT_TOPIC,
- memoryPrompt: "",
- messages: [],
- stat: {
- tokenCount: 0,
- wordCount: 0,
- charCount: 0,
- },
- lastUpdate: Date.now(),
- lastSummarizeIndex: 0,
- mask: createEmptyMask(),
- };
- }
- function getSummarizeModel(currentModel: string) {
- // if it is using gpt-* models, force to use 3.5 to summarize
- return currentModel.startsWith("gpt") ? SUMMARIZE_MODEL : currentModel;
- }
- function countMessages(msgs: ChatMessage[]) {
- return msgs.reduce((pre, cur) => pre + estimateTokenLength(cur.content), 0);
- }
- function fillTemplateWith(input: string, modelConfig: ModelConfig) {
- let cutoff =
- KnowledgeCutOffDate[modelConfig.model] ?? KnowledgeCutOffDate.default;
- const vars = {
- cutoff,
- model: modelConfig.model,
- time: new Date().toLocaleString(),
- lang: getLang(),
- input: input,
- };
- let output = modelConfig.template ?? DEFAULT_INPUT_TEMPLATE;
- // must contains {{input}}
- const inputVar = "{{input}}";
- if (!output.includes(inputVar)) {
- output += "\n" + inputVar;
- }
- Object.entries(vars).forEach(([name, value]) => {
- output = output.replaceAll(`{{${name}}}`, value);
- });
- return output;
- }
// Initial persisted shape: one empty session, selected.
const DEFAULT_CHAT_STATE = {
  sessions: [createEmptySession()],
  currentSessionIndex: 0,
};
- export const useChatStore = createPersistStore(
- DEFAULT_CHAT_STATE,
- (set, _get) => {
// Self-reference helper: merges the freshest persisted state with the
// methods object so methods can call each other and read current state
// through one `get()`.
function get() {
  return {
    ..._get(),
    ...methods,
  };
}
- const methods = {
// Drop every session and start over with a single fresh, selected one.
clearSessions() {
  const fresh = {
    sessions: [createEmptySession()],
    currentSessionIndex: 0,
  };
  set(() => fresh);
},
// Make the session at `index` the active one.
selectSession(index: number) {
  set(() => ({ currentSessionIndex: index }));
},
// Reorder sessions (drag-and-drop): move the session at `from` to `to`,
// adjusting the selected index so the same session stays selected.
moveSession(from: number, to: number) {
  set((state) => {
    const { sessions, currentSessionIndex: oldIndex } = state;

    // Remove the dragged session and reinsert it at its destination.
    const reordered = [...sessions];
    const [moved] = reordered.splice(from, 1);
    reordered.splice(to, 0, moved);

    // Recompute the selection: the moved session follows `to`; sessions
    // between `from` and `to` shift by one in the opposite direction.
    let newIndex: number;
    if (oldIndex === from) {
      newIndex = to;
    } else if (oldIndex > from && oldIndex <= to) {
      newIndex = oldIndex - 1;
    } else if (oldIndex < from && oldIndex >= to) {
      newIndex = oldIndex + 1;
    } else {
      newIndex = oldIndex;
    }

    return {
      currentSessionIndex: newIndex,
      sessions: reordered,
    };
  });
},
// Create a session (optionally from a mask), prepend it, and select it.
newSession(mask?: Mask) {
  const session = createEmptySession();

  if (mask) {
    // A mask-backed session layers the mask's model config over the
    // current global defaults and takes the mask's name as its topic.
    const globalModelConfig = useAppConfig.getState().modelConfig;
    session.mask = {
      ...mask,
      modelConfig: {
        ...globalModelConfig,
        ...mask.modelConfig,
      },
    };
    session.topic = mask.name;
  }

  set((state) => ({
    currentSessionIndex: 0,
    sessions: [session, ...state.sessions],
  }));
},
// Step the selection by `delta` sessions, wrapping around both ends.
nextSession(delta: number) {
  const count = get().sessions.length;
  const current = get().currentSessionIndex;
  const wrapped = (current + delta + count) % count;
  get().selectSession(wrapped);
},
// Delete the session at `index`, keep a sane selection, and offer a
// 5-second undo toast that restores the pre-delete state.
deleteSession(index: number) {
  // Deleting the only session means we must replace it with a fresh one.
  const deletingLastSession = get().sessions.length === 1;
  const deletedSession = get().sessions.at(index);

  if (!deletedSession) return;

  const sessions = get().sessions.slice();
  sessions.splice(index, 1);

  // Keep the selection on the same session: shift left by one when the
  // deleted session sat before the current one, then clamp to the new end.
  const currentIndex = get().currentSessionIndex;
  let nextIndex = Math.min(
    currentIndex - Number(index < currentIndex),
    sessions.length - 1,
  );

  if (deletingLastSession) {
    nextIndex = 0;
    sessions.push(createEmptySession());
  }

  // for undo delete action
  // Snapshot taken BEFORE set() below, so it still holds the pre-delete
  // state; the toast's onClick restores it wholesale.
  const restoreState = {
    currentSessionIndex: get().currentSessionIndex,
    sessions: get().sessions.slice(),
  };

  set(() => ({
    currentSessionIndex: nextIndex,
    sessions,
  }));

  showToast(
    Locale.Home.DeleteToast,
    {
      text: Locale.Home.Revert,
      onClick() {
        set(() => restoreState);
      },
    },
    5000, // undo window, ms
  );
},
// Return the active session, self-healing an out-of-range index (e.g.
// after deletions) by clamping and persisting the corrected value.
currentSession() {
  const sessions = get().sessions;
  let index = get().currentSessionIndex;

  const inRange = index >= 0 && index < sessions.length;
  if (!inRange) {
    index = Math.min(sessions.length - 1, Math.max(0, index));
    set(() => ({ currentSessionIndex: index }));
  }

  return sessions[index];
},
// Hook run after a bot reply completes: bump recency, refresh the
// messages array reference (so subscribers re-render), update stats, and
// kick off summarization.
onNewMessage(message: ChatMessage) {
  get().updateCurrentSession((session) => {
    session.messages = session.messages.slice();
    session.lastUpdate = Date.now();
  });
  get().updateStat(message);
  get().summarizeSession();
},
// Send a user message: template-expand it, append user+bot placeholder
// messages to the session, then stream the model's reply into the bot
// message via the api callbacks below.
async onUserInput(content: string) {
  const session = get().currentSession();
  const modelConfig = session.mask.modelConfig;

  // Expand {{input}}/{{model}}/… template variables into the text that is
  // actually sent to the model.
  const userContent = fillTemplateWith(content, modelConfig);
  console.log("[User Input] after template: ", userContent);

  const userMessage: ChatMessage = createMessage({
    role: "user",
    content: userContent,
  });

  // Placeholder assistant message; mutated in place by the callbacks.
  const botMessage: ChatMessage = createMessage({
    role: "assistant",
    streaming: true,
    model: modelConfig.model,
  });

  // get recent messages
  const recentMessages = get().getMessagesWithMemory();
  const sendMessages = recentMessages.concat(userMessage);
  // Fallback key for the controller pool if the message id is missing.
  const messageIndex = get().currentSession().messages.length + 1;

  // save user's and bot's message
  get().updateCurrentSession((session) => {
    // The session stores the raw (pre-template) user text; the templated
    // version only goes over the wire.
    const savedUserMessage = {
      ...userMessage,
      content,
    };
    session.messages = session.messages.concat([
      savedUserMessage,
      botMessage,
    ]);
  });

  // make request
  api.llm.chat({
    messages: sendMessages,
    config: { ...modelConfig, stream: true },
    onUpdate(message) {
      // Streaming chunk: overwrite bot content and clone the messages
      // array so the UI re-renders.
      botMessage.streaming = true;
      if (message) {
        botMessage.content = message;
      }
      get().updateCurrentSession((session) => {
        session.messages = session.messages.concat();
      });
    },
    onFinish(message) {
      botMessage.streaming = false;
      if (message) {
        botMessage.content = message;
        get().onNewMessage(botMessage);
      }
      ChatControllerPool.remove(session.id, botMessage.id);
    },
    onError(error) {
      // User-initiated aborts are not flagged as errors.
      const isAborted = error.message.includes("aborted");
      // Append a pretty-printed error to any partial content received.
      botMessage.content +=
        "\n\n" +
        prettyObject({
          error: true,
          message: error.message,
        });
      botMessage.streaming = false;
      userMessage.isError = !isAborted;
      botMessage.isError = !isAborted;
      get().updateCurrentSession((session) => {
        session.messages = session.messages.concat();
      });
      ChatControllerPool.remove(
        session.id,
        botMessage.id ?? messageIndex,
      );
      console.error("[Chat] failed ", error);
    },
    onController(controller) {
      // collect controller for stop/retry
      ChatControllerPool.addController(
        session.id,
        botMessage.id ?? messageIndex,
        controller,
      );
    },
  });
},
// Wrap the session's summarized memory in a system message; content is
// empty when no memory has been accumulated yet.
getMemoryPrompt() {
  const session = get().currentSession();
  const hasMemory = session.memoryPrompt.length > 0;
  return {
    role: "system",
    content: hasMemory
      ? Locale.Store.Prompt.History(session.memoryPrompt)
      : "",
    date: "",
  } as ChatMessage;
},
// Assemble the context window sent with the next request: optional global
// system prompt, optional long-term memory, the mask's in-context prompts,
// and as many recent messages as the token budget allows.
getMessagesWithMemory() {
  const session = get().currentSession();
  const modelConfig = session.mask.modelConfig;
  // Messages before clearContextIndex were "cleared" by the user and are
  // never sent.
  const clearContextIndex = session.clearContextIndex ?? 0;
  const messages = session.messages.slice();
  const totalMessageCount = session.messages.length;

  // in-context prompts
  const contextPrompts = session.mask.context.slice();

  // system prompts, to get close to OpenAI Web ChatGPT
  const shouldInjectSystemPrompts = modelConfig.enableInjectSystemPrompts;
  const systemPrompts = shouldInjectSystemPrompts
    ? [
        createMessage({
          role: "system",
          content: fillTemplateWith("", {
            ...modelConfig,
            template: DEFAULT_SYSTEM_TEMPLATE,
          }),
        }),
      ]
    : [];
  if (shouldInjectSystemPrompts) {
    console.log(
      "[Global System Prompt] ",
      systemPrompts.at(0)?.content ?? "empty",
    );
  }

  // long term memory: only sent when enabled, non-empty, and not wholly
  // invalidated by a later "clear context".
  const shouldSendLongTermMemory =
    modelConfig.sendMemory &&
    session.memoryPrompt &&
    session.memoryPrompt.length > 0 &&
    session.lastSummarizeIndex > clearContextIndex;
  const longTermMemoryPrompts = shouldSendLongTermMemory
    ? [get().getMemoryPrompt()]
    : [];
  const longTermMemoryStartIndex = session.lastSummarizeIndex;

  // short term memory: at most the last historyMessageCount messages.
  const shortTermMemoryStartIndex = Math.max(
    0,
    totalMessageCount - modelConfig.historyMessageCount,
  );

  // lets concat send messages, including 4 parts:
  // 0. system prompt: to get close to OpenAI Web ChatGPT
  // 1. long term memory: summarized memory messages
  // 2. pre-defined in-context prompts
  // 3. short term memory: latest n messages
  // 4. newest input message
  const memoryStartIndex = shouldSendLongTermMemory
    ? Math.min(longTermMemoryStartIndex, shortTermMemoryStartIndex)
    : shortTermMemoryStartIndex;
  // and if user has cleared history messages, we should exclude the memory too.
  const contextStartIndex = Math.max(clearContextIndex, memoryStartIndex);
  const maxTokenThreshold = modelConfig.max_tokens;

  // get recent messages as much as possible: walk backwards from the
  // newest message, accumulating until the token budget is hit.
  const reversedRecentMessages = [];
  for (
    let i = totalMessageCount - 1, tokenCount = 0;
    i >= contextStartIndex && tokenCount < maxTokenThreshold;
    i -= 1
  ) {
    const msg = messages[i];
    if (!msg || msg.isError) continue; // skip holes and failed messages
    tokenCount += estimateTokenLength(msg.content);
    reversedRecentMessages.push(msg);
  }

  // concat all messages, oldest first
  const recentMessages = [
    ...systemPrompts,
    ...longTermMemoryPrompts,
    ...contextPrompts,
    ...reversedRecentMessages.reverse(),
  ];

  return recentMessages;
},
// Apply `updater` to one message (mutating it in place), then re-set the
// sessions array so subscribers are notified. Out-of-range indices hand
// the updater `undefined`.
updateMessage(
  sessionIndex: number,
  messageIndex: number,
  updater: (message?: ChatMessage) => void,
) {
  const sessions = get().sessions;
  const targetSession = sessions.at(sessionIndex);
  updater(targetSession?.messages?.at(messageIndex));
  set(() => ({ sessions }));
},
// Wipe the current session's transcript and its summarized memory.
resetSession() {
  get().updateCurrentSession((session) => {
    session.messages = [];
    session.memoryPrompt = "";
  });
},
// Two independent background summarizations for the current session:
//  1. auto-generate a topic/title once the chat is long enough, and
//  2. compress older history into the session's memory prompt.
summarizeSession() {
  const config = useAppConfig.getState();
  const session = get().currentSession();
  const messages = session.messages;

  // --- Part 1: topic generation ---
  // should summarize topic after chating more than 50 words
  const SUMMARIZE_MIN_LEN = 50;
  if (
    config.enableAutoGenerateTitle &&
    session.topic === DEFAULT_TOPIC &&
    countMessages(messages) >= SUMMARIZE_MIN_LEN
  ) {
    const topicMessages = messages.concat(
      createMessage({
        role: "user",
        content: Locale.Store.Prompt.Topic,
      }),
    );
    api.llm.chat({
      messages: topicMessages,
      config: {
        model: getSummarizeModel(session.mask.modelConfig.model),
      },
      onFinish(message) {
        get().updateCurrentSession(
          (session) =>
            (session.topic =
              message.length > 0 ? trimTopic(message) : DEFAULT_TOPIC),
        );
      },
    });
  }

  // --- Part 2: history compression ---
  const modelConfig = session.mask.modelConfig;

  // Only messages after the last summary (and after any "clear context"
  // marker) are candidates; error messages are excluded.
  const summarizeIndex = Math.max(
    session.lastSummarizeIndex,
    session.clearContextIndex ?? 0,
  );
  let toBeSummarizedMsgs = messages
    .filter((msg) => !msg.isError)
    .slice(summarizeIndex);

  const historyMsgLength = countMessages(toBeSummarizedMsgs);

  // BUGFIX: this condition was `historyMsgLength > modelConfig?.max_tokens
  // ?? 4000`, which parses as `(x > max_tokens) ?? 4000` — the boolean is
  // never nullish, so the 4000-token fallback was dead code and the check
  // was always false when max_tokens was undefined. Parenthesized so the
  // fallback applies to max_tokens as intended.
  if (historyMsgLength > (modelConfig?.max_tokens ?? 4000)) {
    // Too much to summarize at once: keep only the newest
    // historyMessageCount messages.
    const n = toBeSummarizedMsgs.length;
    toBeSummarizedMsgs = toBeSummarizedMsgs.slice(
      Math.max(0, n - modelConfig.historyMessageCount),
    );
  }

  // add memory prompt, so the new summary builds on the previous one
  toBeSummarizedMsgs.unshift(get().getMemoryPrompt());

  // Everything up to here will be covered by the new summary.
  const lastSummarizeIndex = session.messages.length;

  console.log(
    "[Chat History] ",
    toBeSummarizedMsgs,
    historyMsgLength,
    modelConfig.compressMessageLengthThreshold,
  );

  if (
    historyMsgLength > modelConfig.compressMessageLengthThreshold &&
    modelConfig.sendMemory
  ) {
    api.llm.chat({
      messages: toBeSummarizedMsgs.concat(
        createMessage({
          role: "system",
          content: Locale.Store.Prompt.Summarize,
          date: "",
        }),
      ),
      config: {
        ...modelConfig,
        stream: true,
        model: getSummarizeModel(session.mask.modelConfig.model),
      },
      onUpdate(message) {
        // Streaming: surface the in-progress summary immediately (direct
        // mutation; persisted via set() in onFinish below).
        session.memoryPrompt = message;
      },
      onFinish(message) {
        console.log("[Memory] ", message);
        get().updateCurrentSession((session) => {
          session.lastSummarizeIndex = lastSummarizeIndex;
          session.memoryPrompt = message; // Update the memory prompt for stored it in local storage
        });
      },
      onError(err) {
        console.error("[Summarize] ", err);
      },
    });
  }
},
// Fold one message into the session's usage stats.
// Only charCount is tracked today; token/word counts remain TODO.
updateStat(message: ChatMessage) {
  get().updateCurrentSession((session) => {
    session.stat.charCount += message.content.length;
  });
},
// Run `updater` against the active session (mutation in place), then
// re-set the sessions array so the store notifies subscribers.
updateCurrentSession(updater: (session: ChatSession) => void) {
  const index = get().currentSessionIndex;
  const sessions = get().sessions;
  updater(sessions[index]);
  set(() => ({ sessions }));
},
// Nuke ALL persisted app data (every store, not just chat) and hard-reload
// so the app reboots from defaults.
clearAllData() {
  localStorage.clear();
  location.reload();
},
- };
- return methods;
- },
- {
- name: StoreKey.Chat,
- version: 3.1,
// Upgrade older persisted snapshots to the current schema (version 3.1).
// Runs once on hydration when the stored version is older.
migrate(persistedState, version) {
  const state = persistedState as any;
  // Deep clone via JSON round-trip so migrations never mutate the
  // persisted snapshot in place.
  const newState = JSON.parse(
    JSON.stringify(state),
  ) as typeof DEFAULT_CHAT_STATE;

  if (version < 2) {
    // v2: rebuild each session on a fresh skeleton, keeping only topic and
    // messages, and seed memory-related model-config defaults.
    newState.sessions = [];

    const oldSessions = state.sessions;
    for (const oldSession of oldSessions) {
      const newSession = createEmptySession();
      newSession.topic = oldSession.topic;
      newSession.messages = [...oldSession.messages];
      newSession.mask.modelConfig.sendMemory = true;
      newSession.mask.modelConfig.historyMessageCount = 4;
      newSession.mask.modelConfig.compressMessageLengthThreshold = 1000;
      newState.sessions.push(newSession);
    }
  }

  if (version < 3) {
    // migrate id to nanoid
    newState.sessions.forEach((s) => {
      s.id = nanoid();
      s.messages.forEach((m) => (m.id = nanoid()));
    });
  }

  // Enable `enableInjectSystemPrompts` attribute for old sessions.
  // Resolve issue of old sessions not automatically enabling.
  if (version < 3.1) {
    newState.sessions.forEach((s) => {
      if (
        // Exclude those already set by user
        !s.mask.modelConfig.hasOwnProperty("enableInjectSystemPrompts")
      ) {
        // Because users may have changed this configuration,
        // the user's current configuration is used instead of the default
        const config = useAppConfig.getState();
        s.mask.modelConfig.enableInjectSystemPrompts =
          config.modelConfig.enableInjectSystemPrompts;
      }
    });
  }

  return newState as any;
},
- },
- );
|