Skip to content

Commit 0df42e9

Browse files
author
kinggo
authored
feat: add Connection to execute multiple commands in a single connection; (#4982)
1 parent f757b8f commit 0df42e9

File tree

2 files changed

+72
-0
lines changed

2 files changed

+72
-0
lines changed

finisher_api.go

+24
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,30 @@ func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
515515
return tx.Error
516516
}
517517

518+
// Connection use a db conn to execute Multiple commands,this conn will put conn pool after it is executed.
519+
func (db *DB) Connection(fc func(tx *DB) error) (err error) {
520+
if db.Error != nil {
521+
return db.Error
522+
}
523+
524+
tx := db.getInstance()
525+
sqlDB, err := tx.DB()
526+
if err != nil {
527+
return
528+
}
529+
530+
conn, err := sqlDB.Conn(tx.Statement.Context)
531+
if err != nil {
532+
return
533+
}
534+
535+
defer conn.Close()
536+
tx.Statement.ConnPool = conn
537+
err = fc(tx)
538+
539+
return
540+
}
541+
518542
// Transaction start a transaction as a block, return error will rollback, otherwise to commit.
519543
func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) {
520544
panicked := true

tests/connection_test.go

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
package tests_test
2+
3+
import (
4+
"fmt"
5+
"gorm.io/driver/mysql"
6+
"gorm.io/gorm"
7+
"testing"
8+
)
9+
10+
func TestWithSingleConnection(t *testing.T) {
11+
12+
var expectedName = "test"
13+
var actualName string
14+
15+
setSQL, getSQL := getSetSQL(DB.Dialector.Name())
16+
if len(setSQL) == 0 || len(getSQL) == 0 {
17+
return
18+
}
19+
20+
err := DB.Connection(func(tx *gorm.DB) error {
21+
if err := tx.Exec(setSQL, expectedName).Error; err != nil {
22+
return err
23+
}
24+
25+
if err := tx.Raw(getSQL).Scan(&actualName).Error; err != nil {
26+
return err
27+
}
28+
return nil
29+
})
30+
31+
if err != nil {
32+
t.Errorf(fmt.Sprintf("WithSingleConnection should work, but got err %v", err))
33+
}
34+
35+
if actualName != expectedName {
36+
t.Errorf("WithSingleConnection() method should get correct value, expect: %v, got %v", expectedName, actualName)
37+
}
38+
39+
}
40+
41+
func getSetSQL(driverName string) (string, string) {
42+
switch driverName {
43+
case mysql.Dialector{}.Name():
44+
return "SET @testName := ?", "SELECT @testName"
45+
default:
46+
return "", ""
47+
}
48+
}

0 commit comments

Comments
 (0)