mask.ts 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. import { create } from "zustand";
  2. import { persist } from "zustand/middleware";
  3. import { BUILTIN_MASKS } from "../masks";
  4. import { getLang, Lang } from "../locales";
  5. import { DEFAULT_TOPIC, Message } from "./chat";
  6. import { ModelConfig, ModelType, useAppConfig } from "./config";
  7. import { StoreKey } from "../constant";
  8. export type Mask = {
  9. id: number;
  10. avatar: string;
  11. name: string;
  12. context: Message[];
  13. modelConfig: ModelConfig;
  14. lang: Lang;
  15. builtin: boolean;
  16. };
  17. export const DEFAULT_MASK_STATE = {
  18. masks: {} as Record<number, Mask>,
  19. globalMaskId: 0,
  20. };
  21. export type MaskState = typeof DEFAULT_MASK_STATE;
  22. type MaskStore = MaskState & {
  23. create: (mask?: Partial<Mask>) => Mask;
  24. update: (id: number, updater: (mask: Mask) => void) => void;
  25. delete: (id: number) => void;
  26. search: (text: string) => Mask[];
  27. get: (id?: number) => Mask | null;
  28. getAll: () => Mask[];
  29. };
  30. export const DEFAULT_MASK_ID = 1145141919810;
  31. export const DEFAULT_MASK_AVATAR = "gpt-bot";
  32. export const createEmptyMask = () =>
  33. ({
  34. id: DEFAULT_MASK_ID,
  35. avatar: DEFAULT_MASK_AVATAR,
  36. name: DEFAULT_TOPIC,
  37. context: [],
  38. modelConfig: { ...useAppConfig.getState().modelConfig },
  39. lang: getLang(),
  40. builtin: false,
  41. } as Mask);
  42. export const useMaskStore = create<MaskStore>()(
  43. persist(
  44. (set, get) => ({
  45. ...DEFAULT_MASK_STATE,
  46. create(mask) {
  47. set(() => ({ globalMaskId: get().globalMaskId + 1 }));
  48. const id = get().globalMaskId;
  49. const masks = get().masks;
  50. masks[id] = {
  51. ...createEmptyMask(),
  52. ...mask,
  53. id,
  54. };
  55. set(() => ({ masks }));
  56. return masks[id];
  57. },
  58. update(id, updater) {
  59. const masks = get().masks;
  60. const mask = masks[id];
  61. if (!mask) return;
  62. const updateMask = { ...mask };
  63. updater(updateMask);
  64. masks[id] = updateMask;
  65. set(() => ({ masks }));
  66. },
  67. delete(id) {
  68. const masks = get().masks;
  69. delete masks[id];
  70. set(() => ({ masks }));
  71. },
  72. get(id) {
  73. return get().masks[id ?? 1145141919810];
  74. },
  75. getAll() {
  76. const userMasks = Object.values(get().masks).sort(
  77. (a, b) => b.id - a.id,
  78. );
  79. return userMasks.concat(BUILTIN_MASKS);
  80. },
  81. search(text) {
  82. return Object.values(get().masks);
  83. },
  84. }),
  85. {
  86. name: StoreKey.Mask,
  87. version: 2,
  88. },
  89. ),
  90. );