diff --git a/src/executor/paginator.rs b/src/executor/paginator.rs index 59453ff1fb..4627ce2ad7 100644 --- a/src/executor/paginator.rs +++ b/src/executor/paginator.rs @@ -1,6 +1,6 @@ use crate::{ error::*, ConnectionTrait, DbBackend, EntityTrait, FromQueryResult, Select, SelectModel, - SelectTwo, SelectTwoModel, Selector, SelectorTrait, + SelectTwo, SelectTwoModel, Selector, SelectorRaw, SelectorTrait, }; use async_stream::stream; use futures::Stream; @@ -219,6 +219,31 @@ where } } +impl<'db, C, S> PaginatorTrait<'db, C> for SelectorRaw +where + C: ConnectionTrait, + S: SelectorTrait + Send + Sync + 'db, +{ + type Selector = S; + fn paginate(self, db: &'db C, page_size: usize) -> Paginator<'db, C, S> { + let sql = &self.stmt.sql[6..]; + let mut query = SelectStatement::new(); + query.expr(if let Some(values) = self.stmt.values { + Expr::cust_with_values(sql, values.0) + } else { + Expr::cust(sql) + }); + + Paginator { + query, + page: 0, + page_size, + db, + selector: PhantomData, + } + } +} + impl<'db, C, M, E> PaginatorTrait<'db, C> for Select where C: ConnectionTrait, @@ -252,11 +277,20 @@ where mod tests { use super::*; use crate::entity::prelude::*; - use crate::{tests_cfg::*, ConnectionTrait}; + use crate::{tests_cfg::*, ConnectionTrait, Statement}; use crate::{DatabaseConnection, DbBackend, MockDatabase, Transaction}; use futures::TryStreamExt; + use once_cell::sync::Lazy; use sea_query::{Alias, Expr, SelectStatement, Value}; + static RAW_STMT: Lazy = Lazy::new(|| { + Statement::from_sql_and_values( + DbBackend::Postgres, + r#"SELECT "fruit"."id", "fruit"."name", "fruit"."cake_id" FROM "fruit""#, + vec![], + ) + }); + fn setup() -> (DatabaseConnection, Vec>) { let page1 = vec![ fruit::Model { @@ -327,6 +361,38 @@ mod tests { Ok(()) } + #[smol_potat::test] + async fn fetch_page_raw() -> Result<(), DbErr> { + let (db, pages) = setup(); + + let paginator = fruit::Entity::find() + .from_raw_sql(RAW_STMT.clone()) + .paginate(&db, 2); + + assert_eq!(paginator.fetch_page(0).await?, pages[0].clone()); + assert_eq!(paginator.fetch_page(1).await?, pages[1].clone()); + assert_eq!(paginator.fetch_page(2).await?, pages[2].clone()); + + let mut select = SelectStatement::new() + .exprs(vec![ + Expr::tbl(fruit::Entity, fruit::Column::Id), + Expr::tbl(fruit::Entity, fruit::Column::Name), + Expr::tbl(fruit::Entity, fruit::Column::CakeId), + ]) + .from(fruit::Entity) + .to_owned(); + + let query_builder = db.get_database_backend(); + let stmts = vec![ + query_builder.build(select.clone().offset(0).limit(2)), + query_builder.build(select.clone().offset(2).limit(2)), + query_builder.build(select.offset(4).limit(2)), + ]; + + assert_eq!(db.into_transaction_log(), Transaction::wrap(stmts)); + Ok(()) + } + #[smol_potat::test] async fn fetch() -> Result<(), DbErr> { let (db, pages) = setup(); @@ -361,6 +427,42 @@ mod tests { Ok(()) } + #[smol_potat::test] + async fn fetch_raw() -> Result<(), DbErr> { + let (db, pages) = setup(); + + let mut paginator = fruit::Entity::find() + .from_raw_sql(RAW_STMT.clone()) + .paginate(&db, 2); + + assert_eq!(paginator.fetch().await?, pages[0].clone()); + paginator.next(); + + assert_eq!(paginator.fetch().await?, pages[1].clone()); + paginator.next(); + + assert_eq!(paginator.fetch().await?, pages[2].clone()); + + let mut select = SelectStatement::new() + .exprs(vec![ + Expr::tbl(fruit::Entity, fruit::Column::Id), + Expr::tbl(fruit::Entity, fruit::Column::Name), + Expr::tbl(fruit::Entity, fruit::Column::CakeId), + ]) + .from(fruit::Entity) + .to_owned(); + + let query_builder = db.get_database_backend(); + let stmts = vec![ + query_builder.build(select.clone().offset(0).limit(2)), + query_builder.build(select.clone().offset(2).limit(2)), + query_builder.build(select.offset(4).limit(2)), + ]; + + assert_eq!(db.into_transaction_log(), Transaction::wrap(stmts)); + Ok(()) + } + #[smol_potat::test] async fn num_pages() -> Result<(), DbErr> { let (db, num_items) = setup_num_items(); @@ -393,6 +495,40 @@ mod tests { Ok(()) } + #[smol_potat::test] + async fn num_pages_raw() -> Result<(), DbErr> { + let (db, num_items) = setup_num_items(); + + let num_items = num_items as usize; + let page_size = 2_usize; + let num_pages = (num_items / page_size) + (num_items % page_size > 0) as usize; + let paginator = fruit::Entity::find() + .from_raw_sql(RAW_STMT.clone()) + .paginate(&db, page_size); + + assert_eq!(paginator.num_pages().await?, num_pages); + + let sub_query = SelectStatement::new() + .exprs(vec![ + Expr::tbl(fruit::Entity, fruit::Column::Id), + Expr::tbl(fruit::Entity, fruit::Column::Name), + Expr::tbl(fruit::Entity, fruit::Column::CakeId), + ]) + .from(fruit::Entity) + .to_owned(); + + let select = SelectStatement::new() + .expr(Expr::cust("COUNT(*) AS num_items")) + .from_subquery(sub_query, Alias::new("sub_query")) + .to_owned(); + + let query_builder = db.get_database_backend(); + let stmts = vec![query_builder.build(&select)]; + + assert_eq!(db.into_transaction_log(), Transaction::wrap(stmts)); + Ok(()) + } + #[smol_potat::test] async fn next_and_cur_page() -> Result<(), DbErr> { let (db, _) = setup(); @@ -409,6 +545,24 @@ mod tests { Ok(()) } + #[smol_potat::test] + async fn next_and_cur_page_raw() -> Result<(), DbErr> { + let (db, _) = setup(); + + let mut paginator = fruit::Entity::find() + .from_raw_sql(RAW_STMT.clone()) + .paginate(&db, 2); + + assert_eq!(paginator.cur_page(), 0); + paginator.next(); + + assert_eq!(paginator.cur_page(), 1); + paginator.next(); + + assert_eq!(paginator.cur_page(), 2); + Ok(()) + } + #[smol_potat::test] async fn fetch_and_next() -> Result<(), DbErr> { let (db, pages) = setup(); @@ -444,6 +598,43 @@ mod tests { Ok(()) } + #[smol_potat::test] + async fn fetch_and_next_raw() -> Result<(), DbErr> { + let (db, pages) = setup(); + + let mut paginator = fruit::Entity::find() + .from_raw_sql(RAW_STMT.clone()) + .paginate(&db, 2); + + assert_eq!(paginator.cur_page(), 0); + assert_eq!(paginator.fetch_and_next().await?, Some(pages[0].clone())); + + assert_eq!(paginator.cur_page(), 1); + assert_eq!(paginator.fetch_and_next().await?, Some(pages[1].clone())); + + assert_eq!(paginator.cur_page(), 2); + assert_eq!(paginator.fetch_and_next().await?, None); + + let mut select = SelectStatement::new() + .exprs(vec![ + Expr::tbl(fruit::Entity, fruit::Column::Id), + Expr::tbl(fruit::Entity, fruit::Column::Name), + Expr::tbl(fruit::Entity, fruit::Column::CakeId), + ]) + .from(fruit::Entity) + .to_owned(); + + let query_builder = db.get_database_backend(); + let stmts = vec![ + query_builder.build(select.clone().offset(0).limit(2)), + query_builder.build(select.clone().offset(2).limit(2)), + query_builder.build(select.offset(4).limit(2)), + ]; + + assert_eq!(db.into_transaction_log(), Transaction::wrap(stmts)); + Ok(()) + } + #[smol_potat::test] async fn into_stream() -> Result<(), DbErr> { let (db, pages) = setup(); @@ -475,4 +666,39 @@ mod tests { assert_eq!(db.into_transaction_log(), Transaction::wrap(stmts)); Ok(()) } + + #[smol_potat::test] + async fn into_stream_raw() -> Result<(), DbErr> { + let (db, pages) = setup(); + + let mut fruit_stream = fruit::Entity::find() + .from_raw_sql(RAW_STMT.clone()) + .paginate(&db, 2) + .into_stream(); + + assert_eq!(fruit_stream.try_next().await?, Some(pages[0].clone())); + assert_eq!(fruit_stream.try_next().await?, Some(pages[1].clone())); + assert_eq!(fruit_stream.try_next().await?, None); + + drop(fruit_stream); + + let mut select = SelectStatement::new() + .exprs(vec![ + Expr::tbl(fruit::Entity, fruit::Column::Id), + Expr::tbl(fruit::Entity, fruit::Column::Name), + Expr::tbl(fruit::Entity, fruit::Column::CakeId), + ]) + .from(fruit::Entity) + .to_owned(); + + let query_builder = db.get_database_backend(); + let stmts = vec![ + query_builder.build(select.clone().offset(0).limit(2)), + query_builder.build(select.clone().offset(2).limit(2)), + query_builder.build(select.offset(4).limit(2)), + ]; + + assert_eq!(db.into_transaction_log(), Transaction::wrap(stmts)); + Ok(()) + } } diff --git a/src/executor/select.rs b/src/executor/select.rs index aa0b86de14..b9c997585f 100644 --- a/src/executor/select.rs +++ b/src/executor/select.rs @@ -27,7 +27,7 @@ pub struct SelectorRaw where S: SelectorTrait, { - stmt: Statement, + pub(crate) stmt: Statement, #[allow(dead_code)] selector: S, }