Skip to content

Commit 4b3b8a5

Browse files
committed
sql client init
1 parent 110e2d7 commit 4b3b8a5

21 files changed

+4732
-20
lines changed

Cargo.lock

+1,946
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[package]
2-
name = "rust-raft-db"
2+
name = "rustraftdb"
33
version = "0.1.0"
44
edition = "2021"
55

src/bin/dump.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
#![warn(clippy::all)]
22

3-
use rrdb::error::{Error, Result};
4-
use rrdb::storage::debug;
5-
use rrdb::storage::engine::{BitCask, Engine};
3+
use rustraftdb::error::{Error, Result};
4+
use rustraftdb::storage::debug;
5+
use rustraftdb::storage::engine::{BitCask, Engine};
66

77
fn main() -> Result<()> {
88
let args = clap::command!()

src/bin/rrdb.rs

+5-5
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22

33
use serde_derive::Deserialize;
44
use std::collections::HashMap;
5-
use rrdb::error::{Error, Result};
6-
use rrdb::raft;
7-
use rrdb::sql;
8-
use rrdb::storage;
9-
use rrdb::Server;
5+
use rustraftdb::error::{Error, Result};
6+
use rustraftdb::raft;
7+
use rustraftdb::sql;
8+
use rustraftdb::storage;
9+
use rustraftdb::Server;
1010

1111
#[tokio::main]
1212
async fn main() -> Result<()> {

src/bin/sql.rs

+8-8
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@ use rustyline::history::DefaultHistory;
44
use rustyline::validate::{ValidationContext, ValidationResult, Validator};
55
use rustyline::{error::ReadlineError, Editor, Modifiers};
66
use rustyline_derive::{Completer, Helper, Highlighter, Hinter};
7-
use rrdb::error::{Error, Result};
8-
use rrdb::sql::execution::ResultSet;
9-
use rrdb::sql::parser::{Lexer, Token};
10-
use rrdb::Client;
7+
use rustraftdb::error::{Error, Result};
8+
use rustraftdb::sql::execution::ResultSet;
9+
use rustraftdb::sql::parser::{Lexer, Token};
10+
use rustraftdb::Client;
1111

1212
#[tokio::main]
1313
async fn main() -> Result<()> {
@@ -30,14 +30,14 @@ async fn main() -> Result<()> {
3030
])
3131
.get_matches();
3232

33-
let mut toysql =
34-
ToySQL::new(opts.get_one::<String>("host").unwrap(), *opts.get_one("port").unwrap())
33+
let mut sjysql =
34+
SjySQL::new(opts.get_one::<String>("host").unwrap(), *opts.get_one("port").unwrap())
3535
.await?;
3636

3737
if let Some(command) = opts.get_one::<&str>("command") {
38-
toysql.execute(command).await
38+
sjysql.execute(command).await
3939
} else {
40-
toysql.run().await
40+
sjysql.run().await
4141
}
4242
}
4343

src/client.rs

+226
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
use crate::error::{Error, Result};
2+
use crate::server::{Request, Response};
3+
use crate::sql::engine::Status;
4+
use crate::sql::execution::ResultSet;
5+
use crate::sql::schema::Table;
6+
7+
use futures::future::FutureExt as _;
8+
use futures::sink::SinkExt as _;
9+
use futures::stream::TryStreamExt as _;
10+
use rand::Rng as _;
11+
use std::cell::Cell;
12+
use std::future::Future;
13+
use std::ops::{Deref, Drop};
14+
use std::sync::Arc;
15+
use tokio::net::{TcpStream, ToSocketAddrs};
16+
use tokio::sync::{Mutex, MutexGuard};
17+
use tokio_util::codec::{Framed, LengthDelimitedCodec};
18+
19+
type Connection = tokio_serde::Framed<
20+
Framed<TcpStream, LengthDelimitedCodec>,
21+
Result<Response>,
22+
Request,
23+
tokio_serde::formats::Bincode<Result<Response>, Request>,
24+
>;
25+
26+
/// Number of serialization retries in with_txn()
27+
const WITH_TXN_RETRIES: u8 = 8;
28+
29+
/// A rrDB client
30+
#[derive(Clone)]
31+
pub struct Client {
32+
conn: Arc<Mutex<Connection>>,
33+
txn: Cell<Option<(u64, bool)>>,
34+
}
35+
36+
impl Client {
37+
/// Creates a new client
38+
pub async fn new<A: ToSocketAddrs>(addr: A) -> Result<Self> {
39+
Ok(Self {
40+
conn: Arc::new(Mutex::new(tokio_serde::Framed::new(
41+
Framed::new(TcpStream::connect(addr).await?, LengthDelimitedCodec::new()),
42+
tokio_serde::formats::Bincode::default(),
43+
))),
44+
txn: Cell::new(None),
45+
})
46+
}
47+
48+
/// Call a server method
49+
async fn call(&self, request: Request) -> Result<Response> {
50+
let mut conn = self.conn.lock().await;
51+
self.call_locked(&mut conn, request).await
52+
}
53+
54+
/// Call a server method while holding the mutex lock
55+
async fn call_locked(
56+
&self,
57+
conn: &mut MutexGuard<'_, Connection>,
58+
request: Request,
59+
) -> Result<Response> {
60+
conn.send(request).await?;
61+
match conn.try_next().await? {
62+
Some(result) => result,
63+
None => Err(Error::Internal("Server disconnected".into())),
64+
}
65+
}
66+
67+
/// Executes a query
68+
pub async fn execute(&self, query: &str) -> Result<ResultSet> {
69+
let mut conn = self.conn.lock().await;
70+
let mut resultset =
71+
match self.call_locked(&mut conn, Request::Execute(query.into())).await? {
72+
Response::Execute(rs) => rs,
73+
resp => return Err(Error::Internal(format!("Unexpected response {:?}", resp))),
74+
};
75+
if let ResultSet::Query { columns, .. } = resultset {
76+
// FIXME We buffer rows for now to avoid lifetime hassles
77+
let mut rows = Vec::new();
78+
while let Some(result) = conn.try_next().await? {
79+
match result? {
80+
Response::Row(Some(row)) => rows.push(row),
81+
Response::Row(None) => break,
82+
response => {
83+
return Err(Error::Internal(format!("Unexpected response {:?}", response)))
84+
}
85+
}
86+
}
87+
resultset = ResultSet::Query { columns, rows: Box::new(rows.into_iter().map(Ok)) }
88+
};
89+
match &resultset {
90+
ResultSet::Begin { version, read_only } => self.txn.set(Some((*version, *read_only))),
91+
ResultSet::Commit { .. } => self.txn.set(None),
92+
ResultSet::Rollback { .. } => self.txn.set(None),
93+
_ => {}
94+
}
95+
Ok(resultset)
96+
}
97+
98+
/// Fetches the table schema as SQL
99+
pub async fn get_table(&self, table: &str) -> Result<Table> {
100+
match self.call(Request::GetTable(table.into())).await? {
101+
Response::GetTable(t) => Ok(t),
102+
resp => Err(Error::Value(format!("Unexpected response: {:?}", resp))),
103+
}
104+
}
105+
106+
/// Lists database tables
107+
pub async fn list_tables(&self) -> Result<Vec<String>> {
108+
match self.call(Request::ListTables).await? {
109+
Response::ListTables(t) => Ok(t),
110+
resp => Err(Error::Value(format!("Unexpected response: {:?}", resp))),
111+
}
112+
}
113+
114+
/// Checks server status
115+
pub async fn status(&self) -> Result<Status> {
116+
match self.call(Request::Status).await? {
117+
Response::Status(s) => Ok(s),
118+
resp => Err(Error::Value(format!("Unexpected response: {:?}", resp))),
119+
}
120+
}
121+
122+
/// Returns the version and read-only state of the txn
123+
pub fn txn(&self) -> Option<(u64, bool)> {
124+
self.txn.get()
125+
}
126+
127+
/// Runs a query in a transaction, automatically retrying serialization failures with
128+
/// exponential backoff.
129+
pub async fn with_txn<W, F, R>(&self, mut with: W) -> Result<R>
130+
where
131+
W: FnMut(Client) -> F,
132+
F: Future<Output = Result<R>>,
133+
{
134+
for i in 0..WITH_TXN_RETRIES {
135+
if i > 0 {
136+
tokio::time::sleep(std::time::Duration::from_millis(
137+
2_u64.pow(i as u32 - 1) * rand::thread_rng().gen_range(25..=75),
138+
))
139+
.await;
140+
}
141+
let result = async {
142+
self.execute("BEGIN").await?;
143+
let result = with(self.clone()).await?;
144+
self.execute("COMMIT").await?;
145+
Ok(result)
146+
}
147+
.await;
148+
if result.is_err() {
149+
self.execute("ROLLBACK").await.ok();
150+
if matches!(result, Err(Error::Serialization) | Err(Error::Abort)) {
151+
continue;
152+
}
153+
}
154+
return result;
155+
}
156+
Err(Error::Serialization)
157+
}
158+
}
159+
160+
/// A rrDB client pool
161+
pub struct Pool {
162+
clients: Vec<Mutex<Client>>,
163+
}
164+
165+
impl Pool {
166+
/// Creates a new connection pool for the given servers, eagerly connecting clients.
167+
pub async fn new<A: ToSocketAddrs + Clone>(addrs: Vec<A>, size: u64) -> Result<Self> {
168+
let mut addrs = addrs.into_iter().cycle();
169+
let clients = futures::future::try_join_all(
170+
std::iter::from_fn(|| {
171+
Some(Client::new(addrs.next().unwrap()).map(|r| r.map(Mutex::new)))
172+
})
173+
.take(size as usize),
174+
)
175+
.await?;
176+
Ok(Self { clients })
177+
}
178+
179+
/// Fetches a client from the pool. It is reset (i.e. any open txns are rolled back) and
180+
/// returned when it goes out of scope.
181+
pub async fn get(&self) -> PoolClient<'_> {
182+
let (client, index, _) =
183+
futures::future::select_all(self.clients.iter().map(|m| m.lock().boxed())).await;
184+
PoolClient::new(index, client)
185+
}
186+
187+
/// Returns the size of the pool
188+
pub fn size(&self) -> usize {
189+
self.clients.len()
190+
}
191+
}
192+
193+
/// A client returned from the pool
194+
pub struct PoolClient<'a> {
195+
id: usize,
196+
client: MutexGuard<'a, Client>,
197+
}
198+
199+
impl<'a> PoolClient<'a> {
200+
/// Creates a new PoolClient
201+
fn new(id: usize, client: MutexGuard<'a, Client>) -> Self {
202+
Self { id, client }
203+
}
204+
205+
/// Returns the ID of the client in the pool
206+
pub fn id(&self) -> usize {
207+
self.id
208+
}
209+
}
210+
211+
impl<'a> Deref for PoolClient<'a> {
212+
type Target = MutexGuard<'a, Client>;
213+
214+
fn deref(&self) -> &Self::Target {
215+
&self.client
216+
}
217+
}
218+
219+
impl<'a> Drop for PoolClient<'a> {
220+
fn drop(&mut self) {
221+
if self.txn().is_some() {
222+
// FIXME This should disconnect or destroy the client if it errors.
223+
futures::executor::block_on(self.client.execute("ROLLBACK")).ok();
224+
}
225+
}
226+
}

0 commit comments

Comments
 (0)