// store.ts
  1. import { create } from "zustand";
  2. import { persist } from "zustand/middleware";
  3. import { type ChatCompletionResponseMessage } from "openai";
  4. import { requestChat, requestWithPrompt } from "./requests";
  5. import { trimTopic } from "./utils";
  /**
   * A chat message as kept by the store: the OpenAI response-message shape
   * extended with a `date` string.
   */
  export type Message = ChatCompletionResponseMessage & {
    /** Creation time; this file always fills it with `new Date().toLocaleString()`. */
    date: string;
  };
  /**
   * Chat configuration options.
   *
   * NOTE(review): nothing in the visible part of this file reads this type.
   */
  interface ChatConfig {
    /** Presumably an upper bound on tokens per request — TODO confirm with the consumer. */
    maxToken: number;
  }
  /**
   * Running counters kept per chat session.
   *
   * All three start at 0 for a new session. In this file only `charCount` is
   * ever incremented; the other two are placeholders (see the TODO in
   * `updateStat`).
   */
  interface ChatStat {
    tokenCount: number;
    wordCount: number;
    /** Sum of `content.length` over the messages recorded for the session. */
    charCount: number;
  }
  /** One conversation: its title, message history and bookkeeping data. */
  interface ChatSession {
    /** Title of the conversation; starts as `DEFAULT_TOPIC` until a summary replaces it. */
    topic: string;
    /** Starts as an empty string; not read or written elsewhere in this file. */
    memoryPrompt: string;
    /** Messages in the order they were appended. */
    messages: Message[];
    /** Per-session counters. */
    stat: ChatStat;
    /** Set to the creation time (`toLocaleString()`) when the session is created. */
    lastUpdate: string;
    /** Optional flag; never set or read in this file. */
    deleted?: boolean;
  }
  /**
   * Topic given to every newly created session. A session whose topic still
   * equals this value is treated as "not yet summarized".
   */
  const DEFAULT_TOPIC = "新的聊天";
  /**
   * Builds a brand-new session: default topic, empty memory prompt, zeroed
   * counters and a single assistant greeting.
   *
   * @returns A fresh `ChatSession`. The greeting's `date` and the session's
   * `lastUpdate` share one `toLocaleString()` timestamp taken at call time.
   */
  function createEmptySession(): ChatSession {
    const now = new Date().toLocaleString();

    const greeting: Message = {
      role: "assistant",
      content: "有什么可以帮你的吗",
      date: now,
    };

    return {
      topic: DEFAULT_TOPIC,
      memoryPrompt: "",
      messages: [greeting],
      stat: { tokenCount: 0, wordCount: 0, charCount: 0 },
      lastUpdate: now,
    };
  }
  /** State and actions exposed by the chat store. */
  interface ChatStore {
    /** All chat sessions, in creation order. */
    sessions: ChatSession[];
    /** Index into `sessions` of the session currently shown. */
    currentSessionIndex: number;

    /**
     * Removes the session at `index`. When it is the only session, the list is
     * replaced by a single fresh empty session instead.
     */
    removeSession: (index: number) => void;
    /** Makes the session at `index` the current one. */
    selectSession: (index: number) => void;
    /** Appends a fresh empty session and selects it. */
    newSession: () => void;
    /** Returns the current session, clamping an out-of-range index first. */
    currentSession: () => ChatSession;

    /** Appends `message` to the current session, updates stats and triggers summarization. */
    onNewMessage: (message: Message) => void;
    /** Records the user's text, requests a completion and records the reply. */
    onUserInput: (content: string) => Promise<void>;
    /** Records a bot message; forwards to `onNewMessage`. */
    onBotResponse: (message: Message) => void;
    /** Requests a short topic for the current session while it still has the default topic. */
    summarizeSession: () => void;
    /** Adds `message`'s character count to the current session's stats. */
    updateStat: (message: Message) => void;
    /** Runs `updater` against the current session and publishes the result to the store. */
    updateCurrentSession: (updater: (session: ChatSession) => void) => void;
  }
  /**
   * Chat store (zustand) wrapped in the `persist` middleware under the storage
   * key "chat-next-web-store".
   *
   * State updates are published as new array/object references rather than by
   * mutating the previous state, so selectors comparing by reference observe
   * every change.
   */
  export const useChatStore = create<ChatStore>()(
    persist(
      (set, get) => ({
        sessions: [createEmptySession()],
        currentSessionIndex: 0,

        /** Makes the session at `index` the current one. */
        selectSession(index: number) {
          set({ currentSessionIndex: index });
        },

        /**
         * Removes the session at `index`.
         *
         * - An out-of-range or non-integer `index` is ignored.
         * - Removing the only session replaces it with a fresh empty one.
         * - Removing the current session selects the previous one (or the
         *   first, when the current session was the first).
         * - Removing a session located before the current one shifts the
         *   current index down so the same conversation stays selected.
         */
        removeSession(index: number) {
          set((state) => {
            const { sessions, currentSessionIndex } = state;

            const inRange =
              Number.isInteger(index) && index >= 0 && index < sessions.length;
            if (!inRange) {
              // Returning the same state object makes this a no-op.
              return state;
            }

            if (sessions.length === 1) {
              return {
                currentSessionIndex: 0,
                sessions: [createEmptySession()],
              };
            }

            // Build a new array instead of splicing the live state in place.
            const remaining = sessions.filter((_, i) => i !== index);

            // Everything at or after `index` moved one slot to the left.
            const shifted =
              index <= currentSessionIndex
                ? currentSessionIndex - 1
                : currentSessionIndex;
            const nextIndex = Math.min(
              remaining.length - 1,
              Math.max(0, shifted)
            );

            return {
              currentSessionIndex: nextIndex,
              sessions: remaining,
            };
          });
        },

        /** Appends a fresh empty session and selects it. */
        newSession() {
          set((state) => ({
            currentSessionIndex: state.sessions.length,
            sessions: state.sessions.concat([createEmptySession()]),
          }));
        },

        /**
         * Returns the current session. If the stored index is outside the
         * session list it is clamped into range and written back first.
         */
        currentSession() {
          const { sessions, currentSessionIndex } = get();
          const clamped = Math.min(
            sessions.length - 1,
            Math.max(0, currentSessionIndex)
          );
          if (clamped !== currentSessionIndex) {
            set(() => ({ currentSessionIndex: clamped }));
          }
          return sessions[clamped];
        },

        /** Appends `message` to the current session, then updates stats and the topic. */
        onNewMessage(message) {
          get().updateCurrentSession((session) => {
            session.messages.push(message);
          });
          get().updateStat(message);
          get().summarizeSession();
        },

        /**
         * Records the user's input, sends the conversation to `requestChat`
         * and records the first returned choice as the reply.
         *
         * @throws Error when the response carries no message. A rejection of
         * `requestChat` propagates to the caller unchanged.
         */
        async onUserInput(content) {
          const message: Message = {
            role: "user",
            content,
            date: new Date().toLocaleString(),
          };

          // Snapshot of the history including the new user message.
          const messages = get().currentSession().messages.concat(message);
          get().onNewMessage(message);

          const res = await requestChat(messages);
          const reply = res.choices?.[0]?.message;
          if (!reply) {
            // Fail before touching state instead of storing a message with
            // no role/content.
            throw new Error("Chat completion response contained no message");
          }

          // NOTE(review): the reply is appended to whichever session is
          // current when the request resolves; sessions carry no stable id
          // to target the originating one.
          get().onNewMessage({
            ...reply,
            date: new Date().toLocaleString(),
          });
        },

        /** Records a bot message; forwards to `onNewMessage`. */
        onBotResponse(message) {
          get().onNewMessage(message);
        },

        /**
         * While the current session still has the default topic, asks
         * `requestWithPrompt` for a short summary and stores the trimmed
         * result as the topic. Fire-and-forget: failures are logged.
         */
        summarizeSession() {
          const session = get().currentSession();
          if (session.topic !== DEFAULT_TOPIC) return;

          requestWithPrompt(session.messages, "简明扼要地 10 字以内总结主题")
            .then((res) => {
              get().updateCurrentSession((current) => {
                current.topic = trimTopic(res);
              });
            })
            .catch((error: unknown) => {
              // Without this handler a failed request becomes an unhandled
              // promise rejection.
              console.error("[chat store] failed to summarize session", error);
            });
        },

        /** Adds the character count of `message` to the current session's stats. */
        updateStat(message) {
          get().updateCurrentSession((session) => {
            session.stat.charCount += message.content?.length ?? 0;
            // TODO: should update chat count and word count
          });
        },

        /**
         * Runs `updater` against a copy of the current session (with its own
         * `messages` array and `stat` object) and publishes the copy in a new
         * `sessions` array. Does nothing when there is no session.
         */
        updateCurrentSession(updater) {
          // currentSession() clamps a stale index back into range.
          const target = get().currentSession();
          if (!target) return;
          const { sessions, currentSessionIndex } = get();

          const draft: ChatSession = {
            ...target,
            messages: [...target.messages],
            stat: { ...target.stat },
          };
          updater(draft);

          set(() => ({
            sessions: sessions.map((existing, i) =>
              i === currentSessionIndex ? draft : existing
            ),
          }));
        },
      }),
      { name: "chat-next-web-store" }
    )
  );
  156. );