|
| 1 | +import { flatten } from 'lodash' |
1 | 2 | import { AgentExecutorInput, BaseSingleActionAgent, BaseMultiActionAgent, RunnableAgent, StoppingMethod } from 'langchain/agents'
|
2 |
| -import { ChainValues, AgentStep, AgentFinish, AgentAction, BaseMessage, FunctionMessage, AIMessage } from 'langchain/schema' |
| 3 | +import { ChainValues, AgentStep, AgentAction, BaseMessage, FunctionMessage, AIMessage } from 'langchain/schema' |
3 | 4 | import { OutputParserException } from 'langchain/schema/output_parser'
|
4 | 5 | import { CallbackManager, CallbackManagerForChainRun, Callbacks } from 'langchain/callbacks'
|
5 | 6 | import { ToolInputParsingException, Tool } from '@langchain/core/tools'
|
6 | 7 | import { Runnable } from 'langchain/schema/runnable'
|
7 | 8 | import { BaseChain, SerializedLLMChain } from 'langchain/chains'
|
8 | 9 | import { Serializable } from '@langchain/core/load/serializable'
|
9 | 10 |
|
| 11 | +export const SOURCE_DOCUMENTS_PREFIX = '\n\n----FLOWISE_SOURCE_DOCUMENTS----\n\n' |
| 12 | +type AgentFinish = { |
| 13 | + returnValues: Record<string, any> |
| 14 | + log: string |
| 15 | +} |
10 | 16 | type AgentExecutorOutput = ChainValues
|
11 | 17 |
|
12 | 18 | interface AgentExecutorIteratorInput {
|
@@ -315,10 +321,12 @@ export class AgentExecutor extends BaseChain<ChainValues, AgentExecutorOutput> {
|
315 | 321 |
|
316 | 322 | const steps: AgentStep[] = []
|
317 | 323 | let iterations = 0
|
| 324 | + let sourceDocuments: Array<Document> = [] |
318 | 325 |
|
319 | 326 | const getOutput = async (finishStep: AgentFinish): Promise<AgentExecutorOutput> => {
|
320 | 327 | const { returnValues } = finishStep
|
321 | 328 | const additional = await this.agent.prepareForOutput(returnValues, steps)
|
| 329 | + if (sourceDocuments.length) additional.sourceDocuments = flatten(sourceDocuments) |
322 | 330 |
|
323 | 331 | if (this.returnIntermediateSteps) {
|
324 | 332 | return { ...returnValues, intermediateSteps: steps, ...additional }
|
@@ -406,6 +414,17 @@ export class AgentExecutor extends BaseChain<ChainValues, AgentExecutorOutput> {
|
406 | 414 | return { action, observation: observation ?? '' }
|
407 | 415 | }
|
408 | 416 | }
|
| 417 | + if (observation?.includes(SOURCE_DOCUMENTS_PREFIX)) { |
| 418 | + const observationArray = observation.split(SOURCE_DOCUMENTS_PREFIX) |
| 419 | + observation = observationArray[0] |
| 420 | + const docs = observationArray[1] |
| 421 | + try { |
| 422 | + const parsedDocs = JSON.parse(docs) |
| 423 | + sourceDocuments.push(parsedDocs) |
| 424 | + } catch (e) { |
| 425 | + console.error('Error parsing source documents from tool') |
| 426 | + } |
| 427 | + } |
409 | 428 | return { action, observation: observation ?? '' }
|
410 | 429 | })
|
411 | 430 | )
|
@@ -500,6 +519,10 @@ export class AgentExecutor extends BaseChain<ChainValues, AgentExecutorOutput> {
|
500 | 519 | chatId: this.chatId,
|
501 | 520 | input: this.input
|
502 | 521 | })
|
| 522 | + if (observation?.includes(SOURCE_DOCUMENTS_PREFIX)) { |
| 523 | + const observationArray = observation.split(SOURCE_DOCUMENTS_PREFIX) |
| 524 | + observation = observationArray[0] |
| 525 | + } |
503 | 526 | } catch (e) {
|
504 | 527 | if (e instanceof ToolInputParsingException) {
|
505 | 528 | if (this.handleParsingErrors === true) {
|
|
0 commit comments