Skip to content

Commit 0423fc2

Browse files
authored
Merge pull request #847 from FlowiseAI/feature/SQLDatabaseChain
Feature/Add custom prompt to SQLDbChain
2 parents 2cbaaa7 + 3f0157d commit 0423fc2

File tree

2 files changed

+68
-12
lines changed

2 files changed

+68
-12
lines changed

packages/components/nodes/chains/SqlDatabaseChain/SqlDatabaseChain.ts

+48-5
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,34 @@
11
import { ICommonObject, INode, INodeData, INodeParams } from '../../../src/Interface'
22
import { SqlDatabaseChain, SqlDatabaseChainInput } from 'langchain/chains/sql_db'
3-
import { getBaseClasses } from '../../../src/utils'
3+
import { getBaseClasses, getInputVariables } from '../../../src/utils'
44
import { DataSource } from 'typeorm'
55
import { SqlDatabase } from 'langchain/sql_db'
66
import { BaseLanguageModel } from 'langchain/base_language'
7+
import { PromptTemplate, PromptTemplateInput } from 'langchain/prompts'
78
import { ConsoleCallbackHandler, CustomChainHandler } from '../../../src/handler'
89
import { DataSourceOptions } from 'typeorm/data-source'
910

1011
type DatabaseType = 'sqlite' | 'postgres' | 'mssql' | 'mysql'
1112

13+
const defaultPrompt = `Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results. You can order the results by a relevant column to return the most interesting examples in the database.
14+
15+
Never query for all the columns from a specific table, only ask for a the few relevant columns given the question.
16+
17+
Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
18+
19+
Use the following format:
20+
21+
Question: "Question here"
22+
SQLQuery: "SQL Query to run"
23+
SQLResult: "Result of the SQLQuery"
24+
Answer: "Final answer here"
25+
26+
Only use the tables listed below.
27+
28+
{table_info}
29+
30+
Question: {input}`
31+
1232
class SqlDatabaseChain_Chains implements INode {
1333
label: string
1434
name: string
@@ -23,7 +43,7 @@ class SqlDatabaseChain_Chains implements INode {
2343
constructor() {
2444
this.label = 'Sql Database Chain'
2545
this.name = 'sqlDatabaseChain'
26-
this.version = 1.0
46+
this.version = 2.0
2747
this.type = 'SqlDatabaseChain'
2848
this.icon = 'sqlchain.svg'
2949
this.category = 'Chains'
@@ -64,6 +84,19 @@ class SqlDatabaseChain_Chains implements INode {
6484
name: 'url',
6585
type: 'string',
6686
placeholder: '1270.0.0.1:5432/chinook'
87+
},
88+
{
89+
label: 'Custom Prompt',
90+
name: 'customPrompt',
91+
type: 'string',
92+
description:
93+
'You can provide custom prompt to the chain. This will override the existing default prompt used. See <a target="_blank" href="https://python.langchain.com/docs/integrations/tools/sqlite#customize-prompt">guide</a>',
94+
warning:
95+
'Prompt must include 3 input variables: {input}, {dialect}, {table_info}. You can refer to official guide from description above',
96+
rows: 4,
97+
placeholder: defaultPrompt,
98+
additionalParams: true,
99+
optional: true
67100
}
68101
]
69102
}
@@ -72,17 +105,19 @@ class SqlDatabaseChain_Chains implements INode {
72105
const databaseType = nodeData.inputs?.database as DatabaseType
73106
const model = nodeData.inputs?.model as BaseLanguageModel
74107
const url = nodeData.inputs?.url
108+
const customPrompt = nodeData.inputs?.customPrompt as string
75109

76-
const chain = await getSQLDBChain(databaseType, url, model)
110+
const chain = await getSQLDBChain(databaseType, url, model, customPrompt)
77111
return chain
78112
}
79113

80114
async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string> {
81115
const databaseType = nodeData.inputs?.database as DatabaseType
82116
const model = nodeData.inputs?.model as BaseLanguageModel
83117
const url = nodeData.inputs?.url
118+
const customPrompt = nodeData.inputs?.customPrompt as string
84119

85-
const chain = await getSQLDBChain(databaseType, url, model)
120+
const chain = await getSQLDBChain(databaseType, url, model, customPrompt)
86121
const loggerHandler = new ConsoleCallbackHandler(options.logger)
87122

88123
if (options.socketIO && options.socketIOClientId) {
@@ -96,7 +131,7 @@ class SqlDatabaseChain_Chains implements INode {
96131
}
97132
}
98133

99-
const getSQLDBChain = async (databaseType: DatabaseType, url: string, llm: BaseLanguageModel) => {
134+
const getSQLDBChain = async (databaseType: DatabaseType, url: string, llm: BaseLanguageModel, customPrompt?: string) => {
100135
const datasource = new DataSource(
101136
databaseType === 'sqlite'
102137
? {
@@ -119,6 +154,14 @@ const getSQLDBChain = async (databaseType: DatabaseType, url: string, llm: BaseL
119154
verbose: process.env.DEBUG === 'true' ? true : false
120155
}
121156

157+
if (customPrompt) {
158+
const options: PromptTemplateInput = {
159+
template: customPrompt,
160+
inputVariables: getInputVariables(customPrompt)
161+
}
162+
obj.prompt = new PromptTemplate(options)
163+
}
164+
122165
const chain = new SqlDatabaseChain(obj)
123166
return chain
124167
}

packages/server/marketplaces/chatflows/SQL DB Chain.json

+20-7
Original file line numberDiff line numberDiff line change
@@ -157,17 +157,17 @@
157157
},
158158
{
159159
"width": 300,
160-
"height": 423,
160+
"height": 475,
161161
"id": "sqlDatabaseChain_0",
162162
"position": {
163-
"x": 1229.0092429246013,
164-
"y": 231.59431102290245
163+
"x": 1206.5244299447634,
164+
"y": 201.04431101230608
165165
},
166166
"type": "customNode",
167167
"data": {
168168
"id": "sqlDatabaseChain_0",
169169
"label": "Sql Database Chain",
170-
"version": 1,
170+
"version": 2,
171171
"name": "sqlDatabaseChain",
172172
"type": "SqlDatabaseChain",
173173
"baseClasses": ["SqlDatabaseChain", "BaseChain", "Runnable"],
@@ -205,6 +205,18 @@
205205
"type": "string",
206206
"placeholder": "1270.0.0.1:5432/chinook",
207207
"id": "sqlDatabaseChain_0-input-url-string"
208+
},
209+
{
210+
"label": "Custom Prompt",
211+
"name": "customPrompt",
212+
"type": "string",
213+
"description": "You can provide custom prompt to the chain. This will override the existing default prompt used. See <a target=\"_blank\" href=\"https://python.langchain.com/docs/integrations/tools/sqlite#customize-prompt\">guide</a>",
214+
"warning": "Prompt must include 3 input variables: {input}, {dialect}, {table_info}. You can refer to official guide from description above",
215+
"rows": 4,
216+
"placeholder": "Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results. You can order the results by a relevant column to return the most interesting examples in the database.\n\nNever query for all the columns from a specific table, only ask for a the few relevant columns given the question.\n\nPay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.\n\nUse the following format:\n\nQuestion: \"Question here\"\nSQLQuery: \"SQL Query to run\"\nSQLResult: \"Result of the SQLQuery\"\nAnswer: \"Final answer here\"\n\nOnly use the tables listed below.\n\n{table_info}\n\nQuestion: {input}",
217+
"additionalParams": true,
218+
"optional": true,
219+
"id": "sqlDatabaseChain_0-input-customPrompt-string"
208220
}
209221
],
210222
"inputAnchors": [
@@ -218,7 +230,8 @@
218230
"inputs": {
219231
"model": "{{chatOpenAI_0.data.instance}}",
220232
"database": "sqlite",
221-
"url": ""
233+
"url": "",
234+
"customPrompt": ""
222235
},
223236
"outputAnchors": [
224237
{
@@ -233,8 +246,8 @@
233246
},
234247
"selected": false,
235248
"positionAbsolute": {
236-
"x": 1229.0092429246013,
237-
"y": 231.59431102290245
249+
"x": 1206.5244299447634,
250+
"y": 201.04431101230608
238251
},
239252
"dragging": false
240253
}

0 commit comments

Comments
 (0)