mask.ts 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. import { create } from "zustand";
  2. import { persist } from "zustand/middleware";
  3. import { BUILTIN_MASKS } from "../masks";
  4. import { getLang, Lang } from "../locales";
  5. import { DEFAULT_TOPIC, ChatMessage } from "./chat";
  6. import { ModelConfig, useAppConfig } from "./config";
  7. import { StoreKey } from "../constant";
  8. import { nanoid } from "nanoid";
  /**
   * A mask: display info, seed context messages and a model configuration.
   * `createEmptyMask` supplies the defaults noted on the individual fields.
   */
  export type Mask = {
    // --- identity ---
    /** Unique id; `create` stores the mask in the `masks` record under this same key. */
    id: string;
    /** `create` always sets this to `false`; presumably `true` for `BUILTIN_MASKS` entries — confirm. */
    builtin: boolean;
    /** Creation time from `Date.now()` (ms); `getAll` sorts user masks newest first by it. */
    createdAt: number;

    // --- presentation ---
    /** Avatar identifier; defaults to `DEFAULT_MASK_AVATAR`. */
    avatar: string;
    /** Display name; defaults to `DEFAULT_TOPIC`. */
    name: string;
    /** Language of the mask; defaults to the result of `getLang()`. */
    lang: Lang;

    // --- conversation seed ---
    /** Chat messages that form the mask's context; empty by default. */
    context: ChatMessage[];
    /** NOTE(review): presumably hides `context` from the user when set — TODO confirm at call sites. */
    hideContext?: boolean;

    // --- model settings ---
    /** Model settings carried by the mask; defaults to a copy of the global config's `modelConfig`. */
    modelConfig: ModelConfig;
    /** Whether the global config is used; `createEmptyMask` sets it to `true`. */
    syncGlobalConfig?: boolean;
  };
  /** Initial persisted state of the mask store: an empty record of user masks keyed by id. */
  export const DEFAULT_MASK_STATE = {
    masks: {} as Record<string, Mask>,
  };

  /** Data portion of the mask store (no actions); derived from the default value. */
  export type MaskState = typeof DEFAULT_MASK_STATE;

  /** Mask state together with the actions the store exposes. */
  type MaskStore = MaskState & {
    /** Looks up a mask by id; the declared result allows `null` for "no such mask". */
    get: (id?: string) => Mask | null;
    /** Lists user masks, plus builtin masks unless the app config hides them. */
    getAll: () => Mask[];
    /** Searches masks by `text`. */
    search: (text: string) => Mask[];
    /** Adds a new user mask built from optional overrides and returns it. */
    create: (mask?: Partial<Mask>) => Mask;
    /** Runs `updater` on a copy of the mask with `id` and stores the result. */
    update: (id: string, updater: (mask: Mask) => void) => void;
    /** Removes the mask with `id`. */
    delete: (id: string) => void;
  };
  /** Avatar used by `createEmptyMask` for new masks. */
  export const DEFAULT_MASK_AVATAR = "gpt-bot";

  /**
   * Builds a fresh, non-builtin mask populated with defaults.
   *
   * The return type is declared rather than asserted (`as Mask`), so the compiler
   * reports a missing or misspelled field instead of silently accepting it.
   *
   * `modelConfig` is a shallow copy of the global config's current `modelConfig`,
   * so assigning its top-level fields on the mask does not write through to the
   * app config.
   *
   * @returns a new `Mask` with a freshly generated `nanoid()` id and `createdAt`
   *          set to the current time.
   */
  export const createEmptyMask = (): Mask => ({
    id: nanoid(),
    avatar: DEFAULT_MASK_AVATAR,
    name: DEFAULT_TOPIC,
    context: [],
    syncGlobalConfig: true, // use global config as default
    modelConfig: { ...useAppConfig.getState().modelConfig },
    lang: getLang(),
    builtin: false,
    createdAt: Date.now(),
  });
  /**
   * Persisted zustand store holding user-created masks, keyed by mask id.
   *
   * Every mutator builds a NEW `masks` record instead of editing the current one
   * in place. This way selectors such as `(s) => s.masks` observe a changed
   * reference, and the `DEFAULT_MASK_STATE.masks` object spread in as the initial
   * value is never modified.
   */
  export const useMaskStore = create<MaskStore>()(
    persist(
      (set, get) => ({
        ...DEFAULT_MASK_STATE,

        /**
         * Creates a user mask from `createEmptyMask()` defaults overridden by `mask`.
         * `id` and `builtin` are always assigned by the store, whatever the caller passed.
         */
        create(mask) {
          const id = nanoid();
          const newMask: Mask = {
            ...createEmptyMask(),
            ...mask,
            id,
            builtin: false,
          };
          set((state) => ({ masks: { ...state.masks, [id]: newMask } }));
          return newMask;
        },

        /**
         * Applies `updater` to a shallow copy of the mask with `id` and stores the copy.
         * Does nothing when no mask has that id.
         */
        update(id, updater) {
          const current = get().masks[id];
          if (!current) return;
          const next = { ...current };
          updater(next);
          set((state) => ({ masks: { ...state.masks, [id]: next } }));
        },

        /** Removes the mask with `id`; an unknown id leaves the contents unchanged. */
        delete(id) {
          const masks = { ...get().masks };
          delete masks[id];
          set(() => ({ masks }));
        },

        /**
         * Returns the user mask stored under `id`, or `null` when `id` is missing
         * or unknown (matching the declared `Mask | null`).
         */
        get(id) {
          if (id == null) return null;
          return get().masks[id] ?? null;
        },

        /**
         * Lists user masks (newest first), followed by the builtin masks unless the
         * app config's `hideBuiltinMasks` is set.
         */
        getAll() {
          // Object.values returns a fresh array, so sorting it does not touch the state.
          const userMasks = Object.values(get().masks).sort(
            (a, b) => b.createdAt - a.createdAt,
          );
          const config = useAppConfig.getState();
          if (config.hideBuiltinMasks) return userMasks;
          // Builtin masks start from the global model config, overridden by their own settings.
          const builtinMasks = BUILTIN_MASKS.map(
            (m) =>
              ({
                ...m,
                modelConfig: {
                  ...config.modelConfig,
                  ...m.modelConfig,
                },
              }) as Mask,
          );
          return userMasks.concat(builtinMasks);
        },

        /**
         * Returns user masks whose name contains `text`, case-insensitively.
         * An empty `text` matches every user mask.
         */
        search(text) {
          const query = text.toLowerCase();
          const masks = Object.values(get().masks);
          if (!query) return masks;
          return masks.filter((m) => m.name.toLowerCase().includes(query));
        },
      }),
      {
        name: StoreKey.Mask,
        // Bumped from 3 so stores already migrated by the old v3 logic (which
        // changed ids without re-keying the record) are repaired below.
        version: 3.1,
        migrate(state, version) {
          const newState = JSON.parse(JSON.stringify(state)) as MaskState;
          const masks = newState.masks ?? {};

          // migrate mask id to nanoid
          if (version < 3) {
            Object.values(masks).forEach((m) => (m.id = nanoid()));
          }

          // The record key must equal each mask's id, otherwise get/update/delete
          // by `mask.id` miss. Rebuild the record keyed by the current ids.
          if (version < 3.1) {
            const rekeyed: Record<string, Mask> = {};
            for (const m of Object.values(masks)) {
              rekeyed[m.id] = m;
            }
            newState.masks = rekeyed;
          }

          // The migrated payload carries data fields only (no actions), hence the cast.
          return newState as any;
        },
      },
    ),
  );