mask.ts 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. import { create } from "zustand";
  2. import { persist } from "zustand/middleware";
  3. import { getLang, Lang } from "../locales";
  4. import { DEFAULT_TOPIC, Message } from "./chat";
  5. import { ModelConfig, ModelType, useAppConfig } from "./config";
  /** Storage key passed as `name` to the `persist` middleware of the mask store. */
  export const MASK_KEY = "mask-store";
  /**
   * A single entry of the mask store.
   *
   * Defaults for every field are produced by `createEmptyMask`.
   */
  export type Mask = {
    /** Numeric identifier; the store keeps masks in a map keyed by this id. */
    id: number;

    /** Avatar identifier (the default is `DEFAULT_MASK_AVATAR`). */
    avatar: string;

    /** Display name (the default is `DEFAULT_TOPIC`). */
    name: string;

    /** Messages attached to the mask; empty for a freshly created mask. */
    context: Message[];

    /** Model settings; by default a shallow copy of the app-wide model config. */
    modelConfig: ModelConfig;

    /** Language of the mask; by default the value returned by `getLang()`. */
    lang: Lang;
  };
  /**
   * Initial data of the mask store: no masks, and an id counter at zero.
   *
   * `masks` maps a mask id to its mask; `globalMaskId` is the last id handed
   * out by the store's `create` action.
   */
  export const DEFAULT_MASK_STATE: {
    masks: Record<number, Mask>;
    globalMaskId: number;
  } = {
    masks: {},
    globalMaskId: 0,
  };

  /** Shape of the data part of the mask store (everything except the actions). */
  export type MaskState = typeof DEFAULT_MASK_STATE;
  /** Full store type: the persisted data plus the actions operating on it. */
  type MaskStore = MaskState & {
    // --- reads ---

    /** Looks up the mask stored under `id`, or under the default mask id when `id` is omitted. */
    get: (id?: number) => Mask | null;

    /** Returns every stored mask, ordered by ascending id. */
    getAll: () => Mask[];

    /** Returns stored masks for the given search text. */
    search: (text: string) => Mask[];

    // --- writes ---

    /** Adds a new mask built from the defaults plus the optional overrides and returns it. */
    create: (mask?: Partial<Mask>) => Mask;

    /** Applies `updater` to a copy of the mask stored under `id`; does nothing for an unknown id. */
    update: (id: number, updater: (mask: Mask) => void) => void;

    /** Removes the mask stored under `id`. */
    delete: (id: number) => void;
  };
  /** Placeholder id carried by masks returned from `createEmptyMask`. */
  export const DEFAULT_MASK_ID = 1145141919810;

  /** Avatar identifier given to masks returned from `createEmptyMask`. */
  export const DEFAULT_MASK_AVATAR = "gpt-bot";
  /**
   * Builds a fresh mask filled with default values.
   *
   * The result carries the placeholder id `DEFAULT_MASK_ID`, the default
   * avatar, `DEFAULT_TOPIC` as its name, an empty context, and the language
   * returned by `getLang()` at call time.
   *
   * @returns A new `Mask` object on every call.
   */
  export const createEmptyMask = () => {
    // Shallow copy, so the mask does not share the config object held by the
    // app config store.
    const modelConfig = { ...useAppConfig.getState().modelConfig };

    const emptyMask = {
      id: DEFAULT_MASK_ID,
      avatar: DEFAULT_MASK_AVATAR,
      name: DEFAULT_TOPIC,
      context: [],
      modelConfig,
      lang: getLang(),
    } as Mask;

    return emptyMask;
  };
  /**
   * Persisted zustand store of masks, keyed by numeric id.
   *
   * State is persisted under `MASK_KEY` (persist version 2). All write actions
   * replace the `masks` map with a new object instead of mutating it in place,
   * so the shared `DEFAULT_MASK_STATE.masks` object is never modified and
   * subscribers comparing `masks` by reference observe every change.
   */
  export const useMaskStore = create<MaskStore>()(
    persist(
      (set, get) => ({
        ...DEFAULT_MASK_STATE,

        /**
         * Creates a mask from the defaults plus the optional overrides, stores
         * it under a freshly allocated id and returns it.
         *
         * @param mask Optional field overrides; any `id` in it is ignored.
         * @returns The stored mask.
         */
        create(mask) {
          const id = get().globalMaskId + 1;
          const created: Mask = {
            ...createEmptyMask(),
            ...mask,
            // Assigned last so an `id` carried by `mask` (e.g. when cloning an
            // existing mask) cannot override the newly allocated one.
            id,
          };

          set((state) => ({
            globalMaskId: id,
            masks: { ...state.masks, [id]: created },
          }));

          return created;
        },

        /**
         * Applies `updater` to a shallow copy of the mask stored under `id` and
         * stores the result. Does nothing when no mask has that id.
         *
         * @param id Id of the mask to change.
         * @param updater Callback that mutates the copy it receives.
         */
        update(id, updater) {
          const current = get().masks[id];
          if (!current) return;

          const updated = { ...current };
          updater(updated);

          set((state) => ({ masks: { ...state.masks, [id]: updated } }));
        },

        /**
         * Removes the mask stored under `id`.
         *
         * @param id Id of the mask to remove.
         */
        delete(id) {
          // Work on a copy so the previous state object stays untouched.
          const masks = { ...get().masks };
          delete masks[id];
          set(() => ({ masks }));
        },

        /**
         * Looks up a mask by id.
         *
         * @param id Id to look up; defaults to `DEFAULT_MASK_ID`.
         * @returns The mask, or `null` when none is stored under that id.
         */
        get(id) {
          return get().masks[id ?? DEFAULT_MASK_ID] ?? null;
        },

        /**
         * Lists every stored mask.
         *
         * @returns A new array ordered by ascending id.
         */
        getAll() {
          return Object.values(get().masks).sort((a, b) => a.id - b.id);
        },

        /**
         * Finds masks whose name contains `text`, ignoring case.
         *
         * @param text Search text; blank text matches every mask.
         * @returns The matching masks.
         */
        search(text) {
          const allMasks = Object.values(get().masks);
          const query = text.trim().toLowerCase();
          if (query.length === 0) return allMasks;

          return allMasks.filter((m) => m.name.toLowerCase().includes(query));
        },
      }),
      {
        name: MASK_KEY,
        version: 2,
      },
    ),
  );