Skip to content

Commit eae7362

Browse files
committed
Fix return failed to begin transaction error when failed to start a transaction
1 parent 0df42e9 commit eae7362

File tree

2 files changed

+14
-15
lines changed

2 files changed

+14
-15
lines changed

finisher_api.go

+12-12
Original file line numberDiff line numberDiff line change
@@ -534,9 +534,7 @@ func (db *DB) Connection(fc func(tx *DB) error) (err error) {
534534

535535
defer conn.Close()
536536
tx.Statement.ConnPool = conn
537-
err = fc(tx)
538-
539-
return
537+
return fc(tx)
540538
}
541539

542540
// Transaction start a transaction as a block, return error will rollback, otherwise to commit.
@@ -547,6 +545,10 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er
547545
// nested transaction
548546
if !db.DisableNestedTransaction {
549547
err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error
548+
if err != nil {
549+
return
550+
}
551+
550552
defer func() {
551553
// Make sure to rollback when panic, Block error or Commit error
552554
if panicked || err != nil {
@@ -555,11 +557,12 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er
555557
}()
556558
}
557559

558-
if err == nil {
559-
err = fc(db.Session(&Session{}))
560-
}
560+
err = fc(db.Session(&Session{}))
561561
} else {
562562
tx := db.Begin(opts...)
563+
if tx.Error != nil {
564+
return tx.Error
565+
}
563566

564567
defer func() {
565568
// Make sure to rollback when panic, Block error or Commit error
@@ -568,12 +571,9 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er
568571
}
569572
}()
570573

571-
if err = tx.Error; err == nil {
572-
err = fc(tx)
573-
}
574-
575-
if err == nil {
576-
err = tx.Commit().Error
574+
if err = fc(tx); err == nil {
575+
panicked = false
576+
return tx.Commit().Error
577577
}
578578
}
579579

tests/connection_test.go

+2-3
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@ package tests_test
22

33
import (
44
"fmt"
5+
"testing"
6+
57
"gorm.io/driver/mysql"
68
"gorm.io/gorm"
7-
"testing"
89
)
910

1011
func TestWithSingleConnection(t *testing.T) {
11-
1212
var expectedName = "test"
1313
var actualName string
1414

@@ -35,7 +35,6 @@ func TestWithSingleConnection(t *testing.T) {
3535
if actualName != expectedName {
3636
t.Errorf("WithSingleConnection() method should get correct value, expect: %v, got %v", expectedName, actualName)
3737
}
38-
3938
}
4039

4140
func getSetSQL(driverName string) (string, string) {

0 commit comments

Comments
 (0)