Skip to content

Commit 621c0a3

Browse files
authored
Merge pull request #511 from FlowiseAI/feature/ConversationalRetrievalQAChain-Memory
Feature/conversational retrieval qa chain memory
2 parents a701c6e + aeb143a commit 621c0a3

File tree

11 files changed

+90
-20
lines changed

11 files changed

+90
-20
lines changed

packages/components/nodes/chains/ConversationalRetrievalQAChain/ConversationalRetrievalQAChain.ts

+42-7
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import { BaseLanguageModel } from 'langchain/base_language'
22
import { ICommonObject, IMessage, INode, INodeData, INodeParams } from '../../../src/Interface'
33
import { CustomChainHandler, getBaseClasses } from '../../../src/utils'
44
import { ConversationalRetrievalQAChain } from 'langchain/chains'
5-
import { AIChatMessage, BaseRetriever, HumanChatMessage } from 'langchain/schema'
5+
import { AIMessage, BaseRetriever, HumanMessage } from 'langchain/schema'
66
import { BaseChatMemory, BufferMemory, ChatMessageHistory } from 'langchain/memory'
77
import { PromptTemplate } from 'langchain/prompts'
88

@@ -20,6 +20,20 @@ const qa_template = `Use the following pieces of context to answer the question
2020
Question: {question}
2121
Helpful Answer:`
2222

23+
const CUSTOM_QUESTION_GENERATOR_CHAIN_PROMPT = `Given the following conversation and a follow up question, return the conversation history excerpt that includes any relevant context to the question if it exists and rephrase the follow up question to be a standalone question.
24+
Chat History:
25+
{chat_history}
26+
Follow Up Input: {question}
27+
Your answer should follow the following format:
28+
\`\`\`
29+
Use the following pieces of context to answer the users question.
30+
If you don't know the answer, just say that you don't know, don't try to make up an answer.
31+
----------------
32+
<Relevant chat history excerpt as context here>
33+
Standalone question: <Rephrased question here>
34+
\`\`\`
35+
Your answer:`
36+
2337
class ConversationalRetrievalQAChain_Chains implements INode {
2438
label: string
2539
name: string
@@ -49,6 +63,13 @@ class ConversationalRetrievalQAChain_Chains implements INode {
4963
name: 'vectorStoreRetriever',
5064
type: 'BaseRetriever'
5165
},
66+
{
67+
label: 'Memory',
68+
name: 'memory',
69+
type: 'DynamoDBChatMemory | RedisBackedChatMemory | ZepMemory',
70+
optional: true,
71+
description: 'If no memory connected, default BufferMemory will be used'
72+
},
5273
{
5374
label: 'Return Source Documents',
5475
name: 'returnSourceDocuments',
@@ -99,22 +120,33 @@ class ConversationalRetrievalQAChain_Chains implements INode {
99120
const systemMessagePrompt = nodeData.inputs?.systemMessagePrompt as string
100121
const returnSourceDocuments = nodeData.inputs?.returnSourceDocuments as boolean
101122
const chainOption = nodeData.inputs?.chainOption as string
123+
const memory = nodeData.inputs?.memory
102124

103125
const obj: any = {
104126
verbose: process.env.DEBUG === 'true' ? true : false,
105127
qaChainOptions: {
106128
type: 'stuff',
107129
prompt: PromptTemplate.fromTemplate(systemMessagePrompt ? `${systemMessagePrompt}\n${qa_template}` : default_qa_template)
108130
},
109-
memory: new BufferMemory({
131+
questionGeneratorChainOptions: {
132+
template: CUSTOM_QUESTION_GENERATOR_CHAIN_PROMPT
133+
}
134+
}
135+
if (returnSourceDocuments) obj.returnSourceDocuments = returnSourceDocuments
136+
if (chainOption) obj.qaChainOptions = { ...obj.qaChainOptions, type: chainOption }
137+
if (memory) {
138+
memory.inputKey = 'question'
139+
memory.outputKey = 'text'
140+
memory.memoryKey = 'chat_history'
141+
obj.memory = memory
142+
} else {
143+
obj.memory = new BufferMemory({
110144
memoryKey: 'chat_history',
111145
inputKey: 'question',
112146
outputKey: 'text',
113147
returnMessages: true
114148
})
115149
}
116-
if (returnSourceDocuments) obj.returnSourceDocuments = returnSourceDocuments
117-
if (chainOption) obj.qaChainOptions = { ...obj.qaChainOptions, type: chainOption }
118150

119151
const chain = ConversationalRetrievalQAChain.fromLLM(model, vectorStoreRetriever, obj)
120152
return chain
@@ -123,6 +155,8 @@ class ConversationalRetrievalQAChain_Chains implements INode {
123155
async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string | ICommonObject> {
124156
const chain = nodeData.instance as ConversationalRetrievalQAChain
125157
const returnSourceDocuments = nodeData.inputs?.returnSourceDocuments as boolean
158+
const memory = nodeData.inputs?.memory
159+
126160
let model = nodeData.inputs?.model
127161

128162
// Temporary fix: https://github.com/hwchase17/langchainjs/issues/754
@@ -131,16 +165,17 @@ class ConversationalRetrievalQAChain_Chains implements INode {
131165

132166
const obj = { question: input }
133167

134-
if (chain.memory && options && options.chatHistory) {
168+
// If external memory like Zep, Redis is being used, ignore below
169+
if (!memory && chain.memory && options && options.chatHistory) {
135170
const chatHistory = []
136171
const histories: IMessage[] = options.chatHistory
137172
const memory = chain.memory as BaseChatMemory
138173

139174
for (const message of histories) {
140175
if (message.type === 'apiMessage') {
141-
chatHistory.push(new AIChatMessage(message.message))
176+
chatHistory.push(new AIMessage(message.message))
142177
} else if (message.type === 'userMessage') {
143-
chatHistory.push(new HumanChatMessage(message.message))
178+
chatHistory.push(new HumanMessage(message.message))
144179
}
145180
}
146181
memory.chatHistory = new ChatMessageHistory(chatHistory)

packages/components/nodes/memory/DynamoDb/DynamoDb.ts

+3-2
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@ class DynamoDb_Memory implements INode {
1313
inputs: INodeParams[]
1414

1515
constructor() {
16-
this.label = 'DynamoDB Memory'
17-
this.name = 'DynamoDbMemory'
16+
this.label = 'DynamoDB Chat Memory'
17+
this.name = 'DynamoDBChatMemory'
18+
this.type = 'DynamoDBChatMemory'
1819
this.icon = 'dynamodb.svg'
1920
this.category = 'Memory'
2021
this.description = 'Stores the conversation in dynamo db table'

packages/components/nodes/memory/ZepMemory/ZepMemory.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class ZepMemory_Memory implements INode {
1818
this.label = 'Zep Memory'
1919
this.name = 'ZepMemory'
2020
this.type = 'ZepMemory'
21-
this.icon = 'memory.svg'
21+
this.icon = 'zep.png'
2222
this.category = 'Memory'
2323
this.description = 'Summarizes the conversation and stores the memory in zep server'
2424
this.baseClasses = [this.type, ...getBaseClasses(ZepMemory)]

packages/components/nodes/memory/ZepMemory/memory.svg

-8
This file was deleted.
Loading

packages/server/marketplaces/chatflows/Conversational Retrieval QA Chain.json

+8
Original file line numberDiff line numberDiff line change
@@ -539,6 +539,14 @@
539539
"name": "vectorStoreRetriever",
540540
"type": "BaseRetriever",
541541
"id": "conversationalRetrievalQAChain_0-input-vectorStoreRetriever-BaseRetriever"
542+
},
543+
{
544+
"label": "Memory",
545+
"name": "memory",
546+
"type": "DynamoDBChatMemory | RedisBackedChatMemory | ZepMemory",
547+
"optional": true,
548+
"description": "If no memory connected, default BufferMemory will be used",
549+
"id": "conversationalRetrievalQAChain_0-input-memory-DynamoDBChatMemory | RedisBackedChatMemory | ZepMemory"
542550
}
543551
],
544552
"inputs": {

packages/server/marketplaces/chatflows/Github Repo QnA.json

+8
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,14 @@
556556
"name": "vectorStoreRetriever",
557557
"type": "BaseRetriever",
558558
"id": "conversationalRetrievalQAChain_0-input-vectorStoreRetriever-BaseRetriever"
559+
},
560+
{
561+
"label": "Memory",
562+
"name": "memory",
563+
"type": "DynamoDBChatMemory | RedisBackedChatMemory | ZepMemory",
564+
"optional": true,
565+
"description": "If no memory connected, default BufferMemory will be used",
566+
"id": "conversationalRetrievalQAChain_0-input-memory-DynamoDBChatMemory | RedisBackedChatMemory | ZepMemory"
559567
}
560568
],
561569
"inputs": {

packages/server/marketplaces/chatflows/Local QnA.json

+8
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,14 @@
131131
"name": "vectorStoreRetriever",
132132
"type": "BaseRetriever",
133133
"id": "conversationalRetrievalQAChain_0-input-vectorStoreRetriever-BaseRetriever"
134+
},
135+
{
136+
"label": "Memory",
137+
"name": "memory",
138+
"type": "DynamoDBChatMemory | RedisBackedChatMemory | ZepMemory",
139+
"optional": true,
140+
"description": "If no memory connected, default BufferMemory will be used",
141+
"id": "conversationalRetrievalQAChain_0-input-memory-DynamoDBChatMemory | RedisBackedChatMemory | ZepMemory"
134142
}
135143
],
136144
"inputs": {

packages/server/marketplaces/chatflows/Metadata Filter Load.json

+8
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,14 @@
421421
"name": "vectorStoreRetriever",
422422
"type": "BaseRetriever",
423423
"id": "conversationalRetrievalQAChain_0-input-vectorStoreRetriever-BaseRetriever"
424+
},
425+
{
426+
"label": "Memory",
427+
"name": "memory",
428+
"type": "DynamoDBChatMemory | RedisBackedChatMemory | ZepMemory",
429+
"optional": true,
430+
"description": "If no memory connected, default BufferMemory will be used",
431+
"id": "conversationalRetrievalQAChain_0-input-memory-DynamoDBChatMemory | RedisBackedChatMemory | ZepMemory"
424432
}
425433
],
426434
"inputs": {

packages/server/marketplaces/chatflows/Metadata Filter Upsert.json

+8
Original file line numberDiff line numberDiff line change
@@ -625,6 +625,14 @@
625625
"name": "vectorStoreRetriever",
626626
"type": "BaseRetriever",
627627
"id": "conversationalRetrievalQAChain_0-input-vectorStoreRetriever-BaseRetriever"
628+
},
629+
{
630+
"label": "Memory",
631+
"name": "memory",
632+
"type": "DynamoDBChatMemory | RedisBackedChatMemory | ZepMemory",
633+
"optional": true,
634+
"description": "If no memory connected, default BufferMemory will be used",
635+
"id": "conversationalRetrievalQAChain_0-input-memory-DynamoDBChatMemory | RedisBackedChatMemory | ZepMemory"
628636
}
629637
],
630638
"inputs": {

packages/ui/src/utils/genericHelper.js

+4-2
Original file line numberDiff line numberDiff line change
@@ -168,8 +168,10 @@ export const isValidConnection = (connection, reactFlowInstance) => {
168168
//sourceHandle: "llmChain_0-output-llmChain-BaseChain"
169169
//targetHandle: "mrlkAgentLLM_0-input-model-BaseLanguageModel"
170170

171-
const sourceTypes = sourceHandle.split('-')[sourceHandle.split('-').length - 1].split('|')
172-
const targetTypes = targetHandle.split('-')[targetHandle.split('-').length - 1].split('|')
171+
let sourceTypes = sourceHandle.split('-')[sourceHandle.split('-').length - 1].split('|')
172+
sourceTypes = sourceTypes.map((s) => s.trim())
173+
let targetTypes = targetHandle.split('-')[targetHandle.split('-').length - 1].split('|')
174+
targetTypes = targetTypes.map((t) => t.trim())
173175

174176
if (targetTypes.some((t) => sourceTypes.includes(t))) {
175177
let targetNode = reactFlowInstance.getNode(target)

0 commit comments

Comments
 (0)