app.ts 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478
  1. import { create } from "zustand";
  2. import { persist } from "zustand/middleware";
  3. import { type ChatCompletionResponseMessage } from "openai";
  4. import {
  5. ControllerPool,
  6. requestChatStream,
  7. requestWithPrompt,
  8. } from "../requests";
  9. import { trimTopic } from "../utils";
  10. import Locale from "../locales";
  11. export type Message = ChatCompletionResponseMessage & {
  12. date: string;
  13. streaming?: boolean;
  14. };
  15. export enum SubmitKey {
  16. Enter = "Enter",
  17. CtrlEnter = "Ctrl + Enter",
  18. ShiftEnter = "Shift + Enter",
  19. AltEnter = "Alt + Enter",
  20. }
  21. export enum Theme {
  22. Auto = "auto",
  23. Dark = "dark",
  24. Light = "light",
  25. }
  26. export interface ChatConfig {
  27. maxToken?: number;
  28. historyMessageCount: number; // -1 means all
  29. compressMessageLengthThreshold: number;
  30. sendBotMessages: boolean; // send bot's message or not
  31. submitKey: SubmitKey;
  32. avatar: string;
  33. theme: Theme;
  34. tightBorder: boolean;
  35. modelConfig: {
  36. model: string;
  37. temperature: number;
  38. max_tokens: number;
  39. presence_penalty: number;
  40. };
  41. }
  42. export type ModelConfig = ChatConfig["modelConfig"];
  43. const ENABLE_GPT4 = true;
  44. export const ALL_MODELS = [
  45. {
  46. name: "gpt-4",
  47. available: ENABLE_GPT4,
  48. },
  49. {
  50. name: "gpt-4-0314",
  51. available: ENABLE_GPT4,
  52. },
  53. {
  54. name: "gpt-4-32k",
  55. available: ENABLE_GPT4,
  56. },
  57. {
  58. name: "gpt-4-32k-0314",
  59. available: ENABLE_GPT4,
  60. },
  61. {
  62. name: "gpt-3.5-turbo",
  63. available: true,
  64. },
  65. {
  66. name: "gpt-3.5-turbo-0301",
  67. available: true,
  68. },
  69. ];
  70. export function isValidModel(name: string) {
  71. return ALL_MODELS.some((m) => m.name === name && m.available);
  72. }
  73. export function isValidNumber(x: number, min: number, max: number) {
  74. return typeof x === "number" && x <= max && x >= min;
  75. }
  76. export function filterConfig(config: ModelConfig): Partial<ModelConfig> {
  77. const validator: {
  78. [k in keyof ModelConfig]: (x: ModelConfig[keyof ModelConfig]) => boolean;
  79. } = {
  80. model(x) {
  81. return isValidModel(x as string);
  82. },
  83. max_tokens(x) {
  84. return isValidNumber(x as number, 100, 4000);
  85. },
  86. presence_penalty(x) {
  87. return isValidNumber(x as number, -2, 2);
  88. },
  89. temperature(x) {
  90. return isValidNumber(x as number, 0, 1);
  91. },
  92. };
  93. Object.keys(validator).forEach((k) => {
  94. const key = k as keyof ModelConfig;
  95. if (!validator[key](config[key])) {
  96. delete config[key];
  97. }
  98. });
  99. return config;
  100. }
  101. const DEFAULT_CONFIG: ChatConfig = {
  102. historyMessageCount: 4,
  103. compressMessageLengthThreshold: 1000,
  104. sendBotMessages: true as boolean,
  105. submitKey: SubmitKey.CtrlEnter as SubmitKey,
  106. avatar: "1f603",
  107. theme: Theme.Auto as Theme,
  108. tightBorder: false,
  109. modelConfig: {
  110. model: "gpt-3.5-turbo",
  111. temperature: 1,
  112. max_tokens: 2000,
  113. presence_penalty: 0,
  114. },
  115. };
  116. export interface ChatStat {
  117. tokenCount: number;
  118. wordCount: number;
  119. charCount: number;
  120. }
  121. export interface ChatSession {
  122. id: number;
  123. topic: string;
  124. memoryPrompt: string;
  125. messages: Message[];
  126. stat: ChatStat;
  127. lastUpdate: string;
  128. lastSummarizeIndex: number;
  129. }
  130. const DEFAULT_TOPIC = Locale.Store.DefaultTopic;
  131. function createEmptySession(): ChatSession {
  132. const createDate = new Date().toLocaleString();
  133. return {
  134. id: Date.now(),
  135. topic: DEFAULT_TOPIC,
  136. memoryPrompt: "",
  137. messages: [
  138. {
  139. role: "assistant",
  140. content: Locale.Store.BotHello,
  141. date: createDate,
  142. },
  143. ],
  144. stat: {
  145. tokenCount: 0,
  146. wordCount: 0,
  147. charCount: 0,
  148. },
  149. lastUpdate: createDate,
  150. lastSummarizeIndex: 0,
  151. };
  152. }
  153. interface ChatStore {
  154. config: ChatConfig;
  155. sessions: ChatSession[];
  156. currentSessionIndex: number;
  157. removeSession: (index: number) => void;
  158. selectSession: (index: number) => void;
  159. newSession: () => void;
  160. currentSession: () => ChatSession;
  161. onNewMessage: (message: Message) => void;
  162. onUserInput: (content: string) => Promise<void>;
  163. summarizeSession: () => void;
  164. updateStat: (message: Message) => void;
  165. updateCurrentSession: (updater: (session: ChatSession) => void) => void;
  166. updateMessage: (
  167. sessionIndex: number,
  168. messageIndex: number,
  169. updater: (message?: Message) => void
  170. ) => void;
  171. getMessagesWithMemory: () => Message[];
  172. getMemoryPrompt: () => Message;
  173. getConfig: () => ChatConfig;
  174. resetConfig: () => void;
  175. updateConfig: (updater: (config: ChatConfig) => void) => void;
  176. clearAllData: () => void;
  177. }
  178. const LOCAL_KEY = "chat-next-web-store";
  179. export const useChatStore = create<ChatStore>()(
  180. persist(
  181. (set, get) => ({
  182. sessions: [createEmptySession()],
  183. currentSessionIndex: 0,
  184. config: {
  185. ...DEFAULT_CONFIG,
  186. },
  187. resetConfig() {
  188. set(() => ({ config: { ...DEFAULT_CONFIG } }));
  189. },
  190. getConfig() {
  191. return get().config;
  192. },
  193. updateConfig(updater) {
  194. const config = get().config;
  195. updater(config);
  196. set(() => ({ config }));
  197. },
  198. selectSession(index: number) {
  199. set({
  200. currentSessionIndex: index,
  201. });
  202. },
  203. removeSession(index: number) {
  204. set((state) => {
  205. let nextIndex = state.currentSessionIndex;
  206. const sessions = state.sessions;
  207. if (sessions.length === 1) {
  208. return {
  209. currentSessionIndex: 0,
  210. sessions: [createEmptySession()],
  211. };
  212. }
  213. sessions.splice(index, 1);
  214. if (nextIndex === index) {
  215. nextIndex -= 1;
  216. }
  217. return {
  218. currentSessionIndex: nextIndex,
  219. sessions,
  220. };
  221. });
  222. },
  223. newSession() {
  224. set((state) => ({
  225. currentSessionIndex: 0,
  226. sessions: [createEmptySession()].concat(state.sessions),
  227. }));
  228. },
  229. currentSession() {
  230. let index = get().currentSessionIndex;
  231. const sessions = get().sessions;
  232. if (index < 0 || index >= sessions.length) {
  233. index = Math.min(sessions.length - 1, Math.max(0, index));
  234. set(() => ({ currentSessionIndex: index }));
  235. }
  236. const session = sessions[index];
  237. return session;
  238. },
  239. onNewMessage(message) {
  240. get().updateCurrentSession((session) => {
  241. session.lastUpdate = new Date().toLocaleString();
  242. });
  243. get().updateStat(message);
  244. get().summarizeSession();
  245. },
  246. async onUserInput(content) {
  247. const userMessage: Message = {
  248. role: "user",
  249. content,
  250. date: new Date().toLocaleString(),
  251. };
  252. const botMessage: Message = {
  253. content: "",
  254. role: "assistant",
  255. date: new Date().toLocaleString(),
  256. streaming: true,
  257. };
  258. // get recent messages
  259. const recentMessages = get().getMessagesWithMemory();
  260. const sendMessages = recentMessages.concat(userMessage);
  261. const sessionIndex = get().currentSessionIndex;
  262. const messageIndex = get().currentSession().messages.length + 1;
  263. // save user's and bot's message
  264. get().updateCurrentSession((session) => {
  265. session.messages.push(userMessage);
  266. session.messages.push(botMessage);
  267. });
  268. // make request
  269. console.log("[User Input] ", sendMessages);
  270. requestChatStream(sendMessages, {
  271. onMessage(content, done) {
  272. // stream response
  273. if (done) {
  274. botMessage.streaming = false;
  275. botMessage.content = content;
  276. get().onNewMessage(botMessage);
  277. ControllerPool.remove(sessionIndex, messageIndex);
  278. } else {
  279. botMessage.content = content;
  280. set(() => ({}));
  281. }
  282. },
  283. onError(error) {
  284. botMessage.content += "\n\n" + Locale.Store.Error;
  285. botMessage.streaming = false;
  286. set(() => ({}));
  287. ControllerPool.remove(sessionIndex, messageIndex);
  288. },
  289. onController(controller) {
  290. // collect controller for stop/retry
  291. ControllerPool.addController(
  292. sessionIndex,
  293. messageIndex,
  294. controller
  295. );
  296. },
  297. filterBot: !get().config.sendBotMessages,
  298. modelConfig: get().config.modelConfig,
  299. });
  300. },
  301. getMemoryPrompt() {
  302. const session = get().currentSession();
  303. return {
  304. role: "system",
  305. content: Locale.Store.Prompt.History(session.memoryPrompt),
  306. date: "",
  307. } as Message;
  308. },
  309. getMessagesWithMemory() {
  310. const session = get().currentSession();
  311. const config = get().config;
  312. const n = session.messages.length;
  313. const recentMessages = session.messages.slice(
  314. n - config.historyMessageCount
  315. );
  316. const memoryPrompt = get().getMemoryPrompt();
  317. if (session.memoryPrompt) {
  318. recentMessages.unshift(memoryPrompt);
  319. }
  320. return recentMessages;
  321. },
  322. updateMessage(
  323. sessionIndex: number,
  324. messageIndex: number,
  325. updater: (message?: Message) => void
  326. ) {
  327. const sessions = get().sessions;
  328. const session = sessions.at(sessionIndex);
  329. const messages = session?.messages;
  330. updater(messages?.at(messageIndex));
  331. set(() => ({ sessions }));
  332. },
  333. summarizeSession() {
  334. const session = get().currentSession();
  335. if (session.topic === DEFAULT_TOPIC && session.messages.length >= 3) {
  336. // should summarize topic
  337. requestWithPrompt(session.messages, Locale.Store.Prompt.Topic).then(
  338. (res) => {
  339. get().updateCurrentSession(
  340. (session) => (session.topic = trimTopic(res))
  341. );
  342. }
  343. );
  344. }
  345. const config = get().config;
  346. let toBeSummarizedMsgs = session.messages.slice(
  347. session.lastSummarizeIndex
  348. );
  349. const historyMsgLength = toBeSummarizedMsgs.reduce(
  350. (pre, cur) => pre + cur.content.length,
  351. 0
  352. );
  353. if (historyMsgLength > 4000) {
  354. toBeSummarizedMsgs = toBeSummarizedMsgs.slice(
  355. -config.historyMessageCount
  356. );
  357. }
  358. // add memory prompt
  359. toBeSummarizedMsgs.unshift(get().getMemoryPrompt());
  360. const lastSummarizeIndex = session.messages.length;
  361. console.log(
  362. "[Chat History] ",
  363. toBeSummarizedMsgs,
  364. historyMsgLength,
  365. config.compressMessageLengthThreshold
  366. );
  367. if (historyMsgLength > config.compressMessageLengthThreshold) {
  368. requestChatStream(
  369. toBeSummarizedMsgs.concat({
  370. role: "system",
  371. content: Locale.Store.Prompt.Summarize,
  372. date: "",
  373. }),
  374. {
  375. filterBot: false,
  376. onMessage(message, done) {
  377. session.memoryPrompt = message;
  378. if (done) {
  379. console.log("[Memory] ", session.memoryPrompt);
  380. session.lastSummarizeIndex = lastSummarizeIndex;
  381. }
  382. },
  383. onError(error) {
  384. console.error("[Summarize] ", error);
  385. },
  386. }
  387. );
  388. }
  389. },
  390. updateStat(message) {
  391. get().updateCurrentSession((session) => {
  392. session.stat.charCount += message.content.length;
  393. // TODO: should update chat count and word count
  394. });
  395. },
  396. updateCurrentSession(updater) {
  397. const sessions = get().sessions;
  398. const index = get().currentSessionIndex;
  399. updater(sessions[index]);
  400. set(() => ({ sessions }));
  401. },
  402. clearAllData() {
  403. if (confirm(Locale.Store.ConfirmClearAll)) {
  404. localStorage.clear();
  405. location.reload();
  406. }
  407. },
  408. }),
  409. {
  410. name: LOCAL_KEY,
  411. version: 1,
  412. }
  413. )
  414. );