Skip to content

Commit 436b3aa

Browse files
authored
Merge pull request FlowiseAI#1644 from FlowiseAI/feature/Retriever-Tool-Source-Documents
Feature/Return Source Documents to retriever tool
2 parents 18c9c1c + 4d6881b commit 436b3aa

File tree

6 files changed

+86
-14
lines changed

6 files changed

+86
-14
lines changed

packages/components/nodes/agents/OpenAIFunctionAgent/OpenAIFunctionAgent.ts

+10-2
Original file line numberDiff line numberDiff line change
@@ -64,20 +64,28 @@ class OpenAIFunctionAgent_Agents implements INode {
6464
return prepareAgent(nodeData, { sessionId: this.sessionId, chatId: options.chatId, input }, options.chatHistory)
6565
}
6666

67-
async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string> {
67+
async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string | ICommonObject> {
6868
const memory = nodeData.inputs?.memory as FlowiseMemory
6969
const executor = prepareAgent(nodeData, { sessionId: this.sessionId, chatId: options.chatId, input }, options.chatHistory)
7070

7171
const loggerHandler = new ConsoleCallbackHandler(options.logger)
7272
const callbacks = await additionalCallbacks(nodeData, options)
7373

7474
let res: ChainValues = {}
75+
let sourceDocuments: ICommonObject[] = []
7576

7677
if (options.socketIO && options.socketIOClientId) {
7778
const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId)
7879
res = await executor.invoke({ input }, { callbacks: [loggerHandler, handler, ...callbacks] })
80+
if (res.sourceDocuments) {
81+
options.socketIO.to(options.socketIOClientId).emit('sourceDocuments', flatten(res.sourceDocuments))
82+
sourceDocuments = res.sourceDocuments
83+
}
7984
} else {
8085
res = await executor.invoke({ input }, { callbacks: [loggerHandler, ...callbacks] })
86+
if (res.sourceDocuments) {
87+
sourceDocuments = res.sourceDocuments
88+
}
8189
}
8290

8391
await memory.addChatMessages(
@@ -94,7 +102,7 @@ class OpenAIFunctionAgent_Agents implements INode {
94102
this.sessionId
95103
)
96104

97-
return res?.output
105+
return sourceDocuments.length ? { text: res?.output, sourceDocuments: flatten(sourceDocuments) } : res?.output
98106
}
99107
}
100108

packages/components/nodes/outputparsers/CustomListOutputParser/CustomListOutputParser.ts

+8-6
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,17 @@ class CustomListOutputParser implements INode {
2929
label: 'Length',
3030
name: 'length',
3131
type: 'number',
32-
default: 5,
3332
step: 1,
34-
description: 'Number of values to return'
33+
description: 'Number of values to return',
34+
optional: true
3535
},
3636
{
3737
label: 'Separator',
3838
name: 'separator',
3939
type: 'string',
4040
description: 'Separator between values',
41-
default: ','
41+
default: ',',
42+
optional: true
4243
},
4344
{
4445
label: 'Autofix',
@@ -54,10 +55,11 @@ class CustomListOutputParser implements INode {
5455
const separator = nodeData.inputs?.separator as string
5556
const lengthStr = nodeData.inputs?.length as string
5657
const autoFix = nodeData.inputs?.autofixParser as boolean
57-
let length = 5
58-
if (lengthStr) length = parseInt(lengthStr, 10)
5958

60-
const parser = new LangchainCustomListOutputParser({ length: length, separator: separator })
59+
const parser = new LangchainCustomListOutputParser({
60+
length: lengthStr ? parseInt(lengthStr, 10) : undefined,
61+
separator: separator
62+
})
6163
Object.defineProperty(parser, 'autoFix', {
6264
enumerable: true,
6365
configurable: true,

packages/components/nodes/tools/RetrieverTool/RetrieverTool.ts

+25-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
import { INode, INodeData, INodeParams } from '../../../src/Interface'
22
import { getBaseClasses } from '../../../src/utils'
33
import { DynamicTool } from 'langchain/tools'
4-
import { createRetrieverTool } from 'langchain/agents/toolkits'
4+
import { DynamicStructuredTool } from '@langchain/core/tools'
5+
import { CallbackManagerForToolRun } from '@langchain/core/callbacks/manager'
56
import { BaseRetriever } from 'langchain/schema/retriever'
7+
import { z } from 'zod'
8+
import { SOURCE_DOCUMENTS_PREFIX } from '../../../src/agents'
69

710
class Retriever_Tools implements INode {
811
label: string
@@ -19,7 +22,7 @@ class Retriever_Tools implements INode {
1922
constructor() {
2023
this.label = 'Retriever Tool'
2124
this.name = 'retrieverTool'
22-
this.version = 1.0
25+
this.version = 2.0
2326
this.type = 'RetrieverTool'
2427
this.icon = 'retrievertool.svg'
2528
this.category = 'Tools'
@@ -44,6 +47,12 @@ class Retriever_Tools implements INode {
4447
label: 'Retriever',
4548
name: 'retriever',
4649
type: 'BaseRetriever'
50+
},
51+
{
52+
label: 'Return Source Documents',
53+
name: 'returnSourceDocuments',
54+
type: 'boolean',
55+
optional: true
4756
}
4857
]
4958
}
@@ -52,12 +61,25 @@ class Retriever_Tools implements INode {
5261
const name = nodeData.inputs?.name as string
5362
const description = nodeData.inputs?.description as string
5463
const retriever = nodeData.inputs?.retriever as BaseRetriever
64+
const returnSourceDocuments = nodeData.inputs?.returnSourceDocuments as boolean
5565

56-
const tool = createRetrieverTool(retriever, {
66+
const input = {
5767
name,
5868
description
69+
}
70+
71+
const func = async ({ input }: { input: string }, runManager?: CallbackManagerForToolRun) => {
72+
const docs = await retriever.getRelevantDocuments(input, runManager?.getChild('retriever'))
73+
const content = docs.map((doc) => doc.pageContent).join('\n\n')
74+
const sourceDocuments = JSON.stringify(docs)
75+
return returnSourceDocuments ? content + SOURCE_DOCUMENTS_PREFIX + sourceDocuments : content
76+
}
77+
78+
const schema = z.object({
79+
input: z.string().describe('query to look up in retriever')
5980
})
6081

82+
const tool = new DynamicStructuredTool({ ...input, func, schema })
6183
return tool
6284
}
6385
}

packages/components/src/agents.ts

+24-1
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
1+
import { flatten } from 'lodash'
12
import { AgentExecutorInput, BaseSingleActionAgent, BaseMultiActionAgent, RunnableAgent, StoppingMethod } from 'langchain/agents'
2-
import { ChainValues, AgentStep, AgentFinish, AgentAction, BaseMessage, FunctionMessage, AIMessage } from 'langchain/schema'
3+
import { ChainValues, AgentStep, AgentAction, BaseMessage, FunctionMessage, AIMessage } from 'langchain/schema'
34
import { OutputParserException } from 'langchain/schema/output_parser'
45
import { CallbackManager, CallbackManagerForChainRun, Callbacks } from 'langchain/callbacks'
56
import { ToolInputParsingException, Tool } from '@langchain/core/tools'
67
import { Runnable } from 'langchain/schema/runnable'
78
import { BaseChain, SerializedLLMChain } from 'langchain/chains'
89
import { Serializable } from '@langchain/core/load/serializable'
910

11+
export const SOURCE_DOCUMENTS_PREFIX = '\n\n----FLOWISE_SOURCE_DOCUMENTS----\n\n'
12+
type AgentFinish = {
13+
returnValues: Record<string, any>
14+
log: string
15+
}
1016
type AgentExecutorOutput = ChainValues
1117

1218
interface AgentExecutorIteratorInput {
@@ -315,10 +321,12 @@ export class AgentExecutor extends BaseChain<ChainValues, AgentExecutorOutput> {
315321

316322
const steps: AgentStep[] = []
317323
let iterations = 0
324+
let sourceDocuments: Array<Document> = []
318325

319326
const getOutput = async (finishStep: AgentFinish): Promise<AgentExecutorOutput> => {
320327
const { returnValues } = finishStep
321328
const additional = await this.agent.prepareForOutput(returnValues, steps)
329+
if (sourceDocuments.length) additional.sourceDocuments = flatten(sourceDocuments)
322330

323331
if (this.returnIntermediateSteps) {
324332
return { ...returnValues, intermediateSteps: steps, ...additional }
@@ -406,6 +414,17 @@ export class AgentExecutor extends BaseChain<ChainValues, AgentExecutorOutput> {
406414
return { action, observation: observation ?? '' }
407415
}
408416
}
417+
if (observation?.includes(SOURCE_DOCUMENTS_PREFIX)) {
418+
const observationArray = observation.split(SOURCE_DOCUMENTS_PREFIX)
419+
observation = observationArray[0]
420+
const docs = observationArray[1]
421+
try {
422+
const parsedDocs = JSON.parse(docs)
423+
sourceDocuments.push(parsedDocs)
424+
} catch (e) {
425+
console.error('Error parsing source documents from tool')
426+
}
427+
}
409428
return { action, observation: observation ?? '' }
410429
})
411430
)
@@ -500,6 +519,10 @@ export class AgentExecutor extends BaseChain<ChainValues, AgentExecutorOutput> {
500519
chatId: this.chatId,
501520
input: this.input
502521
})
522+
if (observation?.includes(SOURCE_DOCUMENTS_PREFIX)) {
523+
const observationArray = observation.split(SOURCE_DOCUMENTS_PREFIX)
524+
observation = observationArray[0]
525+
}
503526
} catch (e) {
504527
if (e instanceof ToolInputParsingException) {
505528
if (this.handleParsingErrors === true) {

packages/server/marketplaces/chatflows/Conversational Retrieval Agent.json

+9-1
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,13 @@
217217
"rows": 3,
218218
"placeholder": "Searches and returns documents regarding the state-of-the-union.",
219219
"id": "retrieverTool_0-input-description-string"
220+
},
221+
{
222+
"label": "Return Source Documents",
223+
"name": "returnSourceDocuments",
224+
"type": "boolean",
225+
"optional": true,
226+
"id": "retrieverTool_0-input-returnSourceDocuments-boolean"
220227
}
221228
],
222229
"inputAnchors": [
@@ -230,7 +237,8 @@
230237
"inputs": {
231238
"name": "search_website",
232239
"description": "Searches and return documents regarding Jane - a culinary institution that offers top quality coffee, pastries, breakfast, lunch, and a variety of baked goods. They have multiple locations, including Jane on Fillmore, Jane on Larkin, Jane the Bakery, Toy Boat By Jane, and Little Jane on Grant. They emphasize healthy eating with a focus on flavor and quality ingredients. They bake everything in-house and work with local suppliers to source ingredients directly from farmers. They also offer catering services and delivery options.",
233-
"retriever": "{{pinecone_0.data.instance}}"
240+
"retriever": "{{pinecone_0.data.instance}}",
241+
"returnSourceDocuments": true
234242
},
235243
"outputAnchors": [
236244
{

packages/server/src/index.ts

+10-1
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,8 @@ export class App {
473473
const endingNodes = nodes.filter((nd) => endingNodeIds.includes(nd.id))
474474

475475
let isStreaming = false
476+
let isEndingNodeExists = endingNodes.find((node) => node.data?.outputs?.output === 'EndingNode')
477+
476478
for (const endingNode of endingNodes) {
477479
const endingNodeData = endingNode.data
478480
if (!endingNodeData) return res.status(500).send(`Ending node ${endingNode.id} data not found`)
@@ -488,7 +490,8 @@ export class App {
488490
isStreaming = isEndingNode ? false : isFlowValidForStream(nodes, endingNodeData)
489491
}
490492

491-
const obj = { isStreaming }
493+
// If a custom function ending node exists, the flow can never be streamed
494+
const obj = { isStreaming: isEndingNodeExists ? false : isStreaming }
492495
return res.json(obj)
493496
})
494497

@@ -1677,6 +1680,9 @@ export class App {
16771680
if (!endingNodeIds.length) return res.status(500).send(`Ending nodes not found`)
16781681

16791682
const endingNodes = nodes.filter((nd) => endingNodeIds.includes(nd.id))
1683+
1684+
let isEndingNodeExists = endingNodes.find((node) => node.data?.outputs?.output === 'EndingNode')
1685+
16801686
for (const endingNode of endingNodes) {
16811687
const endingNodeData = endingNode.data
16821688
if (!endingNodeData) return res.status(500).send(`Ending node ${endingNode.id} data not found`)
@@ -1704,6 +1710,9 @@ export class App {
17041710
isStreamValid = isFlowValidForStream(nodes, endingNodeData)
17051711
}
17061712

1713+
// If a custom function ending node exists, the flow can never be streamed
1714+
isStreamValid = isEndingNodeExists ? false : isStreamValid
1715+
17071716
let chatHistory: IMessage[] = incomingInput.history ?? []
17081717

17091718
// When {{chat_history}} is used in Prompt Template, fetch the chat conversations from memory node

0 commit comments

Comments
 (0)