import {PromptTemplate} from '@langchain/core/prompts';
import {RunnableSequence} from '@langchain/core/runnables';
import {inject} from '@loopback/core';
import {AiIntegrationBindings} from '../../../keys';
import {LLMProvider} from '../../../types';
import {stripThinkingTokens} from '../../../utils';
import {
  DatabaseSchema,
  QueryTemplate,
  QueryTemplateMetadata,
  TemplatePlaceholder,
} from '../types';
import {RunnableConfig} from '../../../graphs';

/**
 * Deepest `depth` value accepted by `TemplateHelper.resolveTemplate`.
 * The root template is resolved at depth 0 and every `template_ref` hop adds 1.
 */
const MAX_TEMPLATE_RECURSION_DEPTH = 3;

/** Result of resolving a template: final SQL plus the template's description. */
type ResolvedTemplate = {
  sql: string;
  description: string;
};

/** Looks up a referenced template by id; a nullish result means "not found". */
type TemplateFetcher = (
  id: string,
) => Promise<QueryTemplate | null | undefined>;

/**
 * Turns a stored query template into concrete SQL.
 *
 * `template_ref` placeholders are resolved recursively through a caller
 * supplied fetcher; every other placeholder value is extracted from the user
 * prompt by the injected LLM and then formatted into the SQL text according
 * to the placeholder type.
 */
export class TemplateHelper {
  constructor(
    @inject(AiIntegrationBindings.CheapLLM)
    private readonly llm: LLMProvider,
  ) {}

  /**
   * Prompt asking the LLM to return one XML tag per placeholder.
   * The three inputs are wrapped in their own tags so the model can tell
   * where the user prompt ends and the SQL template / placeholder list begin.
   */
  extractionPrompt = PromptTemplate.fromTemplate(`
You are an expert at extracting parameter values from natural language prompts.
Given a user prompt, a SQL template, and a list of placeholders with their descriptions and types, extract the value for each placeholder from the prompt.
For sql_expression placeholders, generate a valid SQL fragment that fits the position of the placeholder in the template.

<user-prompt>
{prompt}
</user-prompt>

<sql-template>
{template}
</sql-template>

<placeholders>
{placeholders}
</placeholders>

Return each extracted value as an XML tag where the tag name is the placeholder name.
If a placeholder value cannot be determined from the prompt, use the default value if provided, or leave the tag empty.

Rules per type:
- string: Return the raw value only, without any surrounding quotes. Example: Acme Corp
- number: Return the numeric value only. Example: 10
- boolean: Return true or false. Example: true
- sql_expression: Return a complete, valid SQL fragment with proper SQL syntax including quotes where needed. Example: created_at > '2024-01-01'

Do not return any other text or explanation, just the XML tags.
`);

  /**
   * Asks the LLM for a value for each placeholder.
   *
   * @param placeholders - Placeholders to extract (callers filter out `template_ref`).
   * @param prompt - The user's natural-language request.
   * @param sqlTemplate - SQL text shown to the model for positional context.
   * @param config - Runnable config forwarded to the chain invocation.
   * @param schema - Optional schema used to enrich placeholder descriptions.
   * @returns Map of placeholder name to extracted text, or `null` when the
   *   model returned no (or an empty) tag for that placeholder.
   */
  async extractPlaceholderValues(
    placeholders: TemplatePlaceholder[],
    prompt: string,
    sqlTemplate: string,
    config: RunnableConfig,
    schema?: DatabaseSchema,
  ): Promise<Record<string, string | null>> {
    const chain = RunnableSequence.from([
      this.extractionPrompt,
      this.llm,
      stripThinkingTokens,
    ]);

    const placeholderDescriptions = placeholders
      .map(p => {
        let desc = `- ${p.name} (type: ${p.type}): ${p.description}`;
        if (p.default) desc += ` [default: ${p.default}]`;
        const columnContext = this._getColumnContext(p, schema);
        if (columnContext) desc += `\n  ${columnContext}`;
        return desc;
      })
      .join('\n');

    const response = await chain.invoke(
      {
        prompt,
        template: sqlTemplate,
        placeholders: placeholderDescriptions,
      },
      config,
    );

    return this._parseXmlValues(response, placeholders);
  }

  /**
   * Builds a one-line description of the column a placeholder is bound to.
   *
   * @returns `null` when no schema is given, the placeholder has no
   *   table/column, or the table/column is absent from the schema.
   */
  private _getColumnContext(
    placeholder: TemplatePlaceholder,
    schema?: DatabaseSchema,
  ): string | null {
    if (!schema || !placeholder.table || !placeholder.column) {
      return null;
    }
    const tableSchema = schema.tables[placeholder.table];
    if (!tableSchema) {
      return null;
    }
    const columnSchema = tableSchema.columns[placeholder.column];
    if (!columnSchema) {
      return null;
    }

    const parts: string[] = [
      `Column "${placeholder.column}" in "${placeholder.table}" (${columnSchema.type})`,
    ];
    if (columnSchema.description) {
      parts.push(columnSchema.description);
    }
    if (columnSchema.metadata) {
      const metaStr = Object.entries(columnSchema.metadata)
        .map(([k, v]) => `${k}: ${JSON.stringify(v)}`)
        .join(', ');
      if (metaStr) parts.push(metaStr);
    }
    // Include table-level context entries relevant to this column
    parts.push(
      ...this._getRelevantContextEntries(
        tableSchema.context,
        placeholder.column,
      ),
    );
    return parts.join('. ');
  }

  /**
   * Selects the table-level context entries that concern `column`.
   *
   * String entries qualify when they mention the column name
   * (case-insensitive). Plain-object entries qualify when they have a truthy
   * OWN property named after the column; the own-property check keeps
   * inherited members such as `constructor` or `toString` from matching.
   * Non-string values are JSON-serialised so the result really is `string[]`.
   */
  private _getRelevantContextEntries(
    context: unknown[] | undefined,
    column: string,
  ): string[] {
    if (!context?.length) {
      return [];
    }
    const needle = column.toLowerCase();
    const results: string[] = [];
    for (const ctx of context) {
      if (typeof ctx === 'string') {
        if (ctx.toLowerCase().includes(needle)) {
          results.push(ctx);
        }
        continue;
      }
      if (
        typeof ctx !== 'object' ||
        ctx === null ||
        Array.isArray(ctx) ||
        !Object.prototype.hasOwnProperty.call(ctx, column)
      ) {
        continue;
      }
      const entry = (ctx as Record<string, unknown>)[column];
      if (!entry) {
        continue;
      }
      const text = typeof entry === 'string' ? entry : JSON.stringify(entry);
      if (text) {
        results.push(text);
      }
    }
    return results;
  }

  /**
   * Reads `<name>value</name>` for every placeholder out of the LLM response.
   *
   * The closing tag is part of the pattern: a lazy capture group with nothing
   * after it would always match the empty string. The placeholder name is
   * regex-escaped because it is interpolated into the pattern.
   *
   * @returns Map of placeholder name to trimmed tag content, `null` when the
   *   tag is missing or empty.
   */
  private _parseXmlValues(
    xml: string,
    placeholders: TemplatePlaceholder[],
  ): Record<string, string | null> {
    const result: Record<string, string | null> = {};
    for (const p of placeholders) {
      const tag = this._escapeRegex(p.name);
      const match = new RegExp(`<${tag}>([\\s\\S]*?)</${tag}>`).exec(xml);
      const value = match?.[1]?.trim();
      result[p.name] = value?.length ? value : null;
    }
    return result;
  }

  /**
   * Resolves a template into executable SQL.
   *
   * @param template - Template to resolve.
   * @param prompt - The user's natural-language request.
   * @param config - Runnable config forwarded to LLM calls.
   * @param schema - Optional schema used to give the LLM column context.
   * @param templateFetcher - Loader for templates referenced by `template_ref`
   *   placeholders.
   * @param depth - Current recursion depth; callers normally omit it.
   * @returns The final SQL together with the template's description.
   * @throws Error when `depth` exceeds `MAX_TEMPLATE_RECURSION_DEPTH`, or when
   *   a `template_ref` placeholder cannot be resolved.
   */
  async resolveTemplate(
    template: QueryTemplate,
    prompt: string,
    config: RunnableConfig,
    schema?: DatabaseSchema,
    templateFetcher?: TemplateFetcher,
    depth = 0,
  ): Promise<ResolvedTemplate> {
    if (depth > MAX_TEMPLATE_RECURSION_DEPTH) {
      throw new Error(
        `Max template recursion depth exceeded (${MAX_TEMPLATE_RECURSION_DEPTH})`,
      );
    }

    // 1. Resolve template_ref placeholders first (before the LLM call)
    let sql = await this._resolveTemplateRefs(
      template,
      prompt,
      config,
      schema,
      templateFetcher,
      depth,
    );

    // 2. Extract values only for non-template_ref placeholders via LLM
    const extractablePlaceholders = template.placeholders.filter(
      p => p.type !== 'template_ref',
    );
    let values: Record<string, string | null> = {};
    if (extractablePlaceholders.length > 0) {
      values = await this.extractPlaceholderValues(
        extractablePlaceholders,
        prompt,
        sql,
        config,
        schema,
      );
    }

    // 3. Substitute extracted values directly into SQL
    sql = this._substitutePlaceholders(sql, extractablePlaceholders, values);

    return {
      sql,
      description: template.description,
    };
  }

  /**
   * Replaces each `{{name}}` marker of a `template_ref` placeholder with the
   * parenthesised SQL of the referenced template, resolved recursively.
   *
   * @returns The template text with all referenced templates inlined.
   * @throws Error when a used `template_ref` has no fetcher / `templateId`, or
   *   the fetcher returns nothing for it.
   */
  private async _resolveTemplateRefs(
    template: QueryTemplate,
    prompt: string,
    config: RunnableConfig,
    schema: DatabaseSchema | undefined,
    templateFetcher: TemplateFetcher | undefined,
    depth: number,
  ): Promise<string> {
    let sql = template.template;
    const templateRefPlaceholders = template.placeholders.filter(
      p => p.type === 'template_ref',
    );

    for (const placeholder of templateRefPlaceholders) {
      const marker = `{{${placeholder.name}}}`;
      if (!sql.includes(marker)) {
        continue;
      }
      if (!templateFetcher || !placeholder.templateId) {
        throw new Error(
          `Cannot resolve template_ref placeholder "${placeholder.name}" - no template fetcher or templateId`,
        );
      }
      const refTemplate = await templateFetcher(placeholder.templateId);
      if (!refTemplate) {
        throw new Error(
          `Referenced template "${placeholder.templateId}" not found`,
        );
      }
      const resolved = await this.resolveTemplate(
        refTemplate,
        prompt,
        config,
        schema,
        templateFetcher,
        depth + 1,
      );
      // split/join inserts the nested SQL literally (String.replace would
      // interpret `$&`, `$'`, `$$` … inside it) and covers every occurrence
      // of the marker instead of only the first one.
      sql = sql.split(marker).join(`(${resolved.sql})`);
    }
    return sql;
  }

  /**
   * Writes formatted placeholder values into the SQL text.
   *
   * All markers are replaced in a single pass, so text that was just inserted
   * is never re-scanned for further `{{…}}` markers, and a replacer function
   * is used so `$` sequences inside values are inserted literally. An optional
   * placeholder without a value is removed together with its surrounding
   * whitespace and replaced by a single space.
   *
   * @param sql - SQL text containing `{{name}}` markers.
   * @param placeholders - Placeholders that may appear in `sql`.
   * @param values - Extracted values keyed by placeholder name.
   */
  private _substitutePlaceholders(
    sql: string,
    placeholders: TemplatePlaceholder[],
    values: Record<string, string | null>,
  ): string {
    if (placeholders.length === 0) {
      return sql;
    }

    // First definition wins if a name is listed more than once.
    const byName = new Map<string, TemplatePlaceholder>();
    for (const placeholder of placeholders) {
      if (!byName.has(placeholder.name)) {
        byName.set(placeholder.name, placeholder);
      }
    }

    const alternatives = Array.from(byName.keys())
      .map(name => this._escapeRegex(name))
      .join('|');
    const markerPattern = new RegExp(
      '(\\s*)\\{\\{(' + alternatives + ')\\}\\}(\\s*)',
      'g',
    );

    return sql.replace(
      markerPattern,
      (match: string, before: string, name: string, after: string) => {
        const placeholder = byName.get(name);
        if (!placeholder) {
          return match;
        }
        const value = values[name] ?? placeholder.default ?? null;
        if (placeholder.optional && !value) {
          return ' ';
        }
        return `${before}${this._formatValue(placeholder.type, value)}${after}`;
      },
    );
  }

  /**
   * Renders a raw value as SQL text for the given placeholder type.
   *
   * - `string`: single-quoted, embedded single quotes doubled.
   * - `number`: parsed with `Number`; anything that is not a finite number
   *   (NaN, ±Infinity) becomes `0` so no non-SQL token is emitted.
   * - `boolean`: `TRUE` / `FALSE`.
   * - `sql_expression`: inserted verbatim, `1=1` when missing.
   * - any other type: inserted verbatim, empty string when missing.
   *
   * NOTE(review): `sql_expression` and unknown types are written into the
   * query unescaped; that is only safe if the resulting SQL is validated or
   * sandboxed downstream — confirm with the caller.
   */
  private _formatValue(type: string, value: string | null): string {
    switch (type) {
      case 'string':
        return `'${(value ?? '').replace(/'/g, "''")}'`;
      case 'number': {
        const parsed = Number(value);
        return `${Number.isFinite(parsed) ? parsed : 0}`;
      }
      case 'boolean':
        return this._isTruthy(value) ? 'TRUE' : 'FALSE';
      case 'sql_expression':
        return value ?? '1=1';
      default:
        return value ?? '';
    }
  }

  /** True for `true` / `yes` (case-insensitive) and the exact string `1`. */
  private _isTruthy(value: string | null): boolean {
    const lower = value?.toLowerCase();
    return lower === 'true' || lower === 'yes' || value === '1';
  }

  /** Escapes regex metacharacters so `str` matches itself literally. */
  private _escapeRegex(str: string): string {
    return str.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
  }

  /**
   * Converts stored template metadata into a `QueryTemplate`.
   *
   * `placeholders` and `tables` are JSON-decoded; `tenantId` and `prompt` are
   * set to empty strings.
   *
   * @throws SyntaxError when `placeholders` or `tables` is not valid JSON.
   */
  parseTemplateMetadata(metadata: QueryTemplateMetadata): QueryTemplate {
    return {
      id: metadata.templateId,
      tenantId: '',
      template: metadata.template,
      description: metadata.description,
      placeholders: JSON.parse(metadata.placeholders),
      tables: JSON.parse(metadata.tables),
      schemaHash: metadata.schemaHash,
      votes: metadata.votes,
      prompt: '',
    };
  }
}