Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/Update Query Engine Tool #1726

Merged
merged 1 commit into from
Feb 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 50 additions & 44 deletions packages/components/nodes/engine/QueryEngine/QueryEngine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@ class QueryEngine_LlamaIndex implements INode {
constructor(fields?: { sessionId?: string }) {
this.label = 'Query Engine'
this.name = 'queryEngine'
this.version = 1.0
this.version = 2.0
this.type = 'QueryEngine'
this.icon = 'query-engine.png'
this.category = 'Engine'
this.description = 'Simple query engine built to answer question over your data, without memory'
this.baseClasses = [this.type]
this.baseClasses = [this.type, 'BaseQueryEngine']
this.tags = ['LlamaIndex']
this.inputs = [
{
Expand All @@ -59,52 +59,13 @@ class QueryEngine_LlamaIndex implements INode {
this.sessionId = fields?.sessionId
}

async init(): Promise<any> {
return null
async init(nodeData: INodeData): Promise<any> {
return prepareEngine(nodeData)
}

async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string | object> {
const returnSourceDocuments = nodeData.inputs?.returnSourceDocuments as boolean
const vectorStoreRetriever = nodeData.inputs?.vectorStoreRetriever
const responseSynthesizerObj = nodeData.inputs?.responseSynthesizer

let queryEngine = new RetrieverQueryEngine(vectorStoreRetriever)

if (responseSynthesizerObj) {
if (responseSynthesizerObj.type === 'TreeSummarize') {
const responseSynthesizer = new ResponseSynthesizer({
responseBuilder: new TreeSummarize(vectorStoreRetriever.serviceContext, responseSynthesizerObj.textQAPromptTemplate),
serviceContext: vectorStoreRetriever.serviceContext
})
queryEngine = new RetrieverQueryEngine(vectorStoreRetriever, responseSynthesizer)
} else if (responseSynthesizerObj.type === 'CompactAndRefine') {
const responseSynthesizer = new ResponseSynthesizer({
responseBuilder: new CompactAndRefine(
vectorStoreRetriever.serviceContext,
responseSynthesizerObj.textQAPromptTemplate,
responseSynthesizerObj.refinePromptTemplate
),
serviceContext: vectorStoreRetriever.serviceContext
})
queryEngine = new RetrieverQueryEngine(vectorStoreRetriever, responseSynthesizer)
} else if (responseSynthesizerObj.type === 'Refine') {
const responseSynthesizer = new ResponseSynthesizer({
responseBuilder: new Refine(
vectorStoreRetriever.serviceContext,
responseSynthesizerObj.textQAPromptTemplate,
responseSynthesizerObj.refinePromptTemplate
),
serviceContext: vectorStoreRetriever.serviceContext
})
queryEngine = new RetrieverQueryEngine(vectorStoreRetriever, responseSynthesizer)
} else if (responseSynthesizerObj.type === 'SimpleResponseBuilder') {
const responseSynthesizer = new ResponseSynthesizer({
responseBuilder: new SimpleResponseBuilder(vectorStoreRetriever.serviceContext),
serviceContext: vectorStoreRetriever.serviceContext
})
queryEngine = new RetrieverQueryEngine(vectorStoreRetriever, responseSynthesizer)
}
}
const queryEngine = prepareEngine(nodeData)

let text = ''
let sourceDocuments: ICommonObject[] = []
Expand Down Expand Up @@ -140,4 +101,49 @@ class QueryEngine_LlamaIndex implements INode {
}
}

const prepareEngine = (nodeData: INodeData) => {
const vectorStoreRetriever = nodeData.inputs?.vectorStoreRetriever
const responseSynthesizerObj = nodeData.inputs?.responseSynthesizer

let queryEngine = new RetrieverQueryEngine(vectorStoreRetriever)

if (responseSynthesizerObj) {
if (responseSynthesizerObj.type === 'TreeSummarize') {
const responseSynthesizer = new ResponseSynthesizer({
responseBuilder: new TreeSummarize(vectorStoreRetriever.serviceContext, responseSynthesizerObj.textQAPromptTemplate),
serviceContext: vectorStoreRetriever.serviceContext
})
queryEngine = new RetrieverQueryEngine(vectorStoreRetriever, responseSynthesizer)
} else if (responseSynthesizerObj.type === 'CompactAndRefine') {
const responseSynthesizer = new ResponseSynthesizer({
responseBuilder: new CompactAndRefine(
vectorStoreRetriever.serviceContext,
responseSynthesizerObj.textQAPromptTemplate,
responseSynthesizerObj.refinePromptTemplate
),
serviceContext: vectorStoreRetriever.serviceContext
})
queryEngine = new RetrieverQueryEngine(vectorStoreRetriever, responseSynthesizer)
} else if (responseSynthesizerObj.type === 'Refine') {
const responseSynthesizer = new ResponseSynthesizer({
responseBuilder: new Refine(
vectorStoreRetriever.serviceContext,
responseSynthesizerObj.textQAPromptTemplate,
responseSynthesizerObj.refinePromptTemplate
),
serviceContext: vectorStoreRetriever.serviceContext
})
queryEngine = new RetrieverQueryEngine(vectorStoreRetriever, responseSynthesizer)
} else if (responseSynthesizerObj.type === 'SimpleResponseBuilder') {
const responseSynthesizer = new ResponseSynthesizer({
responseBuilder: new SimpleResponseBuilder(vectorStoreRetriever.serviceContext),
serviceContext: vectorStoreRetriever.serviceContext
})
queryEngine = new RetrieverQueryEngine(vectorStoreRetriever, responseSynthesizer)
}
}

return queryEngine
}

module.exports = { nodeClass: QueryEngine_LlamaIndex }
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,13 @@ class SubQuestionQueryEngine_LlamaIndex implements INode {
constructor(fields?: { sessionId?: string }) {
this.label = 'Sub Question Query Engine'
this.name = 'subQuestionQueryEngine'
this.version = 1.0
this.version = 2.0
this.type = 'SubQuestionQueryEngine'
this.icon = 'subQueryEngine.svg'
this.category = 'Engine'
this.description =
'Breaks complex query into sub questions for each relevant data source, then gather all the intermediate reponses and synthesizes a final response'
this.baseClasses = [this.type]
this.baseClasses = [this.type, 'BaseQueryEngine']
this.tags = ['LlamaIndex']
this.inputs = [
{
Expand Down Expand Up @@ -76,85 +76,13 @@ class SubQuestionQueryEngine_LlamaIndex implements INode {
this.sessionId = fields?.sessionId
}

async init(): Promise<any> {
return null
async init(nodeData: INodeData): Promise<any> {
return prepareEngine(nodeData)
}

async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string | object> {
const returnSourceDocuments = nodeData.inputs?.returnSourceDocuments as boolean
const embeddings = nodeData.inputs?.embeddings as BaseEmbedding
const model = nodeData.inputs?.model

const serviceContext = serviceContextFromDefaults({
llm: model,
embedModel: embeddings
})

let queryEngineTools = nodeData.inputs?.queryEngineTools as QueryEngineTool[]
queryEngineTools = flatten(queryEngineTools)

let queryEngine = SubQuestionQueryEngine.fromDefaults({
serviceContext,
queryEngineTools,
questionGen: new LLMQuestionGenerator({ llm: model })
})

const responseSynthesizerObj = nodeData.inputs?.responseSynthesizer
if (responseSynthesizerObj) {
if (responseSynthesizerObj.type === 'TreeSummarize') {
const responseSynthesizer = new ResponseSynthesizer({
responseBuilder: new TreeSummarize(serviceContext, responseSynthesizerObj.textQAPromptTemplate),
serviceContext
})
queryEngine = SubQuestionQueryEngine.fromDefaults({
responseSynthesizer,
serviceContext,
queryEngineTools,
questionGen: new LLMQuestionGenerator({ llm: model })
})
} else if (responseSynthesizerObj.type === 'CompactAndRefine') {
const responseSynthesizer = new ResponseSynthesizer({
responseBuilder: new CompactAndRefine(
serviceContext,
responseSynthesizerObj.textQAPromptTemplate,
responseSynthesizerObj.refinePromptTemplate
),
serviceContext
})
queryEngine = SubQuestionQueryEngine.fromDefaults({
responseSynthesizer,
serviceContext,
queryEngineTools,
questionGen: new LLMQuestionGenerator({ llm: model })
})
} else if (responseSynthesizerObj.type === 'Refine') {
const responseSynthesizer = new ResponseSynthesizer({
responseBuilder: new Refine(
serviceContext,
responseSynthesizerObj.textQAPromptTemplate,
responseSynthesizerObj.refinePromptTemplate
),
serviceContext
})
queryEngine = SubQuestionQueryEngine.fromDefaults({
responseSynthesizer,
serviceContext,
queryEngineTools,
questionGen: new LLMQuestionGenerator({ llm: model })
})
} else if (responseSynthesizerObj.type === 'SimpleResponseBuilder') {
const responseSynthesizer = new ResponseSynthesizer({
responseBuilder: new SimpleResponseBuilder(serviceContext),
serviceContext
})
queryEngine = SubQuestionQueryEngine.fromDefaults({
responseSynthesizer,
serviceContext,
queryEngineTools,
questionGen: new LLMQuestionGenerator({ llm: model })
})
}
}
const queryEngine = prepareEngine(nodeData)

let text = ''
let sourceDocuments: ICommonObject[] = []
Expand Down Expand Up @@ -190,4 +118,82 @@ class SubQuestionQueryEngine_LlamaIndex implements INode {
}
}

const prepareEngine = (nodeData: INodeData) => {
const embeddings = nodeData.inputs?.embeddings as BaseEmbedding
const model = nodeData.inputs?.model

const serviceContext = serviceContextFromDefaults({
llm: model,
embedModel: embeddings
})

let queryEngineTools = nodeData.inputs?.queryEngineTools as QueryEngineTool[]
queryEngineTools = flatten(queryEngineTools)

let queryEngine = SubQuestionQueryEngine.fromDefaults({
serviceContext,
queryEngineTools,
questionGen: new LLMQuestionGenerator({ llm: model })
})

const responseSynthesizerObj = nodeData.inputs?.responseSynthesizer
if (responseSynthesizerObj) {
if (responseSynthesizerObj.type === 'TreeSummarize') {
const responseSynthesizer = new ResponseSynthesizer({
responseBuilder: new TreeSummarize(serviceContext, responseSynthesizerObj.textQAPromptTemplate),
serviceContext
})
queryEngine = SubQuestionQueryEngine.fromDefaults({
responseSynthesizer,
serviceContext,
queryEngineTools,
questionGen: new LLMQuestionGenerator({ llm: model })
})
} else if (responseSynthesizerObj.type === 'CompactAndRefine') {
const responseSynthesizer = new ResponseSynthesizer({
responseBuilder: new CompactAndRefine(
serviceContext,
responseSynthesizerObj.textQAPromptTemplate,
responseSynthesizerObj.refinePromptTemplate
),
serviceContext
})
queryEngine = SubQuestionQueryEngine.fromDefaults({
responseSynthesizer,
serviceContext,
queryEngineTools,
questionGen: new LLMQuestionGenerator({ llm: model })
})
} else if (responseSynthesizerObj.type === 'Refine') {
const responseSynthesizer = new ResponseSynthesizer({
responseBuilder: new Refine(
serviceContext,
responseSynthesizerObj.textQAPromptTemplate,
responseSynthesizerObj.refinePromptTemplate
),
serviceContext
})
queryEngine = SubQuestionQueryEngine.fromDefaults({
responseSynthesizer,
serviceContext,
queryEngineTools,
questionGen: new LLMQuestionGenerator({ llm: model })
})
} else if (responseSynthesizerObj.type === 'SimpleResponseBuilder') {
const responseSynthesizer = new ResponseSynthesizer({
responseBuilder: new SimpleResponseBuilder(serviceContext),
serviceContext
})
queryEngine = SubQuestionQueryEngine.fromDefaults({
responseSynthesizer,
serviceContext,
queryEngineTools,
questionGen: new LLMQuestionGenerator({ llm: model })
})
}
}

return queryEngine
}

module.exports = { nodeClass: SubQuestionQueryEngine_LlamaIndex }
21 changes: 8 additions & 13 deletions packages/components/nodes/tools/QueryEngineTool/QueryEngineTool.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { INode, INodeData, INodeParams } from '../../../src/Interface'
import { VectorStoreIndex } from 'llamaindex'
import { BaseQueryEngine } from 'llamaindex'

class QueryEngine_Tools implements INode {
label: string
Expand All @@ -16,7 +16,7 @@ class QueryEngine_Tools implements INode {
constructor() {
this.label = 'QueryEngine Tool'
this.name = 'queryEngineToolLlamaIndex'
this.version = 1.0
this.version = 2.0
this.type = 'QueryEngineTool'
this.icon = 'queryEngineTool.svg'
this.category = 'Tools'
Expand All @@ -25,9 +25,9 @@ class QueryEngine_Tools implements INode {
this.baseClasses = [this.type]
this.inputs = [
{
label: 'Vector Store Index',
name: 'vectorStoreIndex',
type: 'VectorStoreIndex'
label: 'Base QueryEngine',
name: 'baseQueryEngine',
type: 'BaseQueryEngine'
},
{
label: 'Tool Name',
Expand All @@ -45,20 +45,15 @@ class QueryEngine_Tools implements INode {
}

async init(nodeData: INodeData): Promise<any> {
const vectorStoreIndex = nodeData.inputs?.vectorStoreIndex as VectorStoreIndex
const baseQueryEngine = nodeData.inputs?.baseQueryEngine as BaseQueryEngine
const toolName = nodeData.inputs?.toolName as string
const toolDesc = nodeData.inputs?.toolDesc as string
const queryEngineTool = {
queryEngine: vectorStoreIndex.asQueryEngine({
preFilters: {
...(vectorStoreIndex as any).metadatafilter
}
}),
queryEngine: baseQueryEngine,
metadata: {
name: toolName,
description: toolDesc
},
vectorStoreIndex
}
}

return queryEngineTool
Expand Down
8 changes: 4 additions & 4 deletions packages/server/marketplaces/chatflows/Query Engine.json
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
"data": {
"id": "queryEngine_0",
"label": "Query Engine",
"version": 1,
"version": 2,
"name": "queryEngine",
"type": "QueryEngine",
"baseClasses": ["QueryEngine"],
"baseClasses": ["QueryEngine", "BaseQueryEngine"],
"tags": ["LlamaIndex"],
"category": "Engine",
"description": "Simple query engine built to answer question over your data, without memory",
Expand Down Expand Up @@ -55,10 +55,10 @@
},
"outputAnchors": [
{
"id": "queryEngine_0-output-queryEngine-QueryEngine",
"id": "queryEngine_0-output-queryEngine-QueryEngine|BaseQueryEngine",
"name": "queryEngine",
"label": "QueryEngine",
"type": "QueryEngine"
"type": "QueryEngine | BaseQueryEngine"
}
],
"outputs": {},
Expand Down
Loading
Loading