app.ts 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607
  1. import { create } from "zustand";
  2. import { persist } from "zustand/middleware";
  3. import { type ChatCompletionResponseMessage } from "openai";
  4. import {
  5. ControllerPool,
  6. requestChatStream,
  7. requestWithPrompt,
  8. } from "../requests";
  9. import { isMobileScreen, trimTopic } from "../utils";
  10. import Locale from "../locales";
  11. import { showToast } from "../components/ui-lib";
  /**
   * A chat message: the OpenAI chat-completion message (`role`, `content`)
   * extended with the client-side bookkeeping fields the chat store uses.
   */
  export type Message = ChatCompletionResponseMessage & {
    /** Numeric identifier; `createMessage` assigns one unless overridden. */
    id?: number;
    /**
     * Display timestamp string. `createMessage` fills it with
     * `toLocaleString()`; synthetic system prompts use "".
     */
    date: string;
    /** True while an assistant reply is still being streamed in. */
    streaming?: boolean;
    /** Marks a failed exchange; such messages are left out of the history sent to the model. */
    isError?: boolean;
  };
  /** Last id handed out by `createMessage`; keeps generated ids strictly increasing. */
  let lastMessageId = 0;

  /**
   * Build a message with defaults (`role: "user"`, empty `content`, current
   * locale timestamp), overridden by any fields present in `override`.
   *
   * Ids are derived from `Date.now()` but forced to be strictly increasing,
   * so two messages created within the same millisecond (for example a user
   * message and its reply placeholder created back to back) never share an id.
   *
   * @param override fields that replace the defaults; an explicit `id` wins.
   * @returns a new message object.
   */
  export function createMessage(override: Partial<Message>): Message {
    lastMessageId = Math.max(Date.now(), lastMessageId + 1);
    return {
      id: lastMessageId,
      date: new Date().toLocaleString(),
      role: "user",
      content: "",
      ...override,
    };
  }
  /**
   * Key combination that submits the chat input. Each value is the
   * human-readable label of the combination (e.g. "Ctrl + Enter").
   * Member order is kept stable because enum value order is observable.
   */
  export enum SubmitKey {
    Enter = "Enter",
    CtrlEnter = "Ctrl + Enter",
    ShiftEnter = "Shift + Enter",
    AltEnter = "Alt + Enter",
    MetaEnter = "Meta + Enter",
  }

  /** Color theme setting; `Auto` is the value used by the default config. */
  export enum Theme {
    Auto = "auto",
    Dark = "dark",
    Light = "light",
  }
  /** Parameters forwarded as the `modelConfig` option of each chat request. */
  export interface ModelConfig {
    /** Model name, e.g. "gpt-3.5-turbo". */
    model: string;
    temperature: number;
    max_tokens: number;
    presence_penalty: number;
  }

  /** User-adjustable chat settings kept in the persisted store. */
  export interface ChatConfig {
    /** How many recent messages to send as context; -1 is intended to mean "all". */
    historyMessageCount: number;
    /** Character length of unsummarized history above which it is compressed into a memory prompt. */
    compressMessageLengthThreshold: number;
    /** Whether the bot's own earlier messages are sent (passed inverted as `filterBot`). */
    sendBotMessages: boolean;
    /** Key combination that submits the input. */
    submitKey: SubmitKey;
    /** Avatar identifier (the default looks like an emoji code point — confirm with the UI). */
    avatar: string;
    fontSize: number;
    theme: Theme;
    tightBorder: boolean;
    sendPreviewBubble: boolean;
    sidebarWidth: number;
    disablePromptHint: boolean;
    /** Request parameters for the chat model. */
    modelConfig: ModelConfig;
  }
  /** The message roles, listed as `system`, `user`, `assistant`. */
  export const ROLES: Array<Message["role"]> = ["system", "user", "assistant"];
  /** Feature flag controlling whether the GPT-4 family is marked available. */
  const ENABLE_GPT4 = true;

  /**
   * Known chat models and whether each may be selected.
   * Order matters: `limitModel` falls back to `ALL_MODELS[4]`.
   */
  export const ALL_MODELS = [
    ...["gpt-4", "gpt-4-0314", "gpt-4-32k", "gpt-4-32k-0314"].map((name) => ({
      name,
      available: ENABLE_GPT4,
    })),
    ...["gpt-3.5-turbo", "gpt-3.5-turbo-0301"].map((name) => ({
      name,
      available: true,
    })),
  ];
  /**
   * Clamp `x` into `[min, max]`.
   *
   * When `x` is not a usable number (not of type number at runtime, or NaN),
   * `defaultValue` is returned instead. The lower bound is applied first, so
   * if `min > max` the result is `max`.
   */
  export function limitNumber(
    x: number,
    min: number,
    max: number,
    defaultValue: number,
  ) {
    const usable = typeof x === "number" && !Number.isNaN(x);
    if (!usable) return defaultValue;
    const atLeastMin = Math.max(min, x);
    return Math.min(max, atLeastMin);
  }
  /**
   * Return `name` when it is a known model marked available; otherwise fall
   * back to the model at `ALL_MODELS[4]`.
   */
  export function limitModel(name: string) {
    const match = ALL_MODELS.find((m) => m.name === name && m.available);
    if (match === undefined) {
      return ALL_MODELS[4].name;
    }
    return name;
  }
  /**
   * Sanitizers for each model-config field: each one takes a raw value and
   * returns a valid one (clamped number, or a known model name).
   * (The "Modal" spelling is kept because callers import this name.)
   */
  export const ModalConfigValidator = {
    model: (x: string) => limitModel(x),
    max_tokens: (x: number) => limitNumber(x, 0, 32000, 2000),
    presence_penalty: (x: number) => limitNumber(x, -2, 2, 0),
    temperature: (x: number) => limitNumber(x, 0, 2, 1),
  };
  /**
   * Factory defaults for the chat settings; the store's initial config and
   * `resetConfig` are derived from this object.
   *
   * NOTE: a shallow copy (`{ ...DEFAULT_CONFIG }`) still shares the nested
   * `modelConfig` object with this constant, so copy that object as well
   * before mutating a config derived from it.
   *
   * The `ChatConfig` annotation already types every field, so no `as`
   * assertions are needed (they would only suppress type checking).
   */
  const DEFAULT_CONFIG: ChatConfig = {
    historyMessageCount: 4,
    compressMessageLengthThreshold: 1000,
    sendBotMessages: true,
    submitKey: SubmitKey.CtrlEnter,
    avatar: "1f603",
    fontSize: 14,
    theme: Theme.Auto,
    tightBorder: false,
    sendPreviewBubble: true,
    sidebarWidth: 300,
    disablePromptHint: false,
    modelConfig: {
      model: "gpt-3.5-turbo",
      temperature: 1,
      max_tokens: 2000,
      presence_penalty: 0,
    },
  };
  /** Running per-session counters; `updateStat` currently maintains only `charCount`. */
  export interface ChatStat {
    tokenCount: number;
    wordCount: number;
    charCount: number;
  }

  /** One conversation plus the bookkeeping needed to send and summarize it. */
  export interface ChatSession {
    /** Numeric identifier (`createEmptySession` uses the creation time in ms). */
    id: number;
    /** Session title; stays the default topic until one is generated. */
    topic: string;
    /** Whether `memoryPrompt` is included when building the request. */
    sendMemory: boolean;
    /** Model-written summary of earlier messages ("" until one is produced). */
    memoryPrompt: string;
    /** Messages placed before the history in every request. */
    context: Message[];
    /** Conversation messages, appended in order (oldest first). */
    messages: Message[];
    /** Accumulated counters for this session. */
    stat: ChatStat;
    /** Locale-formatted time of the last update. */
    lastUpdate: string;
    /** Index into `messages` where the next summarization pass starts. */
    lastSummarizeIndex: number;
  }
  /** Placeholder topic for a session that has not been titled yet. */
  const DEFAULT_TOPIC = Locale.Store.DefaultTopic;

  /** Assistant greeting message; created once, when this module is loaded. */
  export const BOT_HELLO: Message = createMessage({
    content: Locale.Store.BotHello,
    role: "assistant",
  });
  /**
   * Build a brand-new session: default topic, no messages or context, memory
   * enabled, zeroed counters, and a fresh locale timestamp.
   */
  function createEmptySession(): ChatSession {
    const createdAt = new Date().toLocaleString();
    const zeroStat = { tokenCount: 0, wordCount: 0, charCount: 0 };
    const session: ChatSession = {
      id: Date.now(),
      topic: DEFAULT_TOPIC,
      sendMemory: true,
      memoryPrompt: "",
      context: [],
      messages: [],
      stat: zeroStat,
      lastUpdate: createdAt,
      lastSummarizeIndex: 0,
    };
    return session;
  }
  /** Shape of the chat store: persisted state plus the actions operating on it. */
  interface ChatStore {
    // ---- state ----
    /** Current chat settings. */
    config: ChatConfig;
    /** All sessions; `newSession` prepends to this list. */
    sessions: ChatSession[];
    /** Index into `sessions` of the selected session. */
    currentSessionIndex: number;

    // ---- session list ----
    /** Replace every session with a single empty one. */
    clearSessions: () => void;
    /** Remove the session at `index` immediately (no confirmation, no undo). */
    removeSession: (index: number) => void;
    /** Move the session at `from` to position `to`, keeping the same session selected. */
    moveSession: (from: number, to: number) => void;
    /** Select the session at `index`. */
    selectSession: (index: number) => void;
    /** Prepend a new empty session and select it. */
    newSession: () => void;
    /** Delete a session (default: the current one); confirms on mobile and offers an undo toast. */
    deleteSession: (index?: number) => void;
    /** Return the selected session, clamping an out-of-range index first. */
    currentSession: () => ChatSession;

    // ---- messaging ----
    /** Run when an assistant reply finishes: refreshes `lastUpdate`, stats, and may summarize. */
    onNewMessage: (message: Message) => void;
    /** Send user text and stream the assistant reply into the current session. */
    onUserInput: (content: string) => Promise<void>;
    /** Auto-title the session and compress long history into its memory prompt. */
    summarizeSession: () => void;
    /** Accumulate per-session counters for `message`. */
    updateStat: (message: Message) => void;
    /** Mutate the current session through `updater`, then notify subscribers. */
    updateCurrentSession: (updater: (session: ChatSession) => void) => void;
    /** Mutate one message by position; `updater` gets `undefined` if it does not exist. */
    updateMessage: (
      sessionIndex: number,
      messageIndex: number,
      updater: (message?: Message) => void,
    ) => void;
    /** Clear the current session's messages and memory prompt. */
    resetSession: () => void;
    /** Build the outgoing message list: context, optional memory prompt, recent history. */
    getMessagesWithMemory: () => Message[];
    /** System message that carries the session's memory prompt. */
    getMemoryPrompt: () => Message;

    // ---- configuration & data ----
    /** Return the current settings object. */
    getConfig: () => ChatConfig;
    /** Restore the default settings. */
    resetConfig: () => void;
    /** Mutate the settings through `updater`, then notify subscribers. */
    updateConfig: (updater: (config: ChatConfig) => void) => void;
    /** After confirmation, clear localStorage and reload the page. */
    clearAllData: () => void;
  }
  /** Total number of characters across all message contents (characters, not tokens). */
  function countMessages(msgs: Message[]) {
    let total = 0;
    for (const { content } of msgs) {
      total += content.length;
    }
    return total;
  }

  /** Storage key (the persist middleware's `name`) for the chat store. */
  const LOCAL_KEY = "chat-next-web-store";
  /** Fresh copy of DEFAULT_CONFIG that does not share its nested `modelConfig`. */
  function cloneDefaultConfig() {
    return {
      ...DEFAULT_CONFIG,
      modelConfig: { ...DEFAULT_CONFIG.modelConfig },
    };
  }

  /**
   * Return the last `count` items of `items` as a new array. A negative count
   * (the "-1 means all" setting of `historyMessageCount`) keeps every item.
   */
  function takeRecent<T>(items: T[], count: number): T[] {
    if (count < 0) {
      return items.slice();
    }
    return items.slice(Math.max(0, items.length - count));
  }

  /**
   * Global chat store (zustand), persisted by the `persist` middleware under
   * LOCAL_KEY. Most actions mutate session/config objects in place and then
   * call `set` so subscribers are notified and the state is re-persisted.
   */
  export const useChatStore = create<ChatStore>()(
    persist(
      (set, get) => ({
        sessions: [createEmptySession()],
        currentSessionIndex: 0,
        // Copy modelConfig too: updateConfig mutates the config in place, and
        // a shallow spread would let those edits leak into DEFAULT_CONFIG.
        config: cloneDefaultConfig(),

        clearSessions() {
          set(() => ({
            sessions: [createEmptySession()],
            currentSessionIndex: 0,
          }));
        },

        resetConfig() {
          set(() => ({ config: cloneDefaultConfig() }));
        },

        getConfig() {
          return get().config;
        },

        updateConfig(updater) {
          const config = get().config;
          updater(config);
          set(() => ({ config }));
        },

        selectSession(index: number) {
          set({
            currentSessionIndex: index,
          });
        },

        removeSession(index: number) {
          set((state) => {
            const { sessions, currentSessionIndex } = state;
            if (sessions.length === 1) {
              return {
                currentSessionIndex: 0,
                sessions: [createEmptySession()],
              };
            }
            // Work on a copy rather than splicing the previous state's array.
            const remaining = sessions.slice();
            remaining.splice(index, 1);
            // Keep the same session selected when an earlier one is removed;
            // when the selected session itself is removed, select its
            // predecessor (or the new first session).
            let nextIndex = currentSessionIndex;
            if (index < currentSessionIndex) {
              nextIndex -= 1;
            } else if (index === currentSessionIndex) {
              nextIndex = Math.max(0, nextIndex - 1);
            }
            nextIndex = Math.min(nextIndex, remaining.length - 1);
            return {
              currentSessionIndex: nextIndex,
              sessions: remaining,
            };
          });
        },

        moveSession(from: number, to: number) {
          set((state) => {
            const { sessions, currentSessionIndex: oldIndex } = state;

            // move the session
            const newSessions = [...sessions];
            const session = newSessions[from];
            newSessions.splice(from, 1);
            newSessions.splice(to, 0, session);

            // keep the selection on the same session
            let newIndex = oldIndex === from ? to : oldIndex;
            if (oldIndex > from && oldIndex <= to) {
              newIndex -= 1;
            } else if (oldIndex < from && oldIndex >= to) {
              newIndex += 1;
            }

            return {
              currentSessionIndex: newIndex,
              sessions: newSessions,
            };
          });
        },

        newSession() {
          set((state) => ({
            currentSessionIndex: 0,
            sessions: [createEmptySession()].concat(state.sessions),
          }));
        },

        deleteSession(i?: number) {
          // currentSession() also clamps an out-of-range currentSessionIndex,
          // so the index read below is valid when `i` is omitted.
          get().currentSession();
          const previousIndex = get().currentSessionIndex;
          const index = i ?? previousIndex;
          // Capture the session actually being deleted (not the current one,
          // which differs whenever `i` is given) so "Revert" restores it.
          const deletedSession = get().sessions[index];
          if (deletedSession === undefined) {
            return;
          }
          const isLastSession = get().sessions.length === 1;

          if (!isMobileScreen() || confirm(Locale.Home.DeleteChat)) {
            get().removeSession(index);

            showToast(
              Locale.Home.DeleteToast,
              {
                text: Locale.Home.Revert,
                onClick() {
                  set((state) => ({
                    // When the only session was deleted, removeSession put a
                    // fresh empty one in its place; drop that on revert.
                    sessions: state.sessions
                      .slice(0, index)
                      .concat([deletedSession])
                      .concat(
                        state.sessions.slice(index + Number(isLastSession)),
                      ),
                    // removeSession shifted the selection; restore it.
                    currentSessionIndex: previousIndex,
                  }));
                },
              },
              5000,
            );
          }
        },

        currentSession() {
          let index = get().currentSessionIndex;
          const sessions = get().sessions;

          if (index < 0 || index >= sessions.length) {
            index = Math.min(sessions.length - 1, Math.max(0, index));
            set(() => ({ currentSessionIndex: index }));
          }

          const session = sessions[index];

          return session;
        },

        onNewMessage(message) {
          get().updateCurrentSession((session) => {
            session.lastUpdate = new Date().toLocaleString();
          });
          get().updateStat(message);
          get().summarizeSession();
        },

        async onUserInput(content) {
          const userMessage: Message = createMessage({
            role: "user",
            content,
          });

          const botMessage: Message = createMessage({
            role: "assistant",
            streaming: true,
          });

          // get recent messages
          const recentMessages = get().getMessagesWithMemory();
          const sendMessages = recentMessages.concat(userMessage);
          const sessionIndex = get().currentSessionIndex;
          const messageIndex = get().currentSession().messages.length + 1;

          // save user's and bot's message
          get().updateCurrentSession((session) => {
            session.messages.push(userMessage);
            session.messages.push(botMessage);
          });

          // make request
          console.log("[User Input] ", sendMessages);
          requestChatStream(sendMessages, {
            onMessage(content, done) {
              // stream response
              if (done) {
                botMessage.streaming = false;
                botMessage.content = content;
                get().onNewMessage(botMessage);
                ControllerPool.remove(
                  sessionIndex,
                  botMessage.id ?? messageIndex,
                );
              } else {
                botMessage.content = content;
                set(() => ({}));
              }
            },
            onError(error, statusCode) {
              if (statusCode === 401) {
                botMessage.content = Locale.Error.Unauthorized;
              } else {
                botMessage.content += "\n\n" + Locale.Store.Error;
              }
              botMessage.streaming = false;
              userMessage.isError = true;
              botMessage.isError = true;
              set(() => ({}));
              ControllerPool.remove(sessionIndex, botMessage.id ?? messageIndex);
            },
            onController(controller) {
              // collect controller for stop/retry
              ControllerPool.addController(
                sessionIndex,
                botMessage.id ?? messageIndex,
                controller,
              );
            },
            filterBot: !get().config.sendBotMessages,
            modelConfig: get().config.modelConfig,
          });
        },

        getMemoryPrompt() {
          const session = get().currentSession();

          return {
            role: "system",
            content: Locale.Store.Prompt.History(session.memoryPrompt),
            date: "",
          } as Message;
        },

        getMessagesWithMemory() {
          const session = get().currentSession();
          const config = get().config;
          const messages = session.messages.filter((msg) => !msg.isError);

          const context = session.context.slice();

          if (
            session.sendMemory &&
            session.memoryPrompt &&
            session.memoryPrompt.length > 0
          ) {
            const memoryPrompt = get().getMemoryPrompt();
            context.push(memoryPrompt);
          }

          // A negative historyMessageCount ("-1 means all") keeps everything;
          // previously `n - (-1)` sliced past the end and sent no history.
          return context.concat(
            takeRecent(messages, config.historyMessageCount),
          );
        },

        updateMessage(
          sessionIndex: number,
          messageIndex: number,
          updater: (message?: Message) => void,
        ) {
          const sessions = get().sessions;
          const session = sessions.at(sessionIndex);
          const messages = session?.messages;
          updater(messages?.at(messageIndex));
          set(() => ({ sessions }));
        },

        resetSession() {
          get().updateCurrentSession((session) => {
            session.messages = [];
            session.memoryPrompt = "";
          });
        },

        summarizeSession() {
          const session = get().currentSession();

          // generate a topic once the chat exceeds 50 characters
          // (countMessages counts characters, not words)
          const SUMMARIZE_MIN_LEN = 50;
          if (
            session.topic === DEFAULT_TOPIC &&
            countMessages(session.messages) >= SUMMARIZE_MIN_LEN
          ) {
            requestWithPrompt(session.messages, Locale.Store.Prompt.Topic)
              .then((res) => {
                get().updateCurrentSession(
                  (session) =>
                    (session.topic = res ? trimTopic(res) : DEFAULT_TOPIC),
                );
              })
              .catch((error) => {
                console.error("[Topic] ", error);
              });
          }

          const config = get().config;
          let toBeSummarizedMsgs = session.messages.slice(
            session.lastSummarizeIndex,
          );

          const historyMsgLength = countMessages(toBeSummarizedMsgs);

          // The fallback must be parenthesized: `a > b ?? c` parses as
          // `(a > b) ?? c`, which made the 4000 default unreachable.
          const maxTokens = config.modelConfig?.max_tokens ?? 4000;
          if (historyMsgLength > maxTokens) {
            toBeSummarizedMsgs = takeRecent(
              toBeSummarizedMsgs,
              config.historyMessageCount,
            );
          }

          // add memory prompt
          toBeSummarizedMsgs.unshift(get().getMemoryPrompt());

          const lastSummarizeIndex = session.messages.length;

          console.log(
            "[Chat History] ",
            toBeSummarizedMsgs,
            historyMsgLength,
            config.compressMessageLengthThreshold,
          );

          if (historyMsgLength > config.compressMessageLengthThreshold) {
            requestChatStream(
              toBeSummarizedMsgs.concat({
                role: "system",
                content: Locale.Store.Prompt.Summarize,
                date: "",
              }),
              {
                filterBot: false,
                onMessage(message, done) {
                  session.memoryPrompt = message;
                  if (done) {
                    console.log("[Memory] ", session.memoryPrompt);
                    session.lastSummarizeIndex = lastSummarizeIndex;
                    // The session was mutated in place; notify subscribers so
                    // the finished summary is persisted right away.
                    set(() => ({}));
                  }
                },
                onError(error) {
                  console.error("[Summarize] ", error);
                },
              },
            );
          }
        },

        updateStat(message) {
          get().updateCurrentSession((session) => {
            session.stat.charCount += message.content.length;
            // TODO: should update chat count and word count
          });
        },

        updateCurrentSession(updater) {
          const sessions = get().sessions;
          const index = get().currentSessionIndex;
          updater(sessions[index]);
          set(() => ({ sessions }));
        },

        clearAllData() {
          if (confirm(Locale.Store.ConfirmClearAll)) {
            localStorage.clear();
            location.reload();
          }
        },
      }),
      {
        name: LOCAL_KEY,
        version: 1.2,
        migrate(persistedState, version) {
          const state = persistedState as ChatStore;

          if (version === 1) {
            state.sessions.forEach((s) => (s.context = []));
          }

          if (version < 1.2) {
            state.sessions.forEach((s) => (s.sendMemory = true));
          }

          return state;
        },
      },
    ),
  );