import { Aggregate, Func } from '../ir'
import { isRefProxy, toExpression } from './ref-proxy.js'
import type { BasicExpression } from '../ir'
import type { RefProxy } from './ref-proxy.js'
import type {
  Context,
  GetRawResult,
  RefLeaf,
  StringifiableScalar,
} from './types.js'
import type { QueryBuilder } from './index.js'

// Query-builder expression functions. Every public helper below lowers its
// arguments with `toExpression` and returns an IR node: a `Func` for
// operators/scalar functions, an `Aggregate` for aggregates. Projection-only
// constructs (`toArray`, `concat(toArray(...))`, projection-valued `caseWhen`)
// instead return the marker wrapper classes declared near the end of the file.

// Ref leaves accepted where a string-valued operand is expected.
type StringRef =
  | RefLeaf
  | RefLeaf
  | RefLeaf

// Ref proxies accepted where a string-valued operand is expected.
type StringRefProxy =
  | RefProxy
  | RefProxy
  | RefProxy

// Already-built expressions accepted where a string-valued operand is expected.
type StringBasicExpression =
  | BasicExpression
  | BasicExpression
  | BasicExpression

// Operand type for `like()` / `ilike()`: string refs, ref proxies, basic
// expressions, raw string literals, or null/undefined.
type StringLike =
  | StringRef
  | StringRefProxy
  | StringBasicExpression
  | string
  | null
  | undefined

// Operand type for the comparison operators (eq/gt/gte/lt/lte) when at least
// one side is a ref: a ref proxy, ref leaf, raw value, expression, or nullish.
type ComparisonOperand =
  | RefProxy
  | RefLeaf
  | T
  | BasicExpression
  | undefined
  | null

// Operand type for the comparison operators when both sides are raw values
// or already-built expressions (no refs).
type ComparisonOperandPrimitive =
  | T
  | BasicExpression
  | undefined
  | null

// Helper type for values that can be lowered to expressions.
type ExpressionLike =
  | Aggregate
  | BasicExpression
  | RefProxy
  | RefLeaf
  | string
  | number
  | boolean
  | bigint
  | Date
  | null
  | undefined
  | Array

// Values accepted as a `caseWhen` branch value or default: any ExpressionLike,
// plus projection-only values (sub-queries, toArray/concat wrappers, plain
// projection records). At runtime, a non-expression value here makes
// `caseWhen` return a CaseWhenWrapper instead of a `Func` (see
// isExpressionValue below).
type CaseWhenValue =
  | ExpressionLike
  | QueryBuilder
  | ToArrayWrapper
  | ConcatToArrayWrapper
  | Record

// Unwraps the result type carried by a CaseWhenWrapper; other types pass through.
type ExtractCaseWhenValue = T extends CaseWhenWrapper ? TResult : T

// Static result of `caseWhen`. If every branch/default value is ExpressionLike,
// the result is a BasicExpression that includes `null` unless a default was
// supplied. Otherwise it is a select-only CaseWhenWrapper whose value includes
// `undefined` unless a default was supplied.
type CaseWhenResult<
  TValues extends Array,
  THasDefault extends boolean,
> = TValues[number] extends ExpressionLike
  ? BasicExpression<
      ExtractType | (THasDefault extends true ? never : null)
    >
  : CaseWhenWrapper<
      | ExtractCaseWhenValue
      | (THasDefault extends true ? never : undefined)
    >

// Helper type to extract the underlying type from various expression types
// (ref proxy, ref leaf, basic expression); any other type is returned as-is.
type ExtractType =
  T extends RefProxy
    ? U
    : T extends RefLeaf
      ? U
      : T extends BasicExpression
        ? U
        : T

// Helper type to determine aggregate return type based on input nullability
type AggregateReturnType =
  ExtractType extends infer U
    ? U extends number | undefined | null | Date | bigint | string
      ? Aggregate
      : Aggregate
    : Aggregate

// Helper type to determine string function return type based on input nullability
type StringFunctionReturnType =
  ExtractType extends infer U
    ? U extends string | undefined | null
      ? BasicExpression
      : BasicExpression
    : BasicExpression

// Helper type to determine numeric function return type based on input nullability
// This handles string, array, and number inputs for functions like length()
type NumericFunctionReturnType =
  ExtractType extends infer U
    ? U extends string | Array | undefined | null | number
      ? BasicExpression>
      : BasicExpression
    : BasicExpression

// Transform string/array types to number while preserving nullability
type MapToNumber = T extends string | Array
  ? number
  : T extends undefined
    ? undefined
    : T extends null
      ? null
      : T

// Helper type for binary numeric operations (combines nullability of both operands).
// Branches enumerate each operand as non-null number, number | undefined, or
// number | null; combinations not listed fall through to the final branches.
type BinaryNumericReturnType =
  ExtractType extends infer U1
    ? ExtractType extends infer U2
      ? U1 extends number
        ? U2 extends number
          ? BasicExpression
          : U2 extends number | undefined
            ? BasicExpression
            : U2 extends number | null
              ? BasicExpression
              : BasicExpression
        : U1 extends number | undefined
          ? U2 extends number
            ? BasicExpression
            : U2 extends number | undefined
              ? BasicExpression
              : BasicExpression
          : U1 extends number | null
            ? U2 extends number
              ? BasicExpression
              : BasicExpression
            : BasicExpression
      : BasicExpression
    : BasicExpression

// Operators
//
// Each comparison operator builds a `Func` IR node named after the operator,
// lowering both operands with `toExpression`. The three public overloads are:
// ref-based operands, primitive/expression operands, and an Aggregate on the
// left-hand side (right side untyped).

export function eq(
  left: ComparisonOperand,
  right: ComparisonOperand,
): BasicExpression
export function eq(
  left: ComparisonOperandPrimitive,
  right: ComparisonOperandPrimitive,
): BasicExpression
export function eq(left: Aggregate, right: any): BasicExpression
export function eq(left: any, right: any): BasicExpression {
  return new Func(`eq`, [toExpression(left), toExpression(right)])
}

export function gt(
  left: ComparisonOperand,
  right: ComparisonOperand,
): BasicExpression
export function gt(
  left: ComparisonOperandPrimitive,
  right: ComparisonOperandPrimitive,
): BasicExpression
export function gt(left: Aggregate, right: any): BasicExpression
export function gt(left: any, right: any): BasicExpression {
  return new Func(`gt`, [toExpression(left), toExpression(right)])
}

export function gte(
  left: ComparisonOperand,
  right: ComparisonOperand,
): BasicExpression
export function gte(
  left: ComparisonOperandPrimitive,
  right: ComparisonOperandPrimitive,
): BasicExpression
export function gte(left: Aggregate, right: any): BasicExpression
export function gte(left: any, right: any): BasicExpression {
  return new Func(`gte`, [toExpression(left), toExpression(right)])
}

export function lt(
  left: ComparisonOperand,
  right: ComparisonOperand,
): BasicExpression
export function lt(
  left: ComparisonOperandPrimitive,
  right: ComparisonOperandPrimitive,
): BasicExpression
export function lt(left: Aggregate, right: any): BasicExpression
export function lt(left: any, right: any): BasicExpression {
  return new Func(`lt`, [toExpression(left), toExpression(right)])
}

export function lte(
  left: ComparisonOperand,
  right: ComparisonOperand,
): BasicExpression
export function lte(
  left: ComparisonOperandPrimitive,
  right: ComparisonOperandPrimitive,
): BasicExpression
export function lte(left: Aggregate, right: any): BasicExpression
export function lte(left: any, right: any): BasicExpression {
  return new Func(`lte`, [toExpression(left), toExpression(right)])
}

// Overloads for and() - support 2 or more arguments.
// All operands are lowered in order into a single variadic `and` Func node.
export function and(
  left: ExpressionLike,
  right: ExpressionLike,
): BasicExpression
export function and(
  left: ExpressionLike,
  right: ExpressionLike,
  ...rest: Array
): BasicExpression
export function and(
  left: ExpressionLike,
  right: ExpressionLike,
  ...rest: Array
): BasicExpression {
  const allArgs = [left, right, ...rest]
  return new Func(
    `and`,
    allArgs.map((arg) => toExpression(arg)),
  )
}

// Overloads for or() - support 2 or more arguments.
// All operands are lowered in order into a single variadic `or` Func node.
export function or(
  left: ExpressionLike,
  right: ExpressionLike,
): BasicExpression
export function or(
  left: ExpressionLike,
  right: ExpressionLike,
  ...rest: Array
): BasicExpression
export function or(
  left: ExpressionLike,
  right: ExpressionLike,
  ...rest: Array
): BasicExpression {
  const allArgs = [left, right, ...rest]
  return new Func(
    `or`,
    allArgs.map((arg) => toExpression(arg)),
  )
}

// Logical negation of a single operand.
export function not(value: ExpressionLike): BasicExpression {
  return new Func(`not`, [toExpression(value)])
}

// Null/undefined checking functions
export function isUndefined(value: ExpressionLike): BasicExpression {
  return new Func(`isUndefined`, [toExpression(value)])
}

export function isNull(value: ExpressionLike): BasicExpression {
  return new Func(`isNull`, [toExpression(value)])
}

// Membership test. Note the emitted IR function name is `in`, not `inArray`
// (this is the `in` entry in `comparisonFunctions` / `operators` below).
export function inArray(
  value: ExpressionLike,
  array: ExpressionLike,
): BasicExpression {
  return new Func(`in`, [toExpression(value), toExpression(array)])
}

// Pattern match on string operands; builds a `like` Func node.
// NOTE(review): `like` declares a separate `any`-typed implementation signature,
// while `ilike` below uses its public signature directly; both build the same
// two-operand node shape.
export function like(
  left: StringLike,
  right: StringLike,
): BasicExpression
export function like(left: any, right: any): BasicExpression {
  return new Func(`like`, [toExpression(left), toExpression(right)])
}

// Case-insensitive variant of `like`; builds an `ilike` Func node.
export function ilike(
  left: StringLike,
  right: StringLike,
): BasicExpression {
  return new Func(`ilike`, [toExpression(left), toExpression(right)])
}

// Functions

// Upper-cases a string operand; result nullability follows the input type.
export function upper(
  arg: T,
): StringFunctionReturnType {
  return new Func(`upper`, [toExpression(arg)]) as StringFunctionReturnType
}

// Lower-cases a string operand; result nullability follows the input type.
export function lower(
  arg: T,
): StringFunctionReturnType {
  return new Func(`lower`, [toExpression(arg)]) as StringFunctionReturnType
}

// Length of a string or array operand (see NumericFunctionReturnType).
export function length(
  arg: T,
): NumericFunctionReturnType {
  return new Func(`length`, [toExpression(arg)]) as NumericFunctionReturnType
}

/**
 * Concatenates operands into a single `concat` expression.
 *
 * Special case: when an argument is a `ToArrayWrapper` (from `toArray(...)`),
 * it must be the only argument. The call then returns a
 * `ConcatToArrayWrapper` around the same sub-query instead of a `Func` node.
 *
 * @throws Error if a `toArray(...)` argument is combined with any other argument.
 */
export function concat(
  arg: ToArrayWrapper,
): ConcatToArrayWrapper
export function concat(...args: Array): BasicExpression
export function concat(
  ...args: Array>
): BasicExpression | ConcatToArrayWrapper {
  const toArrayArg = args.find(
    (arg): arg is ToArrayWrapper => arg instanceof ToArrayWrapper,
  )

  if (toArrayArg) {
    // The toArray form is only supported as concat's sole argument.
    if (args.length !== 1) {
      throw new Error(
        `concat(toArray(...)) currently supports only a single toArray(...) argument`,
      )
    }

    return new ConcatToArrayWrapper(toArrayArg.query)
  }

  return new Func(
    `concat`,
    args.map((arg) => toExpression(arg)),
  )
}

// Helper type for coalesce: extracts non-nullish value types from all args
type CoalesceArgTypes> = {
  [K in keyof T]: NonNullable>
}[number]

// Whether any arg in the tuple is statically guaranteed non-null (i.e., does not include null | undefined)
type HasGuaranteedNonNull> = {
  [K in keyof T]: null extends ExtractType
    ? false
    : undefined extends ExtractType
      ? false
      : true
}[number] extends false
  ? false
  : true

// coalesce() return type: union of all non-null arg types; null included unless a guaranteed non-null arg exists
type CoalesceReturnType> =
  HasGuaranteedNonNull extends true
    ? BasicExpression>
    : BasicExpression | null>

// Builds a `coalesce` Func node from one or more operands, lowered in order.
export function coalesce]>(
  ...args: T
): CoalesceReturnType {
  return new Func(
    `coalesce`,
    args.map((arg) => toExpression(arg)),
  ) as CoalesceReturnType
}

/**
 * Returns the value for the first matching condition, similar to SQL
 * `CASE WHEN`.
 *
 * Arguments are evaluated as condition/value pairs followed by an optional
 * default value. Scalar branch values return a query expression and can be used
 * in expression contexts like `select`, `where`, `orderBy`, `groupBy`,
 * `having`, and equality join operands. If no scalar branch matches and no
 * default is provided, the result is `null`.
 *
 * When a branch value is a projection object, `caseWhen` becomes a select-only
 * projection value. Projection branches can include nested fields, ref spreads,
 * and includes. If no projection branch matches and no default is provided, the
 * result is `undefined`.
 *
 * The overloads type up to five condition/value pairs (each with or without a
 * default). Calls with more pairs fall back to a signature returning `any`.
 *
 * @throws Error when fewer than two arguments are passed.
 * @throws Error when a condition position (the first element of each pair)
 *   holds a value that is not expression-like, such as an array or a plain
 *   projection object.
 *
 * @example
 * ```ts
 * caseWhen(gt(user.age, 18), `adult`, `minor`)
 * ```
 *
 * @example
 * ```ts
 * caseWhen(
 *   gt(user.age, 65),
 *   `senior`,
 *   gt(user.age, 18),
 *   `adult`,
 *   `minor`,
 * )
 * ```
 *
 * @example
 * ```ts
 * caseWhen(gt(user.age, 18), {
 *   ...user,
 *   posts: q
 *     .from({ post: postsCollection })
 *     .where(({ post }) => eq(post.userId, user.id)),
 * })
 * ```
 */
export function caseWhen(
  condition1: C1,
  value1: V1,
): CaseWhenResult<[V1], false>
export function caseWhen<
  C1 extends ExpressionLike,
  V1 extends CaseWhenValue,
  D extends CaseWhenValue,
>(condition1: C1, value1: V1, defaultValue: D): CaseWhenResult<[V1, D], true>
export function caseWhen<
  C1 extends ExpressionLike,
  V1 extends CaseWhenValue,
  C2 extends ExpressionLike,
  V2 extends CaseWhenValue,
>(
  condition1: C1,
  value1: V1,
  condition2: C2,
  value2: V2,
): CaseWhenResult<[V1, V2], false>
export function caseWhen<
  C1 extends ExpressionLike,
  V1 extends CaseWhenValue,
  C2 extends ExpressionLike,
  V2 extends CaseWhenValue,
  D extends CaseWhenValue,
>(
  condition1: C1,
  value1: V1,
  condition2: C2,
  value2: V2,
  defaultValue: D,
): CaseWhenResult<[V1, V2, D], true>
export function caseWhen<
  C1 extends ExpressionLike,
  V1 extends CaseWhenValue,
  C2 extends ExpressionLike,
  V2 extends CaseWhenValue,
  C3 extends ExpressionLike,
  V3 extends CaseWhenValue,
>(
  condition1: C1,
  value1: V1,
  condition2: C2,
  value2: V2,
  condition3: C3,
  value3: V3,
): CaseWhenResult<[V1, V2, V3], false>
export function caseWhen<
  C1 extends ExpressionLike,
  V1 extends CaseWhenValue,
  C2 extends ExpressionLike,
  V2 extends CaseWhenValue,
  C3 extends ExpressionLike,
  V3 extends CaseWhenValue,
  D extends CaseWhenValue,
>(
  condition1: C1,
  value1: V1,
  condition2: C2,
  value2: V2,
  condition3: C3,
  value3: V3,
  defaultValue: D,
): CaseWhenResult<[V1, V2, V3, D], true>
export function caseWhen<
  C1 extends ExpressionLike,
  V1 extends CaseWhenValue,
  C2 extends ExpressionLike,
  V2 extends CaseWhenValue,
  C3 extends ExpressionLike,
  V3 extends CaseWhenValue,
  C4 extends ExpressionLike,
  V4 extends CaseWhenValue,
>(
  condition1: C1,
  value1: V1,
  condition2: C2,
  value2: V2,
  condition3: C3,
  value3: V3,
  condition4: C4,
  value4: V4,
): CaseWhenResult<[V1, V2, V3, V4], false>
export function caseWhen<
  C1 extends ExpressionLike,
  V1 extends CaseWhenValue,
  C2 extends ExpressionLike,
  V2 extends CaseWhenValue,
  C3 extends ExpressionLike,
  V3 extends CaseWhenValue,
  C4 extends ExpressionLike,
  V4 extends CaseWhenValue,
  D extends CaseWhenValue,
>(
  condition1: C1,
  value1: V1,
  condition2: C2,
  value2: V2,
  condition3: C3,
  value3: V3,
  condition4: C4,
  value4: V4,
  defaultValue: D,
): CaseWhenResult<[V1, V2, V3, V4, D], true>
export function caseWhen<
  C1 extends ExpressionLike,
  V1 extends CaseWhenValue,
  C2 extends ExpressionLike,
  V2 extends CaseWhenValue,
  C3 extends ExpressionLike,
  V3 extends CaseWhenValue,
  C4 extends ExpressionLike,
  V4 extends CaseWhenValue,
  C5 extends ExpressionLike,
  V5 extends CaseWhenValue,
>(
  condition1: C1,
  value1: V1,
  condition2: C2,
  value2: V2,
  condition3: C3,
  value3: V3,
  condition4: C4,
  value4: V4,
  condition5: C5,
  value5: V5,
): CaseWhenResult<[V1, V2, V3, V4, V5], false>
export function caseWhen<
  C1 extends ExpressionLike,
  V1 extends CaseWhenValue,
  C2 extends ExpressionLike,
  V2 extends CaseWhenValue,
  C3 extends ExpressionLike,
  V3 extends CaseWhenValue,
  C4 extends ExpressionLike,
  V4 extends CaseWhenValue,
  C5 extends ExpressionLike,
  V5 extends CaseWhenValue,
  D extends CaseWhenValue,
>(
  condition1: C1,
  value1: V1,
  condition2: C2,
  value2: V2,
  condition3: C3,
  value3: V3,
  condition4: C4,
  value4: V4,
  condition5: C5,
  value5: V5,
  defaultValue: D,
): CaseWhenResult<[V1, V2, V3, V4, V5, D], true>
// Fallback for six or more condition/value pairs: accepted, but typed as `any`.
export function caseWhen<
  C1 extends ExpressionLike,
  V1 extends CaseWhenValue,
  C2 extends ExpressionLike,
  V2 extends CaseWhenValue,
  C3 extends ExpressionLike,
  V3 extends CaseWhenValue,
  C4 extends ExpressionLike,
  V4 extends CaseWhenValue,
  C5 extends ExpressionLike,
  V5 extends CaseWhenValue,
>(
  condition1: C1,
  value1: V1,
  condition2: C2,
  value2: V2,
  condition3: C3,
  value3: V3,
  condition4: C4,
  value4: V4,
  condition5: C5,
  value5: V5,
  condition6: ExpressionLike,
  value6: CaseWhenValue,
  ...rest: Array
): any
export function caseWhen(...args: Array): any {
  if (args.length < 2) {
    throw new Error(`caseWhen() requires at least two arguments`)
  }

  // Conditions sit at even indices 0, 2, 4, ... of each complete pair; with an
  // odd argument count the trailing argument is the default and is not checked
  // here.
  const pairCount = Math.floor(args.length / 2)
  for (let i = 0; i < pairCount; i++) {
    const condition = args[i * 2]
    if (!isConditionValue(condition)) {
      throw new Error(`caseWhen() conditions must be expression-like values`)
    }
  }

  // All branch values (and the default, if any) are expression-like: emit a
  // regular `caseWhen` Func usable in any expression context.
  if (caseWhenHasOnlyExpressionValues(args)) {
    return new Func(
      `caseWhen`,
      args.map((arg) => toExpression(arg)),
    )
  }

  // At least one value is a projection-only value: keep the raw, un-lowered
  // arguments in a select-only wrapper.
  return new CaseWhenWrapper(args)
}

// Numeric addition; result nullability is derived from both operands
// (see BinaryNumericReturnType).
export function add(
  left: T1,
  right: T2,
): BinaryNumericReturnType {
  return new Func(`add`, [
    toExpression(left),
    toExpression(right),
  ]) as BinaryNumericReturnType
}

// Aggregates
//
// Each aggregate builds an `Aggregate` IR node (not a `Func`) named after the
// function, with its single operand lowered via `toExpression`.

export function count(arg: ExpressionLike): Aggregate {
  return new Aggregate(`count`, [toExpression(arg)])
}

export function avg(arg: T): AggregateReturnType {
  return new Aggregate(`avg`, [toExpression(arg)]) as AggregateReturnType
}

export function sum(arg: T): AggregateReturnType {
  return new Aggregate(`sum`, [toExpression(arg)]) as AggregateReturnType
}

export function min(arg: T): AggregateReturnType {
  return new Aggregate(`min`, [toExpression(arg)]) as AggregateReturnType
}

export function max(arg: T): AggregateReturnType {
  return new Aggregate(`max`, [toExpression(arg)]) as AggregateReturnType
}

/**
 * List of comparison function names that can be used with indexes.
 * These are IR function names: `in` is the name emitted by `inArray()`.
 */
export const comparisonFunctions = [
  `eq`,
  `gt`,
  `gte`,
  `lt`,
  `lte`,
  `in`,
  `like`,
  `ilike`,
] as const

/**
 * All supported operator names in TanStack DB expressions.
 * These are the IR names emitted by the builders above (e.g. `in` for
 * `inArray()`).
 */
export const operators = [
  // Comparison operators
  `eq`,
  `gt`,
  `gte`,
  `lt`,
  `lte`,
  `in`,
  `like`,
  `ilike`,
  // Logical operators
  `and`,
  `or`,
  `not`,
  // Null checking
  `isNull`,
  `isUndefined`,
  // String functions
  `upper`,
  `lower`,
  `length`,
  `concat`,
  // Numeric functions
  `add`,
  // Utility functions
  `coalesce`,
  `caseWhen`,
  // Aggregate functions
  `count`,
  `avg`,
  `sum`,
  `min`,
  `max`,
] as const

// Union of all operator name literals listed in `operators`.
export type OperatorName = (typeof operators)[number]

// Marker returned by `toArray(query)`: carries the sub-query. `_type` and
// `_result` are `declare`d, i.e. type-level only, with no runtime field emitted.
export class ToArrayWrapper<_T = unknown> {
  readonly __brand = `ToArrayWrapper` as const
  declare readonly _type: `toArray`
  declare readonly _result: _T
  constructor(public readonly query: QueryBuilder) {}
}

// Marker returned by `concat(toArray(query))`: carries the same sub-query as
// the ToArrayWrapper it was built from.
export class ConcatToArrayWrapper<_T = unknown> {
  readonly __brand = `ConcatToArrayWrapper` as const
  declare readonly _type: `concatToArray`
  declare readonly _result: _T
  constructor(public readonly query: QueryBuilder) {}
}

// Select-only result of `caseWhen` when some branch value is not
// expression-like. Holds the raw (un-lowered) caseWhen arguments.
// NOTE(review): unlike the other wrappers, `_result` here is an optional real
// property rather than a `declare`d one — confirm this difference is intentional.
export class CaseWhenWrapper<_T = any> {
  readonly __brand = `CaseWhenWrapper` as const
  declare readonly _type: `caseWhen`
  readonly _result?: _T
  constructor(public readonly args: Array) {}
}

// Wraps a sub-query so it can be used as an array-valued projection
// (or passed alone to `concat`).
export function toArray(
  query: QueryBuilder,
): ToArrayWrapper> {
  return new ToArrayWrapper(query)
}

// True when every caseWhen value position (branch values and optional default)
// holds an expression-like value, so the call can be lowered to a `Func`.
function caseWhenHasOnlyExpressionValues(args: Array): boolean {
  const valueIndexes = getCaseWhenValueIndexes(args.length)
  return valueIndexes.every((index) => isExpressionValue(args[index]))
}

// Indexes of the value positions in a caseWhen argument list: the odd index of
// each complete pair (1, 3, 5, ...), plus the last index when the argument
// count is odd (the default value). E.g. argCount 5 -> [1, 3, 4].
function getCaseWhenValueIndexes(argCount: number): Array {
  const valueIndexes: Array = []
  const hasDefaultValue = argCount % 2 === 1
  const pairCount = Math.floor(argCount / 2)

  for (let i = 0; i < pairCount; i++) {
    valueIndexes.push(i * 2 + 1)
  }

  if (hasDefaultValue) {
    valueIndexes.push(argCount - 1)
  }

  return valueIndexes
}

// Runtime check for "can be lowered with toExpression". The following count as
// expression values:
// - ref proxies;
// - Aggregate/Func instances;
// - null/undefined;
// - scalar primitives (string, number, boolean, bigint);
// - Dates and arrays;
// - plain objects that structurally look like IR nodes:
//   `agg`/`func` with a string name and an args array, `ref` with a path
//   array, or `val` with a `value` key.
// Anything else (e.g. projection records) is not.
function isExpressionValue(value: CaseWhenValue | undefined): boolean {
  if (isRefProxy(value)) return true
  if (value instanceof Aggregate || value instanceof Func) return true
  if (value == null) return true
  if (
    typeof value === `string` ||
    typeof value === `number` ||
    typeof value === `boolean` ||
    typeof value === `bigint`
  ) {
    return true
  }
  if (value instanceof Date || Array.isArray(value)) return true
  if (typeof value === `object`) {
    // Duck-type check for IR nodes that are not instances of the imported classes.
    const candidate = value as {
      type?: unknown
      args?: unknown
      name?: unknown
      path?: unknown
      value?: unknown
    }
    if (
      (candidate.type === `agg` || candidate.type === `func`) &&
      typeof candidate.name === `string` &&
      Array.isArray(candidate.args)
    ) {
      return true
    }
    if (candidate.type === `ref` && Array.isArray(candidate.path)) return true
    if (candidate.type === `val` && `value` in candidate) return true
  }
  return false
}

// A caseWhen condition must be expression-like; arrays are accepted as values
// but rejected as conditions.
function isConditionValue(value: CaseWhenValue | undefined): boolean {
  return isExpressionValue(value) && !Array.isArray(value)
}