1
+ use crate :: error:: { Error , Result } ;
2
+ use crate :: server:: { Request , Response } ;
3
+ use crate :: sql:: engine:: Status ;
4
+ use crate :: sql:: execution:: ResultSet ;
5
+ use crate :: sql:: schema:: Table ;
6
+
7
+ use futures:: future:: FutureExt as _;
8
+ use futures:: sink:: SinkExt as _;
9
+ use futures:: stream:: TryStreamExt as _;
10
+ use rand:: Rng as _;
11
+ use std:: cell:: Cell ;
12
+ use std:: future:: Future ;
13
+ use std:: ops:: { Deref , Drop } ;
14
+ use std:: sync:: Arc ;
15
+ use tokio:: net:: { TcpStream , ToSocketAddrs } ;
16
+ use tokio:: sync:: { Mutex , MutexGuard } ;
17
+ use tokio_util:: codec:: { Framed , LengthDelimitedCodec } ;
18
+
19
+ type Connection = tokio_serde:: Framed <
20
+ Framed < TcpStream , LengthDelimitedCodec > ,
21
+ Result < Response > ,
22
+ Request ,
23
+ tokio_serde:: formats:: Bincode < Result < Response > , Request > ,
24
+ > ;
25
+
26
+ /// Number of serialization retries in with_txn()
27
+ const WITH_TXN_RETRIES : u8 = 8 ;
28
+
29
+ /// A rrDB client
30
+ #[ derive( Clone ) ]
31
+ pub struct Client {
32
+ conn : Arc < Mutex < Connection > > ,
33
+ txn : Cell < Option < ( u64 , bool ) > > ,
34
+ }
35
+
36
+ impl Client {
37
+ /// Creates a new client
38
+ pub async fn new < A : ToSocketAddrs > ( addr : A ) -> Result < Self > {
39
+ Ok ( Self {
40
+ conn : Arc :: new ( Mutex :: new ( tokio_serde:: Framed :: new (
41
+ Framed :: new ( TcpStream :: connect ( addr) . await ?, LengthDelimitedCodec :: new ( ) ) ,
42
+ tokio_serde:: formats:: Bincode :: default ( ) ,
43
+ ) ) ) ,
44
+ txn : Cell :: new ( None ) ,
45
+ } )
46
+ }
47
+
48
+ /// Call a server method
49
+ async fn call ( & self , request : Request ) -> Result < Response > {
50
+ let mut conn = self . conn . lock ( ) . await ;
51
+ self . call_locked ( & mut conn, request) . await
52
+ }
53
+
54
+ /// Call a server method while holding the mutex lock
55
+ async fn call_locked (
56
+ & self ,
57
+ conn : & mut MutexGuard < ' _ , Connection > ,
58
+ request : Request ,
59
+ ) -> Result < Response > {
60
+ conn. send ( request) . await ?;
61
+ match conn. try_next ( ) . await ? {
62
+ Some ( result) => result,
63
+ None => Err ( Error :: Internal ( "Server disconnected" . into ( ) ) ) ,
64
+ }
65
+ }
66
+
67
+ /// Executes a query
68
+ pub async fn execute ( & self , query : & str ) -> Result < ResultSet > {
69
+ let mut conn = self . conn . lock ( ) . await ;
70
+ let mut resultset =
71
+ match self . call_locked ( & mut conn, Request :: Execute ( query. into ( ) ) ) . await ? {
72
+ Response :: Execute ( rs) => rs,
73
+ resp => return Err ( Error :: Internal ( format ! ( "Unexpected response {:?}" , resp) ) ) ,
74
+ } ;
75
+ if let ResultSet :: Query { columns, .. } = resultset {
76
+ // FIXME We buffer rows for now to avoid lifetime hassles
77
+ let mut rows = Vec :: new ( ) ;
78
+ while let Some ( result) = conn. try_next ( ) . await ? {
79
+ match result? {
80
+ Response :: Row ( Some ( row) ) => rows. push ( row) ,
81
+ Response :: Row ( None ) => break ,
82
+ response => {
83
+ return Err ( Error :: Internal ( format ! ( "Unexpected response {:?}" , response) ) )
84
+ }
85
+ }
86
+ }
87
+ resultset = ResultSet :: Query { columns, rows : Box :: new ( rows. into_iter ( ) . map ( Ok ) ) }
88
+ } ;
89
+ match & resultset {
90
+ ResultSet :: Begin { version, read_only } => self . txn . set ( Some ( ( * version, * read_only) ) ) ,
91
+ ResultSet :: Commit { .. } => self . txn . set ( None ) ,
92
+ ResultSet :: Rollback { .. } => self . txn . set ( None ) ,
93
+ _ => { }
94
+ }
95
+ Ok ( resultset)
96
+ }
97
+
98
+ /// Fetches the table schema as SQL
99
+ pub async fn get_table ( & self , table : & str ) -> Result < Table > {
100
+ match self . call ( Request :: GetTable ( table. into ( ) ) ) . await ? {
101
+ Response :: GetTable ( t) => Ok ( t) ,
102
+ resp => Err ( Error :: Value ( format ! ( "Unexpected response: {:?}" , resp) ) ) ,
103
+ }
104
+ }
105
+
106
+ /// Lists database tables
107
+ pub async fn list_tables ( & self ) -> Result < Vec < String > > {
108
+ match self . call ( Request :: ListTables ) . await ? {
109
+ Response :: ListTables ( t) => Ok ( t) ,
110
+ resp => Err ( Error :: Value ( format ! ( "Unexpected response: {:?}" , resp) ) ) ,
111
+ }
112
+ }
113
+
114
+ /// Checks server status
115
+ pub async fn status ( & self ) -> Result < Status > {
116
+ match self . call ( Request :: Status ) . await ? {
117
+ Response :: Status ( s) => Ok ( s) ,
118
+ resp => Err ( Error :: Value ( format ! ( "Unexpected response: {:?}" , resp) ) ) ,
119
+ }
120
+ }
121
+
122
+ /// Returns the version and read-only state of the txn
123
+ pub fn txn ( & self ) -> Option < ( u64 , bool ) > {
124
+ self . txn . get ( )
125
+ }
126
+
127
+ /// Runs a query in a transaction, automatically retrying serialization failures with
128
+ /// exponential backoff.
129
+ pub async fn with_txn < W , F , R > ( & self , mut with : W ) -> Result < R >
130
+ where
131
+ W : FnMut ( Client ) -> F ,
132
+ F : Future < Output = Result < R > > ,
133
+ {
134
+ for i in 0 ..WITH_TXN_RETRIES {
135
+ if i > 0 {
136
+ tokio:: time:: sleep ( std:: time:: Duration :: from_millis (
137
+ 2_u64 . pow ( i as u32 - 1 ) * rand:: thread_rng ( ) . gen_range ( 25 ..=75 ) ,
138
+ ) )
139
+ . await ;
140
+ }
141
+ let result = async {
142
+ self . execute ( "BEGIN" ) . await ?;
143
+ let result = with ( self . clone ( ) ) . await ?;
144
+ self . execute ( "COMMIT" ) . await ?;
145
+ Ok ( result)
146
+ }
147
+ . await ;
148
+ if result. is_err ( ) {
149
+ self . execute ( "ROLLBACK" ) . await . ok ( ) ;
150
+ if matches ! ( result, Err ( Error :: Serialization ) | Err ( Error :: Abort ) ) {
151
+ continue ;
152
+ }
153
+ }
154
+ return result;
155
+ }
156
+ Err ( Error :: Serialization )
157
+ }
158
+ }
159
+
160
+ /// A rrDB client pool
161
+ pub struct Pool {
162
+ clients : Vec < Mutex < Client > > ,
163
+ }
164
+
165
+ impl Pool {
166
+ /// Creates a new connection pool for the given servers, eagerly connecting clients.
167
+ pub async fn new < A : ToSocketAddrs + Clone > ( addrs : Vec < A > , size : u64 ) -> Result < Self > {
168
+ let mut addrs = addrs. into_iter ( ) . cycle ( ) ;
169
+ let clients = futures:: future:: try_join_all (
170
+ std:: iter:: from_fn ( || {
171
+ Some ( Client :: new ( addrs. next ( ) . unwrap ( ) ) . map ( |r| r. map ( Mutex :: new) ) )
172
+ } )
173
+ . take ( size as usize ) ,
174
+ )
175
+ . await ?;
176
+ Ok ( Self { clients } )
177
+ }
178
+
179
+ /// Fetches a client from the pool. It is reset (i.e. any open txns are rolled back) and
180
+ /// returned when it goes out of scope.
181
+ pub async fn get ( & self ) -> PoolClient < ' _ > {
182
+ let ( client, index, _) =
183
+ futures:: future:: select_all ( self . clients . iter ( ) . map ( |m| m. lock ( ) . boxed ( ) ) ) . await ;
184
+ PoolClient :: new ( index, client)
185
+ }
186
+
187
+ /// Returns the size of the pool
188
+ pub fn size ( & self ) -> usize {
189
+ self . clients . len ( )
190
+ }
191
+ }
192
+
193
+ /// A client returned from the pool
194
+ pub struct PoolClient < ' a > {
195
+ id : usize ,
196
+ client : MutexGuard < ' a , Client > ,
197
+ }
198
+
199
+ impl < ' a > PoolClient < ' a > {
200
+ /// Creates a new PoolClient
201
+ fn new ( id : usize , client : MutexGuard < ' a , Client > ) -> Self {
202
+ Self { id, client }
203
+ }
204
+
205
+ /// Returns the ID of the client in the pool
206
+ pub fn id ( & self ) -> usize {
207
+ self . id
208
+ }
209
+ }
210
+
211
+ impl < ' a > Deref for PoolClient < ' a > {
212
+ type Target = MutexGuard < ' a , Client > ;
213
+
214
+ fn deref ( & self ) -> & Self :: Target {
215
+ & self . client
216
+ }
217
+ }
218
+
219
+ impl < ' a > Drop for PoolClient < ' a > {
220
+ fn drop ( & mut self ) {
221
+ if self . txn ( ) . is_some ( ) {
222
+ // FIXME This should disconnect or destroy the client if it errors.
223
+ futures:: executor:: block_on ( self . client . execute ( "ROLLBACK" ) ) . ok ( ) ;
224
+ }
225
+ }
226
+ }
0 commit comments