app.ts 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550
  1. import { create } from "zustand";
  2. import { persist } from "zustand/middleware";
  3. import { type ChatCompletionResponseMessage } from "openai";
  4. import {
  5. ControllerPool,
  6. requestChatStream,
  7. requestWithPrompt,
  8. } from "../requests";
  9. import { trimTopic } from "../utils";
  10. import Locale from "../locales";
  11. export type Message = ChatCompletionResponseMessage & {
  12. date: string;
  13. streaming?: boolean;
  14. isError?: boolean;
  15. id?: number;
  16. };
  17. export function createMessage(override: Partial<Message>): Message {
  18. return {
  19. id: Date.now(),
  20. date: new Date().toLocaleString(),
  21. role: "user",
  22. content: "",
  23. ...override,
  24. };
  25. }
  26. export enum SubmitKey {
  27. Enter = "Enter",
  28. CtrlEnter = "Ctrl + Enter",
  29. ShiftEnter = "Shift + Enter",
  30. AltEnter = "Alt + Enter",
  31. MetaEnter = "Meta + Enter",
  32. }
  33. export enum Theme {
  34. Auto = "auto",
  35. Dark = "dark",
  36. Light = "light",
  37. }
  38. export interface ChatConfig {
  39. historyMessageCount: number; // -1 means all
  40. compressMessageLengthThreshold: number;
  41. sendBotMessages: boolean; // send bot's message or not
  42. submitKey: SubmitKey;
  43. avatar: string;
  44. fontSize: number;
  45. theme: Theme;
  46. tightBorder: boolean;
  47. sendPreviewBubble: boolean;
  48. disablePromptHint: boolean;
  49. modelConfig: {
  50. model: string;
  51. temperature: number;
  52. max_tokens: number;
  53. presence_penalty: number;
  54. };
  55. }
  56. export type ModelConfig = ChatConfig["modelConfig"];
  57. export const ROLES: Message["role"][] = ["system", "user", "assistant"];
  58. const ENABLE_GPT4 = true;
  59. export const ALL_MODELS = [
  60. {
  61. name: "gpt-4",
  62. available: ENABLE_GPT4,
  63. },
  64. {
  65. name: "gpt-4-0314",
  66. available: ENABLE_GPT4,
  67. },
  68. {
  69. name: "gpt-4-32k",
  70. available: ENABLE_GPT4,
  71. },
  72. {
  73. name: "gpt-4-32k-0314",
  74. available: ENABLE_GPT4,
  75. },
  76. {
  77. name: "gpt-3.5-turbo",
  78. available: true,
  79. },
  80. {
  81. name: "gpt-3.5-turbo-0301",
  82. available: true,
  83. },
  84. ];
  85. export function limitNumber(
  86. x: number,
  87. min: number,
  88. max: number,
  89. defaultValue: number,
  90. ) {
  91. if (typeof x !== "number" || isNaN(x)) {
  92. return defaultValue;
  93. }
  94. return Math.min(max, Math.max(min, x));
  95. }
  96. export function limitModel(name: string) {
  97. return ALL_MODELS.some((m) => m.name === name && m.available)
  98. ? name
  99. : ALL_MODELS[4].name;
  100. }
  101. export const ModalConfigValidator = {
  102. model(x: string) {
  103. return limitModel(x);
  104. },
  105. max_tokens(x: number) {
  106. return limitNumber(x, 0, 32000, 2000);
  107. },
  108. presence_penalty(x: number) {
  109. return limitNumber(x, -2, 2, 0);
  110. },
  111. temperature(x: number) {
  112. return limitNumber(x, 0, 2, 1);
  113. },
  114. };
  115. const DEFAULT_CONFIG: ChatConfig = {
  116. historyMessageCount: 4,
  117. compressMessageLengthThreshold: 1000,
  118. sendBotMessages: true as boolean,
  119. submitKey: SubmitKey.CtrlEnter as SubmitKey,
  120. avatar: "1f603",
  121. fontSize: 14,
  122. theme: Theme.Auto as Theme,
  123. tightBorder: false,
  124. sendPreviewBubble: true,
  125. disablePromptHint: false,
  126. modelConfig: {
  127. model: "gpt-3.5-turbo",
  128. temperature: 1,
  129. max_tokens: 2000,
  130. presence_penalty: 0,
  131. },
  132. };
  133. export interface ChatStat {
  134. tokenCount: number;
  135. wordCount: number;
  136. charCount: number;
  137. }
  138. export interface ChatSession {
  139. id: number;
  140. topic: string;
  141. sendMemory: boolean;
  142. memoryPrompt: string;
  143. context: Message[];
  144. messages: Message[];
  145. stat: ChatStat;
  146. lastUpdate: string;
  147. lastSummarizeIndex: number;
  148. }
  149. const DEFAULT_TOPIC = Locale.Store.DefaultTopic;
  150. export const BOT_HELLO: Message = createMessage({
  151. role: "assistant",
  152. content: Locale.Store.BotHello,
  153. });
  154. function createEmptySession(): ChatSession {
  155. const createDate = new Date().toLocaleString();
  156. return {
  157. id: Date.now(),
  158. topic: DEFAULT_TOPIC,
  159. sendMemory: true,
  160. memoryPrompt: "",
  161. context: [],
  162. messages: [],
  163. stat: {
  164. tokenCount: 0,
  165. wordCount: 0,
  166. charCount: 0,
  167. },
  168. lastUpdate: createDate,
  169. lastSummarizeIndex: 0,
  170. };
  171. }
  172. interface ChatStore {
  173. config: ChatConfig;
  174. sessions: ChatSession[];
  175. currentSessionIndex: number;
  176. clearSessions: () => void;
  177. removeSession: (index: number) => void;
  178. selectSession: (index: number) => void;
  179. newSession: () => void;
  180. currentSession: () => ChatSession;
  181. onNewMessage: (message: Message) => void;
  182. onUserInput: (content: string) => Promise<void>;
  183. summarizeSession: () => void;
  184. updateStat: (message: Message) => void;
  185. updateCurrentSession: (updater: (session: ChatSession) => void) => void;
  186. updateMessage: (
  187. sessionIndex: number,
  188. messageIndex: number,
  189. updater: (message?: Message) => void,
  190. ) => void;
  191. resetSession: () => void;
  192. getMessagesWithMemory: () => Message[];
  193. getMemoryPrompt: () => Message;
  194. getConfig: () => ChatConfig;
  195. resetConfig: () => void;
  196. updateConfig: (updater: (config: ChatConfig) => void) => void;
  197. clearAllData: () => void;
  198. }
  199. function countMessages(msgs: Message[]) {
  200. return msgs.reduce((pre, cur) => pre + cur.content.length, 0);
  201. }
  202. const LOCAL_KEY = "chat-next-web-store";
  203. export const useChatStore = create<ChatStore>()(
  204. persist(
  205. (set, get) => ({
  206. sessions: [createEmptySession()],
  207. currentSessionIndex: 0,
  208. config: {
  209. ...DEFAULT_CONFIG,
  210. },
  211. clearSessions() {
  212. set(() => ({
  213. sessions: [createEmptySession()],
  214. currentSessionIndex: 0,
  215. }));
  216. },
  217. resetConfig() {
  218. set(() => ({ config: { ...DEFAULT_CONFIG } }));
  219. },
  220. getConfig() {
  221. return get().config;
  222. },
  223. updateConfig(updater) {
  224. const config = get().config;
  225. updater(config);
  226. set(() => ({ config }));
  227. },
  228. selectSession(index: number) {
  229. set({
  230. currentSessionIndex: index,
  231. });
  232. },
  233. removeSession(index: number) {
  234. set((state) => {
  235. let nextIndex = state.currentSessionIndex;
  236. const sessions = state.sessions;
  237. if (sessions.length === 1) {
  238. return {
  239. currentSessionIndex: 0,
  240. sessions: [createEmptySession()],
  241. };
  242. }
  243. sessions.splice(index, 1);
  244. if (nextIndex === index) {
  245. nextIndex -= 1;
  246. }
  247. return {
  248. currentSessionIndex: nextIndex,
  249. sessions,
  250. };
  251. });
  252. },
  253. newSession() {
  254. set((state) => ({
  255. currentSessionIndex: 0,
  256. sessions: [createEmptySession()].concat(state.sessions),
  257. }));
  258. },
  259. currentSession() {
  260. let index = get().currentSessionIndex;
  261. const sessions = get().sessions;
  262. if (index < 0 || index >= sessions.length) {
  263. index = Math.min(sessions.length - 1, Math.max(0, index));
  264. set(() => ({ currentSessionIndex: index }));
  265. }
  266. const session = sessions[index];
  267. return session;
  268. },
  269. onNewMessage(message) {
  270. get().updateCurrentSession((session) => {
  271. session.lastUpdate = new Date().toLocaleString();
  272. });
  273. get().updateStat(message);
  274. get().summarizeSession();
  275. },
  276. async onUserInput(content) {
  277. const userMessage: Message = createMessage({
  278. role: "user",
  279. content,
  280. });
  281. const botMessage: Message = createMessage({
  282. role: "assistant",
  283. streaming: true,
  284. });
  285. // get recent messages
  286. const recentMessages = get().getMessagesWithMemory();
  287. const sendMessages = recentMessages.concat(userMessage);
  288. const sessionIndex = get().currentSessionIndex;
  289. const messageIndex = get().currentSession().messages.length + 1;
  290. // save user's and bot's message
  291. get().updateCurrentSession((session) => {
  292. session.messages.push(userMessage);
  293. session.messages.push(botMessage);
  294. });
  295. // make request
  296. console.log("[User Input] ", sendMessages);
  297. requestChatStream(sendMessages, {
  298. onMessage(content, done) {
  299. // stream response
  300. if (done) {
  301. botMessage.streaming = false;
  302. botMessage.content = content;
  303. get().onNewMessage(botMessage);
  304. ControllerPool.remove(
  305. sessionIndex,
  306. botMessage.id ?? messageIndex,
  307. );
  308. } else {
  309. botMessage.content = content;
  310. set(() => ({}));
  311. }
  312. },
  313. onError(error, statusCode) {
  314. if (statusCode === 401) {
  315. botMessage.content = Locale.Error.Unauthorized;
  316. } else {
  317. botMessage.content += "\n\n" + Locale.Store.Error;
  318. }
  319. botMessage.streaming = false;
  320. userMessage.isError = true;
  321. botMessage.isError = true;
  322. set(() => ({}));
  323. ControllerPool.remove(sessionIndex, botMessage.id ?? messageIndex);
  324. },
  325. onController(controller) {
  326. // collect controller for stop/retry
  327. ControllerPool.addController(
  328. sessionIndex,
  329. botMessage.id ?? messageIndex,
  330. controller,
  331. );
  332. },
  333. filterBot: !get().config.sendBotMessages,
  334. modelConfig: get().config.modelConfig,
  335. });
  336. },
  337. getMemoryPrompt() {
  338. const session = get().currentSession();
  339. return {
  340. role: "system",
  341. content: Locale.Store.Prompt.History(session.memoryPrompt),
  342. date: "",
  343. } as Message;
  344. },
  345. getMessagesWithMemory() {
  346. const session = get().currentSession();
  347. const config = get().config;
  348. const messages = session.messages.filter((msg) => !msg.isError);
  349. const n = messages.length;
  350. const context = session.context.slice();
  351. if (
  352. session.sendMemory &&
  353. session.memoryPrompt &&
  354. session.memoryPrompt.length > 0
  355. ) {
  356. const memoryPrompt = get().getMemoryPrompt();
  357. context.push(memoryPrompt);
  358. }
  359. const recentMessages = context.concat(
  360. messages.slice(Math.max(0, n - config.historyMessageCount)),
  361. );
  362. return recentMessages;
  363. },
  364. updateMessage(
  365. sessionIndex: number,
  366. messageIndex: number,
  367. updater: (message?: Message) => void,
  368. ) {
  369. const sessions = get().sessions;
  370. const session = sessions.at(sessionIndex);
  371. const messages = session?.messages;
  372. updater(messages?.at(messageIndex));
  373. set(() => ({ sessions }));
  374. },
  375. resetSession() {
  376. get().updateCurrentSession((session) => {
  377. session.messages = [];
  378. session.memoryPrompt = "";
  379. });
  380. },
  381. summarizeSession() {
  382. const session = get().currentSession();
  383. // should summarize topic after chating more than 50 words
  384. const SUMMARIZE_MIN_LEN = 50;
  385. if (
  386. session.topic === DEFAULT_TOPIC &&
  387. countMessages(session.messages) >= SUMMARIZE_MIN_LEN
  388. ) {
  389. requestWithPrompt(session.messages, Locale.Store.Prompt.Topic).then(
  390. (res) => {
  391. get().updateCurrentSession(
  392. (session) =>
  393. (session.topic = res ? trimTopic(res) : DEFAULT_TOPIC),
  394. );
  395. },
  396. );
  397. }
  398. const config = get().config;
  399. let toBeSummarizedMsgs = session.messages.slice(
  400. session.lastSummarizeIndex,
  401. );
  402. const historyMsgLength = countMessages(toBeSummarizedMsgs);
  403. if (historyMsgLength > get().config?.modelConfig?.max_tokens ?? 4000) {
  404. const n = toBeSummarizedMsgs.length;
  405. toBeSummarizedMsgs = toBeSummarizedMsgs.slice(
  406. Math.max(0, n - config.historyMessageCount),
  407. );
  408. }
  409. // add memory prompt
  410. toBeSummarizedMsgs.unshift(get().getMemoryPrompt());
  411. const lastSummarizeIndex = session.messages.length;
  412. console.log(
  413. "[Chat History] ",
  414. toBeSummarizedMsgs,
  415. historyMsgLength,
  416. config.compressMessageLengthThreshold,
  417. );
  418. if (historyMsgLength > config.compressMessageLengthThreshold) {
  419. requestChatStream(
  420. toBeSummarizedMsgs.concat({
  421. role: "system",
  422. content: Locale.Store.Prompt.Summarize,
  423. date: "",
  424. }),
  425. {
  426. filterBot: false,
  427. onMessage(message, done) {
  428. session.memoryPrompt = message;
  429. if (done) {
  430. console.log("[Memory] ", session.memoryPrompt);
  431. session.lastSummarizeIndex = lastSummarizeIndex;
  432. }
  433. },
  434. onError(error) {
  435. console.error("[Summarize] ", error);
  436. },
  437. },
  438. );
  439. }
  440. },
  441. updateStat(message) {
  442. get().updateCurrentSession((session) => {
  443. session.stat.charCount += message.content.length;
  444. // TODO: should update chat count and word count
  445. });
  446. },
  447. updateCurrentSession(updater) {
  448. const sessions = get().sessions;
  449. const index = get().currentSessionIndex;
  450. updater(sessions[index]);
  451. set(() => ({ sessions }));
  452. },
  453. clearAllData() {
  454. if (confirm(Locale.Store.ConfirmClearAll)) {
  455. localStorage.clear();
  456. location.reload();
  457. }
  458. },
  459. }),
  460. {
  461. name: LOCAL_KEY,
  462. version: 1.2,
  463. migrate(persistedState, version) {
  464. const state = persistedState as ChatStore;
  465. if (version === 1) {
  466. state.sessions.forEach((s) => (s.context = []));
  467. }
  468. if (version < 1.2) {
  469. state.sessions.forEach((s) => (s.sendMemory = true));
  470. }
  471. return state;
  472. },
  473. },
  474. ),
  475. );