Skip to content

Commit 7d1234a

Browse files
authored
Bugfix/SQLite agent memory node (#3650)
* add dedicated agent memory nodes * sqlite agent memory fix * Update pnpm-lock.yaml
1 parent cadc3b8 commit 7d1234a

File tree

1 file changed

+75
-43
lines changed
  • packages/components/nodes/memory/AgentMemory/SQLiteAgentMemory

1 file changed

+75
-43
lines changed

packages/components/nodes/memory/AgentMemory/SQLiteAgentMemory/sqliteSaver.ts

+75-43
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,47 @@
11
import { BaseCheckpointSaver, Checkpoint, CheckpointMetadata } from '@langchain/langgraph'
22
import { RunnableConfig } from '@langchain/core/runnables'
33
import { BaseMessage } from '@langchain/core/messages'
4-
import { DataSource, QueryRunner } from 'typeorm'
4+
import { DataSource } from 'typeorm'
55
import { CheckpointTuple, SaverOptions, SerializerProtocol } from '../interface'
66
import { IMessage, MemoryMethods } from '../../../../src/Interface'
77
import { mapChatMessageToBaseMessage } from '../../../../src/utils'
88

99
export class SqliteSaver extends BaseCheckpointSaver implements MemoryMethods {
1010
protected isSetup: boolean
11-
12-
datasource: DataSource
13-
14-
queryRunner: QueryRunner
15-
1611
config: SaverOptions
17-
1812
threadId: string
19-
2013
tableName = 'checkpoints'
2114

2215
constructor(config: SaverOptions, serde?: SerializerProtocol<Checkpoint>) {
2316
super(serde)
2417
this.config = config
25-
const { datasourceOptions, threadId } = config
18+
const { threadId } = config
2619
this.threadId = threadId
27-
this.datasource = new DataSource(datasourceOptions)
2820
}
2921

30-
private async setup(): Promise<void> {
22+
private async getDataSource(): Promise<DataSource> {
23+
const { datasourceOptions } = this.config
24+
const dataSource = new DataSource(datasourceOptions)
25+
await dataSource.initialize()
26+
return dataSource
27+
}
28+
29+
private async setup(dataSource: DataSource): Promise<void> {
3130
if (this.isSetup) {
3231
return
3332
}
3433

3534
try {
36-
const appDataSource = await this.datasource.initialize()
37-
38-
this.queryRunner = appDataSource.createQueryRunner()
39-
await this.queryRunner.manager.query(`
35+
const queryRunner = dataSource.createQueryRunner()
36+
await queryRunner.manager.query(`
4037
CREATE TABLE IF NOT EXISTS ${this.tableName} (
4138
thread_id TEXT NOT NULL,
4239
checkpoint_id TEXT NOT NULL,
4340
parent_id TEXT,
4441
checkpoint BLOB,
4542
metadata BLOB,
4643
PRIMARY KEY (thread_id, checkpoint_id));`)
44+
await queryRunner.release()
4745
} catch (error) {
4846
console.error(`Error creating ${this.tableName} table`, error)
4947
throw new Error(`Error creating ${this.tableName} table`)
@@ -53,16 +51,20 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
5351
}
5452

5553
async getTuple(config: RunnableConfig): Promise<CheckpointTuple | undefined> {
56-
await this.setup()
54+
const dataSource = await this.getDataSource()
55+
await this.setup(dataSource)
56+
5757
const thread_id = config.configurable?.thread_id || this.threadId
5858
const checkpoint_id = config.configurable?.checkpoint_id
5959

6060
if (checkpoint_id) {
6161
try {
62+
const queryRunner = dataSource.createQueryRunner()
6263
const keys = [thread_id, checkpoint_id]
6364
const sql = `SELECT checkpoint, parent_id, metadata FROM ${this.tableName} WHERE thread_id = ? AND checkpoint_id = ?`
6465

65-
const rows = await this.queryRunner.manager.query(sql, [...keys])
66+
const rows = await queryRunner.manager.query(sql, [...keys])
67+
await queryRunner.release()
6668

6769
if (rows && rows.length > 0) {
6870
return {
@@ -82,39 +84,53 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
8284
} catch (error) {
8385
console.error(`Error retrieving ${this.tableName}`, error)
8486
throw new Error(`Error retrieving ${this.tableName}`)
87+
} finally {
88+
await dataSource.destroy()
8589
}
8690
} else {
87-
const keys = [thread_id]
88-
const sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = ? ORDER BY checkpoint_id DESC LIMIT 1`
91+
try {
92+
const queryRunner = dataSource.createQueryRunner()
93+
const keys = [thread_id]
94+
const sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = ? ORDER BY checkpoint_id DESC LIMIT 1`
8995

90-
const rows = await this.queryRunner.manager.query(sql, [...keys])
96+
const rows = await queryRunner.manager.query(sql, [...keys])
97+
await queryRunner.release()
9198

92-
if (rows && rows.length > 0) {
93-
return {
94-
config: {
95-
configurable: {
96-
thread_id: rows[0].thread_id,
97-
checkpoint_id: rows[0].checkpoint_id
98-
}
99-
},
100-
checkpoint: (await this.serde.parse(rows[0].checkpoint)) as Checkpoint,
101-
metadata: (await this.serde.parse(rows[0].metadata)) as CheckpointMetadata,
102-
parentConfig: rows[0].parent_id
103-
? {
104-
configurable: {
105-
thread_id: rows[0].thread_id,
106-
checkpoint_id: rows[0].parent_id
99+
if (rows && rows.length > 0) {
100+
return {
101+
config: {
102+
configurable: {
103+
thread_id: rows[0].thread_id,
104+
checkpoint_id: rows[0].checkpoint_id
105+
}
106+
},
107+
checkpoint: (await this.serde.parse(rows[0].checkpoint)) as Checkpoint,
108+
metadata: (await this.serde.parse(rows[0].metadata)) as CheckpointMetadata,
109+
parentConfig: rows[0].parent_id
110+
? {
111+
configurable: {
112+
thread_id: rows[0].thread_id,
113+
checkpoint_id: rows[0].parent_id
114+
}
107115
}
108-
}
109-
: undefined
116+
: undefined
117+
}
110118
}
119+
} catch (error) {
120+
console.error(`Error retrieving ${this.tableName}`, error)
121+
throw new Error(`Error retrieving ${this.tableName}`)
122+
} finally {
123+
await dataSource.destroy()
111124
}
112125
}
113126
return undefined
114127
}
115128

116129
async *list(config: RunnableConfig, limit?: number, before?: RunnableConfig): AsyncGenerator<CheckpointTuple> {
117-
await this.setup()
130+
const dataSource = await this.getDataSource()
131+
await this.setup(dataSource)
132+
133+
const queryRunner = dataSource.createQueryRunner()
118134
const thread_id = config.configurable?.thread_id || this.threadId
119135
let sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = ? ${
120136
before ? 'AND checkpoint_id < ?' : ''
@@ -125,7 +141,8 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
125141
const args = [thread_id, before?.configurable?.checkpoint_id].filter(Boolean)
126142

127143
try {
128-
const rows = await this.queryRunner.manager.query(sql, [...args])
144+
const rows = await queryRunner.manager.query(sql, [...args])
145+
await queryRunner.release()
129146

130147
if (rows && rows.length > 0) {
131148
for (const row of rows) {
@@ -152,13 +169,18 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
152169
} catch (error) {
153170
console.error(`Error listing ${this.tableName}`, error)
154171
throw new Error(`Error listing ${this.tableName}`)
172+
} finally {
173+
await dataSource.destroy()
155174
}
156175
}
157176

158177
async put(config: RunnableConfig, checkpoint: Checkpoint, metadata: CheckpointMetadata): Promise<RunnableConfig> {
159-
await this.setup()
178+
const dataSource = await this.getDataSource()
179+
await this.setup(dataSource)
180+
160181
if (!config.configurable?.checkpoint_id) return {}
161182
try {
183+
const queryRunner = dataSource.createQueryRunner()
162184
const row = [
163185
config.configurable?.thread_id || this.threadId,
164186
checkpoint.id,
@@ -169,10 +191,13 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
169191

170192
const query = `INSERT OR REPLACE INTO ${this.tableName} (thread_id, checkpoint_id, parent_id, checkpoint, metadata) VALUES (?, ?, ?, ?, ?)`
171193

172-
await this.queryRunner.manager.query(query, row)
194+
await queryRunner.manager.query(query, row)
195+
await queryRunner.release()
173196
} catch (error) {
174197
console.error('Error saving checkpoint', error)
175198
throw new Error('Error saving checkpoint')
199+
} finally {
200+
await dataSource.destroy()
176201
}
177202

178203
return {
@@ -187,13 +212,20 @@ CREATE TABLE IF NOT EXISTS ${this.tableName} (
187212
if (!threadId) {
188213
return
189214
}
190-
await this.setup()
215+
216+
const dataSource = await this.getDataSource()
217+
await this.setup(dataSource)
218+
191219
const query = `DELETE FROM "${this.tableName}" WHERE thread_id = ?;`
192220

193221
try {
194-
await this.queryRunner.manager.query(query, [threadId])
222+
const queryRunner = dataSource.createQueryRunner()
223+
await queryRunner.manager.query(query, [threadId])
224+
await queryRunner.release()
195225
} catch (error) {
196226
console.error(`Error deleting thread_id ${threadId}`, error)
227+
} finally {
228+
await dataSource.destroy()
197229
}
198230
}
199231

0 commit comments

Comments
 (0)