
Feature/conversational retrieval qa chain memory #511

Merged: 3 commits, Jul 9, 2023
@@ -2,7 +2,7 @@ import { BaseLanguageModel } from 'langchain/base_language'
import { ICommonObject, IMessage, INode, INodeData, INodeParams } from '../../../src/Interface'
import { CustomChainHandler, getBaseClasses } from '../../../src/utils'
import { ConversationalRetrievalQAChain } from 'langchain/chains'
import { AIChatMessage, BaseRetriever, HumanChatMessage } from 'langchain/schema'
import { AIMessage, BaseRetriever, HumanMessage } from 'langchain/schema'
import { BaseChatMemory, BufferMemory, ChatMessageHistory } from 'langchain/memory'
import { PromptTemplate } from 'langchain/prompts'

@@ -20,6 +20,20 @@ const qa_template = `Use the following pieces of context to answer the question
Question: {question}
Helpful Answer:`

const CUSTOM_QUESTION_GENERATOR_CHAIN_PROMPT = `Given the following conversation and a follow up question, return the conversation history excerpt that includes any relevant context to the question if it exists and rephrase the follow up question to be a standalone question.
Chat History:
{chat_history}
Follow Up Input: {question}
Your answer should follow the following format:
\`\`\`
Use the following pieces of context to answer the users question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
----------------
<Relevant chat history excerpt as context here>
Standalone question: <Rephrased question here>
\`\`\`
Your answer:`

class ConversationalRetrievalQAChain_Chains implements INode {
label: string
name: string
@@ -49,6 +63,13 @@ class ConversationalRetrievalQAChain_Chains implements INode {
name: 'vectorStoreRetriever',
type: 'BaseRetriever'
},
{
label: 'Memory',
name: 'memory',
type: 'DynamoDBChatMemory | RedisBackedChatMemory | ZepMemory',
optional: true,
description: 'If no memory connected, default BufferMemory will be used'
},
{
label: 'Return Source Documents',
name: 'returnSourceDocuments',
@@ -99,22 +120,33 @@ class ConversationalRetrievalQAChain_Chains implements INode {
const systemMessagePrompt = nodeData.inputs?.systemMessagePrompt as string
const returnSourceDocuments = nodeData.inputs?.returnSourceDocuments as boolean
const chainOption = nodeData.inputs?.chainOption as string
const memory = nodeData.inputs?.memory

const obj: any = {
verbose: process.env.DEBUG === 'true' ? true : false,
qaChainOptions: {
type: 'stuff',
prompt: PromptTemplate.fromTemplate(systemMessagePrompt ? `${systemMessagePrompt}\n${qa_template}` : default_qa_template)
},
memory: new BufferMemory({
questionGeneratorChainOptions: {
template: CUSTOM_QUESTION_GENERATOR_CHAIN_PROMPT
}
}
if (returnSourceDocuments) obj.returnSourceDocuments = returnSourceDocuments
if (chainOption) obj.qaChainOptions = { ...obj.qaChainOptions, type: chainOption }
if (memory) {
memory.inputKey = 'question'
memory.outputKey = 'text'
memory.memoryKey = 'chat_history'
obj.memory = memory
} else {
obj.memory = new BufferMemory({
memoryKey: 'chat_history',
inputKey: 'question',
outputKey: 'text',
returnMessages: true
})
}
if (returnSourceDocuments) obj.returnSourceDocuments = returnSourceDocuments
if (chainOption) obj.qaChainOptions = { ...obj.qaChainOptions, type: chainOption }

const chain = ConversationalRetrievalQAChain.fromLLM(model, vectorStoreRetriever, obj)
return chain
@@ -123,6 +155,8 @@ class ConversationalRetrievalQAChain_Chains implements INode {
async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string | ICommonObject> {
const chain = nodeData.instance as ConversationalRetrievalQAChain
const returnSourceDocuments = nodeData.inputs?.returnSourceDocuments as boolean
const memory = nodeData.inputs?.memory

let model = nodeData.inputs?.model

// Temporary fix: https://github.com/hwchase17/langchainjs/issues/754
@@ -131,16 +165,17 @@ class ConversationalRetrievalQAChain_Chains implements INode {

const obj = { question: input }

if (chain.memory && options && options.chatHistory) {
// If external memory like Zep, Redis is being used, ignore below
if (!memory && chain.memory && options && options.chatHistory) {
const chatHistory = []
const histories: IMessage[] = options.chatHistory
const memory = chain.memory as BaseChatMemory

for (const message of histories) {
if (message.type === 'apiMessage') {
chatHistory.push(new AIChatMessage(message.message))
chatHistory.push(new AIMessage(message.message))
} else if (message.type === 'userMessage') {
chatHistory.push(new HumanChatMessage(message.message))
chatHistory.push(new HumanMessage(message.message))
}
}
memory.chatHistory = new ChatMessageHistory(chatHistory)
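For reference, the memory wiring introduced above can be exercised outside Flowise roughly as follows. This is a minimal sketch against the langchain JS API used in this file, not code from this PR: the ChatOpenAI model, OpenAIEmbeddings, and the in-memory vector store are placeholder choices, and CUSTOM_QUESTION_GENERATOR_CHAIN_PROMPT refers to the template defined earlier in the diff.

import { ChatOpenAI } from 'langchain/chat_models/openai'
import { OpenAIEmbeddings } from 'langchain/embeddings/openai'
import { MemoryVectorStore } from 'langchain/vectorstores/memory'
import { ConversationalRetrievalQAChain } from 'langchain/chains'
import { BufferMemory } from 'langchain/memory'

const model = new ChatOpenAI({ temperature: 0 })
const vectorStore = await MemoryVectorStore.fromTexts(
    ['Flowise lets you build LLM flows visually'],
    [{}],
    new OpenAIEmbeddings()
)

// Same keys as the default branch of the diff: 'question' in, 'text' out,
// so the memory records each turn under 'chat_history'.
const memory = new BufferMemory({
    memoryKey: 'chat_history',
    inputKey: 'question',
    outputKey: 'text',
    returnMessages: true
})

const chain = ConversationalRetrievalQAChain.fromLLM(model, vectorStore.asRetriever(), {
    memory,
    questionGeneratorChainOptions: { template: CUSTOM_QUESTION_GENERATOR_CHAIN_PROMPT }
})

// The second call is rephrased into a standalone question using chat_history from memory.
await chain.call({ question: 'What is Flowise?' })
const followUp = await chain.call({ question: 'Can it run locally?' })
console.log(followUp.text)

The same shape applies when an external memory node (DynamoDB, Redis, Zep) is connected; the diff simply overrides its inputKey, outputKey, and memoryKey before passing it to the chain.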
5 changes: 3 additions & 2 deletions packages/components/nodes/memory/DynamoDb/DynamoDb.ts
@@ -13,8 +13,9 @@ class DynamoDb_Memory implements INode {
inputs: INodeParams[]

constructor() {
this.label = 'DynamoDB Memory'
this.name = 'DynamoDbMemory'
this.label = 'DynamoDB Chat Memory'
this.name = 'DynamoDBChatMemory'
this.type = 'DynamoDBChatMemory'
this.icon = 'dynamodb.svg'
this.category = 'Memory'
this.description = 'Stores the conversation in dynamo db table'
2 changes: 1 addition & 1 deletion packages/components/nodes/memory/ZepMemory/ZepMemory.ts
@@ -18,7 +18,7 @@ class ZepMemory_Memory implements INode {
this.label = 'Zep Memory'
this.name = 'ZepMemory'
this.type = 'ZepMemory'
this.icon = 'memory.svg'
this.icon = 'zep.png'
this.category = 'Memory'
this.description = 'Summarizes the conversation and stores the memory in zep server'
this.baseClasses = [this.type, ...getBaseClasses(ZepMemory)]
8 changes: 0 additions & 8 deletions packages/components/nodes/memory/ZepMemory/memory.svg

This file was deleted.

@@ -539,6 +539,14 @@
"name": "vectorStoreRetriever",
"type": "BaseRetriever",
"id": "conversationalRetrievalQAChain_0-input-vectorStoreRetriever-BaseRetriever"
},
{
"label": "Memory",
"name": "memory",
"type": "DynamoDBChatMemory | RedisBackedChatMemory | ZepMemory",
"optional": true,
"description": "If no memory connected, default BufferMemory will be used",
"id": "conversationalRetrievalQAChain_0-input-memory-DynamoDBChatMemory | RedisBackedChatMemory | ZepMemory"
}
],
"inputs": {
8 changes: 8 additions & 0 deletions packages/server/marketplaces/chatflows/Github Repo QnA.json
@@ -556,6 +556,14 @@
"name": "vectorStoreRetriever",
"type": "BaseRetriever",
"id": "conversationalRetrievalQAChain_0-input-vectorStoreRetriever-BaseRetriever"
},
{
"label": "Memory",
"name": "memory",
"type": "DynamoDBChatMemory | RedisBackedChatMemory | ZepMemory",
"optional": true,
"description": "If no memory connected, default BufferMemory will be used",
"id": "conversationalRetrievalQAChain_0-input-memory-DynamoDBChatMemory | RedisBackedChatMemory | ZepMemory"
}
],
"inputs": {
8 changes: 8 additions & 0 deletions packages/server/marketplaces/chatflows/Local QnA.json
@@ -131,6 +131,14 @@
"name": "vectorStoreRetriever",
"type": "BaseRetriever",
"id": "conversationalRetrievalQAChain_0-input-vectorStoreRetriever-BaseRetriever"
},
{
"label": "Memory",
"name": "memory",
"type": "DynamoDBChatMemory | RedisBackedChatMemory | ZepMemory",
"optional": true,
"description": "If no memory connected, default BufferMemory will be used",
"id": "conversationalRetrievalQAChain_0-input-memory-DynamoDBChatMemory | RedisBackedChatMemory | ZepMemory"
}
],
"inputs": {
@@ -421,6 +421,14 @@
"name": "vectorStoreRetriever",
"type": "BaseRetriever",
"id": "conversationalRetrievalQAChain_0-input-vectorStoreRetriever-BaseRetriever"
},
{
"label": "Memory",
"name": "memory",
"type": "DynamoDBChatMemory | RedisBackedChatMemory | ZepMemory",
"optional": true,
"description": "If no memory connected, default BufferMemory will be used",
"id": "conversationalRetrievalQAChain_0-input-memory-DynamoDBChatMemory | RedisBackedChatMemory | ZepMemory"
}
],
"inputs": {
@@ -625,6 +625,14 @@
"name": "vectorStoreRetriever",
"type": "BaseRetriever",
"id": "conversationalRetrievalQAChain_0-input-vectorStoreRetriever-BaseRetriever"
},
{
"label": "Memory",
"name": "memory",
"type": "DynamoDBChatMemory | RedisBackedChatMemory | ZepMemory",
"optional": true,
"description": "If no memory connected, default BufferMemory will be used",
"id": "conversationalRetrievalQAChain_0-input-memory-DynamoDBChatMemory | RedisBackedChatMemory | ZepMemory"
}
],
"inputs": {
6 changes: 4 additions & 2 deletions packages/ui/src/utils/genericHelper.js
@@ -168,8 +168,10 @@ export const isValidConnection = (connection, reactFlowInstance) => {
//sourceHandle: "llmChain_0-output-llmChain-BaseChain"
//targetHandle: "mrlkAgentLLM_0-input-model-BaseLanguageModel"

const sourceTypes = sourceHandle.split('-')[sourceHandle.split('-').length - 1].split('|')
const targetTypes = targetHandle.split('-')[targetHandle.split('-').length - 1].split('|')
let sourceTypes = sourceHandle.split('-')[sourceHandle.split('-').length - 1].split('|')
sourceTypes = sourceTypes.map((s) => s.trim())
let targetTypes = targetHandle.split('-')[targetHandle.split('-').length - 1].split('|')
targetTypes = targetTypes.map((t) => t.trim())

if (targetTypes.some((t) => sourceTypes.includes(t))) {
let targetNode = reactFlowInstance.getNode(target)
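The trimming added here matters because the new Memory anchor advertises its accepted types as the pipe-separated string 'DynamoDBChatMemory | RedisBackedChatMemory | ZepMemory' (with spaces), which ends up as the last segment of the target handle id. A minimal sketch of the failure mode, using an illustrative handle id in the same format as the chatflow JSON above:

const targetHandle =
    'conversationalRetrievalQAChain_0-input-memory-DynamoDBChatMemory | RedisBackedChatMemory | ZepMemory'

// The last '-'-separated segment lists the accepted types.
const rawTypes = targetHandle.split('-')[targetHandle.split('-').length - 1].split('|')
// rawTypes is ['DynamoDBChatMemory ', ' RedisBackedChatMemory ', ' ZepMemory'],
// so rawTypes.includes('ZepMemory') is false and the connection would be rejected.
const targetTypes = rawTypes.map((t) => t.trim())
console.log(targetTypes.includes('ZepMemory')) // true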