store.ts 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360
  1. import { create } from "zustand";
  2. import { persist } from "zustand/middleware";
  3. import { type ChatCompletionResponseMessage } from "openai";
  4. import { requestChatStream, requestWithPrompt } from "./requests";
  5. import { trimTopic } from "./utils";
  6. import Locale from './locales'
  7. export type Message = ChatCompletionResponseMessage & {
  8. date: string;
  9. streaming?: boolean;
  10. };
  11. export enum SubmitKey {
  12. Enter = "Enter",
  13. CtrlEnter = "Ctrl + Enter",
  14. ShiftEnter = "Shift + Enter",
  15. AltEnter = "Alt + Enter",
  16. }
  17. export enum Theme {
  18. Auto = "auto",
  19. Dark = "dark",
  20. Light = "light",
  21. }
  22. export interface ChatConfig {
  23. maxToken?: number
  24. historyMessageCount: number; // -1 means all
  25. compressMessageLengthThreshold: number;
  26. sendBotMessages: boolean; // send bot's message or not
  27. submitKey: SubmitKey;
  28. avatar: string;
  29. theme: Theme;
  30. tightBorder: boolean;
  31. }
  32. const DEFAULT_CONFIG: ChatConfig = {
  33. historyMessageCount: 4,
  34. compressMessageLengthThreshold: 1000,
  35. sendBotMessages: true as boolean,
  36. submitKey: SubmitKey.CtrlEnter as SubmitKey,
  37. avatar: "1f603",
  38. theme: Theme.Auto as Theme,
  39. tightBorder: false,
  40. };
  41. export interface ChatStat {
  42. tokenCount: number;
  43. wordCount: number;
  44. charCount: number;
  45. }
  46. export interface ChatSession {
  47. id: number;
  48. topic: string;
  49. memoryPrompt: string;
  50. messages: Message[];
  51. stat: ChatStat;
  52. lastUpdate: string;
  53. lastSummarizeIndex: number;
  54. }
  55. const DEFAULT_TOPIC = Locale.Store.DefaultTopic;
  56. function createEmptySession(): ChatSession {
  57. const createDate = new Date().toLocaleString();
  58. return {
  59. id: Date.now(),
  60. topic: DEFAULT_TOPIC,
  61. memoryPrompt: "",
  62. messages: [
  63. {
  64. role: "assistant",
  65. content: Locale.Store.BotHello,
  66. date: createDate,
  67. },
  68. ],
  69. stat: {
  70. tokenCount: 0,
  71. wordCount: 0,
  72. charCount: 0,
  73. },
  74. lastUpdate: createDate,
  75. lastSummarizeIndex: 0,
  76. };
  77. }
  78. interface ChatStore {
  79. config: ChatConfig;
  80. sessions: ChatSession[];
  81. currentSessionIndex: number;
  82. removeSession: (index: number) => void;
  83. selectSession: (index: number) => void;
  84. newSession: () => void;
  85. currentSession: () => ChatSession;
  86. onNewMessage: (message: Message) => void;
  87. onUserInput: (content: string) => Promise<void>;
  88. summarizeSession: () => void;
  89. updateStat: (message: Message) => void;
  90. updateCurrentSession: (updater: (session: ChatSession) => void) => void;
  91. updateMessage: (
  92. sessionIndex: number,
  93. messageIndex: number,
  94. updater: (message?: Message) => void
  95. ) => void;
  96. getMessagesWithMemory: () => Message[];
  97. getMemoryPrompt: () => Message,
  98. getConfig: () => ChatConfig;
  99. resetConfig: () => void;
  100. updateConfig: (updater: (config: ChatConfig) => void) => void;
  101. clearAllData: () => void;
  102. }
  103. const LOCAL_KEY = "chat-next-web-store";
  104. export const useChatStore = create<ChatStore>()(
  105. persist(
  106. (set, get) => ({
  107. sessions: [createEmptySession()],
  108. currentSessionIndex: 0,
  109. config: {
  110. ...DEFAULT_CONFIG,
  111. },
  112. resetConfig() {
  113. set(() => ({ config: { ...DEFAULT_CONFIG } }));
  114. },
  115. getConfig() {
  116. return get().config;
  117. },
  118. updateConfig(updater) {
  119. const config = get().config;
  120. updater(config);
  121. set(() => ({ config }));
  122. },
  123. selectSession(index: number) {
  124. set({
  125. currentSessionIndex: index,
  126. });
  127. },
  128. removeSession(index: number) {
  129. set((state) => {
  130. let nextIndex = state.currentSessionIndex;
  131. const sessions = state.sessions;
  132. if (sessions.length === 1) {
  133. return {
  134. currentSessionIndex: 0,
  135. sessions: [createEmptySession()],
  136. };
  137. }
  138. sessions.splice(index, 1);
  139. if (nextIndex === index) {
  140. nextIndex -= 1;
  141. }
  142. return {
  143. currentSessionIndex: nextIndex,
  144. sessions,
  145. };
  146. });
  147. },
  148. newSession() {
  149. set((state) => ({
  150. currentSessionIndex: 0,
  151. sessions: [createEmptySession()].concat(state.sessions),
  152. }));
  153. },
  154. currentSession() {
  155. let index = get().currentSessionIndex;
  156. const sessions = get().sessions;
  157. if (index < 0 || index >= sessions.length) {
  158. index = Math.min(sessions.length - 1, Math.max(0, index));
  159. set(() => ({ currentSessionIndex: index }));
  160. }
  161. const session = sessions[index];
  162. return session;
  163. },
  164. onNewMessage(message) {
  165. get().updateCurrentSession(session => {
  166. session.lastUpdate = new Date().toLocaleString()
  167. })
  168. get().updateStat(message);
  169. get().summarizeSession();
  170. },
  171. async onUserInput(content) {
  172. const userMessage: Message = {
  173. role: "user",
  174. content,
  175. date: new Date().toLocaleString(),
  176. };
  177. const botMessage: Message = {
  178. content: "",
  179. role: "assistant",
  180. date: new Date().toLocaleString(),
  181. streaming: true,
  182. };
  183. // get recent messages
  184. const recentMessages = get().getMessagesWithMemory()
  185. const sendMessages = recentMessages.concat(userMessage)
  186. // save user's and bot's message
  187. get().updateCurrentSession((session) => {
  188. session.messages.push(userMessage);
  189. session.messages.push(botMessage);
  190. });
  191. console.log('[User Input] ', sendMessages)
  192. requestChatStream(sendMessages, {
  193. onMessage(content, done) {
  194. if (done) {
  195. botMessage.streaming = false;
  196. get().onNewMessage(botMessage)
  197. } else {
  198. botMessage.content = content;
  199. set(() => ({}));
  200. }
  201. },
  202. onError(error) {
  203. botMessage.content += "\n\n" + Locale.Store.Error;
  204. botMessage.streaming = false;
  205. set(() => ({}));
  206. },
  207. filterBot: !get().config.sendBotMessages,
  208. });
  209. },
  210. getMemoryPrompt() {
  211. const session = get().currentSession()
  212. return {
  213. role: 'system',
  214. content: Locale.Store.Prompt.History(session.memoryPrompt),
  215. date: ''
  216. } as Message
  217. },
  218. getMessagesWithMemory() {
  219. const session = get().currentSession()
  220. const config = get().config
  221. const n = session.messages.length
  222. const recentMessages = session.messages.slice(n - config.historyMessageCount);
  223. const memoryPrompt = get().getMemoryPrompt()
  224. if (session.memoryPrompt) {
  225. recentMessages.unshift(memoryPrompt)
  226. }
  227. return recentMessages
  228. },
  229. updateMessage(
  230. sessionIndex: number,
  231. messageIndex: number,
  232. updater: (message?: Message) => void
  233. ) {
  234. const sessions = get().sessions;
  235. const session = sessions.at(sessionIndex);
  236. const messages = session?.messages;
  237. updater(messages?.at(messageIndex));
  238. set(() => ({ sessions }));
  239. },
  240. summarizeSession() {
  241. const session = get().currentSession();
  242. if (session.topic === DEFAULT_TOPIC && session.messages.length >= 3) {
  243. // should summarize topic
  244. requestWithPrompt(
  245. session.messages,
  246. Locale.Store.Prompt.Topic
  247. ).then((res) => {
  248. get().updateCurrentSession(
  249. (session) => (session.topic = trimTopic(res))
  250. );
  251. });
  252. }
  253. const config = get().config
  254. let toBeSummarizedMsgs = session.messages.slice(session.lastSummarizeIndex)
  255. const historyMsgLength = toBeSummarizedMsgs.reduce((pre, cur) => pre + cur.content.length, 0)
  256. if (historyMsgLength > 4000) {
  257. toBeSummarizedMsgs = toBeSummarizedMsgs.slice(-config.historyMessageCount)
  258. }
  259. // add memory prompt
  260. toBeSummarizedMsgs.unshift(get().getMemoryPrompt())
  261. const lastSummarizeIndex = session.messages.length
  262. console.log('[Chat History] ', toBeSummarizedMsgs, historyMsgLength, config.compressMessageLengthThreshold)
  263. if (historyMsgLength > config.compressMessageLengthThreshold) {
  264. requestChatStream(toBeSummarizedMsgs.concat({
  265. role: 'system',
  266. content: Locale.Store.Prompt.Summarize,
  267. date: ''
  268. }), {
  269. filterBot: false,
  270. onMessage(message, done) {
  271. session.memoryPrompt = message
  272. if (done) {
  273. console.log('[Memory] ', session.memoryPrompt)
  274. session.lastSummarizeIndex = lastSummarizeIndex
  275. }
  276. },
  277. onError(error) {
  278. console.error('[Summarize] ', error)
  279. },
  280. })
  281. }
  282. },
  283. updateStat(message) {
  284. get().updateCurrentSession((session) => {
  285. session.stat.charCount += message.content.length;
  286. // TODO: should update chat count and word count
  287. });
  288. },
  289. updateCurrentSession(updater) {
  290. const sessions = get().sessions;
  291. const index = get().currentSessionIndex;
  292. updater(sessions[index]);
  293. set(() => ({ sessions }));
  294. },
  295. clearAllData() {
  296. if (confirm(Locale.Store.ConfirmClearAll)) {
  297. localStorage.clear()
  298. location.reload()
  299. }
  300. },
  301. }),
  302. {
  303. name: LOCAL_KEY,
  304. }
  305. )
  306. );