@@ -6,6 +6,12 @@ package db
6
6
import (
7
7
"context"
8
8
"database/sql"
9
+ "errors"
10
+ "runtime"
11
+ "slices"
12
+ "sync"
13
+
14
+ "code.gitea.io/gitea/modules/setting"
9
15
10
16
"xorm.io/builder"
11
17
"xorm.io/xorm"
@@ -15,76 +21,90 @@ import (
15
21
// will be overwritten by Init with HammerContext
16
22
var DefaultContext context.Context
17
23
18
- // contextKey is a value for use with context.WithValue.
19
- type contextKey struct {
20
- name string
21
- }
24
+ type engineContextKeyType struct {}
22
25
23
- // enginedContextKey is a context key. It is used with context.Value() to get the current Engined for the context
24
- var (
25
- enginedContextKey = & contextKey {"engined" }
26
- _ Engined = & Context {}
27
- )
26
+ var engineContextKey = engineContextKeyType {}
28
27
29
28
// Context represents a db context
30
29
type Context struct {
31
30
context.Context
32
- e Engine
33
- transaction bool
34
- }
35
-
36
- func newContext (ctx context.Context , e Engine , transaction bool ) * Context {
37
- return & Context {
38
- Context : ctx ,
39
- e : e ,
40
- transaction : transaction ,
41
- }
42
- }
43
-
44
- // InTransaction if context is in a transaction
45
- func (ctx * Context ) InTransaction () bool {
46
- return ctx .transaction
31
+ engine Engine
47
32
}
48
33
49
- // Engine returns db engine
50
- func (ctx * Context ) Engine () Engine {
51
- return ctx .e
34
+ func newContext (ctx context.Context , e Engine ) * Context {
35
+ return & Context {Context : ctx , engine : e }
52
36
}
53
37
54
38
// Value shadows Value for context.Context but allows us to get ourselves and an Engined object
55
39
func (ctx * Context ) Value (key any ) any {
56
- if key == enginedContextKey {
40
+ if key == engineContextKey {
57
41
return ctx
58
42
}
59
43
return ctx .Context .Value (key )
60
44
}
61
45
62
46
// WithContext returns this engine tied to this context
63
47
func (ctx * Context ) WithContext (other context.Context ) * Context {
64
- return newContext (ctx , ctx .e .Context (other ), ctx . transaction )
48
+ return newContext (ctx , ctx .engine .Context (other ))
65
49
}
66
50
67
- // Engined structs provide an Engine
68
- type Engined interface {
69
- Engine () Engine
51
+ var (
52
+ contextSafetyOnce sync.Once
53
+ contextSafetyDeniedFuncPCs []uintptr
54
+ )
55
+
56
+ func contextSafetyCheck (e Engine ) {
57
+ if setting .IsProd && ! setting .IsInTesting {
58
+ return
59
+ }
60
+ if e == nil {
61
+ return
62
+ }
63
+ // Only do this check for non-end-users. If the problem could be fixed in the future, this code could be removed.
64
+ contextSafetyOnce .Do (func () {
65
+ // try to figure out the bad functions to deny
66
+ type m struct {}
67
+ _ = e .SQL ("SELECT 1" ).Iterate (& m {}, func (int , any ) error {
68
+ callers := make ([]uintptr , 32 )
69
+ callerNum := runtime .Callers (1 , callers )
70
+ for i := 0 ; i < callerNum ; i ++ {
71
+ if funcName := runtime .FuncForPC (callers [i ]).Name (); funcName == "xorm.io/xorm.(*Session).Iterate" {
72
+ contextSafetyDeniedFuncPCs = append (contextSafetyDeniedFuncPCs , callers [i ])
73
+ }
74
+ }
75
+ return nil
76
+ })
77
+ if len (contextSafetyDeniedFuncPCs ) != 1 {
78
+ panic (errors .New ("unable to determine the functions to deny" ))
79
+ }
80
+ })
81
+
82
+ // it should be very fast: xxxx ns/op
83
+ callers := make ([]uintptr , 32 )
84
+ callerNum := runtime .Callers (3 , callers ) // skip 3: runtime.Callers, contextSafetyCheck, GetEngine
85
+ for i := 0 ; i < callerNum ; i ++ {
86
+ if slices .Contains (contextSafetyDeniedFuncPCs , callers [i ]) {
87
+ panic (errors .New ("using database context in an iterator would cause corrupted results" ))
88
+ }
89
+ }
70
90
}
71
91
72
- // GetEngine will get a db Engine from this context or return an Engine restricted to this context
92
+ // GetEngine gets an existing db Engine/Statement or creates a new Session
73
93
func GetEngine (ctx context.Context ) Engine {
74
- if e := getEngine (ctx ); e != nil {
94
+ if e := getExistingEngine (ctx ); e != nil {
75
95
return e
76
96
}
77
97
return x .Context (ctx )
78
98
}
79
99
80
- // getEngine will get a db Engine from this context or return nil
81
- func getEngine (ctx context.Context ) Engine {
82
- if engined , ok := ctx .(Engined ); ok {
83
- return engined .Engine ()
100
+ // getExistingEngine gets an existing db Engine/Statement from this context or returns nil
101
+ func getExistingEngine (ctx context.Context ) (e Engine ) {
102
+ defer func () { contextSafetyCheck (e ) }()
103
+ if engined , ok := ctx .(* Context ); ok {
104
+ return engined .engine
84
105
}
85
- enginedInterface := ctx .Value (enginedContextKey )
86
- if enginedInterface != nil {
87
- return enginedInterface .(Engined ).Engine ()
106
+ if engined , ok := ctx .Value (engineContextKey ).(* Context ); ok {
107
+ return engined .engine
88
108
}
89
109
return nil
90
110
}
@@ -132,23 +152,23 @@ func (c *halfCommitter) Close() error {
132
152
// d. It doesn't mean rollback is forbidden, but always do it only when there is an error, and you do want to rollback.
133
153
func TxContext (parentCtx context.Context ) (* Context , Committer , error ) {
134
154
if sess , ok := inTransaction (parentCtx ); ok {
135
- return newContext (parentCtx , sess , true ), & halfCommitter {committer : sess }, nil
155
+ return newContext (parentCtx , sess ), & halfCommitter {committer : sess }, nil
136
156
}
137
157
138
158
sess := x .NewSession ()
139
159
if err := sess .Begin (); err != nil {
140
- sess .Close ()
160
+ _ = sess .Close ()
141
161
return nil , nil , err
142
162
}
143
163
144
- return newContext (DefaultContext , sess , true ), sess , nil
164
+ return newContext (DefaultContext , sess ), sess , nil
145
165
}
146
166
147
167
// WithTx represents executing database operations on a transaction, if the transaction exist,
148
168
// this function will reuse it otherwise will create a new one and close it when finished.
149
169
func WithTx (parentCtx context.Context , f func (ctx context.Context ) error ) error {
150
170
if sess , ok := inTransaction (parentCtx ); ok {
151
- err := f (newContext (parentCtx , sess , true ))
171
+ err := f (newContext (parentCtx , sess ))
152
172
if err != nil {
153
173
// rollback immediately, in case the caller ignores returned error and tries to commit the transaction.
154
174
_ = sess .Close ()
@@ -165,7 +185,7 @@ func txWithNoCheck(parentCtx context.Context, f func(ctx context.Context) error)
165
185
return err
166
186
}
167
187
168
- if err := f (newContext (parentCtx , sess , true )); err != nil {
188
+ if err := f (newContext (parentCtx , sess )); err != nil {
169
189
return err
170
190
}
171
191
@@ -312,7 +332,7 @@ func InTransaction(ctx context.Context) bool {
312
332
}
313
333
314
334
func inTransaction (ctx context.Context ) (* xorm.Session , bool ) {
315
- e := getEngine (ctx )
335
+ e := getExistingEngine (ctx )
316
336
if e == nil {
317
337
return nil , false
318
338
}
0 commit comments