/* Copyright 2026 Marimo. All rights reserved. */

import { z } from "zod";
import { FieldOptions, randomNumber } from "@/components/forms/options";
import { DATA_TYPES } from "@/core/kernel/messages";
import {
  AGGREGATION_FNS,
  type ColumnId,
  NUMPY_DTYPES,
} from "@/plugins/impl/data-frames/types";
import {
  ALL_OPERATORS,
  isConditionValueValid,
  type OperatorType,
} from "./utils/operators";

/**
 * Array of `[name, [dataType, string]]` tuples.
 *
 * The name is coerced to a string; `dataType` must be one of `DATA_TYPES`.
 * NOTE(review): the trailing string is presumably the backend's native type
 * name -- confirm against the producer of this payload.
 */
export const columnToFieldTypesSchema = z.array(
  z.tuple([z.coerce.string(), z.tuple([z.enum(DATA_TYPES), z.string()])]),
);

/**
 * A single column identifier: a non-empty string or a number.
 *
 * The `min(1)` check applies to the string branch only; numbers are accepted
 * as-is. The transform is a type-level cast to `ColumnId` and does not change
 * the runtime value.
 */
export const column_id = z
  .string()
  .min(1, "Required")
  .or(z.number())
  .transform((v) => v as ColumnId)
  .describe(FieldOptions.of({ label: "Column", special: "column_id" }));

/**
 * A list of column identifiers that must contain at least one entry and
 * defaults to an empty list when omitted.
 */
export const column_id_array = z
  .array(column_id.describe(FieldOptions.of({ special: "column_id" })))
  .min(1, "At least one column is required")
  .default([])
  .describe(FieldOptions.of({ label: "Columns", minLength: 1 }));

/** Convert one column to a numpy dtype, ignoring or raising on errors. */
const ColumnConversionTransformSchema = z
  .object({
    type: z.literal("column_conversion"),
    column_id: column_id,
    data_type: z
      .enum(NUMPY_DTYPES)
      .describe(FieldOptions.of({ label: "Data type (numpy)" }))
      .default("bool"),
    errors: z
      .enum(["ignore", "raise"])
      .default("ignore")
      .describe(
        FieldOptions.of({ label: "Handle errors", special: "radio_group" }),
      ),
  })
  .describe(FieldOptions.of({}));

/** Rename one column; the new name must be a non-empty string. */
const RenameColumnTransformSchema = z.object({
  type: z.literal("rename_column"),
  column_id: column_id,
  new_column_id: z
    .string()
    .min(1, "Required")
    .transform((v) => v as ColumnId)
    .describe(FieldOptions.of({ label: "New column name", minLength: 1 })),
});

/** Sort by one column (ascending by default, N/A values last by default). */
const SortColumnTransformSchema = z.object({
  type: z.literal("sort_column"),
  column_id: column_id,
  ascending: z
    .boolean()
    .describe(FieldOptions.of({ label: "Ascending" }))
    .default(true),
  na_position: z
    .enum(["first", "last"])
    .describe(
      FieldOptions.of({ label: "N/A position", special: "radio_group" }),
    )
    .default("last"),
});

/**
 * A single filter condition: `column_id <operator> value`, optionally negated.
 *
 * `value` is `z.any()` here; `FilterRowsTransformSchema` later drops
 * conditions whose value fails `isConditionValueValid` for the operator.
 */
export const FilterConditionSchema = z
  .object({
    column_id: column_id,
    operator: z
      // z.enum needs a non-empty tuple type, hence the cast of the key list.
      .enum(Object.keys(ALL_OPERATORS) as [OperatorType, ...OperatorType[]])
      .describe(FieldOptions.of({ label: " " })),
    type: z.literal("condition").default("condition"),
    value: z.any().describe(FieldOptions.of({ label: "Value" })),
    negate: z.boolean().default(false),
  })
  .describe(FieldOptions.of({ direction: "row", special: "column_filter" }));

/** Parsed (output) shape of {@link FilterConditionSchema}. */
export type FilterConditionType = z.infer<typeof FilterConditionSchema>;

/**
 * A group of conditions and/or nested groups combined with `and` / `or`.
 *
 * Declared by hand because the type is recursive and cannot be inferred from
 * the lazily-defined schema below.
 */
export interface FilterGroupType {
  type: "group";
  operator: "and" | "or";
  children: (FilterConditionType | FilterGroupType)[];
  negate: boolean;
}

/**
 * Recursive schema for {@link FilterGroupType}.
 *
 * Wrapped in `z.lazy` so the schema can reference itself in `children`.
 */
export const FilterGroupSchema: z.ZodType<FilterGroupType> = z.lazy(() =>
  z.object({
    type: z.literal("group").default("group"),
    operator: z.enum(["and", "or"]).default("and"),
    children: z
      .array(z.union([FilterConditionSchema, FilterGroupSchema]))
      .default([]),
    negate: z.boolean().default(false),
  }),
);

/**
 * Keep or remove rows matching a set of conditions.
 *
 * The `where` input is a flat, non-empty array of conditions (defaulting to a
 * single blank `==` condition). On parse it is transformed into an `and`
 * group whose children are only the conditions that pass
 * `isConditionValueValid`, so the output type of `where` is a
 * `FilterGroupType`, not an array.
 */
const FilterRowsTransformSchema = z.object({
  type: z.literal("filter_rows"),
  operation: z
    .enum(["keep_rows", "remove_rows"])
    .default("keep_rows")
    .describe(FieldOptions.of({ special: "radio_group" })),
  where: z
    .array(FilterConditionSchema)
    .min(1)
    .describe(FieldOptions.of({ label: "Value", minLength: 1 }))
    .default(() => [
      {
        column_id: "" as ColumnId,
        operator: "==" as const,
        value: "",
        type: "condition" as const,
        negate: false,
      },
    ])
    .transform((value): FilterGroupType => {
      // Drop conditions whose value is not valid for their operator.
      const validConditions = value.filter((condition) => {
        return isConditionValueValid(condition.operator, condition.value);
      });
      return {
        type: "group",
        operator: "and",
        children: validConditions,
        negate: false,
      };
    }),
});

/** Group by columns and apply one aggregation to the selected columns. */
const GroupByTransformSchema = z
  .object({
    type: z.literal("group_by"),
    column_ids: z
      .array(column_id.describe(FieldOptions.of({ special: "column_id" })))
      .default([])
      .describe(FieldOptions.of({ label: "Group by columns", minLength: 1 })),
    aggregation_column_ids: z
      .array(column_id.describe(FieldOptions.of({ special: "column_id" })))
      .default([])
      .describe(FieldOptions.of({ label: "Aggregate on columns" })),
    aggregation: z
      .enum(AGGREGATION_FNS)
      .default("count")
      .describe(FieldOptions.of({ label: "Aggregation" })),
    drop_na: z
      .boolean()
      .default(false)
      .describe(FieldOptions.of({ label: "Drop N/A" })),
  })
  .describe(FieldOptions.of({}));

/** Apply one or more aggregations (default: `count`) to the given columns. */
const AggregateTransformSchema = z
  .object({
    type: z.literal("aggregate"),
    column_ids: column_id_array,
    aggregations: z
      .array(z.enum(AGGREGATION_FNS))
      .min(1, "At least one aggregation is required")
      .default(["count"])
      .describe(FieldOptions.of({ label: "Aggregations", minLength: 1 })),
  })
  .describe(FieldOptions.of({ direction: "row" }));

/** Select a subset of columns. */
const SelectColumnsTransformSchema = z.object({
  type: z.literal("select_columns"),
  column_ids: column_id_array,
});

/**
 * Sample `n` rows (positive number), with or without replacement.
 *
 * `seed` defaults to a fresh `randomNumber()` each time the default is used.
 */
const SampleRowsTransformSchema = z.object({
  type: z.literal("sample_rows"),
  n: z
    .number()
    .positive()
    .describe(FieldOptions.of({ label: "Number of rows" })),
  seed: z
    .number()
    .default(() => randomNumber())
    .describe(
      FieldOptions.of({ label: "Re-sample", special: "random_number_button" }),
    ),
  replace: z
    .boolean()
    .default(false)
    .describe(
      FieldOptions.of({
        label: "Sample with replacement",
      }),
    ),
});

/** Shuffle rows; `seed` defaults to a fresh `randomNumber()`. */
const ShuffleRowsTransformSchema = z.object({
  type: z.literal("shuffle_rows"),
  seed: z
    .number()
    .default(() => randomNumber())
    .describe(
      FieldOptions.of({ label: "Re-shuffle", special: "random_number_button" }),
    ),
});

/** Explode the given columns. */
const ExplodeColumnsTransformSchema = z.object({
  type: z.literal("explode_columns"),
  column_ids: column_id_array,
});

/** Expand a single dict-valued column. */
const ExpandDictTransformSchema = z.object({
  type: z.literal("expand_dict"),
  column_id: column_id,
});

/** De-duplicate on the given columns, choosing which duplicate to keep. */
const UniqueTransformSchema = z
  .object({
    type: z.literal("unique"),
    column_ids: column_id_array,
    keep: z
      .enum(["first", "last", "none", "any"])
      .default("first")
      .describe(FieldOptions.of({ label: "Keep" })),
  })
  .describe(FieldOptions.of({ direction: "row" }));

/** Pivot with column, row ("index") and value column lists (default: `sum`). */
const PivotTransformSchema = z
  .object({
    type: z.literal("pivot"),
    column_ids: column_id_array,
    index_column_ids: z
      .array(column_id.describe(FieldOptions.of({ special: "column_id" })))
      .default([])
      .describe(FieldOptions.of({ label: "Rows" })),
    value_column_ids: z
      .array(column_id.describe(FieldOptions.of({ special: "column_id" })))
      .default([])
      .describe(FieldOptions.of({ label: "Values", minLength: 1 })),
    aggregation: z
      .enum(AGGREGATION_FNS)
      .default("sum")
      .describe(FieldOptions.of({ label: "Aggregation" })),
  })
  .describe(FieldOptions.of({}));

/** Union of every supported transform, distinguished by its `type` literal. */
export const TransformTypeSchema = z.union([
  FilterRowsTransformSchema,
  SelectColumnsTransformSchema,
  RenameColumnTransformSchema,
  ColumnConversionTransformSchema,
  SortColumnTransformSchema,
  GroupByTransformSchema,
  AggregateTransformSchema,
  SampleRowsTransformSchema,
  ShuffleRowsTransformSchema,
  ExplodeColumnsTransformSchema,
  ExpandDictTransformSchema,
  UniqueTransformSchema,
  PivotTransformSchema,
]);

/** Parsed (output) shape of any single transform. */
export type TransformType = z.infer<typeof TransformTypeSchema>;

/** Top-level payload: an array of transforms under `transforms`. */
export const TransformationsSchema = z.object({
  transforms: z.array(TransformTypeSchema),
});

/** Parsed (output) shape of {@link TransformationsSchema}. */
export type Transformations = z.infer<typeof TransformationsSchema>;