
refactor: llm client api

Yidadaa · 1 year ago · commit bd90caa99d
8 changed files with 279 additions and 22 deletions
  1. app/client/api.ts (+109 -0)
  2. app/client/controller.ts (+37 -0)
  3. app/client/platforms/openai.ts (+124 -0)
  4. app/constant.ts (+2 -0)
  5. app/requests.ts (+0 -22)
  6. app/store/chat.ts (+1 -0)
  7. package.json (+1 -0)
  8. yarn.lock (+5 -0)

+ 109 - 0
app/client/api.ts

@@ -0,0 +1,109 @@
+import { ACCESS_CODE_PREFIX } from "../constant";
+import { ModelType, useAccessStore } from "../store";
+import { ChatGPTApi } from "./platforms/openai";
+
+export enum MessageRole {
+  System = "system",
+  User = "user",
+  Assistant = "assistant",
+}
+
+export const Models = ["gpt-3.5-turbo", "gpt-4"] as const;
+export type ChatModel = ModelType;
+
+export interface Message {
+  role: MessageRole;
+  content: string;
+}
+
+export interface LLMConfig {
+  temperature?: number;
+  topP?: number;
+  stream?: boolean;
+  presencePenalty?: number;
+  frequencyPenalty?: number;
+}
+
+export interface ChatOptions {
+  messages: Message[];
+  model: ChatModel;
+  config: LLMConfig;
+
+  onUpdate: (message: string, chunk: string) => void; // full text so far, plus the newest chunk
+  onFinish: (message: string) => void; // complete reply once the request settles
+  onError: (err: Error) => void;
+  onUnAuth: () => void; // request was rejected with 401
+}
+
+export interface LLMUsage {
+  used: number;
+  total: number;
+}
+
+export abstract class LLMApi {
+  abstract chat(options: ChatOptions): Promise<void>;
+  abstract usage(): Promise<LLMUsage>;
+}
+
+export class ClientApi {
+  public llm: LLMApi;
+
+  constructor() {
+    this.llm = new ChatGPTApi();
+  }
+
+  headers() {
+    const accessStore = useAccessStore.getState();
+    let headers: Record<string, string> = {};
+
+    const makeBearer = (token: string) => `Bearer ${token.trim()}`;
+    const validString = (x: string) => x && x.length > 0;
+
+    // use user's api key first
+    if (validString(accessStore.token)) {
+      headers.Authorization = makeBearer(accessStore.token);
+    } else if (
+      accessStore.enabledAccessControl() &&
+      validString(accessStore.accessCode)
+    ) {
+      headers.Authorization = makeBearer(
+        ACCESS_CODE_PREFIX + accessStore.accessCode,
+      );
+    }
+
+    return headers;
+  }
+
+  config() {}
+
+  prompts() {}
+
+  masks() {}
+}
+
+export const api = new ClientApi();
+
+export function getHeaders() {
+  const accessStore = useAccessStore.getState();
+  let headers: Record<string, string> = {
+    "Content-Type": "application/json",
+  };
+
+  const makeBearer = (token: string) => `Bearer ${token.trim()}`;
+  const validString = (x: string) => x && x.length > 0;
+
+  // use user's api key first
+  if (validString(accessStore.token)) {
+    headers.Authorization = makeBearer(accessStore.token);
+  } else if (
+    accessStore.enabledAccessControl() &&
+    validString(accessStore.accessCode)
+  ) {
+    headers.Authorization = makeBearer(
+      ACCESS_CODE_PREFIX + accessStore.accessCode,
+    );
+  }
+
+  return headers;
+}
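
For orientation, a minimal call-site sketch of the new client surface (illustrative only; the relative import path, message text, and settings are hypothetical and not part of this commit):

import { api, MessageRole } from "./client/api";

// Hypothetical call site: stream one completion through the shared singleton.
// "gpt-3.5-turbo" is assumed here to be a valid ModelType value.
api.llm.chat({
  messages: [{ role: MessageRole.User, content: "Hello!" }],
  model: "gpt-3.5-turbo",
  config: { stream: true, temperature: 0.7 },
  onUpdate: (message, chunk) => console.log("partial:", chunk), // accumulated text + newest delta
  onFinish: (message) => console.log("done:", message),
  onError: (err) => console.error("failed:", err),
  onUnAuth: () => console.warn("unauthorized"),
});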

+ 37 - 0
app/client/controller.ts

@@ -0,0 +1,37 @@
+// pool of AbortControllers for in-flight streaming requests, keyed by session and message
+export const ChatControllerPool = {
+  controllers: {} as Record<string, AbortController>,
+
+  addController(
+    sessionIndex: number,
+    messageId: number,
+    controller: AbortController,
+  ) {
+    const key = this.key(sessionIndex, messageId);
+    this.controllers[key] = controller;
+    return key;
+  },
+
+  stop(sessionIndex: number, messageId: number) {
+    const key = this.key(sessionIndex, messageId);
+    const controller = this.controllers[key];
+    controller?.abort();
+  },
+
+  stopAll() {
+    Object.values(this.controllers).forEach((v) => v.abort());
+  },
+
+  hasPending() {
+    return Object.values(this.controllers).length > 0;
+  },
+
+  remove(sessionIndex: number, messageId: number) {
+    const key = this.key(sessionIndex, messageId);
+    delete this.controllers[key];
+  },
+
+  key(sessionIndex: number, messageIndex: number) {
+    return `${sessionIndex},${messageIndex}`;
+  },
+};
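
A sketch of the intended controller lifecycle (the session and message indices are hypothetical):

import { ChatControllerPool } from "./client/controller";

// Register the controller before the stream starts so the UI can find it.
const controller = new AbortController();
ChatControllerPool.addController(0, 42, controller);

// When the user clicks "stop", abort the in-flight request for that message.
ChatControllerPool.stop(0, 42);

// Remove the entry once the request finishes or aborts, or the pool leaks.
ChatControllerPool.remove(0, 42);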

+ 124 - 0
app/client/platforms/openai.ts

@@ -0,0 +1,124 @@
+import { REQUEST_TIMEOUT_MS } from "@/app/constant";
+import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
+import {
+  EventStreamContentType,
+  fetchEventSource,
+} from "@microsoft/fetch-event-source";
+import { ChatOptions, getHeaders, LLMApi, LLMUsage } from "../api";
+
+export class ChatGPTApi implements LLMApi {
+  public ChatPath = "v1/chat/completions";
+
+  path(path: string): string {
+    let openaiUrl = useAccessStore.getState().openaiUrl;
+    if (openaiUrl.endsWith("/")) {
+      openaiUrl = openaiUrl.slice(0, openaiUrl.length - 1);
+    }
+    return [openaiUrl, path].join("/");
+  }
+
+  extractMessage(res: any) {
+    return res.choices?.at(0)?.message?.content ?? "";
+  }
+
+  async chat(options: ChatOptions) {
+    const messages = options.messages.map((v) => ({
+      role: v.role,
+      content: v.content,
+    }));
+
+    const modelConfig = {
+      ...useAppConfig.getState().modelConfig,
+      ...useChatStore.getState().currentSession().mask.modelConfig,
+      model: options.model,
+    };
+
+    const requestPayload = {
+      messages,
+      stream: options.config.stream,
+      model: modelConfig.model,
+      temperature: modelConfig.temperature,
+      presence_penalty: modelConfig.presence_penalty,
+    };
+
+    console.log("[Request] openai payload: ", requestPayload);
+
+    const shouldStream = !!options.config.stream;
+    const controller = new AbortController();
+
+    try {
+      const chatPath = this.path(this.ChatPath);
+      const chatPayload = {
+        method: "POST",
+        body: JSON.stringify(requestPayload),
+        signal: controller.signal,
+        headers: getHeaders(), // Content-Type and Authorization from api.ts
+      };
+
+      // make a fetch request
+      const requestTimeoutId = setTimeout(
+        () => controller.abort(),
+        REQUEST_TIMEOUT_MS,
+      );
+      if (shouldStream) {
+        let responseText = "";
+        let finished = false;
+
+        fetchEventSource(chatPath, {
+          ...chatPayload,
+          async onopen(res) {
+            clearTimeout(requestTimeoutId);
+            if (
+              res.ok &&
+              res.headers.get("Content-Type") === EventStreamContentType
+            ) {
+              return;
+            }
+
+            if (res.status === 401) {
+              // unauthorized: notify the caller, then throw to stop the stream
+              options.onUnAuth();
+              throw new Error("[Request] unauthorized");
+            } else if (res.status !== 200) {
+              console.error("[Request] response", res);
+              throw new Error("[Request] server error");
+            }
+          },
+          onmessage: (ev) => {
+            if (ev.data === "[DONE]") {
+              finished = true;
+              return options.onFinish(responseText);
+            }
+            try {
+              const resJson = JSON.parse(ev.data);
+              // streaming chunks carry the new tokens in choices[0].delta.content,
+              // not choices[0].message.content as in non-streaming responses
+              const delta = resJson.choices?.at(0)?.delta?.content ?? "";
+              responseText += delta;
+              options.onUpdate(responseText, delta);
+            } catch (e) {
+              console.error("[Request] stream error", e);
+              options.onError(e as Error);
+            }
+          },
+          onclose() {
+            if (!finished) {
+              options.onError(new Error("stream closed unexpectedly"));
+            }
+          },
+          onerror(err) {
+            options.onError(err);
+            // rethrow so fetchEventSource stops retrying on fatal errors
+            throw err;
+          },
+        });
+      } else {
+        const res = await fetch(chatPath, chatPayload);
+        // the streaming branch clears this timeout in onopen instead
+        clearTimeout(requestTimeoutId);
+
+        const resJson = await res.json();
+        const message = this.extractMessage(resJson);
+        options.onFinish(message);
+      }
+    } catch (e) {
+      console.log("[Request] failed to make a chat reqeust", e);
+      options.onError(e as Error);
+    }
+  }
+  async usage() {
+    return {
+      used: 0,
+      total: 0,
+    } as LLMUsage;
+  }
+}
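
For reference, the streaming handler above parses OpenAI chat-completion chunks, whose incremental tokens live under choices[0].delta; only the non-streaming response shape read by extractMessage uses choices[0].message. A trimmed sketch of the chunk shape:

// Shape of one streamed chunk, trimmed to the fields the handler reads.
interface ChatCompletionChunk {
  id: string;
  object: "chat.completion.chunk";
  choices: Array<{
    index: number;
    delta: { role?: string; content?: string }; // the newest token(s)
    finish_reason: string | null;
  }>;
}
// The stream terminates with a literal "[DONE]" data line, which the
// onmessage handler checks before attempting to parse JSON.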

+ 2 - 0
app/constant.ts

@@ -40,3 +40,5 @@ export const NARROW_SIDEBAR_WIDTH = 100;
 export const ACCESS_CODE_PREFIX = "ak-";
 
 export const LAST_INPUT_KEY = "last-input";
+
+export const REQUEST_TIMEOUT_MS = 60000;

+ 0 - 22
app/requests.ts

@@ -43,28 +43,6 @@ const makeRequestParam = (
   };
 };
 
-export function getHeaders() {
-  const accessStore = useAccessStore.getState();
-  let headers: Record<string, string> = {};
-
-  const makeBearer = (token: string) => `Bearer ${token.trim()}`;
-  const validString = (x: string) => x && x.length > 0;
-
-  // use user's api key first
-  if (validString(accessStore.token)) {
-    headers.Authorization = makeBearer(accessStore.token);
-  } else if (
-    accessStore.enabledAccessControl() &&
-    validString(accessStore.accessCode)
-  ) {
-    headers.Authorization = makeBearer(
-      ACCESS_CODE_PREFIX + accessStore.accessCode,
-    );
-  }
-
-  return headers;
-}
-
 export function requestOpenaiClient(path: string) {
   const openaiUrl = useAccessStore.getState().openaiUrl;
   return (body: any, method = "POST") =>

+ 1 - 0
app/store/chat.ts

@@ -14,6 +14,7 @@ import { showToast } from "../components/ui-lib";
 import { ModelType } from "./config";
 import { createEmptyMask, Mask } from "./mask";
 import { StoreKey } from "../constant";
+import { api } from "../client/api";
 
 export type Message = ChatCompletionResponseMessage & {
   date: string;

+ 1 - 0
package.json

@@ -14,6 +14,7 @@
   },
   "dependencies": {
     "@hello-pangea/dnd": "^16.2.0",
+    "@microsoft/fetch-event-source": "^2.0.1",
     "@svgr/webpack": "^6.5.1",
     "@vercel/analytics": "^0.1.11",
     "emoji-picker-react": "^4.4.7",

+ 5 - 0
yarn.lock

@@ -1111,6 +1111,11 @@
   dependencies:
     "@types/react" ">=16.0.0"
 
+"@microsoft/fetch-event-source@^2.0.1":
+  version "2.0.1"
+  resolved "https://registry.npmmirror.com/@microsoft/fetch-event-source/-/fetch-event-source-2.0.1.tgz#9ceecc94b49fbaa15666e38ae8587f64acce007d"
+  integrity sha512-W6CLUJ2eBMw3Rec70qrsEW0jOm/3twwJv21mrmj2yORiaVmVYGS4sSS5yUwvQc1ZlDLYGPnClVWmUUMagKNsfA==
+
 "@next/env@13.3.1-canary.8":
   version "13.3.1-canary.8"
   resolved "https://registry.yarnpkg.com/@next/env/-/env-13.3.1-canary.8.tgz#9f5cf57999e4f4b59ef6407924803a247cc4e451"