// store.ts
  1. import { create } from "zustand";
  2. import { persist } from "zustand/middleware";
  3. import { type ChatCompletionResponseMessage } from "openai";
  4. import { requestChat, requestChatStream, requestWithPrompt } from "./requests";
  5. import { trimTopic } from "./utils";
  /**
   * A chat message as kept in the store: the OpenAI response message shape
   * extended with two store-specific bookkeeping fields.
   */
  export type Message = ChatCompletionResponseMessage & {
    /** Creation time; this file fills it with `new Date().toLocaleString()`. */
    date: string;
    /**
     * Optional streaming marker. `onUserInput` sets it on the placeholder
     * assistant message whose content arrives through `requestChatStream`.
     */
    streaming?: boolean;
  };
  /**
   * Chat configuration.
   *
   * NOTE(review): nothing in the visible part of this file reads or writes this
   * type — confirm whether it is still needed.
   */
  interface ChatConfig {
    /** Presumably the upper bound on tokens for a request — TODO confirm. */
    maxToken: number;
  }
  /**
   * Running counters kept per chat session. New sessions start with every
   * counter at zero.
   */
  interface ChatStat {
    /** Token counter; initialised to 0 and not updated in the visible code. */
    tokenCount: number;
    /** Word counter; initialised to 0 and not updated in the visible code. */
    wordCount: number;
    /** Sum of `content.length` over the messages passed to `updateStat`. */
    charCount: number;
  }
  /** One conversation: its title, its messages and its statistics. */
  interface ChatSession {
    /**
     * Session title. Starts as `DEFAULT_TOPIC`; `summarizeSession` replaces it
     * with a generated summary.
     */
    topic: string;
    /** Initialised to an empty string; not otherwise used in the visible code. */
    memoryPrompt: string;
    /** Messages of the conversation, in the order they were appended. */
    messages: Message[];
    /** Counters maintained through `updateStat`. */
    stat: ChatStat;
    /** Set to the creation time (`toLocaleString()`) when the session is created. */
    lastUpdate: string;
    /** Optional flag; not read or written in the visible code. */
    deleted?: boolean;
  }
  /**
   * Title given to every freshly created session. `summarizeSession` only
   * generates a topic while the session still carries this exact value.
   */
  const DEFAULT_TOPIC = "新的聊天";
  /**
   * Builds a brand-new session: default topic, empty memory prompt, zeroed
   * statistics and a single assistant greeting message.
   *
   * @returns A new `ChatSession` whose greeting `date` and `lastUpdate` share
   *          the same `toLocaleString()` timestamp.
   */
  function createEmptySession(): ChatSession {
    // One timestamp is taken so the greeting and the session agree on it.
    const now = new Date().toLocaleString();

    const greeting: Message = {
      role: "assistant",
      content: "有什么可以帮你的吗",
      date: now,
    };
    const emptyStat: ChatStat = { tokenCount: 0, wordCount: 0, charCount: 0 };

    return {
      topic: DEFAULT_TOPIC,
      memoryPrompt: "",
      messages: [greeting],
      stat: emptyStat,
      lastUpdate: now,
    };
  }
  /** State and actions exposed by the chat store. */
  interface ChatStore {
    /** All chat sessions; `newSession` puts the newest one first. */
    sessions: ChatSession[];
    /** Index into `sessions` of the session currently selected. */
    currentSessionIndex: number;

    /** Removes the session at `index`. */
    removeSession: (index: number) => void;
    /** Makes the session at `index` the current one. */
    selectSession: (index: number) => void;
    /** Creates an empty session at the front of the list and selects it. */
    newSession: () => void;
    /** Returns the currently selected session. */
    currentSession: () => ChatSession;

    /** Appends `message` to the current session, then updates stats and topic. */
    onNewMessage: (message: Message) => void;
    /** Records the user's input and starts a streamed assistant reply. */
    onUserInput: (content: string) => Promise<void>;
    /** Records a complete assistant message. */
    onBotResponse: (message: Message) => void;
    /** Requests a short topic for the current session while it has the default one. */
    summarizeSession: () => void;
    /** Adds `message` to the current session's statistics. */
    updateStat: (message: Message) => void;

    /** Runs `updater` on the current session and publishes the change. */
    updateCurrentSession: (updater: (session: ChatSession) => void) => void;
    /**
     * Runs `updater` on one message and publishes the change. The updater
     * receives `undefined` when either index does not resolve to a message.
     */
    updateMessage: (
      sessionIndex: number,
      messageIndex: number,
      updater: (message?: Message) => void
    ) => void;
  }
  /**
   * Zustand store holding every chat session. The `persist` middleware stores
   * it under the "chat-next-web-store" key.
   *
   * Sessions and messages are updated in place by the actions below, and each
   * action then calls `set` so subscribers are notified.
   */
  export const useChatStore = create<ChatStore>()(
    persist(
      (set, get) => ({
        sessions: [createEmptySession()],
        currentSessionIndex: 0,

        /** Makes the session at `index` the current one. */
        selectSession(index: number) {
          set({
            currentSessionIndex: index,
          });
        },

        /**
         * Removes the session at `index`.
         *
         * Removing the last remaining session replaces it with a fresh empty
         * one. An out-of-range `index` is ignored. The selection keeps pointing
         * at the same session when an earlier one is removed, and moves to the
         * previous session (never below 0) when the selected one is removed.
         */
        removeSession(index: number) {
          set((state) => {
            if (state.sessions.length <= 1) {
              return {
                currentSessionIndex: 0,
                sessions: [createEmptySession()],
              };
            }

            // Guard explicitly: a negative index would otherwise count from the end.
            if (index < 0 || index >= state.sessions.length) {
              return state;
            }

            // Build a new array instead of splicing the live state array.
            const sessions = state.sessions.filter((_, i) => i !== index);

            let nextIndex = state.currentSessionIndex;
            if (index <= nextIndex) {
              // Everything at or after `index` shifted one slot to the left.
              nextIndex -= 1;
            }
            nextIndex = Math.min(sessions.length - 1, Math.max(0, nextIndex));

            return {
              currentSessionIndex: nextIndex,
              sessions,
            };
          });
        },

        /** Creates an empty session at the front of the list and selects it. */
        newSession() {
          set((state) => ({
            currentSessionIndex: 0,
            sessions: [createEmptySession(), ...state.sessions],
          }));
        },

        /**
         * Returns the currently selected session. An out-of-range
         * `currentSessionIndex` is clamped into the list and written back.
         */
        currentSession() {
          let index = get().currentSessionIndex;
          const sessions = get().sessions;

          if (index < 0 || index >= sessions.length) {
            index = Math.min(sessions.length - 1, Math.max(0, index));
            set(() => ({ currentSessionIndex: index }));
          }

          return sessions[index];
        },

        /** Appends `message` to the current session, then updates stats and topic. */
        onNewMessage(message) {
          get().updateCurrentSession((session) => {
            session.messages.push(message);
          });
          get().updateStat(message);
          get().summarizeSession();
        },

        /**
         * Records the user's input and starts a streamed assistant reply.
         *
         * A placeholder assistant message flagged `streaming` is appended
         * first. Each partial callback replaces its content, and the final
         * callback clears the flag and accounts for the finished reply.
         *
         * @param content Text typed by the user.
         */
        async onUserInput(content) {
          const message: Message = {
            role: "user",
            content,
            date: new Date().toLocaleString(),
          };

          // Snapshot of the history to send, taken before the placeholder is added.
          const messages = get().currentSession().messages.concat(message);
          get().onNewMessage(message);

          const botMessage: Message = {
            content: "",
            role: "assistant",
            date: new Date().toLocaleString(),
            streaming: true,
          };
          get().updateCurrentSession((session) => {
            session.messages.push(botMessage);
          });

          requestChatStream(messages, {
            onMessage(content, done) {
              if (done) {
                // The reply is complete: only now does streaming stop, and the
                // assistant message is the one to count. The user message was
                // already counted by onNewMessage above.
                botMessage.streaming = false;
                get().updateStat(botMessage);
                get().summarizeSession();
              } else {
                botMessage.content = content;
                // Empty patch: notifies subscribers of the in-place change.
                set(() => ({}));
              }
            },
          });
        },

        /**
         * Runs `updater` on the message at `messageIndex` of the session at
         * `sessionIndex`, then publishes the change. The updater receives
         * `undefined` when either index does not resolve.
         */
        updateMessage(
          sessionIndex: number,
          messageIndex: number,
          updater: (message?: Message) => void
        ) {
          const sessions = get().sessions;
          const target = sessions.at(sessionIndex)?.messages.at(messageIndex);
          updater(target);
          set(() => ({ sessions }));
        },

        /** Records a complete assistant message. */
        onBotResponse(message) {
          get().onNewMessage(message);
        },

        /**
         * Requests a short topic for the current session. Does nothing once the
         * session no longer carries `DEFAULT_TOPIC`.
         */
        summarizeSession() {
          const session = get().currentSession();
          if (session.topic !== DEFAULT_TOPIC) return;

          requestWithPrompt(session.messages, "简明扼要地 10 字以内总结主题")
            .then((res) => {
              // Write to the session that was summarised. The user may have
              // switched sessions while the request was in flight.
              session.topic = trimTopic(res);
              set(() => ({ sessions: get().sessions }));
            })
            .catch((error: unknown) => {
              // The topic stays at its default and a later message retries.
              console.error("[chat store] failed to summarize session", error);
            });
        },

        /** Adds the length of `message.content` to the current session's char count. */
        updateStat(message) {
          get().updateCurrentSession((session) => {
            session.stat.charCount += message.content.length;
            // TODO: should update chat count and word count
          });
        },

        /**
         * Runs `updater` on the current session and publishes the change. The
         * session is resolved through `currentSession()` so an out-of-range
         * index is clamped rather than handing `undefined` to the updater.
         */
        updateCurrentSession(updater) {
          const session = get().currentSession();
          updater(session);
          set(() => ({ sessions: get().sessions }));
        },
      }),
      { name: "chat-next-web-store" }
    )
  );