app.ts 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525
  1. import { create } from "zustand";
  2. import { persist } from "zustand/middleware";
  3. import { type ChatCompletionResponseMessage } from "openai";
  4. import {
  5. ControllerPool,
  6. requestChatStream,
  7. requestWithPrompt,
  8. } from "../requests";
  9. import { trimTopic } from "../utils";
  10. import Locale from "../locales";
  /**
   * A chat message as kept in the store: the OpenAI response-message shape
   * extended with client-side bookkeeping fields.
   */
  export type Message = ChatCompletionResponseMessage & {
    /** Timestamp string; this file fills it with `toLocaleString()` output or "". */
    date: string;
    /** True while the reply is still being streamed in. */
    streaming?: boolean;
    /** Set on messages that belong to a failed request. */
    isError?: boolean;
  };
  /** Keyboard shortcuts selectable for `ChatConfig.submitKey`. */
  export enum SubmitKey {
    Enter = "Enter",
    CtrlEnter = "Ctrl + Enter",
    ShiftEnter = "Shift + Enter",
    AltEnter = "Alt + Enter",
    MetaEnter = "Meta + Enter",
  }
  /** Values selectable for `ChatConfig.theme`. */
  export enum Theme {
    Auto = "auto",
    Dark = "dark",
    Light = "light",
  }
  /** User-adjustable settings persisted together with the chat sessions. */
  export interface ChatConfig {
    /** How many of the most recent messages accompany a request. */
    historyMessageCount: number; // -1 means all
    /** Total content length above which the pending history gets summarized. */
    compressMessageLengthThreshold: number;
    sendBotMessages: boolean; // send bot's message or not
    /** Key combination that submits the input. */
    submitKey: SubmitKey;
    avatar: string;
    fontSize: number;
    theme: Theme;
    tightBorder: boolean;
    sendPreviewBubble: boolean;
    disablePromptHint: boolean;
    /** Options handed to the chat request as `modelConfig`. */
    modelConfig: {
      model: string;
      temperature: number;
      max_tokens: number;
      presence_penalty: number;
    };
  }
  /** Shape of the model options nested inside `ChatConfig`. */
  export type ModelConfig = ChatConfig["modelConfig"];

  /** Message roles exposed by this module, in a fixed order. */
  export const ROLES: Array<Message["role"]> = ["system", "user", "assistant"];
  /** Switch controlling whether the GPT-4 family is marked available. */
  const ENABLE_GPT4 = true;

  /**
   * Every model known to the app together with its availability flag.
   * GPT-4 variants come first and follow ENABLE_GPT4; the GPT-3.5 variants
   * are always available.
   */
  export const ALL_MODELS = [
    ...["gpt-4", "gpt-4-0314", "gpt-4-32k", "gpt-4-32k-0314"].map((name) => ({
      name,
      available: ENABLE_GPT4,
    })),
    ...["gpt-3.5-turbo", "gpt-3.5-turbo-0301"].map((name) => ({
      name,
      available: true,
    })),
  ];
  /**
   * Tells whether `name` refers to a model that is listed in ALL_MODELS and
   * flagged as available.
   *
   * @param name model identifier to look up
   * @returns true only for a listed, available model
   */
  export function isValidModel(name: string) {
    for (const model of ALL_MODELS) {
      if (model.name === name && model.available) {
        return true;
      }
    }
    return false;
  }
  /**
   * Checks that `x` is a number lying inside the closed interval [min, max].
   * NaN never passes because every comparison involving it is false.
   *
   * @param x value to check (verified at runtime to be a number)
   * @param min inclusive lower bound
   * @param max inclusive upper bound
   * @returns true when `x` is a number within the bounds
   */
  export function isValidNumber(x: number, min: number, max: number) {
    if (typeof x !== "number") {
      return false;
    }
    return x <= max && x >= min;
  }
  /**
   * Returns a copy of `oldConfig` from which every field that fails its
   * validity check has been removed. The input object is left untouched.
   *
   * @param oldConfig model options to sanitize
   * @returns the subset of `oldConfig` whose values are acceptable
   */
  export function filterConfig(oldConfig: ModelConfig): Partial<ModelConfig> {
    // One acceptance rule per model option.
    const rules: {
      [k in keyof ModelConfig]: (x: ModelConfig[keyof ModelConfig]) => boolean;
    } = {
      model: (value) => isValidModel(value as string),
      max_tokens: (value) => isValidNumber(value as number, 100, 4000),
      presence_penalty: (value) => isValidNumber(value as number, -2, 2),
      temperature: (value) => isValidNumber(value as number, 0, 2),
    };

    const filtered: Partial<ModelConfig> = { ...oldConfig };
    for (const field of Object.keys(rules) as (keyof ModelConfig)[]) {
      const accepted = rules[field](oldConfig[field]);
      if (!accepted) {
        delete filtered[field];
      }
    }
    return filtered;
  }
  /** Settings a fresh store starts with and `resetConfig` falls back to. */
  const DEFAULT_CONFIG: ChatConfig = {
    historyMessageCount: 4,
    compressMessageLengthThreshold: 1000,
    sendBotMessages: true,
    submitKey: SubmitKey.CtrlEnter,
    avatar: "1f603",
    fontSize: 14,
    theme: Theme.Auto,
    tightBorder: false,
    sendPreviewBubble: true,
    disablePromptHint: false,

    modelConfig: {
      model: "gpt-3.5-turbo",
      temperature: 1,
      max_tokens: 2000,
      presence_penalty: 0,
    },
  };
  /** Running counters kept per session. */
  export interface ChatStat {
    tokenCount: number;
    wordCount: number;
    /** Sum of the content lengths of messages passed to `updateStat`. */
    charCount: number;
  }
  /** One conversation together with its summarization bookkeeping. */
  export interface ChatSession {
    /** Creation timestamp in milliseconds (`Date.now()`), used as identifier. */
    id: number;
    /** Title of the conversation; starts out as the default topic. */
    topic: string;
    /** Latest summary of older messages; empty until one has been produced. */
    memoryPrompt: string;
    /** Messages placed in front of the recent history on every request. */
    context: Message[];
    /** Full message log of the conversation. */
    messages: Message[];
    stat: ChatStat;
    /** Locale-formatted time of the last update. */
    lastUpdate: string;
    /** Index into `messages` where the next summarization starts. */
    lastSummarizeIndex: number;
  }
  /** Topic given to a session until a summarized one replaces it. */
  const DEFAULT_TOPIC = Locale.Store.DefaultTopic;

  /** Localized assistant greeting message; it carries no date. */
  export const BOT_HELLO: Message = {
    role: "assistant",
    content: Locale.Store.BotHello,
    date: "",
  };
  /**
   * Builds a brand-new session: default topic, no messages, zeroed counters,
   * and the current time as both identifier and last-update stamp.
   *
   * @returns a fresh, independent ChatSession
   */
  function createEmptySession(): ChatSession {
    const lastUpdate = new Date().toLocaleString();
    const stat: ChatStat = { tokenCount: 0, wordCount: 0, charCount: 0 };

    const session: ChatSession = {
      id: Date.now(),
      topic: DEFAULT_TOPIC,
      memoryPrompt: "",
      context: [],
      messages: [],
      stat,
      lastUpdate,
      lastSummarizeIndex: 0,
    };
    return session;
  }
  /** State and actions exposed by the chat store. */
  interface ChatStore {
    // ---- state ----
    config: ChatConfig;
    sessions: ChatSession[];
    currentSessionIndex: number;

    // ---- session management ----
    clearSessions: () => void;
    removeSession: (index: number) => void;
    selectSession: (index: number) => void;
    newSession: () => void;
    currentSession: () => ChatSession;

    // ---- messaging ----
    onNewMessage: (message: Message) => void;
    onUserInput: (content: string) => Promise<void>;
    summarizeSession: () => void;
    updateStat: (message: Message) => void;
    updateCurrentSession: (updater: (session: ChatSession) => void) => void;
    updateMessage: (
      sessionIndex: number,
      messageIndex: number,
      updater: (message?: Message) => void,
    ) => void;
    getMessagesWithMemory: () => Message[];
    getMemoryPrompt: () => Message;

    // ---- configuration and housekeeping ----
    getConfig: () => ChatConfig;
    resetConfig: () => void;
    updateConfig: (updater: (config: ChatConfig) => void) => void;
    clearAllData: () => void;
  }
  /**
   * Adds up the content lengths of the given messages.
   *
   * @param msgs messages to measure
   * @returns total number of UTF-16 code units across all contents
   */
  function countMessages(msgs: Message[]) {
    let total = 0;
    for (const msg of msgs) {
      total += msg.content.length;
    }
    return total;
  }
  /** Storage key under which the persisted store is saved. */
  const LOCAL_KEY = "chat-next-web-store";
  /**
   * Produces an independent copy of DEFAULT_CONFIG.
   *
   * The nested `modelConfig` object is copied as well, so in-place edits made
   * through `updateConfig` can never leak into the defaults.
   */
  function createDefaultConfig(): ChatConfig {
    return {
      ...DEFAULT_CONFIG,
      modelConfig: { ...DEFAULT_CONFIG.modelConfig },
    };
  }

  /**
   * Persisted store holding the chat sessions, the index of the active
   * session and the user configuration, plus the actions operating on them.
   */
  export const useChatStore = create<ChatStore>()(
    persist(
      (set, get) => ({
        sessions: [createEmptySession()],
        currentSessionIndex: 0,
        config: createDefaultConfig(),

        /** Drops every session and starts over with a single empty one. */
        clearSessions() {
          set(() => ({
            sessions: [createEmptySession()],
            currentSessionIndex: 0,
          }));
        },

        /** Replaces the configuration with a fresh copy of the defaults. */
        resetConfig() {
          set(() => ({ config: createDefaultConfig() }));
        },

        /** Returns the current configuration object. */
        getConfig() {
          return get().config;
        },

        /** Lets `updater` edit the configuration in place, then publishes it. */
        updateConfig(updater) {
          const config = get().config;
          updater(config);
          set(() => ({ config }));
        },

        /** Makes the session at `index` the active one. */
        selectSession(index: number) {
          set({
            currentSessionIndex: index,
          });
        },

        /**
         * Removes the session at `index`. Removing the last remaining session
         * replaces it with an empty one. The active index keeps pointing at
         * the same session when an earlier one is removed, and moves to the
         * previous session (never below 0) when the active one is removed.
         */
        removeSession(index: number) {
          set((state) => {
            if (state.sessions.length === 1) {
              return {
                currentSessionIndex: 0,
                sessions: [createEmptySession()],
              };
            }

            // Work on a copy so the previous state array is left untouched.
            const sessions = state.sessions.slice();
            sessions.splice(index, 1);

            let nextIndex = state.currentSessionIndex;
            if (index <= nextIndex) {
              // Either the active session or one before it went away.
              nextIndex -= 1;
            }
            // Keep the index inside the remaining list.
            nextIndex = Math.min(Math.max(nextIndex, 0), sessions.length - 1);

            return {
              currentSessionIndex: nextIndex,
              sessions,
            };
          });
        },

        /** Puts a new empty session at the front and activates it. */
        newSession() {
          set((state) => ({
            currentSessionIndex: 0,
            sessions: [createEmptySession()].concat(state.sessions),
          }));
        },

        /**
         * Returns the active session. An out-of-range active index is clamped
         * into the session list and stored back first.
         */
        currentSession() {
          let index = get().currentSessionIndex;
          const sessions = get().sessions;

          if (index < 0 || index >= sessions.length) {
            index = Math.min(sessions.length - 1, Math.max(0, index));
            set(() => ({ currentSessionIndex: index }));
          }

          return sessions[index];
        },

        /**
         * Bookkeeping after a completed bot reply: refresh the last-update
         * stamp, update the counters and trigger summarization.
         */
        onNewMessage(message) {
          get().updateCurrentSession((session) => {
            session.lastUpdate = new Date().toLocaleString();
          });
          get().updateStat(message);
          get().summarizeSession();
        },

        /**
         * Appends the user's message and a streaming placeholder for the bot,
         * then starts the streamed request that fills the placeholder in.
         *
         * @param content text entered by the user
         */
        async onUserInput(content) {
          const userMessage: Message = {
            role: "user",
            content,
            date: new Date().toLocaleString(),
          };

          const botMessage: Message = {
            content: "",
            role: "assistant",
            date: new Date().toLocaleString(),
            streaming: true,
          };

          // get recent messages
          const recentMessages = get().getMessagesWithMemory();
          const sendMessages = recentMessages.concat(userMessage);
          const sessionIndex = get().currentSessionIndex;
          // Position the bot message will have once both messages are pushed.
          const messageIndex = get().currentSession().messages.length + 1;

          // save user's and bot's message
          get().updateCurrentSession((session) => {
            session.messages.push(userMessage);
            session.messages.push(botMessage);
          });

          // make request
          console.log("[User Input] ", sendMessages);
          requestChatStream(sendMessages, {
            onMessage(content, done) {
              // stream response
              if (done) {
                botMessage.streaming = false;
                botMessage.content = content;
                get().onNewMessage(botMessage);
                ControllerPool.remove(sessionIndex, messageIndex);
              } else {
                botMessage.content = content;
                // Empty update: notifies subscribers of the in-place edit.
                set(() => ({}));
              }
            },
            onError(error, statusCode) {
              if (statusCode === 401) {
                botMessage.content = Locale.Error.Unauthorized;
              } else {
                botMessage.content += "\n\n" + Locale.Store.Error;
              }
              botMessage.streaming = false;
              // Flag both messages so later requests leave them out.
              userMessage.isError = true;
              botMessage.isError = true;
              set(() => ({}));
              ControllerPool.remove(sessionIndex, messageIndex);
            },
            onController(controller) {
              // collect controller for stop/retry
              ControllerPool.addController(
                sessionIndex,
                messageIndex,
                controller,
              );
            },
            filterBot: !get().config.sendBotMessages,
            modelConfig: get().config.modelConfig,
          });
        },

        /** Wraps the active session's memory prompt in a system message. */
        getMemoryPrompt() {
          const session = get().currentSession();

          return {
            role: "system",
            content: Locale.Store.Prompt.History(session.memoryPrompt),
            date: "",
          } as Message;
        },

        /**
         * Assembles the messages sent ahead of a new user input: the session
         * context, the memory prompt (when one exists) and the most recent
         * error-free messages. A negative `historyMessageCount` selects the
         * whole history, as documented on the config field.
         */
        getMessagesWithMemory() {
          const session = get().currentSession();
          const config = get().config;
          const messages = session.messages.filter((msg) => !msg.isError);
          const n = messages.length;

          const context = session.context.slice();

          if (session.memoryPrompt && session.memoryPrompt.length > 0) {
            const memoryPrompt = get().getMemoryPrompt();
            context.push(memoryPrompt);
          }

          const historyCount = config.historyMessageCount;
          const start = historyCount < 0 ? 0 : Math.max(0, n - historyCount);

          return context.concat(messages.slice(start));
        },

        /**
         * Hands the addressed message (or undefined when it does not exist)
         * to `updater`, then publishes the sessions.
         */
        updateMessage(
          sessionIndex: number,
          messageIndex: number,
          updater: (message?: Message) => void,
        ) {
          const sessions = get().sessions;
          const session = sessions.at(sessionIndex);
          const messages = session?.messages;
          updater(messages?.at(messageIndex));
          set(() => ({ sessions }));
        },

        /**
         * Requests a topic for sessions still carrying the default one, and
         * condenses not-yet-summarized history into the memory prompt once it
         * grows past the configured threshold.
         */
        summarizeSession() {
          const session = get().currentSession();

          // Ask for a topic once the messages hold at least 50 characters.
          const SUMMARIZE_MIN_LEN = 50;
          if (
            session.topic === DEFAULT_TOPIC &&
            countMessages(session.messages) >= SUMMARIZE_MIN_LEN
          ) {
            requestWithPrompt(session.messages, Locale.Store.Prompt.Topic)
              .then((res) => {
                get().updateCurrentSession(
                  (session) => (session.topic = trimTopic(res)),
                );
              })
              .catch((error) => {
                // A failed topic request must not surface as an unhandled rejection.
                console.error("[Summarize] ", error);
              });
          }

          const config = get().config;
          let toBeSummarizedMsgs = session.messages.slice(
            session.lastSummarizeIndex,
          );

          const historyMsgLength = countMessages(toBeSummarizedMsgs);

          // Resolve the fallback before comparing: `a > b ?? c` would parse
          // as `(a > b) ?? c` and never use the default.
          const maxTokens = config?.modelConfig?.max_tokens ?? 4000;
          if (historyMsgLength > maxTokens) {
            // Too long to summarize in full: keep only the newest messages.
            // NOTE(review): a negative historyMessageCount is not
            // special-cased here and leaves no messages; confirm intent.
            const n = toBeSummarizedMsgs.length;
            toBeSummarizedMsgs = toBeSummarizedMsgs.slice(
              Math.max(0, n - config.historyMessageCount),
            );
          }

          // add memory prompt
          toBeSummarizedMsgs.unshift(get().getMemoryPrompt());

          const lastSummarizeIndex = session.messages.length;

          console.log(
            "[Chat History] ",
            toBeSummarizedMsgs,
            historyMsgLength,
            config.compressMessageLengthThreshold,
          );

          if (historyMsgLength > config.compressMessageLengthThreshold) {
            requestChatStream(
              toBeSummarizedMsgs.concat({
                role: "system",
                content: Locale.Store.Prompt.Summarize,
                date: "",
              }),
              {
                filterBot: false,
                onMessage(message, done) {
                  session.memoryPrompt = message;
                  if (done) {
                    console.log("[Memory] ", session.memoryPrompt);
                    // Later summaries start after the messages covered here.
                    session.lastSummarizeIndex = lastSummarizeIndex;
                  }
                },
                onError(error) {
                  console.error("[Summarize] ", error);
                },
              },
            );
          }
        },

        /** Adds the message's content length to the session's char counter. */
        updateStat(message) {
          get().updateCurrentSession((session) => {
            session.stat.charCount += message.content.length;
            // TODO: should update chat count and word count
          });
        },

        /** Lets `updater` edit the active session in place, then publishes. */
        updateCurrentSession(updater) {
          const sessions = get().sessions;
          const index = get().currentSessionIndex;
          updater(sessions[index]);
          set(() => ({ sessions }));
        },

        /** After user confirmation, clears localStorage and reloads the page. */
        clearAllData() {
          if (confirm(Locale.Store.ConfirmClearAll)) {
            localStorage.clear();
            location.reload();
          }
        },
      }),
      {
        name: LOCAL_KEY,
        version: 1.1,
        migrate(persistedState, version) {
          const state = persistedState as ChatStore;

          // State saved under version 1 gets an empty `context` per session.
          if (version === 1) {
            state.sessions.forEach((s) => (s.context = []));
          }

          return state;
        },
      },
    ),
  );