{"version":3,"file":"sql_db_chain.cjs","names":["BaseChain","DEFAULT_SQL_DATABASE_PROMPT","getPromptTemplateFromDataSource","LLMChain","SQL_PROMPTS_MAP","RunnableSequence","RunnablePassthrough","StringOutputParser"],"sources":["../../../src/chains/sql_db/sql_db_chain.ts"],"sourcesContent":["import type {\n  BaseLanguageModel,\n  BaseLanguageModelInterface,\n} from \"@langchain/core/language_models/base\";\nimport type { OpenAI, TiktokenModel } from \"@langchain/openai\";\nimport { ChainValues } from \"@langchain/core/utils/types\";\nimport { BasePromptTemplate, PromptTemplate } from \"@langchain/core/prompts\";\nimport {\n  calculateMaxTokens,\n  getModelContextSize,\n} from \"@langchain/core/language_models/base\";\nimport { CallbackManagerForChainRun } from \"@langchain/core/callbacks/manager\";\nimport {\n  RunnablePassthrough,\n  RunnableSequence,\n} from \"@langchain/core/runnables\";\nimport { StringOutputParser } from \"@langchain/core/output_parsers\";\nimport {\n  DEFAULT_SQL_DATABASE_PROMPT,\n  SQL_PROMPTS_MAP,\n  SqlDialect,\n} from \"./sql_db_prompt.js\";\nimport { BaseChain, ChainInputs } from \"../base.js\";\nimport { LLMChain } from \"../llm_chain.js\";\nimport type { SqlDatabase } from \"../../sql_db.js\";\nimport { getPromptTemplateFromDataSource } from \"../../util/sql_utils.js\";\n\n/**\n * Interface that extends the ChainInputs interface and defines additional\n * fields specific to a SQL database chain. It represents the input fields\n * for a SQL database chain.\n */\nexport interface SqlDatabaseChainInput extends ChainInputs {\n  llm: BaseLanguageModelInterface;\n  database: SqlDatabase;\n  topK?: number;\n  inputKey?: string;\n  outputKey?: string;\n  sqlOutputKey?: string;\n  prompt?: PromptTemplate;\n}\n\n/**\n * Class that represents a SQL database chain in the LangChain framework.\n * It extends the BaseChain class and implements the functionality\n * specific to a SQL database chain.\n *\n * @security **Security Notice**\n * This chain generates SQL queries for the given database.\n * The SQLDatabase class provides a getTableInfo method that can be used\n * to get column information as well as sample data from the table.\n * To mitigate risk of leaking sensitive data, limit permissions\n * to read and scope to the tables that are needed.\n * Optionally, use the includesTables or ignoreTables class parameters\n * to limit which tables can/cannot be accessed.\n *\n * @link See https://js.langchain.com/docs/security for more information.\n * @example\n * ```typescript\n * const chain = new SqlDatabaseChain({\n *   llm: new OpenAI({ temperature: 0 }),\n *   database: new SqlDatabase({ ...config }),\n * });\n *\n * const result = await chain.run(\"How many tracks are there?\");\n * ```\n */\nexport class SqlDatabaseChain extends BaseChain {\n  static lc_name() {\n    return \"SqlDatabaseChain\";\n  }\n\n  // LLM wrapper to use\n  llm: BaseLanguageModelInterface;\n\n  // SQL Database to connect to.\n  database: SqlDatabase;\n\n  // Prompt to use to translate natural language to SQL.\n  prompt = DEFAULT_SQL_DATABASE_PROMPT;\n\n  // Number of results to return from the query\n  topK = 5;\n\n  inputKey = \"query\";\n\n  outputKey = \"result\";\n\n  sqlOutputKey: string | undefined = undefined;\n\n  // Whether to return the result of querying the SQL table directly.\n  returnDirect = false;\n\n  constructor(fields: SqlDatabaseChainInput) {\n    super(fields);\n    this.llm = fields.llm;\n    this.database = fields.database;\n    this.topK = fields.topK ?? this.topK;\n    this.inputKey = fields.inputKey ?? this.inputKey;\n    this.outputKey = fields.outputKey ?? this.outputKey;\n    this.sqlOutputKey = fields.sqlOutputKey ?? this.sqlOutputKey;\n    this.prompt =\n      fields.prompt ??\n      getPromptTemplateFromDataSource(this.database.appDataSource);\n  }\n\n  /** @ignore */\n  async _call(\n    values: ChainValues,\n    runManager?: CallbackManagerForChainRun\n  ): Promise<ChainValues> {\n    const llmChain = new LLMChain({\n      prompt: this.prompt,\n      llm: this.llm,\n      outputKey: this.outputKey,\n      memory: this.memory,\n    });\n    if (!(this.inputKey in values)) {\n      throw new Error(`Question key ${this.inputKey} not found.`);\n    }\n    const question: string = values[this.inputKey];\n    let inputText = `${question}\\nSQLQuery:`;\n    const tablesToUse = values.table_names_to_use;\n    const tableInfo = await this.database.getTableInfo(tablesToUse);\n\n    const llmInputs = {\n      input: inputText,\n      top_k: this.topK,\n      dialect: this.database.appDataSourceOptions.type,\n      table_info: tableInfo,\n      stop: [\"\\nSQLResult:\"],\n    };\n    await this.verifyNumberOfTokens(inputText, tableInfo);\n\n    const sqlCommand = await llmChain.predict(\n      llmInputs,\n      runManager?.getChild(\"sql_generation\")\n    );\n    let queryResult = \"\";\n    try {\n      queryResult = await this.database.appDataSource.query(sqlCommand);\n    } catch (error) {\n      console.error(error);\n    }\n\n    let finalResult;\n    if (this.returnDirect) {\n      finalResult = { [this.outputKey]: queryResult };\n    } else {\n      inputText += `${sqlCommand}\\nSQLResult: ${JSON.stringify(\n        queryResult\n      )}\\nAnswer:`;\n      llmInputs.input = inputText;\n      finalResult = {\n        [this.outputKey]: await llmChain.predict(\n          llmInputs,\n          runManager?.getChild(\"result_generation\")\n        ),\n      };\n    }\n\n    if (this.sqlOutputKey != null) {\n      finalResult[this.sqlOutputKey] = sqlCommand;\n    }\n\n    return finalResult;\n  }\n\n  _chainType() {\n    return \"sql_database_chain\" as const;\n  }\n\n  get inputKeys(): string[] {\n    return [this.inputKey];\n  }\n\n  get outputKeys(): string[] {\n    if (this.sqlOutputKey != null) {\n      return [this.outputKey, this.sqlOutputKey];\n    }\n    return [this.outputKey];\n  }\n\n  /**\n   * Private method that verifies the number of tokens in the input text and\n   * table information. It throws an error if the number of tokens exceeds\n   * the maximum allowed by the language model.\n   * @param inputText The input text.\n   * @param tableinfo The table information.\n   * @returns A promise that resolves when the verification is complete.\n   */\n  private async verifyNumberOfTokens(\n    inputText: string,\n    tableinfo: string\n  ): Promise<void> {\n    // We verify it only for OpenAI for the moment\n    if (this.llm._llmType() !== \"openai\") {\n      return;\n    }\n    const llm = this.llm as OpenAI;\n    const promptTemplate = this.prompt.template;\n    const stringWeSend = `${inputText}${promptTemplate}${tableinfo}`;\n\n    const maxToken = await calculateMaxTokens({\n      prompt: stringWeSend,\n      // Cast here to allow for other models that may not fit the union\n      modelName: llm.model as TiktokenModel,\n    });\n\n    if (maxToken < (llm.maxTokens ?? -1)) {\n      throw new Error(`The combination of the database structure and your question is too big for the model ${\n        llm.model\n      } which can compute only a max tokens of ${getModelContextSize(\n        llm.model\n      )}.\n      We suggest you to use the includeTables parameters when creating the SqlDatabase object to select only a subset of the tables. You can also use a model which can handle more tokens.`);\n    }\n  }\n}\n\nexport interface CreateSqlQueryChainFields {\n  llm: BaseLanguageModel;\n  db: SqlDatabase;\n  prompt?: BasePromptTemplate;\n  /**\n   * @default 5\n   */\n  k?: number;\n  dialect: SqlDialect;\n}\n\ntype SqlInput = {\n  question: string;\n};\n\ntype SqlInoutWithTables = SqlInput & {\n  tableNamesToUse: string[];\n};\n\nconst strip = (text: string) => {\n  // Replace escaped quotes with actual quotes\n  let newText = text.replace(/\\\\\"/g, '\"').trim();\n  // Remove wrapping quotes if the entire string is wrapped in quotes\n  if (newText.startsWith('\"') && newText.endsWith('\"')) {\n    newText = newText.substring(1, newText.length - 1);\n  }\n  return newText;\n};\n\nconst difference = (setA: Set<string>, setB: Set<string>) =>\n  new Set([...setA].filter((x) => !setB.has(x)));\n\n/**\n * Create a SQL query chain that can create SQL queries for the given database.\n * Returns a Runnable.\n *\n * @param {BaseLanguageModel} llm The language model to use in the chain.\n * @param {SqlDatabase} db The database to use in the chain.\n * @param {BasePromptTemplate | undefined} prompt The prompt to use in the chain.\n * @param {BaseLanguageModel | undefined} k The amount of docs/results to return. Passed through the prompt input value `top_k`.\n * @param {SqlDialect} dialect The SQL dialect to use in the chain.\n * @returns {Promise<RunnableSequence<Record<string, unknown>, string>>} A runnable sequence representing the chain.\n * @example ```typescript\n * const datasource = new DataSource({\n *   type: \"sqlite\",\n *   database: \"../../../../Chinook.db\",\n * });\n * const db = await SqlDatabase.fromDataSourceParams({\n *   appDataSource: datasource,\n * });\n * const llm = new ChatOpenAI({ model: \"gpt-4o-mini\", temperature: 0 });\n * const chain = await createSqlQueryChain({\n *   llm,\n *   db,\n *   dialect: \"sqlite\",\n * });\n * ```\n */\nexport async function createSqlQueryChain({\n  llm,\n  db,\n  prompt,\n  k = 5,\n  dialect,\n}: CreateSqlQueryChainFields) {\n  let promptToUse: BasePromptTemplate;\n  if (prompt) {\n    promptToUse = prompt;\n  } else if (SQL_PROMPTS_MAP[dialect]) {\n    promptToUse = SQL_PROMPTS_MAP[dialect];\n  } else {\n    promptToUse = DEFAULT_SQL_DATABASE_PROMPT;\n  }\n\n  if (\n    difference(\n      new Set([\"input\", \"top_k\", \"table_info\"]),\n      new Set(promptToUse.inputVariables)\n    ).size > 0\n  ) {\n    throw new Error(\n      `Prompt must have input variables: 'input', 'top_k', 'table_info'. Received prompt with input variables: ` +\n        `${promptToUse.inputVariables}. Full prompt:\\n\\n${promptToUse}`\n    );\n  }\n  if (promptToUse.inputVariables.includes(\"dialect\")) {\n    promptToUse = await promptToUse.partial({ dialect });\n  }\n\n  promptToUse = await promptToUse.partial({ top_k: k.toString() });\n\n  const inputs = {\n    input: (x: Record<string, unknown>) => {\n      if (\"question\" in x) {\n        return `${(x as SqlInput).question}\\nSQLQuery: `;\n      }\n      throw new Error(\"Input must include a question property.\");\n    },\n    table_info: async (x: Record<string, unknown>) =>\n      db.getTableInfo((x as SqlInoutWithTables).tableNamesToUse),\n  };\n\n  return RunnableSequence.from([\n    RunnablePassthrough.assign(inputs),\n    (x) => {\n      const newInputs = { ...x };\n      delete newInputs.question;\n      delete newInputs.tableNamesToUse;\n      return newInputs;\n    },\n    promptToUse,\n    llm.withConfig({ stop: [\"\\nSQLResult:\"] }),\n    new StringOutputParser(),\n    strip,\n  ]);\n}\n"],"mappings":";;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;AAmEA,IAAa,mBAAb,cAAsCA,aAAAA,UAAU;CAC9C,OAAO,UAAU;AACf,SAAO;;CAIT;CAGA;CAGA,SAASC,sBAAAA;CAGT,OAAO;CAEP,WAAW;CAEX,YAAY;CAEZ,eAAmC,KAAA;CAGnC,eAAe;CAEf,YAAY,QAA+B;AACzC,QAAM,OAAO;AACb,OAAK,MAAM,OAAO;AAClB,OAAK,WAAW,OAAO;AACvB,OAAK,OAAO,OAAO,QAAQ,KAAK;AAChC,OAAK,WAAW,OAAO,YAAY,KAAK;AACxC,OAAK,YAAY,OAAO,aAAa,KAAK;AAC1C,OAAK,eAAe,OAAO,gBAAgB,KAAK;AAChD,OAAK,SACH,OAAO,UACPC,kBAAAA,gCAAgC,KAAK,SAAS,cAAc;;;CAIhE,MAAM,MACJ,QACA,YACsB;EACtB,MAAM,WAAW,IAAIC,kBAAAA,SAAS;GAC5B,QAAQ,KAAK;GACb,KAAK,KAAK;GACV,WAAW,KAAK;GAChB,QAAQ,KAAK;GACd,CAAC;AACF,MAAI,EAAE,KAAK,YAAY,QACrB,OAAM,IAAI,MAAM,gBAAgB,KAAK,SAAS,aAAa;EAG7D,IAAI,YAAY,GADS,OAAO,KAAK,UACT;EAC5B,MAAM,cAAc,OAAO;EAC3B,MAAM,YAAY,MAAM,KAAK,SAAS,aAAa,YAAY;EAE/D,MAAM,YAAY;GAChB,OAAO;GACP,OAAO,KAAK;GACZ,SAAS,KAAK,SAAS,qBAAqB;GAC5C,YAAY;GACZ,MAAM,CAAC,eAAe;GACvB;AACD,QAAM,KAAK,qBAAqB,WAAW,UAAU;EAErD,MAAM,aAAa,MAAM,SAAS,QAChC,WACA,YAAY,SAAS,iBAAiB,CACvC;EACD,IAAI,cAAc;AAClB,MAAI;AACF,iBAAc,MAAM,KAAK,SAAS,cAAc,MAAM,WAAW;WAC1D,OAAO;AACd,WAAQ,MAAM,MAAM;;EAGtB,IAAI;AACJ,MAAI,KAAK,aACP,eAAc,GAAG,KAAK,YAAY,aAAa;OAC1C;AACL,gBAAa,GAAG,WAAW,eAAe,KAAK,UAC7C,YACD,CAAC;AACF,aAAU,QAAQ;AAClB,iBAAc,GACX,KAAK,YAAY,MAAM,SAAS,QAC/B,WACA,YAAY,SAAS,oBAAoB,CAC1C,EACF;;AAGH,MAAI,KAAK,gBAAgB,KACvB,aAAY,KAAK,gBAAgB;AAGnC,SAAO;;CAGT,aAAa;AACX,SAAO;;CAGT,IAAI,YAAsB;AACxB,SAAO,CAAC,KAAK,SAAS;;CAGxB,IAAI,aAAuB;AACzB,MAAI,KAAK,gBAAgB,KACvB,QAAO,CAAC,KAAK,WAAW,KAAK,aAAa;AAE5C,SAAO,CAAC,KAAK,UAAU;;;;;;;;;;CAWzB,MAAc,qBACZ,WACA,WACe;AAEf,MAAI,KAAK,IAAI,UAAU,KAAK,SAC1B;EAEF,MAAM,MAAM,KAAK;AAUjB,MANiB,OAAA,GAAA,qCAAA,oBAAyB;GACxC,QAHmB,GAAG,YADD,KAAK,OAAO,WACkB;GAKnD,WAAW,IAAI;GAChB,CAAC,IAEc,IAAI,aAAa,IAC/B,OAAM,IAAI,MAAM,wFACd,IAAI,MACL,2CAAA,GAAA,qCAAA,qBACC,IAAI,MACL,CAAC;6LACqL;;;AAwB7L,MAAM,SAAS,SAAiB;CAE9B,IAAI,UAAU,KAAK,QAAQ,QAAQ,KAAI,CAAC,MAAM;AAE9C,KAAI,QAAQ,WAAW,KAAI,IAAI,QAAQ,SAAS,KAAI,CAClD,WAAU,QAAQ,UAAU,GAAG,QAAQ,SAAS,EAAE;AAEpD,QAAO;;AAGT,MAAM,cAAc,MAAmB,SACrC,IAAI,IAAI,CAAC,GAAG,KAAK,CAAC,QAAQ,MAAM,CAAC,KAAK,IAAI,EAAE,CAAC,CAAC;;;;;;;;;;;;;;;;;;;;;;;;;;;AA4BhD,eAAsB,oBAAoB,EACxC,KACA,IACA,QACA,IAAI,GACJ,WAC4B;CAC5B,IAAI;AACJ,KAAI,OACF,eAAc;UACLC,sBAAAA,gBAAgB,SACzB,eAAcA,sBAAAA,gBAAgB;KAE9B,eAAcH,sBAAAA;AAGhB,KACE,WACE,IAAI,IAAI;EAAC;EAAS;EAAS;EAAa,CAAC,EACzC,IAAI,IAAI,YAAY,eAAe,CACpC,CAAC,OAAO,EAET,OAAM,IAAI,MACR,2GACK,YAAY,eAAe,oBAAoB,cACrD;AAEH,KAAI,YAAY,eAAe,SAAS,UAAU,CAChD,eAAc,MAAM,YAAY,QAAQ,EAAE,SAAS,CAAC;AAGtD,eAAc,MAAM,YAAY,QAAQ,EAAE,OAAO,EAAE,UAAU,EAAE,CAAC;AAahE,QAAOI,0BAAAA,iBAAiB,KAAK;EAC3BC,0BAAAA,oBAAoB,OAZP;GACb,QAAQ,MAA+B;AACrC,QAAI,cAAc,EAChB,QAAO,GAAI,EAAe,SAAS;AAErC,UAAM,IAAI,MAAM,0CAA0C;;GAE5D,YAAY,OAAO,MACjB,GAAG,aAAc,EAAyB,gBAAgB;GAC7D,CAGmC;GACjC,MAAM;GACL,MAAM,YAAY,EAAE,GAAG,GAAG;AAC1B,UAAO,UAAU;AACjB,UAAO,UAAU;AACjB,UAAO;;EAET;EACA,IAAI,WAAW,EAAE,MAAM,CAAC,eAAe,EAAE,CAAC;EAC1C,IAAIC,+BAAAA,oBAAoB;EACxB;EACD,CAAC"}