app.ts 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636
  1. import { create } from "zustand";
  2. import { persist } from "zustand/middleware";
  3. import { type ChatCompletionResponseMessage } from "openai";
  4. import {
  5. ControllerPool,
  6. requestChatStream,
  7. requestWithPrompt,
  8. } from "../requests";
  9. import { isMobileScreen, trimTopic } from "../utils";
  10. import Locale from "../locales";
  11. import { showToast } from "../components/ui-lib";
  /**
   * A chat message as kept in a session: the OpenAI response message
   * (`role` + `content`) extended with bookkeeping fields used by the store.
   */
  export type Message = ChatCompletionResponseMessage & {
    /** Creation time as produced by `new Date().toLocaleString()` in createMessage. */
    date: string;
    /** True while an assistant reply is still being streamed into this message. */
    streaming?: boolean;
    /** Marked true on both the user and assistant message when a request fails. */
    isError?: boolean;
    /** Numeric identifier; createMessage seeds it with `Date.now()`. */
    id?: number;
  };
  /**
   * Builds a message with defaults filled in: role "user", empty content,
   * `id` = current epoch milliseconds and `date` = current locale time string.
   * Every field present in `override` replaces the corresponding default.
   */
  export function createMessage(override: Partial<Message>): Message {
    const base: Message = {
      id: Date.now(),
      date: new Date().toLocaleString(),
      role: "user",
      content: "",
    };
    return { ...base, ...override };
  }
  /**
   * Key combination configured for sending a message. String enum whose values
   * are human-readable combination labels such as "Ctrl + Enter".
   */
  export enum SubmitKey {
    Enter = "Enter",
    CtrlEnter = "Ctrl + Enter",
    ShiftEnter = "Shift + Enter",
    AltEnter = "Alt + Enter",
    MetaEnter = "Meta + Enter",
  }
  /** Color theme setting; `Auto` is the default used by DEFAULT_CONFIG. */
  export enum Theme {
    Auto = "auto",
    Dark = "dark",
    Light = "light",
  }
  /** User-adjustable settings for the chat UI and for model requests. */
  export interface ChatConfig {
    /**
     * How many of the most recent messages are sent as short-term history.
     * Documented as "-1 means all". NOTE(review): confirm the consumer
     * special-cases negative values.
     */
    historyMessageCount: number;
    /**
     * Content-length threshold: history longer than this is compressed into the
     * session memory prompt, and it also bounds how much recent history is sent.
     */
    compressMessageLengthThreshold: number;
    /** Whether assistant messages are sent back with requests (forwarded as `filterBot: !sendBotMessages`). */
    sendBotMessages: boolean;
    /** Key combination that submits the input. */
    submitKey: SubmitKey;
    /** Avatar identifier; default "1f603" (presumably an emoji code point — confirm). */
    avatar: string;
    /** Font size; presumably in px — confirm against the UI. */
    fontSize: number;
    /** Color theme. */
    theme: Theme;
    /** UI layout flag; semantics live in the UI layer. */
    tightBorder: boolean;
    /** UI flag; semantics live in the UI layer. */
    sendPreviewBubble: boolean;
    /** Sidebar width; presumably in px — confirm against the UI. */
    sidebarWidth: number;
    /** UI flag; semantics live in the UI layer. */
    disablePromptHint: boolean;
    /** Request parameters forwarded with every chat request. */
    modelConfig: {
      model: string;
      temperature: number;
      max_tokens: number;
      presence_penalty: number;
    };
  }

  /** The request-parameter part of ChatConfig. */
  export type ModelConfig = ChatConfig["modelConfig"];
  /** The chat roles a Message may carry. */
  export const ROLES: Message["role"][] = ["system", "user", "assistant"];

  /** Switch controlling the `available` flag of every GPT-4 family entry below. */
  const ENABLE_GPT4 = true;

  /**
   * Known model names with an availability flag, in a fixed order:
   * the GPT-4 family first (gated by ENABLE_GPT4), then GPT-3.5 (always available).
   */
  export const ALL_MODELS = [
    ...["gpt-4", "gpt-4-0314", "gpt-4-32k", "gpt-4-32k-0314"].map((name) => ({
      name,
      available: ENABLE_GPT4,
    })),
    ...["gpt-3.5-turbo", "gpt-3.5-turbo-0301"].map((name) => ({
      name,
      available: true,
    })),
  ];
  /**
   * Clamps `x` into the inclusive range [min, max].
   * Returns `defaultValue` instead when `x` is not a number at runtime or is NaN.
   */
  export function limitNumber(
    x: number,
    min: number,
    max: number,
    defaultValue: number,
  ) {
    const isUsableNumber = typeof x === "number" && !isNaN(x);
    if (!isUsableNumber) return defaultValue;
    const atLeastMin = Math.max(min, x);
    return Math.min(max, atLeastMin);
  }
  /**
   * Returns `name` when it is listed in ALL_MODELS and marked available;
   * otherwise falls back to "gpt-3.5-turbo".
   *
   * The fallback is named explicitly rather than read by position
   * (`ALL_MODELS[4]`), so reordering, inserting or removing entries in
   * ALL_MODELS can no longer silently change it or make it undefined.
   */
  export function limitModel(name: string) {
    const fallbackModel = "gpt-3.5-turbo";
    const isKnownAndAvailable = ALL_MODELS.some(
      (m) => m.name === name && m.available,
    );
    return isKnownAndAvailable ? name : fallbackModel;
  }
  /**
   * Per-field sanitizers for ModelConfig values: unknown or unavailable models
   * fall back via limitModel, numbers are clamped via limitNumber, and
   * non-numeric input yields the field's default.
   */
  export const ModalConfigValidator = {
    /** Keeps only known, available model names. */
    model(x: string) { return limitModel(x); },
    /** Range [0, 32000], default 2000. */
    max_tokens(x: number) { return limitNumber(x, 0, 32000, 2000); },
    /** Range [-2, 2], default 0. */
    presence_penalty(x: number) { return limitNumber(x, -2, 2, 0); },
    /** Range [0, 2], default 1. */
    temperature(x: number) { return limitNumber(x, 0, 2, 1); },
  };
  /**
   * Factory-default settings. The `ChatConfig` annotation already types every
   * field, so no per-field `as` casts are needed.
   */
  const DEFAULT_CONFIG: ChatConfig = {
    historyMessageCount: 4,
    compressMessageLengthThreshold: 1000,
    sendBotMessages: true,
    submitKey: SubmitKey.CtrlEnter,
    avatar: "1f603",
    fontSize: 14,
    theme: Theme.Auto,
    tightBorder: false,
    sendPreviewBubble: true,
    sidebarWidth: 300,
    disablePromptHint: false,
    modelConfig: {
      model: "gpt-3.5-turbo",
      temperature: 1,
      max_tokens: 2000,
      presence_penalty: 0,
    },
  };
  /**
   * Per-session counters. As visible in this file, only `charCount` is
   * incremented (by the store's updateStat); the other two start at 0.
   */
  export interface ChatStat {
    tokenCount: number;
    wordCount: number;
    charCount: number;
  }
  /** One conversation: its messages, summarization state and statistics. */
  export interface ChatSession {
    /** Set from `Date.now()` when the session is created. */
    id: number;
    /** Title; replaced by an auto-generated one while it still equals the default topic. */
    topic: string;
    /** Whether the memory prompt is sent with requests and history compression runs. */
    sendMemory: boolean;
    /** Model-written summary of earlier history (empty until first summarized). */
    memoryPrompt: string;
    /** Messages always placed first in every request for this session. */
    context: Message[];
    /** The conversation history. */
    messages: Message[];
    stat: ChatStat;
    /** Locale time string of the last activity. */
    lastUpdate: string;
    /** Index into `messages` up to which history has been folded into `memoryPrompt`. */
    lastSummarizeIndex: number;
  }
  /** Localized placeholder topic; sessions still using it get auto-titled. */
  const DEFAULT_TOPIC = Locale.Store.DefaultTopic;

  /** Localized assistant greeting message, created once when the module loads. */
  export const BOT_HELLO: Message = createMessage({
    content: Locale.Store.BotHello,
    role: "assistant",
  });
  /**
   * Creates a new session with the default topic, memory enabled, no context
   * or messages, zeroed statistics, and `lastUpdate` set to the current time.
   */
  function createEmptySession(): ChatSession {
    const lastUpdate = new Date().toLocaleString();
    const stat = { tokenCount: 0, wordCount: 0, charCount: 0 };
    return {
      id: Date.now(),
      topic: DEFAULT_TOPIC,
      sendMemory: true,
      memoryPrompt: "",
      context: [],
      messages: [],
      stat,
      lastUpdate,
      lastSummarizeIndex: 0,
    };
  }
  /** Shape of the chat store: persisted state plus the actions that operate on it. */
  interface ChatStore {
    /** Current user settings. */
    config: ChatConfig;
    /** All chat sessions; newSession inserts new ones at index 0. */
    sessions: ChatSession[];
    /** Index into `sessions` of the selected session. */
    currentSessionIndex: number;

    /** Replaces every session with a single empty one and selects it. */
    clearSessions: () => void;
    /** Removes the session at `index` right away (no confirmation, no undo). */
    removeSession: (index: number) => void;
    /** Moves a session between positions, keeping the same session selected. */
    moveSession: (from: number, to: number) => void;
    /** Selects the session at `index`. */
    selectSession: (index: number) => void;
    /** Prepends an empty session and selects it. */
    newSession: () => void;
    /** Deletes a session (the selected one by default) with a mobile confirmation and an undo toast. */
    deleteSession: (index?: number) => void;
    /** Returns the selected session, clamping `currentSessionIndex` into range first. */
    currentSession: () => ChatSession;
    /** Bookkeeping after an assistant reply completes: timestamp, stats, summarization. */
    onNewMessage: (message: Message) => void;
    /** Appends the user message plus a streaming assistant reply, then sends the request. */
    onUserInput: (content: string) => Promise<void>;
    /** Auto-titles the session and compresses long history into the memory prompt. */
    summarizeSession: () => void;
    /** Adds `message` to the selected session's statistics. */
    updateStat: (message: Message) => void;
    /** Mutates the selected session in place via `updater`, then notifies subscribers. */
    updateCurrentSession: (updater: (session: ChatSession) => void) => void;
    /** Mutates one message in place (the updater may receive undefined), then notifies subscribers. */
    updateMessage: (
      sessionIndex: number,
      messageIndex: number,
      updater: (message?: Message) => void,
    ) => void;
    /** Clears the selected session's messages and memory prompt. */
    resetSession: () => void;
    /** Builds a request's message list: fixed context, memory prompt, recent history. */
    getMessagesWithMemory: () => Message[];
    /** Wraps the selected session's memory prompt in a system message. */
    getMemoryPrompt: () => Message;
    /** Returns the current settings object. */
    getConfig: () => ChatConfig;
    /** Restores the default settings. */
    resetConfig: () => void;
    /** Mutates the settings in place via `updater`, then notifies subscribers. */
    updateConfig: (updater: (config: ChatConfig) => void) => void;
    /** After confirmation, clears localStorage and reloads the page. */
    clearAllData: () => void;
  }
  /** Sum of `content.length` (UTF-16 code units) over all messages; 0 for an empty list. */
  function countMessages(msgs: Message[]) {
    let total = 0;
    for (const { content } of msgs) {
      total += content.length;
    }
    return total;
  }
  /** Storage key handed to the zustand `persist` middleware. */
  const LOCAL_KEY = "chat-next-web-store";

  /**
   * Fresh copy of DEFAULT_CONFIG with its nested `modelConfig` copied too.
   * The store mutates its config in place (updateConfig); a shallow
   * `{ ...DEFAULT_CONFIG }` shares `modelConfig`, so such mutations would leak
   * into the defaults and resetConfig could no longer restore them.
   */
  function cloneDefaultConfig(): ChatConfig {
    return {
      ...DEFAULT_CONFIG,
      modelConfig: { ...DEFAULT_CONFIG.modelConfig },
    };
  }

  /**
   * Global chat store (zustand), persisted under LOCAL_KEY.
   *
   * Most actions mutate session/config objects in place and then call `set`,
   * so subscribers are notified and the persist middleware stores the state.
   */
  export const useChatStore = create<ChatStore>()(
    persist(
      (set, get) => ({
        sessions: [createEmptySession()],
        currentSessionIndex: 0,
        config: cloneDefaultConfig(),

        clearSessions() {
          set(() => ({
            sessions: [createEmptySession()],
            currentSessionIndex: 0,
          }));
        },

        resetConfig() {
          set(() => ({ config: cloneDefaultConfig() }));
        },

        getConfig() {
          return get().config;
        },

        updateConfig(updater) {
          const config = get().config;
          updater(config);
          set(() => ({ config }));
        },

        selectSession(index: number) {
          set({
            currentSessionIndex: index,
          });
        },

        removeSession(index: number) {
          set((state) => {
            // Ignore out-of-range (or NaN) indices; splice would otherwise
            // count negative indices from the end.
            if (!(index >= 0 && index < state.sessions.length)) {
              return state;
            }

            if (state.sessions.length === 1) {
              return {
                currentSessionIndex: 0,
                sessions: [createEmptySession()],
              };
            }

            const sessions = state.sessions.slice();
            sessions.splice(index, 1);

            let nextIndex = state.currentSessionIndex;
            if (index < nextIndex) {
              // The selected session shifted up one slot; keep it selected.
              nextIndex -= 1;
            } else if (index === nextIndex) {
              // The selected session itself was removed: select its predecessor.
              nextIndex = Math.max(0, nextIndex - 1);
            }

            return {
              currentSessionIndex: nextIndex,
              sessions,
            };
          });
        },

        moveSession(from: number, to: number) {
          set((state) => {
            const { sessions, currentSessionIndex: oldIndex } = state;

            // move the session
            const newSessions = [...sessions];
            const session = newSessions[from];
            newSessions.splice(from, 1);
            newSessions.splice(to, 0, session);

            // keep the same session selected after the move
            let newIndex = oldIndex === from ? to : oldIndex;
            if (oldIndex > from && oldIndex <= to) {
              newIndex -= 1;
            } else if (oldIndex < from && oldIndex >= to) {
              newIndex += 1;
            }

            return {
              currentSessionIndex: newIndex,
              sessions: newSessions,
            };
          });
        },

        newSession() {
          set((state) => ({
            currentSessionIndex: 0,
            sessions: [createEmptySession()].concat(state.sessions),
          }));
        },

        deleteSession(i?: number) {
          // currentSession() clamps an out-of-range currentSessionIndex as a
          // side effect, so do that before reading the index.
          if (i === undefined) get().currentSession();
          const index = i ?? get().currentSessionIndex;

          // Capture the session actually being deleted (not the current one),
          // so the undo restores the right session.
          const deletedSession = get().sessions[index];
          if (!deletedSession) return;
          const isLastSession = get().sessions.length === 1;

          if (!isMobileScreen() || confirm(Locale.Home.DeleteChat)) {
            get().removeSession(index);

            showToast(
              Locale.Home.DeleteToast,
              {
                text: Locale.Home.Revert,
                onClick() {
                  set((state) => ({
                    // When the last session was deleted, removeSession put an
                    // empty placeholder at index 0; drop it on restore.
                    sessions: state.sessions
                      .slice(0, index)
                      .concat([deletedSession])
                      .concat(
                        state.sessions.slice(index + Number(isLastSession)),
                      ),
                    // Reinserting at or before the selection shifts it down one
                    // slot; follow it so the same session stays selected.
                    currentSessionIndex:
                      !isLastSession && index <= state.currentSessionIndex
                        ? state.currentSessionIndex + 1
                        : state.currentSessionIndex,
                  }));
                },
              },
              5000,
            );
          }
        },

        currentSession() {
          let index = get().currentSessionIndex;
          const sessions = get().sessions;

          if (index < 0 || index >= sessions.length) {
            index = Math.min(sessions.length - 1, Math.max(0, index));
            set(() => ({ currentSessionIndex: index }));
          }

          const session = sessions[index];

          return session;
        },

        onNewMessage(message) {
          get().updateCurrentSession((session) => {
            session.lastUpdate = new Date().toLocaleString();
          });
          get().updateStat(message);
          get().summarizeSession();
        },

        async onUserInput(content) {
          const userMessage: Message = createMessage({
            role: "user",
            content,
          });

          const botMessage: Message = createMessage({
            role: "assistant",
            streaming: true,
            id: userMessage.id! + 1,
          });

          // get recent messages
          const recentMessages = get().getMessagesWithMemory();
          const sendMessages = recentMessages.concat(userMessage);
          const sessionIndex = get().currentSessionIndex;
          // Position the bot message will occupy once both messages are pushed.
          const messageIndex = get().currentSession().messages.length + 1;

          // save user's and bot's message
          get().updateCurrentSession((session) => {
            session.messages.push(userMessage);
            session.messages.push(botMessage);
          });

          // make request
          console.log("[User Input] ", sendMessages);
          requestChatStream(sendMessages, {
            onMessage(content, done) {
              // stream response
              if (done) {
                botMessage.streaming = false;
                botMessage.content = content;
                get().onNewMessage(botMessage);
                ControllerPool.remove(
                  sessionIndex,
                  botMessage.id ?? messageIndex,
                );
              } else {
                botMessage.content = content;
                set(() => ({}));
              }
            },
            onError(error, statusCode) {
              if (statusCode === 401) {
                botMessage.content = Locale.Error.Unauthorized;
              } else if (!error.message.includes("aborted")) {
                botMessage.content += "\n\n" + Locale.Store.Error;
              }
              botMessage.streaming = false;
              userMessage.isError = true;
              botMessage.isError = true;
              set(() => ({}));
              ControllerPool.remove(sessionIndex, botMessage.id ?? messageIndex);
            },
            onController(controller) {
              // collect controller for stop/retry
              ControllerPool.addController(
                sessionIndex,
                botMessage.id ?? messageIndex,
                controller,
              );
            },
            filterBot: !get().config.sendBotMessages,
            modelConfig: get().config.modelConfig,
          });
        },

        getMemoryPrompt() {
          const session = get().currentSession();

          return {
            role: "system",
            content: Locale.Store.Prompt.History(session.memoryPrompt),
            date: "",
          } as Message;
        },

        getMessagesWithMemory() {
          const session = get().currentSession();
          const config = get().config;
          const messages = session.messages.filter((msg) => !msg.isError);
          const n = messages.length;

          const context = session.context.slice();

          // long term memory
          if (
            session.sendMemory &&
            session.memoryPrompt &&
            session.memoryPrompt.length > 0
          ) {
            const memoryPrompt = get().getMemoryPrompt();
            context.push(memoryPrompt);
          }

          // short-term window plus any history not yet folded into memory.
          // A negative historyMessageCount means "all history"; plain
          // `n - historyMessageCount` would point past the end instead.
          const historyCount = config.historyMessageCount;
          const shortTermMemoryMessageIndex =
            historyCount < 0 ? 0 : Math.max(0, n - historyCount);
          const longTermMemoryMessageIndex = session.lastSummarizeIndex;
          const oldestIndex = Math.min(
            shortTermMemoryMessageIndex,
            longTermMemoryMessageIndex,
          );
          const threshold = config.compressMessageLengthThreshold;

          // collect recent messages, newest first, until the length budget is spent
          const reversedRecentMessages = [];
          for (
            let i = n - 1, count = 0;
            i >= oldestIndex && count < threshold;
            i -= 1
          ) {
            const msg = messages[i];
            if (!msg || msg.isError) continue;
            count += msg.content.length;
            reversedRecentMessages.push(msg);
          }

          // concat
          const recentMessages = context.concat(reversedRecentMessages.reverse());

          return recentMessages;
        },

        updateMessage(
          sessionIndex: number,
          messageIndex: number,
          updater: (message?: Message) => void,
        ) {
          const sessions = get().sessions;
          const session = sessions.at(sessionIndex);
          const messages = session?.messages;
          updater(messages?.at(messageIndex));
          set(() => ({ sessions }));
        },

        resetSession() {
          get().updateCurrentSession((session) => {
            session.messages = [];
            session.memoryPrompt = "";
            // Without this, summarization would skip the first
            // `lastSummarizeIndex` messages of the new history.
            session.lastSummarizeIndex = 0;
          });
        },

        summarizeSession() {
          const session = get().currentSession();

          // Auto-title once the session holds at least 50 characters of content.
          const SUMMARIZE_MIN_LEN = 50;
          if (
            session.topic === DEFAULT_TOPIC &&
            countMessages(session.messages) >= SUMMARIZE_MIN_LEN
          ) {
            requestWithPrompt(session.messages, Locale.Store.Prompt.Topic)
              .then((res) => {
                // Write to the session that was summarized, not to whichever
                // session happens to be current when the reply arrives.
                session.topic = res ? trimTopic(res) : DEFAULT_TOPIC;
                set(() => ({ sessions: get().sessions }));
              })
              .catch((error) => {
                console.error("[Topic] ", error);
              });
          }

          const config = get().config;
          let toBeSummarizedMsgs = session.messages.slice(
            session.lastSummarizeIndex,
          );

          const historyMsgLength = countMessages(toBeSummarizedMsgs);

          // `a > b ?? c` parses as `(a > b) ?? c`, which made the 4000 fallback
          // dead code; resolve the limit first.
          const maxTokens = config.modelConfig?.max_tokens ?? 4000;
          if (historyMsgLength > maxTokens) {
            const n = toBeSummarizedMsgs.length;
            const historyCount = config.historyMessageCount;
            if (historyCount >= 0) {
              toBeSummarizedMsgs = toBeSummarizedMsgs.slice(
                Math.max(0, n - historyCount),
              );
            }
          }

          // add memory prompt
          toBeSummarizedMsgs.unshift(get().getMemoryPrompt());

          const lastSummarizeIndex = session.messages.length;

          console.log(
            "[Chat History] ",
            toBeSummarizedMsgs,
            historyMsgLength,
            config.compressMessageLengthThreshold,
          );

          if (
            historyMsgLength > config.compressMessageLengthThreshold &&
            session.sendMemory
          ) {
            requestChatStream(
              toBeSummarizedMsgs.concat({
                role: "system",
                content: Locale.Store.Prompt.Summarize,
                date: "",
              }),
              {
                filterBot: false,
                onMessage(message, done) {
                  session.memoryPrompt = message;
                  if (done) {
                    console.log("[Memory] ", session.memoryPrompt);
                    session.lastSummarizeIndex = lastSummarizeIndex;
                    // Trigger a store update so subscribers (and the persist
                    // middleware) see the final summary.
                    set(() => ({ sessions: get().sessions }));
                  }
                },
                onError(error) {
                  console.error("[Summarize] ", error);
                },
              },
            );
          }
        },

        updateStat(message) {
          get().updateCurrentSession((session) => {
            session.stat.charCount += message.content.length;
            // TODO: should update chat count and word count
          });
        },

        updateCurrentSession(updater) {
          const sessions = get().sessions;
          const index = get().currentSessionIndex;
          updater(sessions[index]);
          set(() => ({ sessions }));
        },

        clearAllData() {
          if (confirm(Locale.Store.ConfirmClearAll)) {
            localStorage.clear();
            location.reload();
          }
        },
      }),
      {
        name: LOCAL_KEY,
        version: 1.2,
        /** Upgrades persisted state written by older store versions. */
        migrate(persistedState, version) {
          const state = persistedState as ChatStore;

          // version 1 sessions have no `context` field
          if (version === 1) {
            state.sessions.forEach((s) => (s.context = []));
          }

          // sendMemory was added in 1.2; default it on
          if (version < 1.2) {
            state.sessions.forEach((s) => (s.sendMemory = true));
          }

          return state;
        },
      },
    ),
  );