import {PromptTemplate} from '@langchain/core/prompts';
import {RunnableSequence} from '@langchain/core/runnables';
import {LangGraphRunnableConfig} from '@langchain/langgraph';
import {inject, service} from '@loopback/core';
import {graphNode} from '../../../decorators';
import {IGraphNode, LLMStreamEventType} from '../../../graphs';
import {AiIntegrationBindings} from '../../../keys';
import {LLMProvider} from '../../../types';
import {stripThinkingTokens} from '../../../utils';
import {AIMessage} from '@langchain/core/messages';
import {DbQueryAIExtensionBindings} from '../keys';
import {DbQueryNodes} from '../nodes.enum';
import {DbSchemaHelperService} from '../services';
import {DbQueryState} from '../state';
import {DbQueryConfig} from '../types';

/**
 * Graph node that asks the LLM which of the available rules/checks are
 * relevant to the current question and schema, and stores the selected rules
 * (joined by newlines) in `validationChecklist`.
 *
 * The candidate rules are the optionally injected global checks followed by
 * the table context returned by `DbSchemaHelperService.getTablesContext`.
 *
 * NOTE(review): `IGraphNode` is used without a type argument here; if the
 * interface is generic in this codebase, confirm the intended argument.
 */
@graphNode(DbQueryNodes.GenerateChecklist)
export class GenerateChecklistNode implements IGraphNode {
  constructor(
    @inject(AiIntegrationBindings.CheapLLM)
    private readonly llm: LLMProvider,
    @inject(DbQueryAIExtensionBindings.Config)
    private readonly config: DbQueryConfig,
    @service(DbSchemaHelperService)
    private readonly schemaHelper: DbSchemaHelperService,
    @inject(DbQueryAIExtensionBindings.GlobalContext, {optional: true})
    private readonly checks?: string[],
  ) {}

  /**
   * Prompt asking the model for the indexes of the relevant rules.
   *
   * Each interpolated input sits in its own labelled section so the model can
   * tell the question, table list, schema and rules apart. The template must
   * not contain curly braces other than the four placeholders, because
   * `PromptTemplate` treats them as variables.
   */
  prompt = PromptTemplate.fromTemplate(`
You are given a user question, the tables selected for SQL generation, the relevant database schema, and a numbered list of rules/checks.
Return ONLY the indexes of the rules that are relevant to the user's question, the selected tables, and the given schema.

A rule is relevant if:
- It directly affects how a correct SQL query should be written for this question.
- It is a dependency of another relevant rule (e.g. if rule 3 requires a currency conversion, and rule 5 defines how currency conversion works, both must be included).
- It applies to any of the selected tables or their relationships.

After selecting relevant rules, review your selection and ensure:
- Any rule that is referenced by, or is a prerequisite for, another selected rule is also included.
- Do not include rules that are completely unrelated to the question, schema, or selected tables.

User question:
{prompt}

Selected tables:
{tables}

Database schema:
{schema}

Rules:
{indexedChecks}

Return only a comma-separated list of the relevant rule indexes.
Do not include any other text, explanation, or formatting.
Example: 1,3,5
If no rules are relevant, return: none
`);

  /**
   * Builds the validation checklist for the current state.
   *
   * Returns an empty state update (no changes) when the node is disabled in
   * config, when a checklist already exists on the state, when the schema has
   * two or fewer tables, when there are no candidate rules, or when the model
   * selects none of them.
   *
   * @param state - Current graph state; reads `prompt`, `schema` and
   *   `validationChecklist`.
   * @param config - Runnable config; its optional `writer` receives a log
   *   event before the LLM is called.
   * @returns A state update containing `validationChecklist`, or an empty
   *   update.
   */
  async execute(
    state: DbQueryState,
    config: LangGraphRunnableConfig,
  ): Promise<DbQueryState> {
    const empty = {} as DbQueryState;

    if (this.config.nodes?.generateChecklistNode?.enabled === false) {
      return empty;
    }
    // A checklist produced earlier is kept as-is.
    if (state.validationChecklist) {
      return empty;
    }

    // Filtering is skipped for small schemas (two tables or fewer).
    const tableCount = Object.keys(state.schema?.tables ?? {}).length;
    if (tableCount <= 2) {
      return empty;
    }

    const allChecks = [
      ...(this.checks ?? []),
      ...this.schemaHelper.getTablesContext(state.schema),
    ];
    if (allChecks.length === 0) {
      return empty;
    }

    config.writer?.({
      type: LLMStreamEventType.Log,
      data: 'Filtering validation checklist for semantic validation.',
    });

    const mergedIndexes = await this.runParallelChecklist(state, allChecks);
    if (mergedIndexes.size === 0) {
      return empty;
    }

    // Indexes shown to the model are 1-based; convert back to array positions.
    const validationChecklist = Array.from(mergedIndexes)
      .sort((a, b) => a - b)
      .map(i => allChecks[i - 1])
      .join('\n');

    return {validationChecklist} as DbQueryState;
  }

  /**
   * Invokes the prompt/LLM chain `parallelism` times concurrently with the
   * same inputs and returns the union of all rule indexes the runs selected.
   *
   * @param state - Current graph state; supplies the question and schema.
   * @param allChecks - Candidate rules, numbered from 1 in the prompt.
   * @returns The set of valid 1-based indexes chosen by at least one run.
   */
  private async runParallelChecklist(
    state: DbQueryState,
    allChecks: string[],
  ): Promise<Set<number>> {
    const indexedChecks = allChecks
      .map((check, i) => `${i + 1}. ${check}`)
      .join('\n');

    // A misconfigured value (0, negative, fractional, NaN) must not silently
    // result in zero LLM calls, so clamp to an integer of at least 1.
    const configured =
      this.config.nodes?.generateChecklistNode?.parallelism ?? 1;
    const parallelism = Number.isFinite(configured)
      ? Math.max(1, Math.floor(configured))
      : 1;

    const chain = RunnableSequence.from([this.prompt, this.llm]);
    const invokeArgs = {
      prompt: state.prompt,
      tables: Object.keys(state.schema?.tables ?? {}).join(', '),
      schema: this.schemaHelper.asString(state.schema),
      indexedChecks,
    };

    const results = await Promise.all(
      Array.from({length: parallelism}, () => chain.invoke(invokeArgs)),
    );

    const mergedIndexes = new Set<number>();
    for (const output of results) {
      for (const index of this.parseIndexes(output, allChecks.length)) {
        mergedIndexes.add(index);
      }
    }
    return mergedIndexes;
  }

  /**
   * Extracts rule indexes from a model response.
   *
   * Thinking tokens are stripped first. Indexes may be separated by commas,
   * semicolons or any whitespace (including newlines). Tokens that do not
   * start with an integer, and integers outside `1..maxIndex`, are discarded.
   *
   * @param output - Message returned by the LLM.
   * @param maxIndex - Highest valid (1-based) rule index.
   * @returns The valid indexes in the order they appear; empty when the
   *   response is blank or is the `none` sentinel (case-insensitive).
   */
  private parseIndexes(output: AIMessage, maxIndex: number): number[] {
    const response = stripThinkingTokens(output).trim();
    if (!response || response.toLowerCase() === 'none') return [];

    return response
      .split(/[\s,;]+/)
      .map(token => Number.parseInt(token, 10))
      .filter(n => !Number.isNaN(n) && n >= 1 && n <= maxIndex);
  }
}