app.ts 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605
  1. import { create } from "zustand";
  2. import { persist } from "zustand/middleware";
  3. import { type ChatCompletionResponseMessage } from "openai";
  4. import {
  5. ControllerPool,
  6. requestChatStream,
  7. requestWithPrompt,
  8. } from "../requests";
  9. import { isMobileScreen, trimTopic } from "../utils";
  10. import Locale from "../locales";
  11. import { showToast } from "../components/ui-lib";
  /** Client-side bookkeeping fields attached to every stored chat message. */
  interface MessageExtras {
    /** Display timestamp; `createMessage` fills it with `new Date().toLocaleString()`. */
    date: string;
    /** Set to true on an assistant message while its reply is still streaming in. */
    streaming?: boolean;
    /** Set to true when the request this message belongs to failed. */
    isError?: boolean;
    /** Numeric identifier; `createMessage` assigns one by default. */
    id?: number;
  }

  /**
   * A chat message as kept in the store: the OpenAI response-message shape
   * extended with the client-side fields above.
   */
  export type Message = ChatCompletionResponseMessage & MessageExtras;
  /** Highest id handed out so far by `createMessage`; keeps default ids unique. */
  let lastMessageId = 0;

  /**
   * Builds a message from defaults (fresh id, current locale timestamp,
   * role "user", empty content) with any field in `override` taking precedence.
   *
   * @param override Fields that replace the defaults, including `id` if given.
   * @returns The assembled message.
   */
  export function createMessage(override: Partial<Message>): Message {
    // Date.now() only has millisecond resolution, so two messages created
    // back-to-back would otherwise share an id. Step past the last id handed
    // out so default ids are strictly increasing while staying time-based.
    lastMessageId = Math.max(Date.now(), lastMessageId + 1);
    return {
      id: lastMessageId,
      date: new Date().toLocaleString(),
      role: "user",
      content: "",
      ...override,
    };
  }
  /**
   * Key combinations that can be configured as `ChatConfig.submitKey`.
   * Each member's value is the label string stored in the config.
   */
  export enum SubmitKey {
    /** Plain Enter. */
    Enter = "Enter",
    /** Ctrl together with Enter. */
    CtrlEnter = "Ctrl + Enter",
    /** Shift together with Enter. */
    ShiftEnter = "Shift + Enter",
    /** Alt together with Enter. */
    AltEnter = "Alt + Enter",
    /** Meta together with Enter. */
    MetaEnter = "Meta + Enter",
  }
  /** Theme choices that can be stored in `ChatConfig.theme`. */
  export enum Theme {
    /** Stored as "auto". */
    Auto = "auto",
    /** Stored as "dark". */
    Dark = "dark",
    /** Stored as "light". */
    Light = "light",
  }
  /** User-adjustable settings kept in the store's `config` field. */
  export interface ChatConfig {
    /** Number of most recent messages included when building a request (original note: -1 means all). */
    historyMessageCount: number;
    /** Character count of un-summarized history above which the store requests a memory summary. */
    compressMessageLengthThreshold: number;
    /** Send the bot's messages or not; the store passes the negation as `filterBot`. */
    sendBotMessages: boolean;
    /** Key combination chosen for submitting input. */
    submitKey: SubmitKey;
    /** Avatar identifier (the default value is "1f603"). */
    avatar: string;
    /** Font size setting (the default value is 14). */
    fontSize: number;
    /** Selected theme. */
    theme: Theme;
    /** Layout flag; not read in this file. */
    tightBorder: boolean;
    /** UI flag; not read in this file. */
    sendPreviewBubble: boolean;
    /** UI flag; not read in this file. */
    disablePromptHint: boolean;
    /** Parameters handed to `requestChatStream` as `modelConfig`. */
    modelConfig: {
      /** Model name. */
      model: string;
      /** Sampling temperature. */
      temperature: number;
      /** Token limit; also used by the store when deciding how much history to summarize. */
      max_tokens: number;
      /** Presence penalty. */
      presence_penalty: number;
    };
  }
  /** Shape of the model parameters nested under `ChatConfig.modelConfig`. */
  export type ModelConfig = ChatConfig["modelConfig"];

  /** Message roles in the order: system, user, assistant. */
  export const ROLES: Array<Message["role"]> = ["system", "user", "assistant"];
  /** Single switch deciding whether the gpt-4 family is marked available. */
  const ENABLE_GPT4 = true;

  // Names are grouped by what decides their availability. The concatenation
  // order below is the order of ALL_MODELS, which other code indexes into.
  const GPT4_MODEL_NAMES = ["gpt-4", "gpt-4-0314", "gpt-4-32k", "gpt-4-32k-0314"];
  const GPT35_MODEL_NAMES = ["gpt-3.5-turbo", "gpt-3.5-turbo-0301"];

  /** Every known model with a flag telling whether it may be selected. */
  export const ALL_MODELS = [
    ...GPT4_MODEL_NAMES.map((name) => ({ name, available: ENABLE_GPT4 })),
    ...GPT35_MODEL_NAMES.map((name) => ({ name, available: true })),
  ];
  /**
   * Clamps `x` into the inclusive range [`min`, `max`].
   *
   * @param x Candidate value; may in practice be a non-number or NaN.
   * @param min Lower bound.
   * @param max Upper bound.
   * @param defaultValue Returned as-is when `x` is not a number or is NaN.
   * @returns The clamped value, or `defaultValue` for unusable input.
   */
  export function limitNumber(
    x: number,
    min: number,
    max: number,
    defaultValue: number,
  ): number {
    const isUsableNumber = typeof x === "number" && !Number.isNaN(x);
    if (!isUsableNumber) {
      return defaultValue;
    }

    const raisedToMin = Math.max(min, x);
    return Math.min(max, raisedToMin);
  }
  /** Model name returned by `limitModel` when the requested one cannot be used. */
  const FALLBACK_MODEL_NAME = "gpt-3.5-turbo";

  /**
   * Validates a model name against `ALL_MODELS`.
   *
   * @param name Requested model name.
   * @returns `name` when a model with that name exists and is marked
   *   available, otherwise the fallback model name.
   */
  export function limitModel(name: string): string {
    const isSelectable = ALL_MODELS.some(
      (model) => model.name === name && model.available,
    );
    // The fallback is named explicitly rather than read from a fixed position
    // in ALL_MODELS, so reordering or trimming that list cannot silently
    // change it or make it undefined.
    return isSelectable ? name : FALLBACK_MODEL_NAME;
  }
  /**
   * One sanitizer per model-config field. Each takes a raw value and returns a
   * value that is safe to store: the model name is checked with `limitModel`,
   * numeric fields are clamped with `limitNumber`.
   */
  export const ModalConfigValidator = {
    /** Falls back to the default model when `x` is unknown or unavailable. */
    model: (x: string) => limitModel(x),
    /** Clamped to 0..32000, default 2000. */
    max_tokens: (x: number) => limitNumber(x, 0, 32000, 2000),
    /** Clamped to -2..2, default 0. */
    presence_penalty: (x: number) => limitNumber(x, -2, 2, 0),
    /** Clamped to 0..2, default 1. */
    temperature: (x: number) => limitNumber(x, 0, 2, 1),
  };
  /**
   * Settings used for a fresh store and restored by `resetConfig`.
   * Consumers spread this object, so the nested `modelConfig` object is shared
   * by reference between copies.
   */
  const DEFAULT_CONFIG: ChatConfig = {
    historyMessageCount: 4,
    compressMessageLengthThreshold: 1000,
    sendBotMessages: true,
    submitKey: SubmitKey.CtrlEnter,
    avatar: "1f603",
    fontSize: 14,
    theme: Theme.Auto,
    tightBorder: false,
    sendPreviewBubble: true,
    disablePromptHint: false,

    modelConfig: {
      model: "gpt-3.5-turbo",
      temperature: 1,
      max_tokens: 2000,
      presence_penalty: 0,
    },
  };
  /** Running counters kept per chat session. */
  export interface ChatStat {
    /** Starts at 0; not updated anywhere in this file. */
    tokenCount: number;
    /** Starts at 0; not updated anywhere in this file. */
    wordCount: number;
    /** Sum of the content lengths of messages passed to `updateStat`. */
    charCount: number;
  }
  /** One conversation together with its summary state. */
  export interface ChatSession {
    /** Numeric identifier; `createEmptySession` uses `Date.now()`. */
    id: number;
    /** Session title; starts as the locale's default topic. */
    topic: string;
    /** When true (and `memoryPrompt` is non-empty) the memory prompt is added to outgoing requests. */
    sendMemory: boolean;
    /** Summary text written by `summarizeSession`; empty until a summary exists. */
    memoryPrompt: string;
    /** Messages placed ahead of the memory prompt and recent history in every request. */
    context: Message[];
    /** The conversation itself, oldest first. */
    messages: Message[];
    /** Running counters for this session. */
    stat: ChatStat;
    /** Locale-formatted timestamp, refreshed by `onNewMessage`. */
    lastUpdate: string;
    /** Index into `messages` from which the next summary starts. */
    lastSummarizeIndex: number;
  }
  /** Topic given to new sessions, taken from the active locale. */
  const DEFAULT_TOPIC = Locale.Store.DefaultTopic;

  /** Assistant greeting message built once from the locale's `BotHello` text. */
  export const BOT_HELLO: Message = createMessage({
    content: Locale.Store.BotHello,
    role: "assistant",
  });
  /**
   * Creates a blank session: default topic, memory sending enabled, no
   * messages or context, zeroed counters, and `lastUpdate` set to now.
   *
   * @returns A new session object; nothing is shared with earlier sessions.
   */
  function createEmptySession(): ChatSession {
    const createdAt = new Date().toLocaleString();
    const emptyStat: ChatStat = { tokenCount: 0, wordCount: 0, charCount: 0 };

    const session: ChatSession = {
      id: Date.now(),
      topic: DEFAULT_TOPIC,
      sendMemory: true,
      memoryPrompt: "",
      context: [],
      messages: [],
      stat: emptyStat,
      lastUpdate: createdAt,
      lastSummarizeIndex: 0,
    };
    return session;
  }
  /** State and actions exposed by `useChatStore`. */
  interface ChatStore {
    // ----- state -----

    /** Current user settings. */
    config: ChatConfig;
    /** All sessions; new sessions are inserted at the front. */
    sessions: ChatSession[];
    /** Index into `sessions` of the selected session. */
    currentSessionIndex: number;

    // ----- session list -----

    /** Replaces all sessions with a single empty one and selects it. */
    clearSessions: () => void;
    /** Removes the session at `index`; a sole remaining session is replaced by an empty one. */
    removeSession: (index: number) => void;
    /** Moves a session from position `from` to position `to`, keeping the selection on the same session. */
    moveSession: (from: number, to: number) => void;
    /** Selects the session at `index`. */
    selectSession: (index: number) => void;
    /** Inserts an empty session at the front and selects it. */
    newSession: () => void;
    /** Removes a session (the selected one when `index` is omitted), asking for confirmation on mobile screens, and shows a toast offering to revert. */
    deleteSession: (index?: number) => void;
    /** Returns the selected session, first clamping an out-of-range `currentSessionIndex`. */
    currentSession: () => ChatSession;

    // ----- messaging -----

    /** Bookkeeping after a finished message: refreshes `lastUpdate`, updates stats, triggers summarization. */
    onNewMessage: (message: Message) => void;
    /** Appends the user's message plus a streaming bot reply and starts the chat request. */
    onUserInput: (content: string) => Promise<void>;
    /** Requests a topic and/or memory summary for the selected session when thresholds are met. */
    summarizeSession: () => void;
    /** Adds the message's content length to the selected session's `charCount`. */
    updateStat: (message: Message) => void;
    /** Runs `updater` on the selected session object, then publishes `sessions`. */
    updateCurrentSession: (updater: (session: ChatSession) => void) => void;
    /** Runs `updater` on the addressed message (undefined when it does not exist), then publishes `sessions`. */
    updateMessage: (
      sessionIndex: number,
      messageIndex: number,
      updater: (message?: Message) => void,
    ) => void;
    /** Clears the selected session's messages and memory prompt. */
    resetSession: () => void;
    /** Builds the message list for a request: context, optional memory prompt, then recent non-error messages. */
    getMessagesWithMemory: () => Message[];
    /** Wraps the selected session's memory prompt in a system message. */
    getMemoryPrompt: () => Message;

    // ----- configuration -----

    /** Returns the current config object. */
    getConfig: () => ChatConfig;
    /** Replaces the config with a shallow copy of the defaults. */
    resetConfig: () => void;
    /** Runs `updater` on the config object, then publishes it. */
    updateConfig: (updater: (config: ChatConfig) => void) => void;
    /** After confirmation, clears localStorage and reloads the page. */
    clearAllData: () => void;
  }
  /**
   * Totals the content length, in characters, of the given messages.
   *
   * @param msgs Messages to measure.
   * @returns Sum of `content.length` over `msgs`; 0 for an empty list.
   */
  function countMessages(msgs: Message[]) {
    let totalLength = 0;
    for (const msg of msgs) {
      totalLength += msg.content.length;
    }
    return totalLength;
  }
  /** Key (the persist `name` option) under which the store is saved. */
  const LOCAL_KEY = "chat-next-web-store";

  /**
   * Returns the trailing `count` messages of `messages`.
   * A negative `count` returns a copy of the whole list, honouring the
   * "-1 means all" contract noted on `ChatConfig.historyMessageCount`.
   */
  function takeRecentMessages(messages: Message[], count: number): Message[] {
    if (count < 0) {
      return messages.slice();
    }
    return messages.slice(Math.max(0, messages.length - count));
  }

  /**
   * The chat store: sessions, the selected session index and user config,
   * persisted under `LOCAL_KEY`.
   *
   * Most actions mutate session/message objects in place and then call `set`
   * to notify subscribers and trigger persistence; streaming callbacks rely on
   * the objects in `sessions` keeping their identity.
   */
  export const useChatStore = create<ChatStore>()(
    persist(
      (set, get) => ({
        sessions: [createEmptySession()],
        currentSessionIndex: 0,
        config: {
          ...DEFAULT_CONFIG,
        },

        /** Drops every session and starts over with one empty session. */
        clearSessions() {
          set(() => ({
            sessions: [createEmptySession()],
            currentSessionIndex: 0,
          }));
        },

        /** Restores the default settings. */
        resetConfig() {
          set(() => ({ config: { ...DEFAULT_CONFIG } }));
        },

        /** Returns the live config object. */
        getConfig() {
          return get().config;
        },

        /** Lets `updater` mutate the config object, then publishes it. */
        updateConfig(updater) {
          const config = get().config;
          updater(config);
          set(() => ({ config }));
        },

        /** Selects the session at `index` (no range check here; `currentSession` clamps). */
        selectSession(index: number) {
          set({
            currentSessionIndex: index,
          });
        },

        /**
         * Removes the session at `index`.
         * A sole remaining session is replaced by a fresh empty one; an
         * out-of-range `index` leaves the state untouched.
         */
        removeSession(index: number) {
          set((state) => {
            const { sessions, currentSessionIndex } = state;

            if (sessions.length === 1) {
              return {
                currentSessionIndex: 0,
                sessions: [createEmptySession()],
              };
            }

            if (!Number.isInteger(index) || index < 0 || index >= sessions.length) {
              return state;
            }

            // Build a new array instead of splicing the one held in state.
            const remaining = sessions.filter((_, i) => i !== index);

            // Removing the selected session selects its predecessor; removing
            // an earlier one shifts the selected session one slot to the left.
            // Either way the index drops by one and must stay inside the list.
            const shifted =
              index <= currentSessionIndex
                ? currentSessionIndex - 1
                : currentSessionIndex;
            const nextIndex = Math.min(remaining.length - 1, Math.max(0, shifted));

            return {
              currentSessionIndex: nextIndex,
              sessions: remaining,
            };
          });
        },

        /** Moves a session from `from` to `to`, keeping the same session selected. */
        moveSession(from: number, to: number) {
          set((state) => {
            const { sessions, currentSessionIndex: oldIndex } = state;

            // move the session
            const newSessions = [...sessions];
            const session = newSessions[from];
            newSessions.splice(from, 1);
            newSessions.splice(to, 0, session);

            // Follow the moved session if it was selected; otherwise shift the
            // selection when the move crossed over it.
            let newIndex = oldIndex === from ? to : oldIndex;
            if (oldIndex > from && oldIndex <= to) {
              newIndex -= 1;
            } else if (oldIndex < from && oldIndex >= to) {
              newIndex += 1;
            }

            return {
              currentSessionIndex: newIndex,
              sessions: newSessions,
            };
          });
        },

        /** Prepends an empty session and selects it. */
        newSession() {
          set((state) => ({
            currentSessionIndex: 0,
            sessions: [createEmptySession()].concat(state.sessions),
          }));
        },

        /**
         * Deletes the session at `i` (the selected one when omitted) after a
         * confirmation on mobile screens, and shows a toast whose action
         * re-inserts the deleted session at its old position.
         */
        deleteSession(i?: number) {
          // currentSession() repairs an out-of-range currentSessionIndex, so
          // the index is read only after calling it.
          get().currentSession();
          const index = i ?? get().currentSessionIndex;

          // Remember the session that is actually being deleted, which is not
          // necessarily the selected one when an explicit index is passed.
          const deletedSession = get().sessions[index];
          if (!deletedSession) {
            return;
          }

          const isLastSession = get().sessions.length === 1;
          if (!isMobileScreen() || confirm(Locale.Home.DeleteChat)) {
            get().removeSession(index);
            showToast(
              Locale.Home.DeleteToast,
              {
                text: Locale.Home.Revert,
                onClick() {
                  set((state) => ({
                    sessions: state.sessions
                      .slice(0, index)
                      .concat([deletedSession])
                      .concat(
                        // When the only session was deleted, removeSession put
                        // an empty placeholder in its slot; skip it.
                        state.sessions.slice(index + Number(isLastSession)),
                      ),
                  }));
                },
              },
              5000,
            );
          }
        },

        /** Returns the selected session, clamping a stale index into range first. */
        currentSession() {
          let index = get().currentSessionIndex;
          const sessions = get().sessions;

          if (index < 0 || index >= sessions.length) {
            index = Math.min(sessions.length - 1, Math.max(0, index));
            set(() => ({ currentSessionIndex: index }));
          }

          const session = sessions[index];
          return session;
        },

        /** Post-processing for a completed message on the selected session. */
        onNewMessage(message) {
          get().updateCurrentSession((session) => {
            session.lastUpdate = new Date().toLocaleString();
          });
          get().updateStat(message);
          get().summarizeSession();
        },

        /**
         * Appends the user's message and an empty streaming bot message to the
         * selected session, then starts the streaming request. Returns once
         * the request has been started, not when the reply is complete.
         */
        async onUserInput(content) {
          const userMessage: Message = createMessage({
            role: "user",
            content,
          });

          const botMessage: Message = createMessage({
            role: "assistant",
            streaming: true,
          });

          // get recent messages
          const recentMessages = get().getMessagesWithMemory();
          const sendMessages = recentMessages.concat(userMessage);
          const sessionIndex = get().currentSessionIndex;
          // Position the bot message will have after both pushes below;
          // used as the controller key only if the message has no id.
          const messageIndex = get().currentSession().messages.length + 1;

          // save user's and bot's message
          get().updateCurrentSession((session) => {
            session.messages.push(userMessage);
            session.messages.push(botMessage);
          });

          // make request
          console.log("[User Input] ", sendMessages);
          requestChatStream(sendMessages, {
            onMessage(content, done) {
              // stream response
              if (done) {
                botMessage.streaming = false;
                botMessage.content = content;
                get().onNewMessage(botMessage);
                ControllerPool.remove(
                  sessionIndex,
                  botMessage.id ?? messageIndex,
                );
              } else {
                botMessage.content = content;
                // botMessage was mutated in place; an empty update notifies
                // subscribers.
                set(() => ({}));
              }
            },
            onError(error, statusCode) {
              if (statusCode === 401) {
                botMessage.content = Locale.Error.Unauthorized;
              } else {
                botMessage.content += "\n\n" + Locale.Store.Error;
              }
              botMessage.streaming = false;
              // Flag both halves of the exchange so they are left out of
              // later requests.
              userMessage.isError = true;
              botMessage.isError = true;
              set(() => ({}));
              ControllerPool.remove(sessionIndex, botMessage.id ?? messageIndex);
            },
            onController(controller) {
              // collect controller for stop/retry
              ControllerPool.addController(
                sessionIndex,
                botMessage.id ?? messageIndex,
                controller,
              );
            },
            filterBot: !get().config.sendBotMessages,
            modelConfig: get().config.modelConfig,
          });
        },

        /** Wraps the selected session's memory prompt in a system message. */
        getMemoryPrompt() {
          const session = get().currentSession();

          return {
            role: "system",
            content: Locale.Store.Prompt.History(session.memoryPrompt),
            date: "",
          } as Message;
        },

        /**
         * Builds the messages to send with the next request: the session's
         * context, the memory prompt (when enabled and non-empty), then the
         * most recent `historyMessageCount` non-error messages.
         */
        getMessagesWithMemory() {
          const session = get().currentSession();
          const config = get().config;
          const messages = session.messages.filter((msg) => !msg.isError);

          const context = session.context.slice();

          if (
            session.sendMemory &&
            session.memoryPrompt &&
            session.memoryPrompt.length > 0
          ) {
            const memoryPrompt = get().getMemoryPrompt();
            context.push(memoryPrompt);
          }

          const recentMessages = context.concat(
            takeRecentMessages(messages, config.historyMessageCount),
          );

          return recentMessages;
        },

        /** Lets `updater` mutate one message (or undefined if absent), then publishes `sessions`. */
        updateMessage(
          sessionIndex: number,
          messageIndex: number,
          updater: (message?: Message) => void,
        ) {
          const sessions = get().sessions;
          const session = sessions.at(sessionIndex);
          const messages = session?.messages;
          updater(messages?.at(messageIndex));
          set(() => ({ sessions }));
        },

        /** Empties the selected session's messages and memory prompt. */
        resetSession() {
          get().updateCurrentSession((session) => {
            session.messages = [];
            session.memoryPrompt = "";
          });
        },

        /**
         * For the selected session: requests a topic once the conversation
         * reaches 50 characters and still has the default topic, and requests
         * a memory summary once the un-summarized history exceeds
         * `compressMessageLengthThreshold` characters.
         */
        summarizeSession() {
          const session = get().currentSession();

          // should summarize topic after chatting more than 50 characters
          const SUMMARIZE_MIN_LEN = 50;
          if (
            session.topic === DEFAULT_TOPIC &&
            countMessages(session.messages) >= SUMMARIZE_MIN_LEN
          ) {
            requestWithPrompt(session.messages, Locale.Store.Prompt.Topic).then(
              (res) => {
                // Write to the session that was summarized, which may no
                // longer be the selected one when the request resolves.
                session.topic = res ? trimTopic(res) : DEFAULT_TOPIC;
                set(() => ({ sessions: get().sessions }));
              },
              (error) => {
                // Keep the default topic; a later message retries.
                console.error("[Summarize Topic] ", error);
              },
            );
          }

          const config = get().config;
          let toBeSummarizedMsgs = session.messages.slice(
            session.lastSummarizeIndex,
          );

          const historyMsgLength = countMessages(toBeSummarizedMsgs);

          // `>` binds tighter than `??`, so the fallback has to be resolved
          // before comparing.
          const maxTokens = config?.modelConfig?.max_tokens ?? 4000;
          if (historyMsgLength > maxTokens) {
            toBeSummarizedMsgs = takeRecentMessages(
              toBeSummarizedMsgs,
              config.historyMessageCount,
            );
          }

          // add memory prompt
          toBeSummarizedMsgs.unshift(get().getMemoryPrompt());

          const lastSummarizeIndex = session.messages.length;

          console.log(
            "[Chat History] ",
            toBeSummarizedMsgs,
            historyMsgLength,
            config.compressMessageLengthThreshold,
          );

          if (historyMsgLength > config.compressMessageLengthThreshold) {
            requestChatStream(
              toBeSummarizedMsgs.concat({
                role: "system",
                content: Locale.Store.Prompt.Summarize,
                date: "",
              }),
              {
                filterBot: false,
                onMessage(message, done) {
                  session.memoryPrompt = message;
                  if (done) {
                    console.log("[Memory] ", session.memoryPrompt);
                    session.lastSummarizeIndex = lastSummarizeIndex;
                    // The session was mutated in place; an empty update makes
                    // the finished summary reach subscribers and persistence.
                    set(() => ({}));
                  }
                },
                onError(error) {
                  console.error("[Summarize] ", error);
                },
              },
            );
          }
        },

        /** Adds the message's content length to the selected session's character count. */
        updateStat(message) {
          get().updateCurrentSession((session) => {
            session.stat.charCount += message.content.length;
            // TODO: should update chat count and word count
          });
        },

        /** Lets `updater` mutate the selected session object, then publishes `sessions`. */
        updateCurrentSession(updater) {
          const sessions = get().sessions;
          const index = get().currentSessionIndex;
          updater(sessions[index]);
          set(() => ({ sessions }));
        },

        /** After confirmation, wipes all of localStorage and reloads the page. */
        clearAllData() {
          if (confirm(Locale.Store.ConfirmClearAll)) {
            localStorage.clear();
            location.reload();
          }
        },
      }),
      {
        name: LOCAL_KEY,
        version: 1.2,
        /** Upgrades persisted state written by older versions of the store. */
        migrate(persistedState, version) {
          const state = persistedState as ChatStore;

          // Version 1 state gets an empty `context` on every session.
          if (version === 1) {
            state.sessions.forEach((s) => (s.context = []));
          }

          // `sendMemory` is switched on for state older than 1.2.
          if (version < 1.2) {
            state.sessions.forEach((s) => (s.sendMemory = true));
          }

          return state;
        },
      },
    ),
  );