|
@@ -1,12 +1,6 @@
|
|
|
import { create } from "zustand";
|
|
|
import { persist } from "zustand/middleware";
|
|
|
|
|
|
-import { type ChatCompletionResponseMessage } from "openai";
|
|
|
-import {
|
|
|
- ControllerPool,
|
|
|
- requestChatStream,
|
|
|
- requestWithPrompt,
|
|
|
-} from "../requests";
|
|
|
import { trimTopic } from "../utils";
|
|
|
|
|
|
import Locale from "../locales";
|
|
@@ -14,9 +8,11 @@ import { showToast } from "../components/ui-lib";
|
|
|
import { ModelType } from "./config";
|
|
|
import { createEmptyMask, Mask } from "./mask";
|
|
|
import { StoreKey } from "../constant";
|
|
|
-import { api } from "../client/api";
|
|
|
+import { api, RequestMessage } from "../client/api";
|
|
|
+import { ChatControllerPool } from "../client/controller";
|
|
|
+import { prettyObject } from "../utils/format";
|
|
|
|
|
|
-export type Message = ChatCompletionResponseMessage & {
|
|
|
+export type ChatMessage = RequestMessage & {
|
|
|
date: string;
|
|
|
streaming?: boolean;
|
|
|
isError?: boolean;
|
|
@@ -24,7 +20,7 @@ export type Message = ChatCompletionResponseMessage & {
|
|
|
model?: ModelType;
|
|
|
};
|
|
|
|
|
|
-export function createMessage(override: Partial<Message>): Message {
|
|
|
+export function createMessage(override: Partial<ChatMessage>): ChatMessage {
|
|
|
return {
|
|
|
id: Date.now(),
|
|
|
date: new Date().toLocaleString(),
|
|
@@ -34,8 +30,6 @@ export function createMessage(override: Partial<Message>): Message {
|
|
|
};
|
|
|
}
|
|
|
|
|
|
-export const ROLES: Message["role"][] = ["system", "user", "assistant"];
|
|
|
-
|
|
|
export interface ChatStat {
|
|
|
tokenCount: number;
|
|
|
wordCount: number;
|
|
@@ -48,7 +42,7 @@ export interface ChatSession {
|
|
|
topic: string;
|
|
|
|
|
|
memoryPrompt: string;
|
|
|
- messages: Message[];
|
|
|
+ messages: ChatMessage[];
|
|
|
stat: ChatStat;
|
|
|
lastUpdate: number;
|
|
|
lastSummarizeIndex: number;
|
|
@@ -57,7 +51,7 @@ export interface ChatSession {
|
|
|
}
|
|
|
|
|
|
export const DEFAULT_TOPIC = Locale.Store.DefaultTopic;
|
|
|
-export const BOT_HELLO: Message = createMessage({
|
|
|
+export const BOT_HELLO: ChatMessage = createMessage({
|
|
|
role: "assistant",
|
|
|
content: Locale.Store.BotHello,
|
|
|
});
|
|
@@ -89,24 +83,24 @@ interface ChatStore {
|
|
|
newSession: (mask?: Mask) => void;
|
|
|
deleteSession: (index: number) => void;
|
|
|
currentSession: () => ChatSession;
|
|
|
- onNewMessage: (message: Message) => void;
|
|
|
+ onNewMessage: (message: ChatMessage) => void;
|
|
|
onUserInput: (content: string) => Promise<void>;
|
|
|
summarizeSession: () => void;
|
|
|
- updateStat: (message: Message) => void;
|
|
|
+ updateStat: (message: ChatMessage) => void;
|
|
|
updateCurrentSession: (updater: (session: ChatSession) => void) => void;
|
|
|
updateMessage: (
|
|
|
sessionIndex: number,
|
|
|
messageIndex: number,
|
|
|
- updater: (message?: Message) => void,
|
|
|
+ updater: (message?: ChatMessage) => void,
|
|
|
) => void;
|
|
|
resetSession: () => void;
|
|
|
- getMessagesWithMemory: () => Message[];
|
|
|
- getMemoryPrompt: () => Message;
|
|
|
+ getMessagesWithMemory: () => ChatMessage[];
|
|
|
+ getMemoryPrompt: () => ChatMessage;
|
|
|
|
|
|
clearAllData: () => void;
|
|
|
}
|
|
|
|
|
|
-function countMessages(msgs: Message[]) {
|
|
|
+function countMessages(msgs: ChatMessage[]) {
|
|
|
return msgs.reduce((pre, cur) => pre + cur.content.length, 0);
|
|
|
}
|
|
|
|
|
@@ -241,12 +235,12 @@ export const useChatStore = create<ChatStore>()(
|
|
|
const session = get().currentSession();
|
|
|
const modelConfig = session.mask.modelConfig;
|
|
|
|
|
|
- const userMessage: Message = createMessage({
|
|
|
+ const userMessage: ChatMessage = createMessage({
|
|
|
role: "user",
|
|
|
content,
|
|
|
});
|
|
|
|
|
|
- const botMessage: Message = createMessage({
|
|
|
+ const botMessage: ChatMessage = createMessage({
|
|
|
role: "assistant",
|
|
|
streaming: true,
|
|
|
id: userMessage.id! + 1,
|
|
@@ -278,45 +272,54 @@ export const useChatStore = create<ChatStore>()(
|
|
|
|
|
|
// make request
|
|
|
console.log("[User Input] ", sendMessages);
|
|
|
- requestChatStream(sendMessages, {
|
|
|
- onMessage(content, done) {
|
|
|
- // stream response
|
|
|
- if (done) {
|
|
|
- botMessage.streaming = false;
|
|
|
- botMessage.content = content;
|
|
|
- get().onNewMessage(botMessage);
|
|
|
- ControllerPool.remove(
|
|
|
- sessionIndex,
|
|
|
- botMessage.id ?? messageIndex,
|
|
|
- );
|
|
|
- } else {
|
|
|
- botMessage.content = content;
|
|
|
- set(() => ({}));
|
|
|
- }
|
|
|
+ api.llm.chat({
|
|
|
+ messages: sendMessages,
|
|
|
+ config: { ...modelConfig, stream: true },
|
|
|
+ onUpdate(message) {
|
|
|
+ botMessage.streaming = true;
|
|
|
+ botMessage.content = message;
|
|
|
+ set(() => ({}));
|
|
|
+ },
|
|
|
+ onFinish(message) {
|
|
|
+ botMessage.streaming = false;
|
|
|
+ botMessage.content = message;
|
|
|
+ get().onNewMessage(botMessage);
|
|
|
+ ChatControllerPool.remove(
|
|
|
+ sessionIndex,
|
|
|
+ botMessage.id ?? messageIndex,
|
|
|
+ );
|
|
|
+ set(() => ({}));
|
|
|
},
|
|
|
- onError(error, statusCode) {
|
|
|
+ onError(error) {
|
|
|
const isAborted = error.message.includes("aborted");
|
|
|
- if (statusCode === 401) {
|
|
|
- botMessage.content = Locale.Error.Unauthorized;
|
|
|
- } else if (!isAborted) {
|
|
|
+ if (
|
|
|
+ botMessage.content !== Locale.Error.Unauthorized &&
|
|
|
+ !isAborted
|
|
|
+ ) {
|
|
|
botMessage.content += "\n\n" + Locale.Store.Error;
|
|
|
+ } else if (botMessage.content.length === 0) {
|
|
|
+ botMessage.content = prettyObject(error);
|
|
|
}
|
|
|
botMessage.streaming = false;
|
|
|
userMessage.isError = !isAborted;
|
|
|
botMessage.isError = !isAborted;
|
|
|
|
|
|
set(() => ({}));
|
|
|
- ControllerPool.remove(sessionIndex, botMessage.id ?? messageIndex);
|
|
|
+ ChatControllerPool.remove(
|
|
|
+ sessionIndex,
|
|
|
+ botMessage.id ?? messageIndex,
|
|
|
+ );
|
|
|
+
|
|
|
+ console.error("[Chat] error ", error);
|
|
|
},
|
|
|
onController(controller) {
|
|
|
// collect controller for stop/retry
|
|
|
- ControllerPool.addController(
|
|
|
+ ChatControllerPool.addController(
|
|
|
sessionIndex,
|
|
|
botMessage.id ?? messageIndex,
|
|
|
controller,
|
|
|
);
|
|
|
},
|
|
|
- modelConfig: { ...modelConfig },
|
|
|
});
|
|
|
},
|
|
|
|
|
@@ -330,7 +333,7 @@ export const useChatStore = create<ChatStore>()(
|
|
|
? Locale.Store.Prompt.History(session.memoryPrompt)
|
|
|
: "",
|
|
|
date: "",
|
|
|
- } as Message;
|
|
|
+ } as ChatMessage;
|
|
|
},
|
|
|
|
|
|
getMessagesWithMemory() {
|
|
@@ -385,7 +388,7 @@ export const useChatStore = create<ChatStore>()(
|
|
|
updateMessage(
|
|
|
sessionIndex: number,
|
|
|
messageIndex: number,
|
|
|
- updater: (message?: Message) => void,
|
|
|
+ updater: (message?: ChatMessage) => void,
|
|
|
) {
|
|
|
const sessions = get().sessions;
|
|
|
const session = sessions.at(sessionIndex);
|
|
@@ -410,13 +413,24 @@ export const useChatStore = create<ChatStore>()(
|
|
|
session.topic === DEFAULT_TOPIC &&
|
|
|
countMessages(session.messages) >= SUMMARIZE_MIN_LEN
|
|
|
) {
|
|
|
- requestWithPrompt(session.messages, Locale.Store.Prompt.Topic, {
|
|
|
- model: "gpt-3.5-turbo",
|
|
|
- }).then((res) => {
|
|
|
- get().updateCurrentSession(
|
|
|
- (session) =>
|
|
|
- (session.topic = res ? trimTopic(res) : DEFAULT_TOPIC),
|
|
|
- );
|
|
|
+ const topicMessages = session.messages.concat(
|
|
|
+ createMessage({
|
|
|
+ role: "user",
|
|
|
+ content: Locale.Store.Prompt.Topic,
|
|
|
+ }),
|
|
|
+ );
|
|
|
+ api.llm.chat({
|
|
|
+ messages: topicMessages,
|
|
|
+ config: {
|
|
|
+ model: "gpt-3.5-turbo",
|
|
|
+ },
|
|
|
+ onFinish(message) {
|
|
|
+ get().updateCurrentSession(
|
|
|
+ (session) =>
|
|
|
+ (session.topic =
|
|
|
+ message.length > 0 ? trimTopic(message) : DEFAULT_TOPIC),
|
|
|
+ );
|
|
|
+ },
|
|
|
});
|
|
|
}
|
|
|
|
|
@@ -450,26 +464,24 @@ export const useChatStore = create<ChatStore>()(
|
|
|
historyMsgLength > modelConfig.compressMessageLengthThreshold &&
|
|
|
session.mask.modelConfig.sendMemory
|
|
|
) {
|
|
|
- requestChatStream(
|
|
|
- toBeSummarizedMsgs.concat({
|
|
|
+ api.llm.chat({
|
|
|
+ messages: toBeSummarizedMsgs.concat({
|
|
|
role: "system",
|
|
|
content: Locale.Store.Prompt.Summarize,
|
|
|
date: "",
|
|
|
}),
|
|
|
- {
|
|
|
- overrideModel: "gpt-3.5-turbo",
|
|
|
- onMessage(message, done) {
|
|
|
- session.memoryPrompt = message;
|
|
|
- if (done) {
|
|
|
- console.log("[Memory] ", session.memoryPrompt);
|
|
|
- session.lastSummarizeIndex = lastSummarizeIndex;
|
|
|
- }
|
|
|
- },
|
|
|
- onError(error) {
|
|
|
- console.error("[Summarize] ", error);
|
|
|
- },
|
|
|
+ config: { ...modelConfig, stream: true },
|
|
|
+ onUpdate(message) {
|
|
|
+ session.memoryPrompt = message;
|
|
|
},
|
|
|
- );
|
|
|
+ onFinish(message) {
|
|
|
+ console.log("[Memory] ", message);
|
|
|
+ session.lastSummarizeIndex = lastSummarizeIndex;
|
|
|
+ },
|
|
|
+ onError(err) {
|
|
|
+ console.error("[Summarize] ", err);
|
|
|
+ },
|
|
|
+ });
|
|
|
}
|
|
|
},
|
|
|
|