import {
    type OperatorNodeFn,
    type OperatorNode,
    type OperatorNodeMap,
    isFunctionNode,
    isSymbolNode,
    type MathNode,
    isOperatorNode,
    isNode,
} from 'mathjs';
import { isVmArray, type VmTypeName, type VmAny } from '@mirascript/mirascript';
import { operations } from '@mirascript/mirascript/subtle';
import { symbolName, constantValue, equalText, scalar, globalFnName } from './utils.js';
import type { State } from './state.js';
import type { Options, Result } from './interface.js';
import { migrateAtomic, migrateExpr, migrateParen } from './node.js';
import { toBoolean, toNumber } from './to-type.js';
import { migrateSymbol } from './symbol.js';
import { serialize } from './serialize.js';

/**
 * Static result type of a `matrix.*` binary call.
 *
 * @returns `'array'` if either operand is known to be an array, `'number'` if both operand
 *          types are known (and neither is an array), otherwise `undefined` (unknown).
 */
function matrixResultType(l: Result, r: Result): 'array' | 'number' | undefined {
    if (l.type === 'array' || r.type === 'array') return 'array';
    return l.type && r.type ? 'number' : undefined;
}

/**
 * mathjs binary arithmetic operators, keyed by the OperatorNode `fn` name.
 *
 * Each entry is `[op, alt]`:
 * - `op`: operator text emitted between scalar operands (also used element-wise when an
 *   operand is a flat literal array).
 * - `alt`: builder used by `binary()` when the operands are not both known scalars.
 *   `false` means "no matrix form": a warning is emitted and the scalar operator is used.
 *   When omitted, `binary()` falls back to `matrix.entrywise`.
 */
const BINARY_MATH_OPERATORS = {
    add: [
        ' + ',
        (state, l, r) => ({
            type: matrixResultType(l, r),
            code: `${globalFnName(state, 'matrix')}.add(${l.code}, ${r.code})`,
        }),
    ],
    subtract: [
        ' - ',
        (state, l, r) => ({
            type: matrixResultType(l, r),
            code: `${globalFnName(state, 'matrix')}.subtract(${l.code}, ${r.code})`,
        }),
    ],
    multiply: [
        ' * ',
        (state, l, r) => ({
            type: matrixResultType(l, r),
            code: `${globalFnName(state, 'matrix')}.multiply(${l.code}, ${r.code})`,
        }),
    ],
    dotMultiply: [
        ' * ',
        (state, l, r) => ({
            type: matrixResultType(l, r),
            code: `${globalFnName(state, 'matrix')}.entrywise_multiply(${l.code}, ${r.code})`,
        }),
    ],
    divide: [
        ' / ',
        (state, l, r) => {
            if (r.type === 'array' || !r.type) {
                // Divisor may be a matrix: emit l * invert(r).
                return {
                    type: matrixResultType(l, r),
                    code: `${globalFnName(state, 'matrix')}.multiply(${l.code}, ${globalFnName(state, 'matrix')}.invert(${r.code}))`,
                };
            }
            // Divisor is a known non-array value: divide element-wise.
            return {
                type: matrixResultType(l, r),
                code: `${globalFnName(state, 'matrix')}.entrywise_divide(${l.code}, ${r.code})`,
            };
        },
    ],
    dotDivide: [
        ' / ',
        (state, l, r) => ({
            type: matrixResultType(l, r),
            code: `${globalFnName(state, 'matrix')}.entrywise_divide(${l.code}, ${r.code})`,
        }),
    ],
    mod: [' % '],
    pow: ['^', false],
    dotPow: ['^'],
} satisfies Record<string, [op: string, alt?: ((state: State, l: Result, r: Result) => Result) | false]>;

/** mathjs operator `fn` names mapped to the global function name emitted as `f(args...)`. */
const MATH_FUNCTIONS = {
    factorial: 'factorial',
    bitAnd: 'b_and',
    bitOr: 'b_or',
    bitNot: 'b_not',
    bitXor: 'b_xor',
    leftShift: 'shl',
    rightArithShift: 'sar',
    rightLogShift: 'shr',
} as const;

/**
 * Boolean replacements for bitwise operators when every operand is boolean.
 * Each entry is `[boolOp, needParen]`; `needParen` selects `migrateParen` for the operands.
 */
const BIT_OPS_TO_BOOL_OPS = {
    bitAnd: ['&&', false],
    bitOr: ['||', false],
    bitXor: ['!=', true],
} as const;

/**
 * Comparison operators: `[op, eqOp?]`. `op` is used for general comparisons; `eqOp`
 * (equality operators only) is used when one side is the constant `null`.
 */
const COMPARE_OPERATORS = {
    smaller: ['<'],
    smallerEq: ['<='],
    larger: ['>'],
    largerEq: ['>='],
    equal: ['=~', '=='],
    unequal: ['!~', '!='],
} as const;

/**
 * Convert to boolean: migrate `node` with `migrator`, pass it through `scalar`, then coerce
 * it with `toBoolean`.
 *
 * @param op Operator name forwarded to `scalar()` (presumably for its diagnostics — verify).
 * @returns The boolean expression code.
 */
function b(op: string, state: State, node: MathNode, migrator = migrateExpr): string {
    const re = migrator(state, node);
    return toBoolean(state, scalar(op, state, re)).code;
}

/**
 * Element type of an array literal.
 *
 * @returns The `$Type` shared by every element, or `undefined` when `lit` is not a
 *          non-empty VM array or its elements have different types.
 */
function elementType(lit: VmAny): VmTypeName | undefined {
    if (!isVmArray(lit) || !lit.length) return undefined;
    let type: VmTypeName | undefined = undefined;
    for (const e of lit) {
        const t = operations.$Type(e);
        if (!type) {
            type = t;
        } else if (type !== t) {
            return undefined;
        }
    }
    return type;
}

/**
 * Binary operation.
 *
 * Strategy, in order:
 * 1. both operand types known and non-array → `op` directly;
 * 2. left operand is a flat literal array → `map` `op` over its elements;
 * 3. right operand is a flat literal array → `map` `op` over its elements;
 * 4. `alt` provided → delegate to it;
 * 5. otherwise → `matrix.entrywise(l, r, fn (a, b) { op })`.
 */
function binary(
    state: State,
    l: Result,
    r: Result,
    op: (state: State, l: Result, r: Result) => Result,
    alt?: (state: State, l: Result, r: Result) => Result,
): Result {
    if (l.type && r.type && l.type !== 'array' && r.type !== 'array') {
        return op(state, l, r);
    }
    if (Array.isArray(l.literal) && l.literal.every((e) => !Array.isArray(e))) {
        const it = { code: 'it', type: elementType(l.literal) };
        return {
            type: 'array',
            code: `${l.code}::${globalFnName(state, 'map')}(fn { ${op(state, it, r).code} })`,
        };
    }
    if (Array.isArray(r.literal) && r.literal.every((e) => !Array.isArray(e))) {
        const it = { code: 'it', type: elementType(r.literal) };
        return {
            type: 'array',
            code: `${r.code}::${globalFnName(state, 'map')}(fn { ${op(state, l, it).code} })`,
        };
    }
    if (alt) {
        return alt(state, l, r);
    }
    // Operands of the generated `fn (a, b)` lambda; named so they do not shadow `b()`.
    const lhsElem = { code: 'a', type: l.type === 'array' ? elementType(l.literal) : l.type };
    const rhsElem = { code: 'b', type: r.type === 'array' ? elementType(r.literal) : r.type };
    return {
        type: l.type === 'array' || r.type === 'array' ? 'array' : undefined,
        code: `${globalFnName(state, 'matrix')}.entrywise(${l.code}, ${r.code}, fn (a, b) { ${op(state, lhsElem, rhsElem).code} })`,
    };
}

/**
 * Unary operation.
 *
 * Known non-array operand → `op` directly; flat literal array → `map` over its elements;
 * otherwise → `matrix.entrywise(v, nil, fn { op })`.
 */
function unary(state: State, v: Result, op: (state: State, v: Result) => Result): Result {
    if (v.type && v.type !== 'array') {
        return op(state, v);
    }
    if (Array.isArray(v.literal) && v.literal.every((e) => !Array.isArray(e))) {
        const it = { code: 'it', type: elementType(v.literal) };
        return {
            type: 'array',
            code: `${v.code}::${globalFnName(state, 'map')}(fn { ${op(state, it).code} })`,
        };
    }
    // NOTE(review): reported as 'array' even when v.type is unknown, whereas binary()'s
    // analogous fallback reports undefined — confirm this is intended.
    return {
        type: 'array',
        code: `${globalFnName(state, 'matrix')}.entrywise(${v.code}, nil, fn { ${op(state, { code: 'it' }).code} })`,
    };
}

/**
 * Migrate a mathjs operator node (identified by its `fn` name) to MiraScript code.
 *
 * @param node Only `fn` and `args` are read, so callers may pass a synthetic node.
 * @param options `options.format === 'paren'` wraps boolean/bit-op results in parentheses.
 * @returns The migrated expression; unsupported operators report `state.err` and emit a
 *          commented placeholder.
 */
export function migrateOperator(
    state: State,
    node: Pick<OperatorNode<OperatorNodeMap[OperatorNodeFn], OperatorNodeFn>, 'fn' | 'args'>,
    options: Options,
): Result {
    const fn = node.fn as string;
    const { args } = node;
    const a0 = args[0]!;
    const a1 = args[1]!;
    const open = options.format === 'paren' ? '(' : '';
    const close = options.format === 'paren' ? ')' : '';

    switch (fn) {
        case 'add':
        case 'subtract':
        case 'multiply':
        case 'divide':
        case 'mod':
        case 'pow':
        case 'dotMultiply':
        case 'dotDivide':
        case 'dotPow': {
            const [op, alt] = BINARY_MATH_OPERATORS[fn];
            return binary(
                state,
                migrateExpr(state, a0),
                migrateExpr(state, a1),
                (state, l, r) => ({
                    type: 'number',
                    code: `${l.code}${op}${r.code}`,
                }),
                alt === false
                    ? (state, l, r) => {
                          state.warn(`无法确定 '${op.trim()}' 的操作数为标量类型,计算结果可能不一致`);
                          return { code: `${l.code}${op}${r.code}` };
                      }
                    : alt,
            );
        }
        case 'unaryMinus':
        case 'unaryPlus': {
            const op = fn === 'unaryMinus' ? '-' : '+';
            const exp = migrateExpr(state, a0);
            // Constant-fold numeric literals.
            if (typeof exp.literal === 'number') {
                const f = fn === 'unaryMinus' ? operations.$Neg : operations.$Pos;
                const v = f(exp.literal);
                return {
                    type: 'number',
                    literal: v,
                    code: serialize(v),
                };
            }
            // Reuse the migrated operand: migrating a0 again duplicates work and may repeat
            // side effects on `state` (e.g. warnings).
            return unary(state, exp, (state, v) => ({
                type: 'number',
                code: `${op}${v.code}`,
            }));
        }
        case 'factorial':
        case 'bitAnd':
        case 'bitOr':
        case 'bitXor':
        case 'bitNot':
        case 'leftShift':
        case 'rightArithShift':
        case 'rightLogShift': {
            const f = MATH_FUNCTIONS[fn];
            const codes = args.map((a) => {
                const r = migrateAtomic(state, a);
                return scalar(f, state, r);
            });
            // All operands boolean: emit the boolean operator, wrapped in to_number() for the
            // numeric value and exposed unwrapped via `as_boolean`.
            // NOTE(review): this path migrates `args` a second time.
            if (codes.every((c) => c.type === 'boolean' || c.as_boolean) && fn in BIT_OPS_TO_BOOL_OPS) {
                const [boolOp, needParen] = BIT_OPS_TO_BOOL_OPS[fn as keyof typeof BIT_OPS_TO_BOOL_OPS];
                const code = args
                    .map((a) => toBoolean(state, (needParen ? migrateParen : migrateExpr)(state, a)).code)
                    .join(` ${boolOp} `);
                return {
                    type: 'number',
                    code: `to_number(${code})`,
                    as_boolean: `${open}${code}${close}`,
                };
            }
            return {
                type: 'number',
                code: `${f}(${codes.map((c) => c.code).join(', ')})`,
            };
        }
        case 'ctranspose': {
            const p = migrateAtomic(state, a0);
            // Emitted as plain `transpose`; state.loose() presumably flags that the result
            // may not be exactly equivalent — TODO confirm.
            state.loose();
            return {
                type: 'array',
                code: `transpose(${p.code})`,
            };
        }
        case 'and':
            return {
                type: 'boolean',
                code: `${open}${b('&&', state, a0)} && ${b('&&', state, a1)}${close}`,
            };
        case 'or':
            return {
                type: 'boolean',
                code: `${open}${b('||', state, a0)} || ${b('||', state, a1)}${close}`,
            };
        case 'xor': {
            return {
                type: 'boolean',
                code: `${open}${b('!=', state, a0, migrateParen)} != ${b('!=', state, a1, migrateParen)}${close}`,
            };
        }
        case 'not': {
            // not equalText(x, y) → text inequality.
            if (isFunctionNode(a0) && symbolName(a0.fn) === 'equalText' && a0.args.length === 2) {
                const p = equalText(state, '!=', a0.args[0]!, a0.args[1]!);
                return {
                    type: p.type,
                    code: `${open}${p.code}${close}`,
                };
            }
            // not (x == y) / not (x != y) → migrate the negated comparison directly.
            // The flip must be expressed through `fn`, because migrateOperator dispatches on
            // `fn` (flipping only `op` would keep the original, non-negated comparison).
            if (isOperatorNode(a0) && (a0.op === '==' || a0.op === '!=')) {
                return migrateOperator(
                    state,
                    { fn: a0.op === '==' ? 'unequal' : 'equal', args: a0.args },
                    {
                        ...options,
                        format: options.format === 'no-paren' ? 'no-paren' : 'paren',
                    },
                );
            }
            return {
                type: 'boolean',
                code: `${open}!${b('!', state, a0)}${close}`,
            };
        }
        case 'equal':
        case 'unequal':
        case 'smaller':
        case 'smallerEq':
        case 'larger':
        case 'largerEq': {
            const [op, eqOp] = COMPARE_OPERATORS[fn];
            const c0 = constantValue(a0);
            const c1 = constantValue(a1);
            if (eqOp && (c0 === null || c1 === null)) {
                // mathjs only supports comparing scalars with null: emit `eqOp` against `nil`.
                const p0 = c0 === null ? { code: `nil` } : migrateExpr(state, a0);
                const p1 = c1 === null ? { code: `nil` } : migrateExpr(state, a1);
                return {
                    type: 'boolean',
                    code: `${open}${p0.code} ${eqOp} ${p1.code}${close}`,
                };
            }
            // When PI / E is compared with a numeric constant, do not convert the symbol.
            const preserveConst = typeof c0 === 'number' || typeof c1 === 'number';
            const migrateOperand = (x: MathNode | Result) => {
                if (!isNode(x)) return x;
                if (isSymbolNode(x)) return migrateSymbol(state, x, !preserveConst);
                return migrateExpr(state, x);
            };
            return binary(state, migrateOperand(a0), migrateOperand(a1), (state, l, r) => {
                if (l.literal !== undefined) {
                    return {
                        type: 'boolean',
                        code: `${toNumber(state, l).code} ${op} ${r.code}`,
                    };
                }
                if (r.literal !== undefined) {
                    return {
                        type: 'boolean',
                        code: `${l.code} ${op} ${toNumber(state, r).code}`,
                    };
                }
                if (l.type === 'number' || r.type === 'number') {
                    return {
                        type: 'boolean',
                        code: `${l.code} ${op} ${r.code}`,
                    };
                }
                return {
                    type: 'boolean',
                    code: `${toNumber(state, l).code} ${op} ${toNumber(state, r).code}`,
                };
            });
        }
        default: {
            state.err(`不支持的运算符: ${fn}`);
            if (a1) {
                return {
                    code: `${open}${migrateExpr(state, a0).code} /* ${fn} */ ${migrateExpr(state, a1).code}${close}`,
                };
            } else {
                return { code: `${open}${migrateExpr(state, a0).code} /* ${fn} */${close}` };
            }
        }
    }
}