app.ts 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599
  1. import { create } from "zustand";
  2. import { persist } from "zustand/middleware";
  3. import { type ChatCompletionResponseMessage } from "openai";
  4. import {
  5. ControllerPool,
  6. requestChatStream,
  7. requestWithPrompt,
  8. } from "../requests";
  9. import { isMobileScreen, trimTopic } from "../utils";
  10. import Locale from "../locales";
  11. import { showToast } from "../components/ui-lib";
  /**
   * A chat message as kept in a session: the OpenAI message shape
   * (`role`, `content`) extended with client-side bookkeeping fields.
   */
  export type Message = ChatCompletionResponseMessage & {
    /** Creation time, formatted with `toLocaleString()` by createMessage. */
    date: string;
    /** Client-side identifier assigned by createMessage. */
    id?: number;
    /** True while an assistant reply is still being streamed in. */
    streaming?: boolean;
    /** Marks a message whose request failed; such messages are not resent. */
    isError?: boolean;
  };
  // Last id handed out by createMessage; keeps ids strictly increasing even
  // when several messages are created within the same millisecond.
  let lastMessageId = 0;

  /**
   * Build a new message from defaults, then apply `override` on top.
   *
   * Ids are based on `Date.now()` but forced to be strictly increasing, so
   * messages created back-to-back (e.g. the user message and the assistant
   * placeholder in `onUserInput`) never share an id.
   *
   * @param override fields that replace the defaults (an explicit `id` wins)
   * @returns the message; `role` defaults to "user" and `content` to "".
   */
  export function createMessage(override: Partial<Message>): Message {
    lastMessageId = Math.max(Date.now(), lastMessageId + 1);
    return {
      id: lastMessageId,
      date: new Date().toLocaleString(),
      role: "user",
      content: "",
      ...override,
    };
  }
  /**
   * Key combination that submits the chat input.
   * The string values are what gets stored in the persisted config.
   */
  export enum SubmitKey {
    Enter = "Enter",
    CtrlEnter = "Ctrl + Enter",
    ShiftEnter = "Shift + Enter",
    AltEnter = "Alt + Enter",
    MetaEnter = "Meta + Enter",
  }

  /** Colour-scheme preference; `Auto` is the default in DEFAULT_CONFIG. */
  export enum Theme {
    Auto = "auto",
    Dark = "dark",
    Light = "light",
  }
  /** User-tunable settings for the chat UI and for model requests. */
  export interface ChatConfig {
    /** How many recent messages accompany each request; -1 means all. */
    historyMessageCount: number;
    /** Unsummarized history longer than this (total string length) gets compressed into a memory prompt. */
    compressMessageLengthThreshold: number;
    /** Send the bot's own earlier messages back to the model (requests get `filterBot: !sendBotMessages`). */
    sendBotMessages: boolean;
    submitKey: SubmitKey;
    avatar: string;
    fontSize: number;
    theme: Theme;
    tightBorder: boolean;
    sendPreviewBubble: boolean;
    disablePromptHint: boolean;
    /** Forwarded as `modelConfig` with each chat request. */
    modelConfig: {
      model: string;
      temperature: number;
      max_tokens: number;
      presence_penalty: number;
    };
  }

  /** The nested model-parameter object of a ChatConfig. */
  export type ModelConfig = ChatConfig["modelConfig"];

  /** Every role a message may carry. */
  export const ROLES: Array<Message["role"]> = ["system", "user", "assistant"];
  // Set to false to mark every GPT-4 variant unavailable (limitModel then
  // refuses them).
  const ENABLE_GPT4 = true;

  const GPT4_MODEL_NAMES = ["gpt-4", "gpt-4-0314", "gpt-4-32k", "gpt-4-32k-0314"];
  const GPT35_MODEL_NAMES = ["gpt-3.5-turbo", "gpt-3.5-turbo-0301"];

  /** Every model name the app knows, in a fixed order, with its availability. */
  export const ALL_MODELS: { name: string; available: boolean }[] = [
    ...GPT4_MODEL_NAMES.map((name) => ({ name, available: ENABLE_GPT4 })),
    ...GPT35_MODEL_NAMES.map((name) => ({ name, available: true })),
  ];
  /**
   * Clamp `x` into [min, max] (computed as `min(max, max(min, x))`).
   *
   * @returns `defaultValue` when `x` is not a number at runtime or is NaN.
   */
  export function limitNumber(
    x: number,
    min: number,
    max: number,
    defaultValue: number,
  ) {
    const usable = typeof x === "number" && !Number.isNaN(x);
    if (!usable) {
      return defaultValue;
    }
    const atLeastMin = Math.max(min, x);
    return Math.min(max, atLeastMin);
  }
  // Model used when a requested name is unknown or not available. Named
  // explicitly instead of indexing ALL_MODELS, so reordering that list
  // cannot silently change the fallback.
  const FALLBACK_MODEL = "gpt-3.5-turbo";

  /**
   * @returns `name` if ALL_MODELS lists it as available, otherwise FALLBACK_MODEL.
   */
  export function limitModel(name: string): string {
    const isSelectable = ALL_MODELS.some(
      (m) => m.name === name && m.available,
    );
    return isSelectable ? name : FALLBACK_MODEL;
  }
  /**
   * Per-field sanitisers for ModelConfig values. Each takes a raw value and
   * returns one inside the accepted range (or a default).
   */
  export const ModalConfigValidator = {
    /** Unknown or unavailable names are replaced via limitModel. */
    model: (x: string) => limitModel(x),
    /** Clamped to [0, 32000]; non-numeric input becomes 2000. */
    max_tokens: (x: number) => limitNumber(x, 0, 32000, 2000),
    /** Clamped to [-2, 2]; non-numeric input becomes 0. */
    presence_penalty: (x: number) => limitNumber(x, -2, 2, 0),
    /** Clamped to [0, 2]; non-numeric input becomes 1. */
    temperature: (x: number) => limitNumber(x, 0, 2, 1),
  };
  /**
   * Initial config for the store and the value `resetConfig` restores.
   *
   * The `ChatConfig` annotation already types every field, so no `as`
   * assertions are needed (they would only hide mismatches).
   * NOTE: a shallow copy (`{ ...DEFAULT_CONFIG }`) still shares the nested
   * `modelConfig` object with this constant.
   */
  const DEFAULT_CONFIG: ChatConfig = {
    historyMessageCount: 4,
    compressMessageLengthThreshold: 1000,
    sendBotMessages: true,
    submitKey: SubmitKey.CtrlEnter,
    avatar: "1f603", // presumably an emoji code point (U+1F603) — confirm against the avatar component
    fontSize: 14,
    theme: Theme.Auto,
    tightBorder: false,
    sendPreviewBubble: true,
    disablePromptHint: false,
    modelConfig: {
      model: "gpt-3.5-turbo",
      temperature: 1,
      max_tokens: 2000,
      presence_penalty: 0,
    },
  };
  /** Per-session counters (the store currently maintains only `charCount`). */
  export interface ChatStat {
    tokenCount: number;
    wordCount: number;
    charCount: number;
  }

  /** One conversation together with its per-session settings. */
  export interface ChatSession {
    id: number;
    topic: string;
    /** Include the summarized `memoryPrompt` in outgoing requests. */
    sendMemory: boolean;
    /** Summary of earlier history, written by summarizeSession. */
    memoryPrompt: string;
    /** Messages placed ahead of the history in every request. */
    context: Message[];
    messages: Message[];
    stat: ChatStat;
    /** `toLocaleString()` time of the last new message. */
    lastUpdate: string;
    /** Index into `messages` up to which history has been summarized. */
    lastSummarizeIndex: number;
  }
  // Placeholder topic for new sessions; summarizeSession replaces it once a
  // session has enough content.
  const DEFAULT_TOPIC = Locale.Store.DefaultTopic;

  /** Canned assistant greeting built from the localized BotHello text. */
  export const BOT_HELLO: Message = createMessage({
    content: Locale.Store.BotHello,
    role: "assistant",
  });
  /**
   * Build a blank session: default topic, no messages or context, memory
   * enabled, and zeroed counters.
   */
  function createEmptySession(): ChatSession {
    const now = new Date();
    const zeroStat = { tokenCount: 0, wordCount: 0, charCount: 0 };

    return {
      id: Date.now(),
      topic: DEFAULT_TOPIC,
      sendMemory: true,
      memoryPrompt: "",
      context: [],
      messages: [],
      stat: zeroStat,
      lastUpdate: now.toLocaleString(),
      lastSummarizeIndex: 0,
    };
  }
  /** Shape of the zustand chat store: persisted state plus its actions. */
  interface ChatStore {
    // ---- state ----
    config: ChatConfig;
    sessions: ChatSession[];
    currentSessionIndex: number;

    // ---- session list ----
    clearSessions: () => void;
    removeSession: (index: number) => void;
    moveSession: (from: number, to: number) => void;
    selectSession: (index: number) => void;
    newSession: () => void;
    deleteSession: () => void;
    currentSession: () => ChatSession;

    // ---- messaging ----
    onNewMessage: (message: Message) => void;
    onUserInput: (content: string) => Promise<void>;
    summarizeSession: () => void;
    updateStat: (message: Message) => void;
    updateCurrentSession: (updater: (session: ChatSession) => void) => void;
    updateMessage: (
      sessionIndex: number,
      messageIndex: number,
      updater: (message?: Message) => void,
    ) => void;
    resetSession: () => void;
    getMessagesWithMemory: () => Message[];
    getMemoryPrompt: () => Message;

    // ---- config ----
    getConfig: () => ChatConfig;
    resetConfig: () => void;
    updateConfig: (updater: (config: ChatConfig) => void) => void;

    // ---- maintenance ----
    clearAllData: () => void;
  }
  /** Sum of `content.length` over all messages (0 for an empty list). */
  function countMessages(msgs: Message[]) {
    const lengths = msgs.map((m) => m.content.length);
    return lengths.reduce((sum, len) => sum + len, 0);
  }
  // Storage name handed to the zustand `persist` middleware below.
  const LOCAL_KEY = "chat-next-web-store";

  /**
   * A fresh copy of DEFAULT_CONFIG, including its nested `modelConfig`.
   *
   * A plain `{ ...DEFAULT_CONFIG }` would share the `modelConfig` object with
   * DEFAULT_CONFIG. Since `updateConfig` mutates the config in place, editing
   * model settings would then silently rewrite the defaults `resetConfig`
   * restores.
   */
  function cloneDefaultConfig() {
    return {
      ...DEFAULT_CONFIG,
      modelConfig: { ...DEFAULT_CONFIG.modelConfig },
    };
  }

  /**
   * The last `count` entries of `messages`, always as a new array.
   * A negative `count` (the "-1 means all" history setting) returns every
   * entry; `count === 0` returns none.
   */
  function takeRecent(messages: Message[], count: number): Message[] {
    const start = count < 0 ? 0 : Math.max(0, messages.length - count);
    return messages.slice(start);
  }

  /**
   * Global chat store: the session list, the selected session, and user
   * config, persisted via zustand's `persist` middleware under LOCAL_KEY.
   *
   * Most actions mutate session objects in place and then call `set` so
   * subscribers are notified and the state is written out again.
   */
  export const useChatStore = create<ChatStore>()(
    persist(
      (set, get) => ({
        sessions: [createEmptySession()],
        currentSessionIndex: 0,
        config: cloneDefaultConfig(),

        clearSessions() {
          set(() => ({
            sessions: [createEmptySession()],
            currentSessionIndex: 0,
          }));
        },

        resetConfig() {
          set(() => ({ config: cloneDefaultConfig() }));
        },

        getConfig() {
          return get().config;
        },

        /** Run `updater` on the live config object (in place), then notify. */
        updateConfig(updater) {
          const config = get().config;
          updater(config);
          set(() => ({ config }));
        },

        selectSession(index: number) {
          set({
            currentSessionIndex: index,
          });
        },

        /**
         * Remove the session at `index` without mutating the existing array.
         * The selection stays on the same session when an earlier one is
         * removed. Removing the only session replaces it with a fresh one.
         */
        removeSession(index: number) {
          set((state) => {
            const { sessions, currentSessionIndex } = state;

            if (sessions.length === 1) {
              return {
                currentSessionIndex: 0,
                sessions: [createEmptySession()],
              };
            }

            const remaining = sessions.filter((_, i) => i !== index);

            let nextIndex = currentSessionIndex;
            if (index < currentSessionIndex) {
              // Everything after `index` shifted left by one.
              nextIndex -= 1;
            } else if (index === currentSessionIndex) {
              // The selected session is gone: select its predecessor, or the
              // new first session when it was at the top.
              nextIndex = Math.max(0, currentSessionIndex - 1);
            }

            return {
              currentSessionIndex: Math.min(nextIndex, remaining.length - 1),
              sessions: remaining,
            };
          });
        },

        moveSession(from: number, to: number) {
          set((state) => {
            const { sessions, currentSessionIndex: oldIndex } = state;

            // move the session
            const newSessions = [...sessions];
            const session = newSessions[from];
            newSessions.splice(from, 1);
            newSessions.splice(to, 0, session);

            // keep the selection on the same session
            let newIndex = oldIndex === from ? to : oldIndex;
            if (oldIndex > from && oldIndex <= to) {
              newIndex -= 1;
            } else if (oldIndex < from && oldIndex >= to) {
              newIndex += 1;
            }

            return {
              currentSessionIndex: newIndex,
              sessions: newSessions,
            };
          });
        },

        newSession() {
          set((state) => ({
            currentSessionIndex: 0,
            sessions: [createEmptySession()].concat(state.sessions),
          }));
        },

        /**
         * Delete the selected session (on mobile only after a confirm dialog)
         * and show a toast whose action re-inserts it and re-selects it.
         */
        deleteSession() {
          const deletedSession = get().currentSession();
          const index = get().currentSessionIndex;
          const isLastSession = get().sessions.length === 1;

          if (!isMobileScreen() || confirm(Locale.Home.DeleteChat)) {
            get().removeSession(index);

            showToast(Locale.Home.DeleteToast, {
              text: Locale.Home.Revert,
              onClick() {
                set((state) => {
                  const before = state.sessions.slice(0, index);
                  // If it was the only session, removeSession put an empty
                  // placeholder at index 0; skip it when re-inserting.
                  const after = state.sessions.slice(
                    index + Number(isLastSession),
                  );
                  return {
                    sessions: before.concat([deletedSession]).concat(after),
                    currentSessionIndex: before.length,
                  };
                });
              },
            });
          }
        },

        /**
         * The selected session. An out-of-range index is clamped and the
         * clamped value is written back to the store.
         */
        currentSession() {
          let index = get().currentSessionIndex;
          const sessions = get().sessions;

          if (index < 0 || index >= sessions.length) {
            index = Math.min(sessions.length - 1, Math.max(0, index));
            set(() => ({ currentSessionIndex: index }));
          }

          const session = sessions[index];

          return session;
        },

        onNewMessage(message) {
          get().updateCurrentSession((session) => {
            session.lastUpdate = new Date().toLocaleString();
          });
          get().updateStat(message);
          get().summarizeSession();
        },

        async onUserInput(content) {
          const userMessage: Message = createMessage({
            role: "user",
            content,
          });

          const botMessage: Message = createMessage({
            role: "assistant",
            streaming: true,
          });

          // get recent messages
          const recentMessages = get().getMessagesWithMemory();
          const sendMessages = recentMessages.concat(userMessage);
          const sessionIndex = get().currentSessionIndex;
          // Fallback ControllerPool key, used only if botMessage has no id.
          const messageIndex = get().currentSession().messages.length + 1;

          // save user's and bot's message
          get().updateCurrentSession((session) => {
            session.messages.push(userMessage);
            session.messages.push(botMessage);
          });

          // make request
          console.log("[User Input] ", sendMessages);
          requestChatStream(sendMessages, {
            onMessage(content, done) {
              // stream response
              if (done) {
                botMessage.streaming = false;
                botMessage.content = content;
                get().onNewMessage(botMessage);
                ControllerPool.remove(
                  sessionIndex,
                  botMessage.id ?? messageIndex,
                );
              } else {
                botMessage.content = content;
                set(() => ({}));
              }
            },
            onError(error, statusCode) {
              if (statusCode === 401) {
                botMessage.content = Locale.Error.Unauthorized;
              } else {
                botMessage.content += "\n\n" + Locale.Store.Error;
              }
              botMessage.streaming = false;
              userMessage.isError = true;
              botMessage.isError = true;
              set(() => ({}));
              ControllerPool.remove(sessionIndex, botMessage.id ?? messageIndex);
            },
            onController(controller) {
              // collect controller for stop/retry
              ControllerPool.addController(
                sessionIndex,
                botMessage.id ?? messageIndex,
                controller,
              );
            },
            filterBot: !get().config.sendBotMessages,
            modelConfig: get().config.modelConfig,
          });
        },

        /** System message carrying the current session's summarized history. */
        getMemoryPrompt() {
          const session = get().currentSession();

          return {
            role: "system",
            content: Locale.Store.Prompt.History(session.memoryPrompt),
            date: "",
          } as Message;
        },

        /**
         * Messages to send with the next request: the session context, the
         * memory prompt (if enabled and non-empty), then the most recent
         * non-error messages, limited by `historyMessageCount` (-1 = all).
         */
        getMessagesWithMemory() {
          const session = get().currentSession();
          const config = get().config;
          const messages = session.messages.filter((msg) => !msg.isError);

          const context = session.context.slice();

          if (
            session.sendMemory &&
            session.memoryPrompt &&
            session.memoryPrompt.length > 0
          ) {
            const memoryPrompt = get().getMemoryPrompt();
            context.push(memoryPrompt);
          }

          return context.concat(
            takeRecent(messages, config.historyMessageCount),
          );
        },

        updateMessage(
          sessionIndex: number,
          messageIndex: number,
          updater: (message?: Message) => void,
        ) {
          const sessions = get().sessions;
          const session = sessions.at(sessionIndex);
          const messages = session?.messages;
          updater(messages?.at(messageIndex));
          set(() => ({ sessions }));
        },

        resetSession() {
          get().updateCurrentSession((session) => {
            session.messages = [];
            session.memoryPrompt = "";
          });
        },

        /**
         * Background upkeep for the current session:
         * 1. once the chat reaches SUMMARIZE_MIN_LEN characters and still has
         *    the default topic, ask the model for a topic;
         * 2. once unsummarized history exceeds
         *    `compressMessageLengthThreshold`, stream a new memory summary.
         * Results are written to the session that was summarized, even if the
         * user has switched sessions by the time a response arrives.
         */
        summarizeSession() {
          const session = get().currentSession();
          const config = get().config;

          // Summarize the topic once the chat holds at least this many
          // characters (total content length).
          const SUMMARIZE_MIN_LEN = 50;
          if (
            session.topic === DEFAULT_TOPIC &&
            countMessages(session.messages) >= SUMMARIZE_MIN_LEN
          ) {
            requestWithPrompt(session.messages, Locale.Store.Prompt.Topic).then(
              (res) => {
                session.topic = res ? trimTopic(res) : DEFAULT_TOPIC;
                set(() => ({}));
              },
            );
          }

          let toBeSummarizedMsgs = session.messages.slice(
            session.lastSummarizeIndex,
          );

          const historyMsgLength = countMessages(toBeSummarizedMsgs);

          // Evaluate the fallback first: `a > b ?? c` parses as `(a > b) ?? c`,
          // which would never apply the 4000 default.
          const maxTokens = config.modelConfig?.max_tokens ?? 4000;
          if (historyMsgLength > maxTokens) {
            toBeSummarizedMsgs = takeRecent(
              toBeSummarizedMsgs,
              config.historyMessageCount,
            );
          }

          // add memory prompt (toBeSummarizedMsgs is always a copy here)
          toBeSummarizedMsgs.unshift(get().getMemoryPrompt());

          const lastSummarizeIndex = session.messages.length;

          console.log(
            "[Chat History] ",
            toBeSummarizedMsgs,
            historyMsgLength,
            config.compressMessageLengthThreshold,
          );

          if (historyMsgLength > config.compressMessageLengthThreshold) {
            requestChatStream(
              toBeSummarizedMsgs.concat({
                role: "system",
                content: Locale.Store.Prompt.Summarize,
                date: "",
              }),
              {
                filterBot: false,
                onMessage(message, done) {
                  session.memoryPrompt = message;
                  if (done) {
                    console.log("[Memory] ", session.memoryPrompt);
                    session.lastSummarizeIndex = lastSummarizeIndex;
                    // Notify subscribers so the new memory is persisted.
                    set(() => ({}));
                  }
                },
                onError(error) {
                  console.error("[Summarize] ", error);
                },
              },
            );
          }
        },

        updateStat(message) {
          get().updateCurrentSession((session) => {
            session.stat.charCount += message.content.length;
            // TODO: should update chat count and word count
          });
        },

        /** Run `updater` on the selected session (in place), then notify. */
        updateCurrentSession(updater) {
          const sessions = get().sessions;
          const index = get().currentSessionIndex;
          updater(sessions[index]);
          set(() => ({ sessions }));
        },

        clearAllData() {
          if (confirm(Locale.Store.ConfirmClearAll)) {
            localStorage.clear();
            location.reload();
          }
        },
      }),
      {
        name: LOCAL_KEY,
        version: 1.2,
        migrate(persistedState, version) {
          const state = persistedState as ChatStore;

          // v1 sessions predate `context`.
          if (version === 1) {
            state.sessions.forEach((s) => (s.context = []));
          }

          // Sessions before v1.2 predate `sendMemory`; default it to on.
          if (version < 1.2) {
            state.sessions.forEach((s) => (s.sendMemory = true));
          }

          return state;
        },
      },
    ),
  );