From 0e1c825e9e8ea2f34b794cbba707ef2e7178e54d Mon Sep 17 00:00:00 2001 From: Marco Napetti Date: Tue, 31 May 2022 12:38:38 +0200 Subject: [PATCH] Stream metrics --- src/database/stream/metric.rs | 56 ++++++++++++++++++++++++++++++ src/database/stream/mod.rs | 2 ++ src/database/stream/query.rs | 48 +++++++++++++------------ src/database/stream/transaction.rs | 51 ++++++++++++++------------- src/metric.rs | 23 ++---------- 5 files changed, 113 insertions(+), 67 deletions(-) create mode 100644 src/database/stream/metric.rs diff --git a/src/database/stream/metric.rs b/src/database/stream/metric.rs new file mode 100644 index 0000000000..b0afab679b --- /dev/null +++ b/src/database/stream/metric.rs @@ -0,0 +1,56 @@ +use std::{time::Duration, pin::Pin, task::Poll}; + +use futures::Stream; + +use crate::{QueryResult, DbErr, Statement}; + +pub(crate) struct MetricStream<'a> { + metric_callback: &'a Option, + stmt: &'a Statement, + elapsed: Option, + stream: Pin> + 'a + Send>>, +} + +impl<'a> MetricStream<'a> { + pub(crate) fn new(metric_callback: &'a Option, stmt: &'a Statement, elapsed: Option, stream: S) -> Self + where + S: Stream> + 'a + Send, + { + MetricStream { + metric_callback, + stmt, + elapsed, + stream: Box::pin(stream), + } + } +} + +impl<'a> Stream for MetricStream<'a> { + type Item = Result; + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let this = self.get_mut(); + let _start = this.metric_callback.is_some().then(std::time::SystemTime::now); + let res = Pin::new(&mut this.stream).poll_next(cx); + if let (Some(_start), Some(elapsed)) = (_start, &mut this.elapsed) { + *elapsed += _start.elapsed().unwrap_or_default(); + } + res + } +} + +impl<'a> Drop for MetricStream<'a> { + fn drop(&mut self) { + if let (Some(callback), Some(elapsed)) = (self.metric_callback.as_deref(), self.elapsed) { + let info = crate::metric::Info { + elapsed: elapsed, + statement: self.stmt, + failed: false, + }; + callback(&info); + } + } +} diff --git a/src/database/stream/mod.rs b/src/database/stream/mod.rs index 774cf45fa7..deb4495f05 100644 --- a/src/database/stream/mod.rs +++ b/src/database/stream/mod.rs @@ -1,3 +1,5 @@ +mod metric; + mod query; mod transaction; diff --git a/src/database/stream/query.rs b/src/database/stream/query.rs index 7a8eba1346..eb66cc7777 100644 --- a/src/database/stream/query.rs +++ b/src/database/stream/query.rs @@ -1,6 +1,6 @@ #![allow(missing_docs)] -use std::{pin::Pin, task::Poll}; +use std::{pin::Pin, task::Poll, time::SystemTime}; #[cfg(feature = "mock")] use std::sync::Arc; @@ -16,6 +16,8 @@ use tracing::instrument; use crate::{DbErr, InnerConnection, QueryResult, Statement}; +use super::metric::MetricStream; + /// Creates a stream from a [QueryResult] #[ouroboros::self_referencing] pub struct QueryStream { @@ -24,7 +26,7 @@ pub struct QueryStream { metric_callback: Option, #[borrows(mut conn, stmt, metric_callback)] #[not_covariant] - stream: Pin> + Send + 'this>>, + stream: MetricStream<'this>, } #[cfg(feature = "sqlx-mysql")] @@ -124,38 +126,40 @@ impl QueryStream { #[cfg(feature = "sqlx-mysql")] InnerConnection::MySql(c) => { let query = crate::driver::sqlx_mysql::sqlx_query(stmt); - crate::metric::metric_ok!(_metric_callback, stmt, { - Box::pin( - c.fetch(query) + let _start = _metric_callback.is_some().then(SystemTime::now); + let stream = c.fetch(query) .map_ok(Into::into) - .map_err(crate::sqlx_error_to_query_err), - ) - }) + .map_err(crate::sqlx_error_to_query_err); + let elapsed = _start.map(|s| s.elapsed().unwrap_or_default()); + MetricStream::new(_metric_callback, stmt, elapsed, stream) } #[cfg(feature = "sqlx-postgres")] InnerConnection::Postgres(c) => { let query = crate::driver::sqlx_postgres::sqlx_query(stmt); - crate::metric::metric_ok!(_metric_callback, stmt, { - Box::pin( - c.fetch(query) + let _start = _metric_callback.is_some().then(SystemTime::now); + let stream = c.fetch(query) .map_ok(Into::into) - .map_err(crate::sqlx_error_to_query_err), - ) - }) + .map_err(crate::sqlx_error_to_query_err); + let elapsed = _start.map(|s| s.elapsed().unwrap_or_default()); + MetricStream::new(_metric_callback, stmt, elapsed, stream) } #[cfg(feature = "sqlx-sqlite")] InnerConnection::Sqlite(c) => { let query = crate::driver::sqlx_sqlite::sqlx_query(stmt); - crate::metric::metric_ok!(_metric_callback, stmt, { - Box::pin( - c.fetch(query) + let _start = _metric_callback.is_some().then(SystemTime::now); + let stream = c.fetch(query) .map_ok(Into::into) - .map_err(crate::sqlx_error_to_query_err), - ) - }) + .map_err(crate::sqlx_error_to_query_err); + let elapsed = _start.map(|s| s.elapsed().unwrap_or_default()); + MetricStream::new(_metric_callback, stmt, elapsed, stream) } #[cfg(feature = "mock")] - InnerConnection::Mock(c) => c.fetch(stmt), + InnerConnection::Mock(c) => { + let _start = _metric_callback.is_some().then(SystemTime::now); + let stream = c.fetch(stmt); + let elapsed = _start.map(|s| s.elapsed().unwrap_or_default()); + MetricStream::new(_metric_callback, stmt, elapsed, stream) + }, #[allow(unreachable_patterns)] _ => unreachable!(), }, @@ -172,6 +176,6 @@ impl Stream for QueryStream { cx: &mut std::task::Context<'_>, ) -> Poll> { let this = self.get_mut(); - this.with_stream_mut(|stream| stream.as_mut().poll_next(cx)) + this.with_stream_mut(|stream| Pin::new(stream).poll_next(cx)) } } diff --git a/src/database/stream/transaction.rs b/src/database/stream/transaction.rs index 74b04421ce..96b00982eb 100644 --- a/src/database/stream/transaction.rs +++ b/src/database/stream/transaction.rs @@ -1,6 +1,6 @@ #![allow(missing_docs)] -use std::{ops::DerefMut, pin::Pin, task::Poll}; +use std::{ops::DerefMut, pin::Pin, task::Poll, time::SystemTime}; use futures::Stream; #[cfg(feature = "sqlx-dep")] @@ -15,6 +15,8 @@ use tracing::instrument; use crate::{DbErr, InnerConnection, QueryResult, Statement}; +use super::metric::MetricStream; + /// `TransactionStream` cannot be used in a `transaction` closure as it does not impl `Send`. /// It seems to be a Rust limitation right now, and solution to work around this deemed to be extremely hard. #[ouroboros::self_referencing] @@ -24,7 +26,7 @@ pub struct TransactionStream<'a> { metric_callback: Option, #[borrows(mut conn, stmt, metric_callback)] #[not_covariant] - stream: Pin> + 'this + Send>>, + stream: MetricStream<'this>, } impl<'a> std::fmt::Debug for TransactionStream<'a> { @@ -48,41 +50,40 @@ impl<'a> TransactionStream<'a> { #[cfg(feature = "sqlx-mysql")] InnerConnection::MySql(c) => { let query = crate::driver::sqlx_mysql::sqlx_query(stmt); - crate::metric::metric_ok!(_metric_callback, stmt, { - Box::pin( - c.fetch(query) + let _start = _metric_callback.is_some().then(SystemTime::now); + let stream = c.fetch(query) .map_ok(Into::into) - .map_err(crate::sqlx_error_to_query_err), - ) - as Pin> + Send>> - }) + .map_err(crate::sqlx_error_to_query_err); + let elapsed = _start.map(|s| s.elapsed().unwrap_or_default()); + MetricStream::new(_metric_callback, stmt, elapsed, stream) } #[cfg(feature = "sqlx-postgres")] InnerConnection::Postgres(c) => { let query = crate::driver::sqlx_postgres::sqlx_query(stmt); - crate::metric::metric_ok!(_metric_callback, stmt, { - Box::pin( - c.fetch(query) + let _start = _metric_callback.is_some().then(SystemTime::now); + let stream = c.fetch(query) .map_ok(Into::into) - .map_err(crate::sqlx_error_to_query_err), - ) - as Pin> + Send>> - }) + .map_err(crate::sqlx_error_to_query_err); + let elapsed = _start.map(|s| s.elapsed().unwrap_or_default()); + MetricStream::new(_metric_callback, stmt, elapsed, stream) } #[cfg(feature = "sqlx-sqlite")] InnerConnection::Sqlite(c) => { let query = crate::driver::sqlx_sqlite::sqlx_query(stmt); - crate::metric::metric_ok!(_metric_callback, stmt, { - Box::pin( - c.fetch(query) + let _start = _metric_callback.is_some().then(SystemTime::now); + let stream = c.fetch(query) .map_ok(Into::into) - .map_err(crate::sqlx_error_to_query_err), - ) - as Pin> + Send>> - }) + .map_err(crate::sqlx_error_to_query_err); + let elapsed = _start.map(|s| s.elapsed().unwrap_or_default()); + MetricStream::new(_metric_callback, stmt, elapsed, stream) } #[cfg(feature = "mock")] - InnerConnection::Mock(c) => c.fetch(stmt), + InnerConnection::Mock(c) => { + let _start = _metric_callback.is_some().then(SystemTime::now); + let stream = c.fetch(stmt); + let elapsed = _start.map(|s| s.elapsed().unwrap_or_default()); + MetricStream::new(_metric_callback, stmt, elapsed, stream) + }, #[allow(unreachable_patterns)] _ => unreachable!(), }, @@ -99,6 +100,6 @@ impl<'a> Stream for TransactionStream<'a> { cx: &mut std::task::Context<'_>, ) -> Poll> { let this = self.get_mut(); - this.with_stream_mut(|stream| stream.as_mut().poll_next(cx)) + this.with_stream_mut(|stream| Pin::new(stream).poll_next(cx)) } } diff --git a/src/metric.rs b/src/metric.rs index 97da30a1fe..6e9521994e 100644 --- a/src/metric.rs +++ b/src/metric.rs @@ -3,7 +3,7 @@ use std::{sync::Arc, time::Duration}; pub(crate) type Callback = Arc) + Send + Sync>; #[allow(unused_imports)] -pub(crate) use inner::{metric, metric_ok}; +pub(crate) use inner::metric; #[derive(Debug)] /// Query execution infos @@ -20,9 +20,9 @@ mod inner { #[allow(unused_macros)] macro_rules! metric { ($metric_callback:expr, $stmt:expr, $code:block) => {{ - let _start = std::time::SystemTime::now(); + let _start = $metric_callback.is_some().then(std::time::SystemTime::now); let res = $code; - if let Some(callback) = $metric_callback.as_deref() { + if let (Some(_start), Some(callback)) = (_start, $metric_callback.as_deref()) { let info = crate::metric::Info { elapsed: _start.elapsed().unwrap_or_default(), statement: $stmt, @@ -34,21 +34,4 @@ mod inner { }}; } pub(crate) use metric; - #[allow(unused_macros)] - macro_rules! metric_ok { - ($metric_callback:expr, $stmt:expr, $code:block) => {{ - let _start = std::time::SystemTime::now(); - let res = $code; - if let Some(callback) = $metric_callback.as_deref() { - let info = crate::metric::Info { - elapsed: _start.elapsed().unwrap_or_default(), - statement: $stmt, - failed: false, - }; - callback(&info); - } - res - }}; - } - pub(crate) use metric_ok; }