store.ts 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267
  1. import { create } from "zustand";
  2. import { persist } from "zustand/middleware";
  3. import { type ChatCompletionResponseMessage } from "openai";
  4. import { requestChat, requestChatStream, requestWithPrompt } from "./requests";
  5. import { trimTopic } from "./utils";
  6. export type Message = ChatCompletionResponseMessage & {
  7. date: string;
  8. streaming?: boolean;
  9. };
/**
 * Key combinations selectable for `ChatConfig.submitKey` (presumably the
 * shortcut that submits the chat input — confirm against the UI).
 *
 * String enum; the values look like display labels (e.g. "Ctrl + Enter").
 * NOTE(review): keep member order stable — code iterating the enum
 * (e.g. a settings picker) would observe a reorder.
 */
export enum SubmitKey {
  Enter = "Enter",
  CtrlEnter = "Ctrl + Enter",
  ShiftEnter = "Shift + Enter",
  AltEnter = "Alt + Enter",
}
  16. interface ChatConfig {
  17. maxToken?: number;
  18. historyMessageCount: number; // -1 means all
  19. sendBotMessages: boolean; // send bot's message or not
  20. submitKey: SubmitKey;
  21. avatar: string;
  22. }
  23. interface ChatStat {
  24. tokenCount: number;
  25. wordCount: number;
  26. charCount: number;
  27. }
  28. interface ChatSession {
  29. id: number;
  30. topic: string;
  31. memoryPrompt: string;
  32. messages: Message[];
  33. stat: ChatStat;
  34. lastUpdate: string;
  35. deleted?: boolean;
  36. }
  37. const DEFAULT_TOPIC = "新的聊天";
  38. function createEmptySession(): ChatSession {
  39. const createDate = new Date().toLocaleString();
  40. return {
  41. id: Date.now(),
  42. topic: DEFAULT_TOPIC,
  43. memoryPrompt: "",
  44. messages: [
  45. {
  46. role: "assistant",
  47. content: "有什么可以帮你的吗",
  48. date: createDate,
  49. },
  50. ],
  51. stat: {
  52. tokenCount: 0,
  53. wordCount: 0,
  54. charCount: 0,
  55. },
  56. lastUpdate: createDate,
  57. };
  58. }
  59. interface ChatStore {
  60. config: ChatConfig;
  61. sessions: ChatSession[];
  62. currentSessionIndex: number;
  63. removeSession: (index: number) => void;
  64. selectSession: (index: number) => void;
  65. newSession: () => void;
  66. currentSession: () => ChatSession;
  67. onNewMessage: (message: Message) => void;
  68. onUserInput: (content: string) => Promise<void>;
  69. onBotResponse: (message: Message) => void;
  70. summarizeSession: () => void;
  71. updateStat: (message: Message) => void;
  72. updateCurrentSession: (updater: (session: ChatSession) => void) => void;
  73. updateMessage: (
  74. sessionIndex: number,
  75. messageIndex: number,
  76. updater: (message?: Message) => void
  77. ) => void;
  78. getConfig: () => ChatConfig;
  79. updateConfig: (updater: (config: ChatConfig) => void) => void;
  80. }
  81. export const useChatStore = create<ChatStore>()(
  82. persist(
  83. (set, get) => ({
  84. sessions: [createEmptySession()],
  85. currentSessionIndex: 0,
  86. config: {
  87. historyMessageCount: 5,
  88. sendBotMessages: false as boolean,
  89. submitKey: SubmitKey.CtrlEnter,
  90. avatar: "1fae0",
  91. },
  92. getConfig() {
  93. return get().config;
  94. },
  95. updateConfig(updater) {
  96. const config = get().config;
  97. updater(config);
  98. set(() => ({ config }));
  99. },
  100. selectSession(index: number) {
  101. set({
  102. currentSessionIndex: index,
  103. });
  104. },
  105. removeSession(index: number) {
  106. set((state) => {
  107. let nextIndex = state.currentSessionIndex;
  108. const sessions = state.sessions;
  109. if (sessions.length === 1) {
  110. return {
  111. currentSessionIndex: 0,
  112. sessions: [createEmptySession()],
  113. };
  114. }
  115. sessions.splice(index, 1);
  116. if (nextIndex === index) {
  117. nextIndex -= 1;
  118. }
  119. return {
  120. currentSessionIndex: nextIndex,
  121. sessions,
  122. };
  123. });
  124. },
  125. newSession() {
  126. set((state) => ({
  127. currentSessionIndex: 0,
  128. sessions: [createEmptySession()].concat(state.sessions),
  129. }));
  130. },
  131. currentSession() {
  132. let index = get().currentSessionIndex;
  133. const sessions = get().sessions;
  134. if (index < 0 || index >= sessions.length) {
  135. index = Math.min(sessions.length - 1, Math.max(0, index));
  136. set(() => ({ currentSessionIndex: index }));
  137. }
  138. const session = sessions[index];
  139. return session;
  140. },
  141. onNewMessage(message) {
  142. get().updateCurrentSession((session) => {
  143. session.messages.push(message);
  144. });
  145. get().updateStat(message);
  146. get().summarizeSession();
  147. },
  148. async onUserInput(content) {
  149. const message: Message = {
  150. role: "user",
  151. content,
  152. date: new Date().toLocaleString(),
  153. };
  154. // get last five messges
  155. const messages = get().currentSession().messages.concat(message);
  156. get().onNewMessage(message);
  157. const botMessage: Message = {
  158. content: "",
  159. role: "assistant",
  160. date: new Date().toLocaleString(),
  161. streaming: true,
  162. };
  163. get().updateCurrentSession((session) => {
  164. session.messages.push(botMessage);
  165. });
  166. const fiveMessages = messages.slice(-5);
  167. requestChatStream(fiveMessages, {
  168. onMessage(content, done) {
  169. if (done) {
  170. botMessage.streaming = false;
  171. get().updateStat(botMessage);
  172. get().summarizeSession();
  173. } else {
  174. botMessage.content = content;
  175. set(() => ({}));
  176. }
  177. },
  178. onError(error) {
  179. botMessage.content = "出错了,稍后重试吧";
  180. botMessage.streaming = false;
  181. set(() => ({}));
  182. },
  183. filterBot: !get().config.sendBotMessages,
  184. });
  185. },
  186. updateMessage(
  187. sessionIndex: number,
  188. messageIndex: number,
  189. updater: (message?: Message) => void
  190. ) {
  191. const sessions = get().sessions;
  192. const session = sessions.at(sessionIndex);
  193. const messages = session?.messages;
  194. console.log(sessions, messages?.length, messages?.at(messageIndex));
  195. updater(messages?.at(messageIndex));
  196. set(() => ({ sessions }));
  197. },
  198. onBotResponse(message) {
  199. get().onNewMessage(message);
  200. },
  201. summarizeSession() {
  202. const session = get().currentSession();
  203. if (session.topic !== DEFAULT_TOPIC) return;
  204. requestWithPrompt(
  205. session.messages,
  206. "简明扼要地 10 字以内总结主题"
  207. ).then((res) => {
  208. get().updateCurrentSession(
  209. (session) => (session.topic = trimTopic(res))
  210. );
  211. });
  212. },
  213. updateStat(message) {
  214. get().updateCurrentSession((session) => {
  215. session.stat.charCount += message.content.length;
  216. // TODO: should update chat count and word count
  217. });
  218. },
  219. updateCurrentSession(updater) {
  220. const sessions = get().sessions;
  221. const index = get().currentSessionIndex;
  222. updater(sessions[index]);
  223. set(() => ({ sessions }));
  224. },
  225. }),
  226. { name: "chat-next-web-store" }
  227. )
  228. );