mask.ts 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. import { create } from "zustand";
  2. import { persist } from "zustand/middleware";
  3. import { getLang, Lang } from "../locales";
  4. import { Message } from "./chat";
  5. import { ModelConfig, useAppConfig } from "./config";
  /**
   * Storage key for the mask store: passed as the `name` option of the
   * `persist` middleware wrapping `useMaskStore`.
   */
  export const MASK_KEY = "mask-store" as const;
  /**
   * A mask: a reusable preset stored in the mask store's `masks` record,
   * keyed by its `id`.
   */
  export type Mask = {
    /** Numeric identifier; also the key of this mask in `MaskState.masks`. */
    id: number;
    /** Display name (the store defaults it to an empty string). */
    name: string;
    /**
     * Avatar identifier. NOTE(review): the store defaults this to "1f916",
     * which looks like an emoji code point — confirm against the renderer.
     */
    avatar: string;
    /**
     * Preset messages carried by the mask. NOTE(review): presumably injected
     * into chats that use this mask — confirm with the chat store.
     */
    context: Message[];
    /** Model settings for this mask (the store seeds them from the app's model config). */
    config: ModelConfig;
    /** Language tag; the store defaults it to `getLang()` at creation time. */
    lang: Lang;
  };
  /** Data portion of the mask store (everything except its actions). */
  export type MaskState = {
    /** Every mask, keyed by its numeric `id`. */
    masks: Record<number, Mask>;
    /** Most recently allocated mask id; incremented before each new mask is stored. */
    globalMaskId: number;
  };

  /** Initial mask-store state: no masks and an id counter starting at 0. */
  export const DEFAULT_MASK_STATE: MaskState = {
    masks: {},
    globalMaskId: 0,
  };
  /** Operations the mask store exposes next to its state. */
  type MaskActions = {
    /** Lists every mask, ordered by ascending `id`. */
    getAll: () => Mask[];
    /** Queries masks with free text `text`; matching rules live in the store implementation. */
    search: (text: string) => Mask[];
    /** Builds a mask from `mask` merged over defaults, stores it, and returns it. */
    create: (mask: Partial<Mask>) => Mask;
    /** Runs `updater` on a copy of the mask with `id` and stores the result; no-op for unknown ids. */
    update: (id: number, updater: (mask: Mask) => void) => void;
    /** Removes the mask with `id`. */
    delete: (id: number) => void;
  };

  /** Complete shape of the mask store: state fields plus actions. */
  type MaskStore = MaskState & MaskActions;
  /**
   * Persisted zustand store holding the user's masks.
   *
   * Every action replaces `masks` with a fresh object rather than mutating the
   * current one, so selectors on `state.masks` observe a new reference and the
   * module-level `DEFAULT_MASK_STATE.masks` (spread in as the initial state) is
   * never modified.
   */
  export const useMaskStore = create<MaskStore>()(
    persist(
      (set, get) => ({
        ...DEFAULT_MASK_STATE,

        /**
         * Creates a mask from `mask` merged over defaults, stores it, and
         * returns it. The newly allocated `id` is applied last so it always
         * wins over any `id` present in `mask`; otherwise the record key and
         * the mask's own `id` could disagree (e.g. when cloning a mask).
         */
        create(mask) {
          const id = get().globalMaskId + 1;
          const created: Mask = {
            avatar: "1f916",
            name: "",
            // Copy so edits to this mask's config never alias the app-wide model config.
            config: { ...useAppConfig.getState().modelConfig },
            context: [],
            lang: getLang(),
            ...mask,
            id,
          };
          // Single update: bump the id counter and insert the mask together.
          set(() => ({
            globalMaskId: id,
            masks: { ...get().masks, [id]: created },
          }));
          return created;
        },

        /**
         * Applies `updater` to a copy of the mask with `id` and stores the
         * copy. Does nothing when no mask has that id. `config` and `context`
         * are copied as well, so an updater mutating them in place does not
         * alter the previous state snapshot.
         */
        update(id, updater) {
          const current = get().masks[id];
          if (!current) return;
          const draft: Mask = {
            ...current,
            config: { ...current.config },
            context: [...current.context],
          };
          updater(draft);
          // Re-read masks in case the updater touched the store itself.
          set(() => ({ masks: { ...get().masks, [id]: draft } }));
        },

        /** Removes the mask with `id` (no error if it does not exist). */
        delete(id) {
          const masks = { ...get().masks };
          delete masks[id];
          set(() => ({ masks }));
        },

        /** Returns every mask as a new array, ordered by ascending `id`. */
        getAll() {
          return Object.values(get().masks).sort((a, b) => a.id - b.id);
        },

        /**
         * Returns masks whose name contains `text` (case-insensitive,
         * surrounding whitespace ignored), ordered by ascending `id`.
         * A blank `text` returns every mask.
         */
        search(text) {
          const query = text.trim().toLowerCase();
          const all = get().getAll();
          if (!query) return all;
          return all.filter((mask) => mask.name.toLowerCase().includes(query));
        },
      }),
      {
        name: MASK_KEY,
        version: 2,
      },
    ),
  );