/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import { SqlExpression, SqlLiteral, SqlTable } from 'druid-query-toolkit';
import { create } from 'zustand';
import type { PersistStorage, StorageValue } from 'zustand/middleware';
import { persist } from 'zustand/middleware';
import type { StateCreator } from 'zustand/vanilla';
import { createStore } from 'zustand/vanilla';

import { inflateValidateAndFillDefaults } from './inflate-value';
import type { ExpressionMeta } from './models';
import type { Parameter, ParameterDefinitions, ParameterTypes, TypedParameter } from './parameter';
import type {
  ParametersToParams,
  VisualModule,
  VisualModuleContext,
  VisualModuleInstance,
} from './visual-module';

/** Version stamped on persisted host state (passed to the `persist` middleware). */
const HOST_STORE_VERSION = 1;

/** Host state keys that are never written to persistent storage. */
const STORAGE_BLACK_LIST: (keyof HostState)[] = ['columns', 'visualModules'];

/**
 * Partial overrides for a parameter of type `T`: every property except `type`
 * may be overridden.
 */
type ParameterOverrides<T extends keyof ParameterTypes> = {
  [K in Exclude<keyof TypedParameter<T>, 'type'>]?: Partial<TypedParameter<T>[K]>;
};

/**
 * Host-side customizations supplied when a visual module is registered.
 */
export interface VisualModuleOverrides<P extends ParameterDefinitions> {
  /**
   * Override properties on each parameter. A parameter's `type` cannot be overridden.
   */
  parameters?: {
    [Name in keyof P]?: ParameterOverrides<P[Name]['type']>;
  };

  /**
   * Override the order of parameters.
   */
  parameterOrder?: (keyof P)[];

  /**
   * Initial table for this module.
   *
   * If left unspecified, the host's table will be used.
   */
  table?: string;

  /**
   * Initial WHERE clause for this module.
   *
   * If left unspecified, the host's WHERE clause will be used.
   */
  where?: SqlExpression | string;

  /**
   * Initial HAVING clause for this module.
   *
   * If left unspecified, the host's HAVING clause will be used.
   */
  having?: SqlExpression | string;

  /**
   * Optional callback to invoke when the module is updated.
   *
   * @param updateEvent - Event that triggered the update.
   * @param context - Current visual module context
   */
  onUpdate?(
    updateEvent: UpdateEvent<ParametersToParams<P>>,
    context: VisualModuleContext<ParametersToParams<P>>,
  ): void;
}

/**
 * A visual module as stored in the host: a callable factory decorated with
 * its registration name and (possibly overridden) parameter metadata.
 */
export interface RegisteredVisualModule {
  /**
   * Name of the visual module.
   */
  moduleName: string;

  /**
   * Parameters that this module accepts, if any.
   */
  parameters?: ParameterDefinitions;

  /**
   * Order of parameters, if any have been defined.
   */
  parameterOrder?: string[];

  /**
   * Creates a new instance of this module.
   */
  (context: VisualModuleContext<Record<string, any>>): VisualModuleInstance<Record<string, any>>;
}

/**
 * Event passed to a visual module's update() handler.
 */
export interface UpdateEvent<Params> {
  /**
   * Parameter values for this visual module.
   */
  params: Params;

  /**
   * Current table name.
   */
  table: SqlExpression;

  /**
   * Where expression.
   */
  where: SqlExpression;

  /**
   * Having expression.
   */
  having?: SqlExpression;
}

/**
 * Per-module state. `undefined` table/where/having means "inherit the host's value".
 */
export interface ModuleState {
  params: Record<string, any>;
  table: SqlExpression | undefined;
  where: SqlExpression | undefined;
  having?: SqlExpression | undefined;
}

/** A value shared between modules through a transfer group: `[parameter type, value]`. */
export type TransferValue = [type: Parameter['type'], value: unknown];

export interface HostState {
  table: SqlExpression;
  columns?: ExpressionMeta[];
  where: SqlExpression;
  having?: SqlExpression;
  /** Latest value written to each transfer group, keyed by `control.transferGroup`. */
  transferState: Record<string, TransferValue>;
  visualModules?: Record<string, RegisteredVisualModule>;
  visualModuleState: Record<string, ModuleState>;

  getModuleParams<P extends ParameterDefinitions>(moduleName: string): ParametersToParams<P>;
  getUpdateEvent<P extends ParameterDefinitions>(
    moduleName: string,
  ): UpdateEvent<ParametersToParams<P>>;
  registerModule<P extends ParameterDefinitions>(
    name: string,
    module: VisualModule<P>,
    overrides?: VisualModuleOverrides<P>,
  ): void;
  removeModule(name: string): void;
  setModuleParams<P extends Record<string, any>>(
    moduleName: string,
    params: Partial<P> | ((prev: P | undefined) => Partial<P>),
  ): void;
  setModuleTable(moduleName: string, table: SqlExpression | undefined): void;
  setModuleWhere(moduleName: string, where: SqlExpression | undefined): void;
  setModuleHaving(moduleName: string, having: SqlExpression | undefined): void;
}

export interface StateStorage {
  /**
   * Gets the persisted state for a store identified by `name`.
   *
   * The returned value should be the serialized (JSON string) state, or
   * `null`/`undefined` when nothing is stored. An already-parsed object is
   * also accepted for backward compatibility.
   *
   * @param name - Identifier for the store.
   * @returns - Persisted store state.
   */
  getItem: (name: string) => any | null;

  /**
   * Sets the persisted state for a store identified by `name`.
   *
   * @param name - Identifier for the store.
   * @param value - Serialized (JSON string) state to store.
   */
  setItem: (name: string, value: any) => void;

  /**
   * Clears the persisted state for a store identified by `name`.
   *
   * @param name - Identifier for the store.
   */
  removeItem: (name: string) => void;
}

export interface HostStorePersistOptions {
  /**
   * Key to use for persisting the host store. Must be unique for each host instance.
   *
   * @default '@druid/host'
   */
  name?: string;

  /**
   * Storage implementation to use for persisting the host store: browser
   * local/session storage, the URL hash, or a custom `StateStorage`.
   */
  storage: 'localstorage' | 'sessionstorage' | 'url' | StateStorage;
}

export interface HostStoreOptions {
  /**
   * Options for persisting the host store.
   *
   * If not defined, the host store will not be persisted.
   */
  persist?: HostStorePersistOptions;

  /**
   * Initial SQL expression for 'table' in the host store.
   */
  table: SqlExpression;

  /**
   * Initial value for 'where' in the host store.
   *
   * @default SqlLiteral.TRUE
   */
  where?: SqlExpression | string;

  /**
   * Initial value for 'having' in the host store.
   *
   * @default SqlLiteral.TRUE
   */
  having?: SqlExpression | string;
}

/**
 * Adapts a string-based `StateStorage` to zustand's `PersistStorage`.
 *
 * Values are JSON-encoded on write and decoded on read. On read, the SQL
 * fields (`where`, `having`, `table`) are re-parsed into expression objects.
 *
 * @param getStorage - Lazily resolves the backing storage on every call.
 */
function createJSONStorage(getStorage: () => StateStorage): PersistStorage<HostState> {
  return {
    getItem(name) {
      const raw = getStorage().getItem(name);
      // Nothing stored yet: tell the persist middleware to keep the initial state.
      if (raw == null || raw === '') return null;

      // Web Storage and the hash storage return strings; tolerate pre-parsed
      // objects from custom storages or older URL hashes.
      const value: StorageValue<HostState> | undefined =
        typeof raw === 'string' ? JSON.parse(raw) : raw;
      if (!value?.state) return null;

      const { state } = value;
      return {
        state: {
          ...state,
          where: SqlExpression.parse(state.where),
          having: state.having ? SqlExpression.parse(state.having) : undefined,
          table: SqlTable.parse(state.table),
        },
        version: value.version,
      };
    },
    setItem(name, value) {
      // Backing storages only hold strings (Web Storage would store "[object Object]").
      getStorage().setItem(name, JSON.stringify(value));
    },
    removeItem(name) {
      getStorage().removeItem(name);
    },
  };
}

/**
 * `StateStorage` backed by the URL hash. Each key is a search parameter inside
 * `location.hash`; values are stored verbatim (they are already JSON strings
 * produced by `createJSONStorage`).
 */
const hashStorage: StateStorage = {
  getItem: (key): string | null => {
    const searchParams = new URLSearchParams(location.hash.slice(1));
    return searchParams.get(key);
  },
  setItem: (key, newValue): void => {
    const searchParams = new URLSearchParams(location.hash.slice(1));
    searchParams.set(key, newValue);
    location.hash = searchParams.toString();
  },
  removeItem: (key): void => {
    const searchParams = new URLSearchParams(location.hash.slice(1));
    searchParams.delete(key);
    location.hash = searchParams.toString();
  },
};

/**
 * Creates the host store. It holds the shared query context (table, WHERE,
 * HAVING), the registered visual modules, and the per-module state.
 *
 * When `options.persist` is set, every state key except those in
 * `STORAGE_BLACK_LIST` is persisted through zustand's `persist` middleware.
 * With `storage: 'url'`, the store also rehydrates when the URL hash changes.
 *
 * @param options - Initial values and persistence settings.
 * @returns A vanilla zustand store when not persisted, otherwise a bound zustand store.
 */
export function createHostStore(options: HostStoreOptions) {
  const creator: StateCreator<HostState> = (set, get) => {
    /**
     * Sets one of a module's `table` / `where` / `having` fields.
     *
     * Creates the module's state entry (with empty params and inherited
     * clauses) if it does not exist. Does nothing when the value is
     * reference-equal to the current one.
     */
    const setModuleField = <K extends 'table' | 'where' | 'having'>(
      moduleName: string,
      key: K,
      value: ModuleState[K],
    ): void => {
      const { visualModuleState } = get();
      const current = visualModuleState[moduleName];
      if (current?.[key] === value) return;

      const moduleState: ModuleState = {
        params: {},
        table: undefined,
        where: undefined,
        ...current,
        [key]: value,
      };
      set({
        visualModuleState: {
          ...visualModuleState,
          [moduleName]: moduleState,
        },
      });
    };

    return {
      transferState: {},
      visualModules: {},
      visualModuleState: {},
      table: options.table,
      columns: [],
      where:
        typeof options.where === 'string'
          ? SqlExpression.parse(options.where)
          : options.where ?? SqlLiteral.TRUE,
      having:
        typeof options.having === 'string'
          ? SqlExpression.parse(options.having)
          : options.having ?? SqlLiteral.TRUE,

      getModuleParams<P extends ParameterDefinitions>(moduleName: string): ParametersToParams<P> {
        const { columns, transferState, visualModules, visualModuleState } = get();
        const visualModule = visualModules ? visualModules[moduleName] : undefined;
        const moduleState = visualModuleState[moduleName];

        return inflateValidateAndFillDefaults(
          moduleState?.params ?? {},
          visualModule?.parameters,
          transferState,
          columns || [],
        ) as ParametersToParams<P>;
      },

      getUpdateEvent<P extends ParameterDefinitions>(
        moduleName: string,
      ): UpdateEvent<ParametersToParams<P>> {
        // Read everything through get() rather than `this`, so this method also
        // works when it is pulled off the state object (e.g. by a selector).
        const { table, where, having, visualModuleState, getModuleParams } = get();
        const moduleState = visualModuleState[moduleName];

        return {
          params: getModuleParams<P>(moduleName),
          table: moduleState?.table ?? table,
          where: moduleState?.where ?? where,
          having: moduleState?.having ?? having,
        };
      },

      registerModule<P extends ParameterDefinitions>(
        name: string,
        module: VisualModule<P>,
        overrides: VisualModuleOverrides<P> = {},
      ) {
        const registeredModule: RegisteredVisualModule = context => {
          return module(context as VisualModuleContext<ParametersToParams<P>>);
        };

        registeredModule.moduleName = name;

        // NOTE(review): overrides.table/where/having/onUpdate are not consumed
        // here; presumably the caller applies them — confirm.
        if ('parameters' in module) {
          const parameters: ParameterDefinitions = { ...module.parameters };
          registeredModule.parameters = parameters;

          if (overrides.parameters) {
            for (const [paramName, paramOverrides] of Object.entries(overrides.parameters)) {
              const baseParameter = parameters[paramName];
              // An override for a parameter the module does not declare (or an
              // explicitly undefined entry) has nothing to merge into.
              if (!baseParameter || !paramOverrides) continue;

              parameters[paramName] = {
                ...baseParameter,
                ...paramOverrides,
                control: {
                  ...baseParameter.control,
                  ...paramOverrides.control,
                },
                // Don't allow the type to be overridden
                type: baseParameter.type,
              };
            }
          }

          if (overrides.parameterOrder) {
            registeredModule.parameterOrder = overrides.parameterOrder as string[];
          }
        }

        set({
          visualModules: { ...get().visualModules, [name]: registeredModule },
        });
      },

      removeModule(name: string) {
        const nextVisualModules = { ...get().visualModules };
        delete nextVisualModules[name];
        set({ visualModules: nextVisualModules });
      },

      setModuleParams<P extends Record<string, any>>(
        moduleName: string,
        params: Partial<P> | ((prev: P | undefined) => Partial<P>),
      ) {
        const { visualModules, visualModuleState, transferState } = get();
        const visualModule = visualModules ? visualModules[moduleName] : undefined;
        const currentParams = visualModuleState[moduleName]?.params;
        const nextParams =
          typeof params === 'function' ? params(currentParams as P | undefined) : params;

        // Params for unregistered modules are ignored.
        if (currentParams === nextParams || !visualModule) return;

        const moduleState: ModuleState = {
          table: undefined,
          where: undefined,
          ...get().visualModuleState[moduleName],
          params: { ...currentParams, ...nextParams },
        };

        // Propagate values of parameters that belong to a transfer group.
        let nextTransferState = transferState;
        for (const [paramName, value] of Object.entries(nextParams)) {
          const parameter = visualModule.parameters?.[paramName];
          if (!parameter?.control?.transferGroup) continue;

          nextTransferState = {
            ...nextTransferState,
            [parameter.control.transferGroup]: [parameter.type, value],
          };
        }

        set({
          visualModuleState: {
            ...visualModuleState,
            [moduleName]: moduleState,
          },
          transferState: nextTransferState,
        });
      },

      setModuleTable(moduleName, table) {
        setModuleField(moduleName, 'table', table);
      },

      setModuleWhere(moduleName, where) {
        setModuleField(moduleName, 'where', where);
      },

      setModuleHaving(moduleName, having) {
        setModuleField(moduleName, 'having', having);
      },
    };
  };

  if (!options.persist) {
    return createStore(creator);
  }

  const { storage } = options.persist;
  const persistKey = options.persist.name ?? '@druid/host';

  /** Drops the blacklisted (non-persistable) keys from a state snapshot. */
  const partialize = (state: HostState): HostState => {
    return Object.fromEntries(
      Object.entries(state).filter(([key]) => !STORAGE_BLACK_LIST.includes(key as keyof HostState)),
    ) as HostState;
  };

  const store = create<HostState>()(
    persist(creator, {
      name: persistKey,
      partialize,
      storage: createJSONStorage(
        typeof storage === 'object'
          ? () => storage
          : storage === 'sessionstorage'
          ? () => sessionStorage
          : storage === 'url'
          ? () => hashStorage
          : () => localStorage,
      ),
      version: HOST_STORE_VERSION,
    }) as StateCreator<HostState>,
  );

  if (storage === 'url') {
    // Rehydrate the store when the hash changes (navigation, pasted link, manual edit).
    // NOTE(review): this listener is never removed — presumably the store lives for
    // the whole page; confirm.
    window.addEventListener('hashchange', () => {
      const storedValue = new URLSearchParams(location.hash.slice(1)).get(persistKey);
      const currentJson = JSON.stringify({
        state: partialize(store.getState()),
        version: HOST_STORE_VERSION,
      });

      // The hash holds the exact string the persist middleware wrote, so an
      // identical string means our own write triggered this event; skip it to
      // avoid unnecessary renders.
      if (storedValue === currentJson) return;

      (store as any).persist.rehydrate();
    });
  }

  return store;
}