app.ts 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576
  1. import { create } from "zustand";
  2. import { persist } from "zustand/middleware";
  3. import { type ChatCompletionResponseMessage } from "openai";
  4. import {
  5. ControllerPool,
  6. requestChatStream,
  7. requestWithPrompt,
  8. } from "../requests";
  9. import { trimTopic } from "../utils";
  10. import Locale from "../locales";
  /**
   * A chat message as kept in a session: the OpenAI message shape (`role`,
   * `content`) extended with UI bookkeeping fields.
   */
  export type Message = ChatCompletionResponseMessage & {
    /** Locale-formatted creation time string. */
    date: string;
    /** True while an assistant reply is still being streamed in. */
    streaming?: boolean;
    /** Marks a message from a failed exchange; such messages are left out of request history. */
    isError?: boolean;
    /** Numeric identifier assigned at creation. */
    id?: number;
  };
  /** Last id handed out by `createMessage`; used to keep ids strictly increasing. */
  let lastMessageId = 0;

  /**
   * Builds a message with defaults (role `"user"`, empty content, a fresh id and
   * the current locale date string). Any field present in `override` wins.
   *
   * @param override - fields to set explicitly, e.g. `role` and `content`.
   * @returns a new `Message` object.
   */
  export function createMessage(override: Partial<Message>): Message {
    // Date.now() alone repeats when several messages are created within the same
    // millisecond (a user message and its bot placeholder are created back to
    // back), so step past the previously issued id whenever the clock has not moved.
    const now = Date.now();
    lastMessageId = now > lastMessageId ? now : lastMessageId + 1;
    return {
      id: lastMessageId,
      date: new Date().toLocaleString(),
      role: "user",
      content: "",
      ...override,
    };
  }
  /** Key combination configured (via `ChatConfig.submitKey`) to submit user input. */
  export enum SubmitKey {
    /** Enter alone. */
    Enter = "Enter",
    /** Ctrl together with Enter. */
    CtrlEnter = "Ctrl + Enter",
    /** Shift together with Enter. */
    ShiftEnter = "Shift + Enter",
    /** Alt together with Enter. */
    AltEnter = "Alt + Enter",
    /** Meta together with Enter. */
    MetaEnter = "Meta + Enter",
  }
  /** Color scheme selected in `ChatConfig.theme`. */
  export enum Theme {
    /** Defer the choice (value "auto"). */
    Auto = "auto",
    /** Dark scheme. */
    Dark = "dark",
    /** Light scheme. */
    Light = "light",
  }
  /** User-tunable settings held in the chat store's `config`. */
  export interface ChatConfig {
    /** Number of recent messages sent as history; -1 means all. */
    historyMessageCount: number;
    /**
     * Character count of not-yet-summarized messages above which the session
     * history is compressed into a memory prompt.
     */
    compressMessageLengthThreshold: number;
    /** Whether the bot's own messages are sent back with a request. */
    sendBotMessages: boolean;
    /** Key combination that submits input. */
    submitKey: SubmitKey;
    /** Avatar identifier; the default "1f603" looks like an emoji code point (confirm). */
    avatar: string;
    /** UI font size; unit not visible here (presumably px — confirm). */
    fontSize: number;
    /** Color scheme. */
    theme: Theme;
    /** UI option; not read in this module. */
    tightBorder: boolean;
    /** UI option; not read in this module. */
    sendPreviewBubble: boolean;
    /** UI option; not read in this module. */
    disablePromptHint: boolean;
    /** Completion request parameters forwarded with each chat request. */
    modelConfig: {
      model: string;
      temperature: number;
      max_tokens: number;
      presence_penalty: number;
    };
  }
  /** The request-parameter section of `ChatConfig`. */
  export type ModelConfig = ChatConfig["modelConfig"];

  /** Every role a `Message` may carry. */
  export const ROLES: Array<Message["role"]> = ["system", "user", "assistant"];
  const ENABLE_GPT4 = true;

  /** GPT-4 family entries, enabled or disabled together through ENABLE_GPT4. */
  const GPT4_MODEL_NAMES = ["gpt-4", "gpt-4-0314", "gpt-4-32k", "gpt-4-32k-0314"];

  /** GPT-3.5 entries, always enabled. */
  const GPT35_MODEL_NAMES = ["gpt-3.5-turbo", "gpt-3.5-turbo-0301"];

  /**
   * Known models in a fixed order: GPT-4 family first, then GPT-3.5.
   * Entries with `available: false` are rejected by `limitModel`.
   */
  export const ALL_MODELS = [
    ...GPT4_MODEL_NAMES.map((name) => ({ name, available: ENABLE_GPT4 })),
    ...GPT35_MODEL_NAMES.map((name) => ({ name, available: true })),
  ];
  /**
   * Clamps `x` into `[min, max]`.
   *
   * @param x - candidate value; anything that is not a number, or is NaN, is rejected.
   * @param min - lower bound.
   * @param max - upper bound (applied last, so it wins if `min > max`).
   * @param defaultValue - returned when `x` is rejected.
   * @returns the clamped value or `defaultValue`.
   */
  export function limitNumber(
    x: number,
    min: number,
    max: number,
    defaultValue: number,
  ) {
    const usable = typeof x === "number" && !Number.isNaN(x);
    if (!usable) {
      return defaultValue;
    }
    const raisedToMin = Math.max(min, x);
    return Math.min(max, raisedToMin);
  }
  /** Model substituted when the requested one is unknown or disabled. */
  const FALLBACK_MODEL = "gpt-3.5-turbo";

  /**
   * Validates a model name against `ALL_MODELS`.
   *
   * @param name - requested model name.
   * @returns `name` when it is listed and available, otherwise `FALLBACK_MODEL`.
   *   The fallback is chosen by name rather than by array position, so
   *   reordering or extending `ALL_MODELS` cannot silently change it.
   */
  export function limitModel(name: string) {
    const isUsable = ALL_MODELS.some((m) => m.name === name && m.available);
    return isUsable ? name : FALLBACK_MODEL;
  }
  /**
   * Per-field sanitizers for `ModelConfig` values: each maps a raw value to an
   * allowed one. (The exported name keeps its historical "Modal" spelling.)
   */
  export const ModalConfigValidator = {
    /** Unknown or unavailable models fall back via `limitModel`. */
    model: (x: string) => limitModel(x),
    /** Clamped to [0, 32000]; invalid input becomes 2000. */
    max_tokens: (x: number) => limitNumber(x, 0, 32000, 2000),
    /** Clamped to [-2, 2]; invalid input becomes 0. */
    presence_penalty: (x: number) => limitNumber(x, -2, 2, 0),
    /** Clamped to [0, 2]; invalid input becomes 1. */
    temperature: (x: number) => limitNumber(x, 0, 2, 1),
  };
  /**
   * Default settings: copied into the store's initial `config` and restored by
   * `resetConfig`. The `ChatConfig` annotation type-checks every field, so no
   * `as` assertions are needed (they would only bypass that check).
   */
  const DEFAULT_CONFIG: ChatConfig = {
    historyMessageCount: 4,
    compressMessageLengthThreshold: 1000,
    sendBotMessages: true,
    submitKey: SubmitKey.CtrlEnter,
    avatar: "1f603",
    fontSize: 14,
    theme: Theme.Auto,
    tightBorder: false,
    sendPreviewBubble: true,
    disablePromptHint: false,
    modelConfig: {
      model: "gpt-3.5-turbo",
      temperature: 1,
      max_tokens: 2000,
      presence_penalty: 0,
    },
  };
  /**
   * Running counters for a session. Only `charCount` is currently incremented
   * (by `updateStat`); the other two stay at their initial value.
   */
  export interface ChatStat {
    /** Token counter (not yet maintained). */
    tokenCount: number;
    /** Word counter (not yet maintained). */
    wordCount: number;
    /** Accumulated length of bot reply contents. */
    charCount: number;
  }
  /** One conversation together with its summarization state. */
  export interface ChatSession {
    /** Creation timestamp (`Date.now()`) used as identifier. */
    id: number;
    /** Title; starts as the default topic and is replaced by a model-generated one. */
    topic: string;
    /** Whether `memoryPrompt` is prepended to request history. */
    sendMemory: boolean;
    /** Model-written summary of earlier messages. */
    memoryPrompt: string;
    /** Messages placed before the history in every request. */
    context: Message[];
    /** The visible conversation. */
    messages: Message[];
    /** Running counters. */
    stat: ChatStat;
    /** Locale time string of the last completed bot reply (or creation). */
    lastUpdate: string;
    /** Index into `messages` from which the next summarization starts. */
    lastSummarizeIndex: number;
  }
  /** Topic of new sessions; a session still carrying it is eligible for topic summarization. */
  const DEFAULT_TOPIC = Locale.Store.DefaultTopic;

  /** Assistant greeting message, created once when the module loads. */
  export const BOT_HELLO: Message = createMessage({
    content: Locale.Store.BotHello,
    role: "assistant",
  });
  /**
   * Builds a new session with no messages, no memory, zeroed counters and the
   * default topic, stamped with the current time.
   */
  function createEmptySession(): ChatSession {
    const createdAt = new Date().toLocaleString();
    const zeroStat = { tokenCount: 0, wordCount: 0, charCount: 0 };
    const session: ChatSession = {
      id: Date.now(),
      topic: DEFAULT_TOPIC,
      sendMemory: true,
      memoryPrompt: "",
      context: [],
      messages: [],
      stat: zeroStat,
      lastUpdate: createdAt,
      lastSummarizeIndex: 0,
    };
    return session;
  }
  /** Shape of the chat store: persisted state plus the actions operating on it. */
  interface ChatStore {
    /** User settings. */
    config: ChatConfig;
    /** All sessions; `newSession` prepends. */
    sessions: ChatSession[];
    /** Index of the selected session in `sessions`. */
    currentSessionIndex: number;
    /** Replace all sessions with a single empty one. */
    clearSessions: () => void;
    /** Delete the session at `index`. */
    removeSession: (index: number) => void;
    /** Move a session from one position to another. */
    moveSession: (from: number, to: number) => void;
    /** Select the session at `index`. */
    selectSession: (index: number) => void;
    /** Add and select an empty session. */
    newSession: () => void;
    /** The selected session. */
    currentSession: () => ChatSession;
    /** Bookkeeping after a bot reply completes. */
    onNewMessage: (message: Message) => void;
    /** Send user text and stream the reply into the selected session. */
    onUserInput: (content: string) => Promise<void>;
    /** Possibly summarize the topic and/or compress history into memory. */
    summarizeSession: () => void;
    /** Update counters for `message`. */
    updateStat: (message: Message) => void;
    /** Mutate the selected session in place and publish the change. */
    updateCurrentSession: (updater: (session: ChatSession) => void) => void;
    /** Mutate one message (if it exists) and publish the change. */
    updateMessage: (
      sessionIndex: number,
      messageIndex: number,
      updater: (message?: Message) => void,
    ) => void;
    /** Clear the selected session's messages and memory. */
    resetSession: () => void;
    /** Context, optional memory prompt and recent history for a request. */
    getMessagesWithMemory: () => Message[];
    /** System message wrapping the selected session's memory prompt. */
    getMemoryPrompt: () => Message;
    /** Current settings. */
    getConfig: () => ChatConfig;
    /** Restore default settings. */
    resetConfig: () => void;
    /** Modify settings through `updater`. */
    updateConfig: (updater: (config: ChatConfig) => void) => void;
    /** After confirmation, wipe localStorage and reload. */
    clearAllData: () => void;
  }
  /** Sum of `content.length` over `msgs` (0 for an empty list). */
  function countMessages(msgs: Message[]) {
    let total = 0;
    for (const msg of msgs) {
      total += msg.content.length;
    }
    return total;
  }
  /** Storage key under which the persisted chat store is saved. */
  const LOCAL_KEY = "chat-next-web-store";
  /**
   * Trailing slice of `msgs` permitted by `historyMessageCount`.
   * A negative count (the documented `-1`, meaning "all") keeps every message;
   * 0 keeps none. Always returns a new array.
   */
  function takeRecentMessages(msgs: Message[], historyMessageCount: number) {
    if (historyMessageCount < 0) {
      return msgs.slice();
    }
    return msgs.slice(Math.max(0, msgs.length - historyMessageCount));
  }

  /** Copy of DEFAULT_CONFIG that shares no nested object with it. */
  function createDefaultConfig(): ChatConfig {
    return {
      ...DEFAULT_CONFIG,
      modelConfig: { ...DEFAULT_CONFIG.modelConfig },
    };
  }

  /**
   * Global chat store: sessions, the selected session index and user settings,
   * persisted under `LOCAL_KEY` by zustand's `persist` middleware.
   *
   * Sessions and messages are mostly mutated in place and then published with
   * `set`; `set(() => ({}))` is used purely to notify subscribers (which also
   * gives `persist` a chance to save) after such an in-place mutation.
   */
  export const useChatStore = create<ChatStore>()(
    persist(
      (set, get) => ({
        sessions: [createEmptySession()],
        currentSessionIndex: 0,
        // a private copy, so in-place edits can never reach DEFAULT_CONFIG
        config: createDefaultConfig(),

        /** Replaces every session with one empty session and selects it. */
        clearSessions() {
          set(() => ({
            sessions: [createEmptySession()],
            currentSessionIndex: 0,
          }));
        },

        /** Restores the default settings (as a fresh copy). */
        resetConfig() {
          set(() => ({ config: createDefaultConfig() }));
        },

        /** Returns the current settings object. */
        getConfig() {
          return get().config;
        },

        /**
         * Applies `updater` to a copy of the settings and stores that copy.
         * Publishing a new object lets reference-comparing selectors see the
         * change and leaves the previous object untouched.
         */
        updateConfig(updater) {
          const current = get().config;
          const config: ChatConfig = {
            ...current,
            modelConfig: { ...current.modelConfig },
          };
          updater(config);
          set(() => ({ config }));
        },

        /** Selects `index`; range is not checked here (currentSession clamps it). */
        selectSession(index: number) {
          set({
            currentSessionIndex: index,
          });
        },

        /**
         * Deletes the session at `index`. The only remaining session is replaced
         * by a fresh empty one instead. The selection stays on the same session;
         * if the selected session itself is deleted, the one before it (or the
         * new first session) becomes selected.
         */
        removeSession(index: number) {
          set((state) => {
            if (state.sessions.length === 1) {
              return {
                currentSessionIndex: 0,
                sessions: [createEmptySession()],
              };
            }
            // work on a copy instead of splicing the array held by the current state
            const sessions = state.sessions.slice();
            sessions.splice(index, 1);
            let nextIndex = state.currentSessionIndex;
            // removing the selected session or any session before it shifts the selection left
            if (index <= nextIndex) {
              nextIndex -= 1;
            }
            return {
              currentSessionIndex: Math.min(
                sessions.length - 1,
                Math.max(0, nextIndex),
              ),
              sessions,
            };
          });
        },

        /**
         * Moves the session at `from` to position `to` and adjusts the selected
         * index so the same session stays selected.
         */
        moveSession(from: number, to: number) {
          set((state) => {
            const { sessions, currentSessionIndex: oldIndex } = state;
            // move the session
            const newSessions = [...sessions];
            const session = newSessions[from];
            newSessions.splice(from, 1);
            newSessions.splice(to, 0, session);
            // modify current session id
            let newIndex = oldIndex === from ? to : oldIndex;
            if (oldIndex > from && oldIndex <= to) {
              newIndex -= 1;
            } else if (oldIndex < from && oldIndex >= to) {
              newIndex += 1;
            }
            return {
              currentSessionIndex: newIndex,
              sessions: newSessions,
            };
          });
        },

        /** Prepends a new empty session and selects it. */
        newSession() {
          set((state) => ({
            currentSessionIndex: 0,
            sessions: [createEmptySession()].concat(state.sessions),
          }));
        },

        /**
         * Returns the selected session. An out-of-range index is clamped into
         * range and the clamped value is written back to the store.
         */
        currentSession() {
          let index = get().currentSessionIndex;
          const sessions = get().sessions;
          if (index < 0 || index >= sessions.length) {
            index = Math.min(sessions.length - 1, Math.max(0, index));
            set(() => ({ currentSessionIndex: index }));
          }
          const session = sessions[index];
          return session;
        },

        /**
         * Called when a bot reply completes: bumps `lastUpdate`, updates the
         * counters and may trigger topic/memory summarization.
         */
        onNewMessage(message) {
          get().updateCurrentSession((session) => {
            session.lastUpdate = new Date().toLocaleString();
          });
          get().updateStat(message);
          get().summarizeSession();
        },

        /**
         * Sends `content` from the user in the selected session. A user message
         * and a streaming bot placeholder are appended; the placeholder is filled
         * as the response streams in. On failure both messages are flagged
         * `isError`, which excludes them from later request history.
         * The request is started but not awaited.
         */
        async onUserInput(content) {
          const userMessage: Message = createMessage({
            role: "user",
            content,
          });
          const botMessage: Message = createMessage({
            role: "assistant",
            streaming: true,
          });
          // get recent messages
          const recentMessages = get().getMessagesWithMemory();
          const sendMessages = recentMessages.concat(userMessage);
          const sessionIndex = get().currentSessionIndex;
          // index the bot message will have once both messages are pushed
          const messageIndex = get().currentSession().messages.length + 1;
          // save user's and bot's message
          get().updateCurrentSession((session) => {
            session.messages.push(userMessage);
            session.messages.push(botMessage);
          });
          // make request
          console.log("[User Input] ", sendMessages);
          requestChatStream(sendMessages, {
            onMessage(content, done) {
              // stream response
              if (done) {
                botMessage.streaming = false;
                botMessage.content = content;
                get().onNewMessage(botMessage);
                ControllerPool.remove(
                  sessionIndex,
                  botMessage.id ?? messageIndex,
                );
              } else {
                botMessage.content = content;
                set(() => ({}));
              }
            },
            onError(error, statusCode) {
              if (statusCode === 401) {
                botMessage.content = Locale.Error.Unauthorized;
              } else {
                botMessage.content += "\n\n" + Locale.Store.Error;
              }
              botMessage.streaming = false;
              userMessage.isError = true;
              botMessage.isError = true;
              set(() => ({}));
              ControllerPool.remove(sessionIndex, botMessage.id ?? messageIndex);
            },
            onController(controller) {
              // collect controller for stop/retry
              ControllerPool.addController(
                sessionIndex,
                botMessage.id ?? messageIndex,
                controller,
              );
            },
            filterBot: !get().config.sendBotMessages,
            modelConfig: get().config.modelConfig,
          });
        },

        /** System message wrapping the selected session's memory prompt. */
        getMemoryPrompt() {
          const session = get().currentSession();
          const prompt: Message = {
            role: "system",
            content: Locale.Store.Prompt.History(session.memoryPrompt),
            date: "",
          };
          return prompt;
        },

        /**
         * Messages to send with a request: the session context, then the memory
         * prompt (when enabled and non-empty), then the most recent non-error
         * messages as allowed by `historyMessageCount` (-1 = all).
         */
        getMessagesWithMemory() {
          const session = get().currentSession();
          const config = get().config;
          const messages = session.messages.filter((msg) => !msg.isError);
          const context = session.context.slice();
          if (
            session.sendMemory &&
            session.memoryPrompt &&
            session.memoryPrompt.length > 0
          ) {
            const memoryPrompt = get().getMemoryPrompt();
            context.push(memoryPrompt);
          }
          const recentMessages = context.concat(
            takeRecentMessages(messages, config.historyMessageCount),
          );
          return recentMessages;
        },

        /**
         * Passes the message at (`sessionIndex`, `messageIndex`) — or undefined
         * when it does not exist — to `updater`, then publishes the sessions.
         */
        updateMessage(
          sessionIndex: number,
          messageIndex: number,
          updater: (message?: Message) => void,
        ) {
          const sessions = get().sessions;
          const session = sessions.at(sessionIndex);
          const messages = session?.messages;
          updater(messages?.at(messageIndex));
          set(() => ({ sessions }));
        },

        /**
         * Clears the selected session's messages and memory. The summarization
         * cursor is reset too; otherwise it would point past the new, empty
         * message list and the next summarization would skip new messages.
         */
        resetSession() {
          get().updateCurrentSession((session) => {
            session.messages = [];
            session.memoryPrompt = "";
            session.lastSummarizeIndex = 0;
          });
        },

        /**
         * Two independent background steps for the selected session:
         * 1. while the topic is still the default and the conversation has at
         *    least 50 characters, ask the model for a topic;
         * 2. when the not-yet-summarized messages exceed
         *    `compressMessageLengthThreshold` characters, ask the model to
         *    compress them (plus the previous memory) into a new memory prompt.
         * Results are written to the session captured at call time, even if the
         * user has switched sessions in the meantime.
         */
        summarizeSession() {
          const session = get().currentSession();
          const config = get().config;

          // should summarize topic after chatting more than 50 characters
          const SUMMARIZE_MIN_LEN = 50;
          if (
            session.topic === DEFAULT_TOPIC &&
            countMessages(session.messages) >= SUMMARIZE_MIN_LEN
          ) {
            requestWithPrompt(session.messages, Locale.Store.Prompt.Topic)
              .then((res) => {
                session.topic = res ? trimTopic(res) : DEFAULT_TOPIC;
                set(() => ({}));
              })
              .catch((error) => {
                console.error("[Topic] ", error);
              });
          }

          let toBeSummarizedMsgs = session.messages.slice(
            session.lastSummarizeIndex,
          );
          const historyMsgLength = countMessages(toBeSummarizedMsgs);
          // parenthesized so the 4000 fallback applies to max_tokens itself;
          // `a > b ?? c` would compare first and make the fallback dead code
          const maxTokens = config.modelConfig?.max_tokens ?? 4000;
          if (historyMsgLength > maxTokens) {
            toBeSummarizedMsgs = takeRecentMessages(
              toBeSummarizedMsgs,
              config.historyMessageCount,
            );
          }

          // add memory prompt
          toBeSummarizedMsgs.unshift(get().getMemoryPrompt());

          const lastSummarizeIndex = session.messages.length;

          console.log(
            "[Chat History] ",
            toBeSummarizedMsgs,
            historyMsgLength,
            config.compressMessageLengthThreshold,
          );

          if (historyMsgLength > config.compressMessageLengthThreshold) {
            requestChatStream(
              toBeSummarizedMsgs.concat({
                role: "system",
                content: Locale.Store.Prompt.Summarize,
                date: "",
              }),
              {
                filterBot: false,
                onMessage(message, done) {
                  session.memoryPrompt = message;
                  if (done) {
                    console.log("[Memory] ", session.memoryPrompt);
                    session.lastSummarizeIndex = lastSummarizeIndex;
                    // publish the new memory so subscribers and persistence see it
                    set(() => ({}));
                  }
                },
                onError(error) {
                  console.error("[Summarize] ", error);
                },
              },
            );
          }
        },

        /** Adds the message length to the selected session's `charCount`. */
        updateStat(message) {
          get().updateCurrentSession((session) => {
            session.stat.charCount += message.content.length;
            // TODO: should update chat count and word count
          });
        },

        /**
         * Mutates the selected session in place via `updater`, then publishes
         * the sessions. `currentSession()` clamps an out-of-range index first.
         */
        updateCurrentSession(updater) {
          const session = get().currentSession();
          updater(session);
          set(() => ({ sessions: get().sessions }));
        },

        /** After user confirmation, clears all of localStorage and reloads the page. */
        clearAllData() {
          if (confirm(Locale.Store.ConfirmClearAll)) {
            localStorage.clear();
            location.reload();
          }
        },
      }),
      {
        name: LOCAL_KEY,
        version: 1.2,
        /** Upgrades state persisted by older versions of this store. */
        migrate(persistedState, version) {
          const state = persistedState as ChatStore;

          if (version === 1) {
            state.sessions.forEach((s) => (s.context = []));
          }

          if (version < 1.2) {
            state.sessions.forEach((s) => {
              s.sendMemory = true;
              // only version 1 is reset above; no older session may be left
              // without the `context` array that getMessagesWithMemory copies
              if (!Array.isArray(s.context)) {
                s.context = [];
              }
            });
          }

          return state;
        },
      },
    ),
  );
  495. },
  496. ),
  497. );