mask.ts 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. import { create } from "zustand";
  2. import { persist } from "zustand/middleware";
  3. import { BUILTIN_MASKS } from "../masks";
  4. import { getLang, Lang } from "../locales";
  5. import { DEFAULT_TOPIC, ChatMessage } from "./chat";
  6. import { ModelConfig, ModelType, useAppConfig } from "./config";
  7. import { StoreKey } from "../constant";
  /**
   * A chat "mask": a named preset bundling an avatar, seed messages, and model
   * settings.
   */
  export type Mask = {
    /** Numeric id; user masks are stored under this key in the mask store's `masks`. */
    id: number;
    /** Avatar identifier; new masks start with `DEFAULT_MASK_AVATAR`. */
    avatar: string;
    /** Mask name; new masks start with `DEFAULT_TOPIC`. */
    name: string;
    /** Messages attached to the mask (presumably seeded into chats — TODO confirm). */
    context: ChatMessage[];
    /**
     * New masks set this to true ("use global config as default"); presumably
     * the global model config then takes precedence over `modelConfig` — confirm.
     */
    syncGlobalConfig?: boolean;
    /** Model settings; new masks receive a shallow copy of the app-wide config. */
    modelConfig: ModelConfig;
    /** Language of the mask; new masks use `getLang()`. */
    lang: Lang;
    /** Masks created through the store are always `false` (presumably `true` for `BUILTIN_MASKS`). */
    builtin: boolean;
  };
  /**
   * Initial persisted state of the mask store: no user masks and an id
   * counter starting at 0.
   */
  export const DEFAULT_MASK_STATE: {
    masks: Record<number, Mask>;
    globalMaskId: number;
  } = {
    masks: {},
    globalMaskId: 0,
  };

  /** Shape of the mask store's persisted data, derived from the defaults above. */
  export type MaskState = typeof DEFAULT_MASK_STATE;
  /** The mask store: persisted `MaskState` plus the actions that operate on it. */
  type MaskStore = MaskState & {
    /** Creates and stores a new user mask from defaults overlaid with `mask`, and returns it. */
    create: (mask?: Partial<Mask>) => Mask;
    /** Runs `updater` on a shallow copy of mask `id` and stores the result; unknown ids are ignored. */
    update: (id: number, updater: (mask: Mask) => void) => void;
    /** Removes mask `id` from the store. */
    delete: (id: number) => void;
    /** Returns user masks for the query `text` (see the implementation for the matching rule). */
    search: (text: string) => Mask[];
    /** Looks up mask `id`, or the default mask id when omitted; nullish when absent. */
    get: (id?: number) => Mask | null;
    /** Returns the user masks followed by the built-in masks. */
    getAll: () => Mask[];
  };
  /**
   * Sentinel id given to masks built by `createEmptyMask`; it is also the id
   * the store's `get()` looks up when called without an argument.
   */
  export const DEFAULT_MASK_ID = 1145141919810;

  /** Avatar identifier assigned to newly created masks. */
  export const DEFAULT_MASK_AVATAR = "gpt-bot";
  /**
   * Builds a new, unsaved mask populated with defaults.
   *
   * - `id` is `DEFAULT_MASK_ID`; the store's `create` overwrites it with a freshly allocated id.
   * - `modelConfig` is a shallow copy of the app config's current `modelConfig`,
   *   so later edits to the mask do not touch the global config object.
   * - `lang` comes from `getLang()` at call time.
   *
   * The return type is declared (rather than asserted with `as Mask`) so the
   * compiler reports any missing or mistyped field in the literal.
   */
  export const createEmptyMask = (): Mask => ({
    id: DEFAULT_MASK_ID,
    avatar: DEFAULT_MASK_AVATAR,
    name: DEFAULT_TOPIC,
    context: [],
    syncGlobalConfig: true, // use global config as default
    modelConfig: { ...useAppConfig.getState().modelConfig },
    lang: getLang(),
    builtin: false,
  });
  /**
   * Persisted zustand store of user-created masks.
   *
   * `masks` maps numeric ids to masks. `globalMaskId` is the counter used to
   * allocate the id of the next created mask. Built-in masks are not stored
   * here; `getAll()` appends `BUILTIN_MASKS` after the user masks.
   *
   * Every mutation writes a *new* `masks` record instead of editing the
   * current one in place. An in-place edit keeps the same object reference,
   * so selectors such as `useMaskStore((s) => s.masks)` would not observe the
   * change. It would also mutate `DEFAULT_MASK_STATE.masks`, which the initial
   * state shares by reference.
   */
  export const useMaskStore = create<MaskStore>()(
    persist(
      (set, get) => ({
        ...DEFAULT_MASK_STATE,

        /** Allocates the next id and stores a new user mask built from defaults plus `mask`. */
        create(mask) {
          const id = get().globalMaskId + 1;
          const created: Mask = {
            ...createEmptyMask(),
            ...mask,
            // Always override: ids come from the counter and store-created masks are never builtin.
            id,
            builtin: false,
          };
          // A single `set` so subscribers see the new counter and the new mask together.
          set((state) => ({
            globalMaskId: id,
            masks: { ...state.masks, [id]: created },
          }));
          return created;
        },

        /** Applies `updater` to a shallow copy of mask `id`; does nothing if the id is unknown. */
        update(id, updater) {
          const current = get().masks[id];
          if (!current) return;
          const next = { ...current };
          updater(next);
          set((state) => ({ masks: { ...state.masks, [id]: next } }));
        },

        /** Removes mask `id`; ids that are not present are ignored. */
        delete(id) {
          const masks = { ...get().masks };
          delete masks[id];
          set(() => ({ masks }));
        },

        /**
         * Returns the mask stored under `id` (defaulting to `DEFAULT_MASK_ID`),
         * or `null` when there is none, matching the declared `Mask | null`.
         */
        get(id) {
          return get().masks[id ?? DEFAULT_MASK_ID] ?? null;
        },

        /** User masks with the highest (most recently allocated) id first, then `BUILTIN_MASKS`. */
        getAll() {
          const userMasks = Object.values(get().masks).sort(
            (a, b) => b.id - a.id,
          );
          return userMasks.concat(BUILTIN_MASKS);
        },

        /**
         * User masks whose name contains `text`, compared case-insensitively
         * after trimming the query. A blank query returns every user mask.
         * Built-in masks are not included.
         */
        search(text) {
          const userMasks = Object.values(get().masks);
          const query = text.trim().toLowerCase();
          if (query === "") return userMasks;
          return userMasks.filter((candidate) =>
            candidate.name.toLowerCase().includes(query),
          );
        },
      }),
      {
        name: StoreKey.Mask,
        // NOTE(review): no `migrate` is supplied; confirm how the installed
        // zustand version treats persisted data saved under a different version.
        version: 2,
      },
    ),
  );