app.ts 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641
  1. import { create } from "zustand";
  2. import { persist } from "zustand/middleware";
  3. import { type ChatCompletionResponseMessage } from "openai";
  4. import {
  5. ControllerPool,
  6. requestChatStream,
  7. requestWithPrompt,
  8. } from "../requests";
  9. import { isMobileScreen, trimTopic } from "../utils";
  10. import Locale from "../locales";
  11. import { showToast } from "../components/ui-lib";
// A chat message as shown in the UI: the OpenAI wire format plus
// local-only presentation/state fields.
export type Message = ChatCompletionResponseMessage & {
  date: string; // locale-formatted creation time (see createMessage)
  streaming?: boolean; // true while the reply is still being streamed
  isError?: boolean; // request failed; such messages are filtered from prompts
  id?: number; // Date.now()-based identifier (see createMessage)
  model?: ModelType; // model that produced an assistant reply
};
  19. export function createMessage(override: Partial<Message>): Message {
  20. return {
  21. id: Date.now(),
  22. date: new Date().toLocaleString(),
  23. role: "user",
  24. content: "",
  25. ...override,
  26. };
  27. }
// Key combination that submits the chat input box.
export enum SubmitKey {
  Enter = "Enter",
  CtrlEnter = "Ctrl + Enter",
  ShiftEnter = "Shift + Enter",
  AltEnter = "Alt + Enter",
  MetaEnter = "Meta + Enter",
}
// UI color scheme selection.
export enum Theme {
  Auto = "auto",
  Dark = "dark",
  Light = "light",
}
// User-tunable application settings, persisted with the store.
export interface ChatConfig {
  historyMessageCount: number; // -1 means all
  compressMessageLengthThreshold: number; // char count that triggers summarization
  sendBotMessages: boolean; // send bot's message or not
  submitKey: SubmitKey;
  avatar: string; // emoji codepoint, e.g. "1f603"
  fontSize: number;
  theme: Theme;
  tightBorder: boolean;
  sendPreviewBubble: boolean;
  sidebarWidth: number;
  disablePromptHint: boolean;
  // Parameters forwarded to the chat completion API.
  modelConfig: {
    model: ModelType;
    temperature: number;
    max_tokens: number;
    presence_penalty: number;
  };
}
// Convenience alias for the nested model configuration.
export type ModelConfig = ChatConfig["modelConfig"];
// Every role a message may carry.
export const ROLES: Message["role"][] = ["system", "user", "assistant"];
// Single switch to enable/disable all GPT-4 variants below.
const ENABLE_GPT4 = true;
// Model catalogue; `available` gates which models limitModel() accepts.
export const ALL_MODELS = [
  {
    name: "gpt-4",
    available: ENABLE_GPT4,
  },
  {
    name: "gpt-4-0314",
    available: ENABLE_GPT4,
  },
  {
    name: "gpt-4-32k",
    available: ENABLE_GPT4,
  },
  {
    name: "gpt-4-32k-0314",
    available: ENABLE_GPT4,
  },
  {
    name: "gpt-3.5-turbo",
    available: true,
  },
  {
    name: "gpt-3.5-turbo-0301",
    available: true,
  },
] as const;
// Union of all known model names, derived from the catalogue above.
export type ModelType = (typeof ALL_MODELS)[number]["name"];
  89. export function limitNumber(
  90. x: number,
  91. min: number,
  92. max: number,
  93. defaultValue: number,
  94. ) {
  95. if (typeof x !== "number" || isNaN(x)) {
  96. return defaultValue;
  97. }
  98. return Math.min(max, Math.max(min, x));
  99. }
  100. export function limitModel(name: string) {
  101. return ALL_MODELS.some((m) => m.name === name && m.available)
  102. ? name
  103. : ALL_MODELS[4].name;
  104. }
  105. export const ModalConfigValidator = {
  106. model(x: string) {
  107. return limitModel(x) as ModelType;
  108. },
  109. max_tokens(x: number) {
  110. return limitNumber(x, 0, 32000, 2000);
  111. },
  112. presence_penalty(x: number) {
  113. return limitNumber(x, -2, 2, 0);
  114. },
  115. temperature(x: number) {
  116. return limitNumber(x, 0, 2, 1);
  117. },
  118. };
  119. const DEFAULT_CONFIG: ChatConfig = {
  120. historyMessageCount: 4,
  121. compressMessageLengthThreshold: 1000,
  122. sendBotMessages: true as boolean,
  123. submitKey: SubmitKey.CtrlEnter as SubmitKey,
  124. avatar: "1f603",
  125. fontSize: 14,
  126. theme: Theme.Auto as Theme,
  127. tightBorder: false,
  128. sendPreviewBubble: true,
  129. sidebarWidth: 300,
  130. disablePromptHint: false,
  131. modelConfig: {
  132. model: "gpt-3.5-turbo",
  133. temperature: 1,
  134. max_tokens: 2000,
  135. presence_penalty: 0,
  136. },
  137. };
// Aggregate statistics for one session. Only charCount is currently
// maintained by updateStat(); the others are placeholders.
export interface ChatStat {
  tokenCount: number;
  wordCount: number;
  charCount: number;
}
// One conversation thread plus its summarization bookkeeping.
export interface ChatSession {
  id: number;
  topic: string; // title; auto-generated by summarizeSession()
  sendMemory: boolean; // include the memory prompt when building requests
  memoryPrompt: string; // rolling summary of older messages
  context: Message[]; // messages prepended before chat history in prompts
  messages: Message[];
  stat: ChatStat;
  lastUpdate: string; // locale-formatted time of last activity
  lastSummarizeIndex: number; // messages before this index are summarized
}
// Localized default title for a fresh session.
const DEFAULT_TOPIC = Locale.Store.DefaultTopic;
// Localized canned assistant greeting message.
export const BOT_HELLO: Message = createMessage({
  role: "assistant",
  content: Locale.Store.BotHello,
});
  159. function createEmptySession(): ChatSession {
  160. const createDate = new Date().toLocaleString();
  161. return {
  162. id: Date.now(),
  163. topic: DEFAULT_TOPIC,
  164. sendMemory: true,
  165. memoryPrompt: "",
  166. context: [],
  167. messages: [],
  168. stat: {
  169. tokenCount: 0,
  170. wordCount: 0,
  171. charCount: 0,
  172. },
  173. lastUpdate: createDate,
  174. lastSummarizeIndex: 0,
  175. };
  176. }
// Shape of the zustand store: persisted state plus all actions.
interface ChatStore {
  config: ChatConfig;
  sessions: ChatSession[];
  currentSessionIndex: number;
  clearSessions: () => void;
  removeSession: (index: number) => void;
  moveSession: (from: number, to: number) => void;
  selectSession: (index: number) => void;
  newSession: () => void;
  deleteSession: (index?: number) => void; // defaults to the current session
  currentSession: () => ChatSession;
  onNewMessage: (message: Message) => void;
  onUserInput: (content: string) => Promise<void>;
  summarizeSession: () => void;
  updateStat: (message: Message) => void;
  updateCurrentSession: (updater: (session: ChatSession) => void) => void;
  updateMessage: (
    sessionIndex: number,
    messageIndex: number,
    updater: (message?: Message) => void,
  ) => void;
  resetSession: () => void;
  getMessagesWithMemory: () => Message[];
  getMemoryPrompt: () => Message;
  getConfig: () => ChatConfig;
  resetConfig: () => void;
  updateConfig: (updater: (config: ChatConfig) => void) => void;
  clearAllData: () => void;
}
  206. function countMessages(msgs: Message[]) {
  207. return msgs.reduce((pre, cur) => pre + cur.content.length, 0);
  208. }
  209. const LOCAL_KEY = "chat-next-web-store";
// Global chat store, persisted to localStorage under LOCAL_KEY.
export const useChatStore = create<ChatStore>()(
  persist(
    (set, get) => ({
      sessions: [createEmptySession()],
      currentSessionIndex: 0,
      config: {
        ...DEFAULT_CONFIG,
      },
      // Drops every session and starts over with a single empty one.
      clearSessions() {
        set(() => ({
          sessions: [createEmptySession()],
          currentSessionIndex: 0,
        }));
      },
      // Restores the default configuration (fresh copy, not shared).
      resetConfig() {
        set(() => ({ config: { ...DEFAULT_CONFIG } }));
      },
      getConfig() {
        return get().config;
      },
  230. updateConfig(updater) {
  231. const config = get().config;
  232. updater(config);
  233. set(() => ({ config }));
  234. },
  235. selectSession(index: number) {
  236. set({
  237. currentSessionIndex: index,
  238. });
  239. },
  240. removeSession(index: number) {
  241. set((state) => {
  242. let nextIndex = state.currentSessionIndex;
  243. const sessions = state.sessions;
  244. if (sessions.length === 1) {
  245. return {
  246. currentSessionIndex: 0,
  247. sessions: [createEmptySession()],
  248. };
  249. }
  250. sessions.splice(index, 1);
  251. if (nextIndex === index) {
  252. nextIndex -= 1;
  253. }
  254. return {
  255. currentSessionIndex: nextIndex,
  256. sessions,
  257. };
  258. });
  259. },
      // Reorders sessions (drag & drop) while keeping the selection on the
      // same logical session.
      moveSession(from: number, to: number) {
        set((state) => {
          const { sessions, currentSessionIndex: oldIndex } = state;
          // move the session
          const newSessions = [...sessions];
          const session = newSessions[from];
          newSessions.splice(from, 1);
          newSessions.splice(to, 0, session);
          // modify current session id
          // The moved session lands on `to`; sessions between `from` and
          // `to` shift by one in the opposite direction.
          let newIndex = oldIndex === from ? to : oldIndex;
          if (oldIndex > from && oldIndex <= to) {
            newIndex -= 1;
          } else if (oldIndex < from && oldIndex >= to) {
            newIndex += 1;
          }
          return {
            currentSessionIndex: newIndex,
            sessions: newSessions,
          };
        });
      },
  281. newSession() {
  282. set((state) => ({
  283. currentSessionIndex: 0,
  284. sessions: [createEmptySession()].concat(state.sessions),
  285. }));
  286. },
      // Deletes a session (current one by default) with an undo toast.
      // Desktop deletes immediately; mobile asks for confirmation first.
      // NOTE(review): `deletedSession` captures the *current* session even
      // when an explicit index `i` is passed — confirm callers always
      // delete the current session when passing an index.
      deleteSession(i?: number) {
        const deletedSession = get().currentSession();
        const index = i ?? get().currentSessionIndex;
        const isLastSession = get().sessions.length === 1;
        if (!isMobileScreen() || confirm(Locale.Home.DeleteChat)) {
          get().removeSession(index);
          showToast(
            Locale.Home.DeleteToast,
            {
              text: Locale.Home.Revert,
              onClick() {
                // Re-insert the deleted session at its old position. When it
                // was the only session, removeSession replaced it with a
                // fresh placeholder, which is skipped here.
                set((state) => ({
                  sessions: state.sessions
                    .slice(0, index)
                    .concat([deletedSession])
                    .concat(
                      state.sessions.slice(index + Number(isLastSession)),
                    ),
                }));
              },
            },
            5000,
          );
        }
      },
  312. currentSession() {
  313. let index = get().currentSessionIndex;
  314. const sessions = get().sessions;
  315. if (index < 0 || index >= sessions.length) {
  316. index = Math.min(sessions.length - 1, Math.max(0, index));
  317. set(() => ({ currentSessionIndex: index }));
  318. }
  319. const session = sessions[index];
  320. return session;
  321. },
  322. onNewMessage(message) {
  323. get().updateCurrentSession((session) => {
  324. session.lastUpdate = new Date().toLocaleString();
  325. });
  326. get().updateStat(message);
  327. get().summarizeSession();
  328. },
      // Sends the user's text to the model and streams the reply into a
      // placeholder assistant message that is mutated in place.
      async onUserInput(content) {
        const userMessage: Message = createMessage({
          role: "user",
          content,
        });
        // Placeholder for the streamed reply; id derives from the user
        // message's id so the pair is uniquely addressable.
        const botMessage: Message = createMessage({
          role: "assistant",
          streaming: true,
          id: userMessage.id! + 1,
          model: get().config.modelConfig.model,
        });
        // get recent messages
        const recentMessages = get().getMessagesWithMemory();
        const sendMessages = recentMessages.concat(userMessage);
        const sessionIndex = get().currentSessionIndex;
        // Fallback controller key: index the bot message will occupy after
        // both pushes below (old length + 1).
        const messageIndex = get().currentSession().messages.length + 1;
        // save user's and bot's message
        get().updateCurrentSession((session) => {
          session.messages.push(userMessage);
          session.messages.push(botMessage);
        });
        // make request
        console.log("[User Input] ", sendMessages);
        requestChatStream(sendMessages, {
          onMessage(content, done) {
            // stream response
            if (done) {
              botMessage.streaming = false;
              botMessage.content = content;
              get().onNewMessage(botMessage);
              ControllerPool.remove(
                sessionIndex,
                botMessage.id ?? messageIndex,
              );
            } else {
              // Partial chunk: mutate the message and force a re-render
              // with an empty set().
              botMessage.content = content;
              set(() => ({}));
            }
          },
          onError(error, statusCode) {
            if (statusCode === 401) {
              botMessage.content = Locale.Error.Unauthorized;
            } else if (!error.message.includes("aborted")) {
              // User-initiated aborts are not surfaced as errors.
              botMessage.content += "\n\n" + Locale.Store.Error;
            }
            botMessage.streaming = false;
            userMessage.isError = true;
            botMessage.isError = true;
            set(() => ({}));
            ControllerPool.remove(sessionIndex, botMessage.id ?? messageIndex);
          },
          onController(controller) {
            // collect controller for stop/retry
            ControllerPool.addController(
              sessionIndex,
              botMessage.id ?? messageIndex,
              controller,
            );
          },
          filterBot: !get().config.sendBotMessages,
          modelConfig: get().config.modelConfig,
        });
      },
  392. getMemoryPrompt() {
  393. const session = get().currentSession();
  394. return {
  395. role: "system",
  396. content: Locale.Store.Prompt.History(session.memoryPrompt),
  397. date: "",
  398. } as Message;
  399. },
      // Builds the prompt context: pinned context messages, the optional
      // long-term memory prompt, then as many recent messages as fit
      // within the character budget.
      getMessagesWithMemory() {
        const session = get().currentSession();
        const config = get().config;
        const messages = session.messages.filter((msg) => !msg.isError);
        const n = messages.length;
        const context = session.context.slice();
        // long term memory
        if (
          session.sendMemory &&
          session.memoryPrompt &&
          session.memoryPrompt.length > 0
        ) {
          const memoryPrompt = get().getMemoryPrompt();
          context.push(memoryPrompt);
        }
        // get short term and unmemoried long term memory
        const shortTermMemoryMessageIndex = Math.max(
          0,
          n - config.historyMessageCount,
        );
        // NOTE(review): lastSummarizeIndex indexes the unfiltered
        // session.messages but is compared against indices of the
        // error-filtered copy — the two can drift when errored messages
        // exist; confirm intended behavior.
        const longTermMemoryMessageIndex = session.lastSummarizeIndex;
        const oldestIndex = Math.max(
          shortTermMemoryMessageIndex,
          longTermMemoryMessageIndex,
        );
        const threshold = config.compressMessageLengthThreshold;
        // get recent messages as many as possible
        // Walk backwards from the newest message until the character
        // budget or the oldest allowed index is hit.
        const reversedRecentMessages = [];
        for (
          let i = n - 1, count = 0;
          i >= oldestIndex && count < threshold;
          i -= 1
        ) {
          const msg = messages[i];
          if (!msg || msg.isError) continue;
          count += msg.content.length;
          reversedRecentMessages.push(msg);
        }
        // concat
        const recentMessages = context.concat(reversedRecentMessages.reverse());
        return recentMessages;
      },
  442. updateMessage(
  443. sessionIndex: number,
  444. messageIndex: number,
  445. updater: (message?: Message) => void,
  446. ) {
  447. const sessions = get().sessions;
  448. const session = sessions.at(sessionIndex);
  449. const messages = session?.messages;
  450. updater(messages?.at(messageIndex));
  451. set(() => ({ sessions }));
  452. },
  453. resetSession() {
  454. get().updateCurrentSession((session) => {
  455. session.messages = [];
  456. session.memoryPrompt = "";
  457. });
  458. },
  459. summarizeSession() {
  460. const session = get().currentSession();
  461. // should summarize topic after chating more than 50 words
  462. const SUMMARIZE_MIN_LEN = 50;
  463. if (
  464. session.topic === DEFAULT_TOPIC &&
  465. countMessages(session.messages) >= SUMMARIZE_MIN_LEN
  466. ) {
  467. requestWithPrompt(session.messages, Locale.Store.Prompt.Topic, {
  468. model: "gpt-3.5-turbo",
  469. }).then((res) => {
  470. get().updateCurrentSession(
  471. (session) =>
  472. (session.topic = res ? trimTopic(res) : DEFAULT_TOPIC),
  473. );
  474. });
  475. }
  476. const config = get().config;
  477. let toBeSummarizedMsgs = session.messages.slice(
  478. session.lastSummarizeIndex,
  479. );
  480. const historyMsgLength = countMessages(toBeSummarizedMsgs);
  481. if (historyMsgLength > get().config?.modelConfig?.max_tokens ?? 4000) {
  482. const n = toBeSummarizedMsgs.length;
  483. toBeSummarizedMsgs = toBeSummarizedMsgs.slice(
  484. Math.max(0, n - config.historyMessageCount),
  485. );
  486. }
  487. // add memory prompt
  488. toBeSummarizedMsgs.unshift(get().getMemoryPrompt());
  489. const lastSummarizeIndex = session.messages.length;
  490. console.log(
  491. "[Chat History] ",
  492. toBeSummarizedMsgs,
  493. historyMsgLength,
  494. config.compressMessageLengthThreshold,
  495. );
  496. if (
  497. historyMsgLength > config.compressMessageLengthThreshold &&
  498. session.sendMemory
  499. ) {
  500. requestChatStream(
  501. toBeSummarizedMsgs.concat({
  502. role: "system",
  503. content: Locale.Store.Prompt.Summarize,
  504. date: "",
  505. }),
  506. {
  507. filterBot: false,
  508. model: "gpt-3.5-turbo",
  509. onMessage(message, done) {
  510. session.memoryPrompt = message;
  511. if (done) {
  512. console.log("[Memory] ", session.memoryPrompt);
  513. session.lastSummarizeIndex = lastSummarizeIndex;
  514. }
  515. },
  516. onError(error) {
  517. console.error("[Summarize] ", error);
  518. },
  519. },
  520. );
  521. }
  522. },
  523. updateStat(message) {
  524. get().updateCurrentSession((session) => {
  525. session.stat.charCount += message.content.length;
  526. // TODO: should update chat count and word count
  527. });
  528. },
  529. updateCurrentSession(updater) {
  530. const sessions = get().sessions;
  531. const index = get().currentSessionIndex;
  532. updater(sessions[index]);
  533. set(() => ({ sessions }));
  534. },
      // Wipes ALL persisted localStorage data after user confirmation and
      // reloads the app.
      clearAllData() {
        if (confirm(Locale.Store.ConfirmClearAll)) {
          localStorage.clear();
          location.reload();
        }
      },
    }),
    {
      name: LOCAL_KEY,
      version: 1.2,
      // Upgrades persisted state written by older schema versions.
      migrate(persistedState, version) {
        const state = persistedState as ChatStore;
        if (version === 1) {
          // v1 sessions lacked the `context` array.
          state.sessions.forEach((s) => (s.context = []));
        }
        if (version < 1.2) {
          // v1.2 introduced the per-session `sendMemory` flag.
          state.sessions.forEach((s) => (s.sendMemory = true));
        }
        return state;
      },
    },
  ),
);