Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement PaginatorTrait for SelectorRaw #617

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
230 changes: 228 additions & 2 deletions src/executor/paginator.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
error::*, ConnectionTrait, DbBackend, EntityTrait, FromQueryResult, Select, SelectModel,
SelectTwo, SelectTwoModel, Selector, SelectorTrait,
SelectTwo, SelectTwoModel, Selector, SelectorRaw, SelectorTrait,
};
use async_stream::stream;
use futures::Stream;
Expand Down Expand Up @@ -219,6 +219,31 @@ where
}
}

impl<'db, C, S> PaginatorTrait<'db, C> for SelectorRaw<S>
where
C: ConnectionTrait,
S: SelectorTrait + Send + Sync + 'db,
{
type Selector = S;
fn paginate(self, db: &'db C, page_size: usize) -> Paginator<'db, C, S> {
let sql = &self.stmt.sql[6..];
let mut query = SelectStatement::new();
query.expr(if let Some(values) = self.stmt.values {
Expr::cust_with_values(sql, values.0)
} else {
Expr::cust(sql)
});

Paginator {
query,
page: 0,
page_size,
db,
selector: PhantomData,
}
}
}

impl<'db, C, M, E> PaginatorTrait<'db, C> for Select<E>
where
C: ConnectionTrait,
Expand Down Expand Up @@ -252,11 +277,20 @@ where
mod tests {
use super::*;
use crate::entity::prelude::*;
use crate::{tests_cfg::*, ConnectionTrait};
use crate::{tests_cfg::*, ConnectionTrait, Statement};
use crate::{DatabaseConnection, DbBackend, MockDatabase, Transaction};
use futures::TryStreamExt;
use once_cell::sync::Lazy;
use sea_query::{Alias, Expr, SelectStatement, Value};

static RAW_STMT: Lazy<Statement> = Lazy::new(|| {
Statement::from_sql_and_values(
DbBackend::Postgres,
r#"SELECT "fruit"."id", "fruit"."name", "fruit"."cake_id" FROM "fruit""#,
vec![],
)
});

fn setup() -> (DatabaseConnection, Vec<Vec<fruit::Model>>) {
let page1 = vec![
fruit::Model {
Expand Down Expand Up @@ -327,6 +361,38 @@ mod tests {
Ok(())
}

#[smol_potat::test]
async fn fetch_page_raw() -> Result<(), DbErr> {
let (db, pages) = setup();

let paginator = fruit::Entity::find()
.from_raw_sql(RAW_STMT.clone())
.paginate(&db, 2);

assert_eq!(paginator.fetch_page(0).await?, pages[0].clone());
assert_eq!(paginator.fetch_page(1).await?, pages[1].clone());
assert_eq!(paginator.fetch_page(2).await?, pages[2].clone());

let mut select = SelectStatement::new()
.exprs(vec![
Expr::tbl(fruit::Entity, fruit::Column::Id),
Expr::tbl(fruit::Entity, fruit::Column::Name),
Expr::tbl(fruit::Entity, fruit::Column::CakeId),
])
.from(fruit::Entity)
.to_owned();

let query_builder = db.get_database_backend();
let stmts = vec![
query_builder.build(select.clone().offset(0).limit(2)),
query_builder.build(select.clone().offset(2).limit(2)),
query_builder.build(select.offset(4).limit(2)),
];

assert_eq!(db.into_transaction_log(), Transaction::wrap(stmts));
Ok(())
}

#[smol_potat::test]
async fn fetch() -> Result<(), DbErr> {
let (db, pages) = setup();
Expand Down Expand Up @@ -361,6 +427,42 @@ mod tests {
Ok(())
}

#[smol_potat::test]
async fn fetch_raw() -> Result<(), DbErr> {
let (db, pages) = setup();

let mut paginator = fruit::Entity::find()
.from_raw_sql(RAW_STMT.clone())
.paginate(&db, 2);

assert_eq!(paginator.fetch().await?, pages[0].clone());
paginator.next();

assert_eq!(paginator.fetch().await?, pages[1].clone());
paginator.next();

assert_eq!(paginator.fetch().await?, pages[2].clone());

let mut select = SelectStatement::new()
.exprs(vec![
Expr::tbl(fruit::Entity, fruit::Column::Id),
Expr::tbl(fruit::Entity, fruit::Column::Name),
Expr::tbl(fruit::Entity, fruit::Column::CakeId),
])
.from(fruit::Entity)
.to_owned();

let query_builder = db.get_database_backend();
let stmts = vec![
query_builder.build(select.clone().offset(0).limit(2)),
query_builder.build(select.clone().offset(2).limit(2)),
query_builder.build(select.offset(4).limit(2)),
];

assert_eq!(db.into_transaction_log(), Transaction::wrap(stmts));
Ok(())
}

#[smol_potat::test]
async fn num_pages() -> Result<(), DbErr> {
let (db, num_items) = setup_num_items();
Expand Down Expand Up @@ -393,6 +495,40 @@ mod tests {
Ok(())
}

#[smol_potat::test]
async fn num_pages_raw() -> Result<(), DbErr> {
let (db, num_items) = setup_num_items();

let num_items = num_items as usize;
let page_size = 2_usize;
let num_pages = (num_items / page_size) + (num_items % page_size > 0) as usize;
let paginator = fruit::Entity::find()
.from_raw_sql(RAW_STMT.clone())
.paginate(&db, page_size);

assert_eq!(paginator.num_pages().await?, num_pages);

let sub_query = SelectStatement::new()
.exprs(vec![
Expr::tbl(fruit::Entity, fruit::Column::Id),
Expr::tbl(fruit::Entity, fruit::Column::Name),
Expr::tbl(fruit::Entity, fruit::Column::CakeId),
])
.from(fruit::Entity)
.to_owned();

let select = SelectStatement::new()
.expr(Expr::cust("COUNT(*) AS num_items"))
.from_subquery(sub_query, Alias::new("sub_query"))
.to_owned();

let query_builder = db.get_database_backend();
let stmts = vec![query_builder.build(&select)];

assert_eq!(db.into_transaction_log(), Transaction::wrap(stmts));
Ok(())
}

#[smol_potat::test]
async fn next_and_cur_page() -> Result<(), DbErr> {
let (db, _) = setup();
Expand All @@ -409,6 +545,24 @@ mod tests {
Ok(())
}

#[smol_potat::test]
async fn next_and_cur_page_raw() -> Result<(), DbErr> {
let (db, _) = setup();

let mut paginator = fruit::Entity::find()
.from_raw_sql(RAW_STMT.clone())
.paginate(&db, 2);

assert_eq!(paginator.cur_page(), 0);
paginator.next();

assert_eq!(paginator.cur_page(), 1);
paginator.next();

assert_eq!(paginator.cur_page(), 2);
Ok(())
}

#[smol_potat::test]
async fn fetch_and_next() -> Result<(), DbErr> {
let (db, pages) = setup();
Expand Down Expand Up @@ -444,6 +598,43 @@ mod tests {
Ok(())
}

#[smol_potat::test]
async fn fetch_and_next_raw() -> Result<(), DbErr> {
let (db, pages) = setup();

let mut paginator = fruit::Entity::find()
.from_raw_sql(RAW_STMT.clone())
.paginate(&db, 2);

assert_eq!(paginator.cur_page(), 0);
assert_eq!(paginator.fetch_and_next().await?, Some(pages[0].clone()));

assert_eq!(paginator.cur_page(), 1);
assert_eq!(paginator.fetch_and_next().await?, Some(pages[1].clone()));

assert_eq!(paginator.cur_page(), 2);
assert_eq!(paginator.fetch_and_next().await?, None);

let mut select = SelectStatement::new()
.exprs(vec![
Expr::tbl(fruit::Entity, fruit::Column::Id),
Expr::tbl(fruit::Entity, fruit::Column::Name),
Expr::tbl(fruit::Entity, fruit::Column::CakeId),
])
.from(fruit::Entity)
.to_owned();

let query_builder = db.get_database_backend();
let stmts = vec![
query_builder.build(select.clone().offset(0).limit(2)),
query_builder.build(select.clone().offset(2).limit(2)),
query_builder.build(select.offset(4).limit(2)),
];

assert_eq!(db.into_transaction_log(), Transaction::wrap(stmts));
Ok(())
}

#[smol_potat::test]
async fn into_stream() -> Result<(), DbErr> {
let (db, pages) = setup();
Expand Down Expand Up @@ -475,4 +666,39 @@ mod tests {
assert_eq!(db.into_transaction_log(), Transaction::wrap(stmts));
Ok(())
}

#[smol_potat::test]
async fn into_stream_raw() -> Result<(), DbErr> {
let (db, pages) = setup();

let mut fruit_stream = fruit::Entity::find()
.from_raw_sql(RAW_STMT.clone())
.paginate(&db, 2)
.into_stream();

assert_eq!(fruit_stream.try_next().await?, Some(pages[0].clone()));
assert_eq!(fruit_stream.try_next().await?, Some(pages[1].clone()));
assert_eq!(fruit_stream.try_next().await?, None);

drop(fruit_stream);

let mut select = SelectStatement::new()
.exprs(vec![
Expr::tbl(fruit::Entity, fruit::Column::Id),
Expr::tbl(fruit::Entity, fruit::Column::Name),
Expr::tbl(fruit::Entity, fruit::Column::CakeId),
])
.from(fruit::Entity)
.to_owned();

let query_builder = db.get_database_backend();
let stmts = vec![
query_builder.build(select.clone().offset(0).limit(2)),
query_builder.build(select.clone().offset(2).limit(2)),
query_builder.build(select.offset(4).limit(2)),
];

assert_eq!(db.into_transaction_log(), Transaction::wrap(stmts));
Ok(())
}
}
2 changes: 1 addition & 1 deletion src/executor/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ pub struct SelectorRaw<S>
where
S: SelectorTrait,
{
stmt: Statement,
pub(crate) stmt: Statement,
#[allow(dead_code)]
selector: S,
}
Expand Down