6
6
7
7
/// The lock you get from [`Mutex`].
8
8
#[ cfg( feature = "multi_threaded" ) ]
9
+ #[ cfg( not( debug_assertions) ) ]
9
10
pub use parking_lot:: MutexGuard ;
10
11
12
+ /// The lock you get from [`Mutex`].
13
+ #[ cfg( feature = "multi_threaded" ) ]
14
+ #[ cfg( debug_assertions) ]
15
+ pub struct MutexGuard < ' a , T > ( parking_lot:: MutexGuard < ' a , T > , * const ( ) ) ;
16
+
11
17
/// Provides interior mutability. Only thread-safe if the `multi_threaded` feature is enabled.
12
18
#[ cfg( feature = "multi_threaded" ) ]
13
19
#[ derive( Default ) ]
14
20
pub struct Mutex < T > ( parking_lot:: Mutex < T > ) ;
15
21
22
+ #[ cfg( debug_assertions) ]
23
+ thread_local ! {
24
+ static HELD_LOCKS_TLS : std:: cell:: RefCell <std:: collections:: HashSet <* const ( ) >> = std:: cell:: RefCell :: new( std:: collections:: HashSet :: new( ) ) ;
25
+ }
26
+
16
27
#[ cfg( feature = "multi_threaded" ) ]
17
28
impl < T > Mutex < T > {
18
29
#[ inline( always) ]
@@ -22,12 +33,21 @@ impl<T> Mutex<T> {
22
33
23
34
#[ cfg( debug_assertions) ]
24
35
pub fn lock ( & self ) -> MutexGuard < ' _ , T > {
25
- // TODO: detect if we are trying to lock the same mutex from the same thread (bad)
26
- // vs locking it from another thread (fine).
27
- // At the moment we just panic on any double-locking of a mutex (so no multithreaded support in debug builds)
28
- self . 0
29
- . try_lock ( )
30
- . expect ( "The Mutex is already locked. Probably a bug" )
36
+ // Detect if we are recursively taking out a lock on this mutex.
37
+
38
+ // use a pointer to the inner data as an id for this lock
39
+ let ptr = ( & self . 0 as * const parking_lot:: Mutex < _ > ) . cast :: < ( ) > ( ) ;
40
+
41
+ // Store it in thread local storage while we have a lock guard taken out
42
+ HELD_LOCKS_TLS . with ( |locks| {
43
+ if locks. borrow ( ) . contains ( & ptr) {
44
+ panic ! ( "Recursively locking a Mutex in the same thread is not supported" )
45
+ } else {
46
+ locks. borrow_mut ( ) . insert ( ptr) ;
47
+ }
48
+ } ) ;
49
+
50
+ MutexGuard ( self . 0 . lock ( ) , ptr)
31
51
}
32
52
33
53
#[ inline( always) ]
@@ -37,6 +57,35 @@ impl<T> Mutex<T> {
37
57
}
38
58
}
39
59
60
+ #[ cfg( debug_assertions) ]
61
+ #[ cfg( feature = "multi_threaded" ) ]
62
+ impl < T > Drop for MutexGuard < ' _ , T > {
63
+ fn drop ( & mut self ) {
64
+ let ptr = self . 1 ;
65
+ HELD_LOCKS_TLS . with ( |locks| {
66
+ locks. borrow_mut ( ) . remove ( & ptr) ;
67
+ } ) ;
68
+ }
69
+ }
70
+
71
+ #[ cfg( debug_assertions) ]
72
+ #[ cfg( feature = "multi_threaded" ) ]
73
+ impl < T > std:: ops:: Deref for MutexGuard < ' _ , T > {
74
+ type Target = T ;
75
+
76
+ fn deref ( & self ) -> & Self :: Target {
77
+ & self . 0
78
+ }
79
+ }
80
+
81
+ #[ cfg( debug_assertions) ]
82
+ #[ cfg( feature = "multi_threaded" ) ]
83
+ impl < T > std:: ops:: DerefMut for MutexGuard < ' _ , T > {
84
+ fn deref_mut ( & mut self ) -> & mut Self :: Target {
85
+ & mut self . 0
86
+ }
87
+ }
88
+
40
89
// ---------------------
41
90
42
91
/// The lock you get from [`RwLock::read`].
@@ -140,3 +189,41 @@ where
140
189
Self :: new ( self . lock ( ) . clone ( ) )
141
190
}
142
191
}
192
+
193
+ #[ cfg( test) ]
194
+ mod tests {
195
+ use crate :: mutex:: Mutex ;
196
+ use std:: time:: Duration ;
197
+
198
+ #[ test]
199
+ fn lock_two_different_mutexes_single_thread ( ) {
200
+ let one = Mutex :: new ( ( ) ) ;
201
+ let two = Mutex :: new ( ( ) ) ;
202
+ let _a = one. lock ( ) ;
203
+ let _b = two. lock ( ) ;
204
+ }
205
+
206
+ #[ test]
207
+ #[ should_panic]
208
+ fn lock_reentry_single_thread ( ) {
209
+ let one = Mutex :: new ( ( ) ) ;
210
+ let _a = one. lock ( ) ;
211
+ let _a2 = one. lock ( ) ; // panics
212
+ }
213
+
214
+ #[ test]
215
+ fn lock_multiple_threads ( ) {
216
+ use std:: sync:: Arc ;
217
+ let one = Arc :: new ( Mutex :: new ( ( ) ) ) ;
218
+ let our_lock = one. lock ( ) ;
219
+ let other_thread = {
220
+ let one = Arc :: clone ( & one) ;
221
+ std:: thread:: spawn ( move || {
222
+ let _ = one. lock ( ) ;
223
+ } )
224
+ } ;
225
+ std:: thread:: sleep ( Duration :: from_millis ( 200 ) ) ;
226
+ drop ( our_lock) ;
227
+ other_thread. join ( ) . unwrap ( ) ;
228
+ }
229
+ }
0 commit comments