Skip to content

Commit

Permalink
Minor: reduce code duplication using rewrite_expr (#5114)
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb authored Feb 1, 2023
1 parent 4c21a72 commit bd64527
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 114 deletions.
22 changes: 4 additions & 18 deletions datafusion/core/src/physical_optimizer/pruning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ use arrow::{
};
use datafusion_common::{downcast_value, ScalarValue};
use datafusion_expr::expr::{BinaryExpr, Cast, TryCast};
use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter};
use datafusion_expr::expr_rewriter::rewrite_expr;
use datafusion_expr::{binary_expr, cast, try_cast, ExprSchemable};
use datafusion_physical_expr::create_physical_expr;
use datafusion_physical_expr::expressions::Literal;
Expand Down Expand Up @@ -631,23 +631,9 @@ fn rewrite_column_expr(
column_old: &Column,
column_new: &Column,
) -> Result<Expr> {
struct ColumnReplacer<'a> {
old: &'a Column,
new: &'a Column,
}

impl<'a> ExprRewriter for ColumnReplacer<'a> {
fn mutate(&mut self, expr: Expr) -> Result<Expr> {
match expr {
Expr::Column(c) if c == *self.old => Ok(Expr::Column(self.new.clone())),
_ => Ok(expr),
}
}
}

e.rewrite(&mut ColumnReplacer {
old: column_old,
new: column_new,
rewrite_expr(e, |expr| match expr {
Expr::Column(c) if c == *column_old => Ok(Expr::Column(column_new.clone())),
_ => Ok(expr),
})
}

Expand Down
64 changes: 28 additions & 36 deletions datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use crate::expr_rewriter::{ExprRewritable, ExprRewriter};
use crate::expr_rewriter::rewrite_expr;
use crate::expr_visitor::inspect_expr_pre;
use crate::expr_visitor::{ExprVisitable, ExpressionVisitor, Recursion};
///! Logical plan types
Expand Down Expand Up @@ -702,44 +702,36 @@ impl LogicalPlan {
/// corresponding values provided in the params_values
fn replace_placeholders_with_values(
expr: Expr,
param_values: &Vec<ScalarValue>,
param_values: &[ScalarValue],
) -> Result<Expr, DataFusionError> {
struct PlaceholderReplacer<'a> {
param_values: &'a Vec<ScalarValue>,
}

impl<'a> ExprRewriter for PlaceholderReplacer<'a> {
fn mutate(&mut self, expr: Expr) -> Result<Expr, DataFusionError> {
if let Expr::Placeholder { id, data_type } = &expr {
// convert id (in format $1, $2, ..) to idx (0, 1, ..)
let idx = id[1..].parse::<usize>().map_err(|e| {
DataFusionError::Internal(format!(
"Failed to parse placeholder id: {e}"
))
})? - 1;
// value at the idx-th position in param_values should be the value for the placeholder
let value = self.param_values.get(idx).ok_or_else(|| {
DataFusionError::Internal(format!(
"No value found for placeholder with id {id}"
))
})?;
// check if the data type of the value matches the data type of the placeholder
if Some(value.get_datatype()) != *data_type {
return Err(DataFusionError::Internal(format!(
"Placeholder value type mismatch: expected {:?}, got {:?}",
data_type,
value.get_datatype()
)));
}
// Replace the placeholder with the value
Ok(Expr::Literal(value.clone()))
} else {
Ok(expr)
rewrite_expr(expr, |expr| {
if let Expr::Placeholder { id, data_type } = &expr {
// convert id (in format $1, $2, ..) to idx (0, 1, ..)
let idx = id[1..].parse::<usize>().map_err(|e| {
DataFusionError::Internal(format!(
"Failed to parse placeholder id: {e}"
))
})? - 1;
// value at the idx-th position in param_values should be the value for the placeholder
let value = param_values.get(idx).ok_or_else(|| {
DataFusionError::Internal(format!(
"No value found for placeholder with id {id}"
))
})?;
// check if the data type of the value matches the data type of the placeholder
if Some(value.get_datatype()) != *data_type {
return Err(DataFusionError::Internal(format!(
"Placeholder value type mismatch: expected {:?}, got {:?}",
data_type,
value.get_datatype()
)));
}
// Replace the placeholder with the value
Ok(Expr::Literal(value.clone()))
} else {
Ok(expr)
}
}

expr.rewrite(&mut PlaceholderReplacer { param_values })
})
}
}

Expand Down
27 changes: 10 additions & 17 deletions datafusion/optimizer/src/push_down_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
use crate::utils::{conjunction, split_conjunction};
use crate::{utils, OptimizerConfig, OptimizerRule};
use datafusion_common::{Column, DFSchema, DataFusionError, Result};
use datafusion_expr::expr_rewriter::rewrite_expr;
use datafusion_expr::{
and,
expr_rewriter::{replace_col, ExprRewritable, ExprRewriter},
expr_rewriter::replace_col,
logical_plan::{CrossJoin, Join, JoinType, LogicalPlan, TableScan, Union},
or,
utils::from_plan,
Expand Down Expand Up @@ -763,24 +764,16 @@ pub fn replace_cols_by_name(
e: Expr,
replace_map: &HashMap<String, Expr>,
) -> Result<Expr> {
struct ColumnReplacer<'a> {
replace_map: &'a HashMap<String, Expr>,
}

impl<'a> ExprRewriter for ColumnReplacer<'a> {
fn mutate(&mut self, expr: Expr) -> Result<Expr> {
if let Expr::Column(c) = &expr {
match self.replace_map.get(&c.flat_name()) {
Some(new_c) => Ok(new_c.clone()),
None => Ok(expr),
}
} else {
Ok(expr)
rewrite_expr(e, |expr| {
if let Expr::Column(c) = &expr {
match replace_map.get(&c.flat_name()) {
Some(new_c) => Ok(new_c.clone()),
None => Ok(expr),
}
} else {
Ok(expr)
}
}

e.rewrite(&mut ColumnReplacer { replace_map })
})
}

#[cfg(test)]
Expand Down
75 changes: 32 additions & 43 deletions datafusion/sql/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
use crate::utils::normalize_ident;
use arrow_schema::DataType;
use datafusion_common::{Column, DFSchema, DataFusionError, Result, ScalarValue};
use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter};
use datafusion_expr::expr_rewriter::rewrite_expr;
use datafusion_expr::{
col, expr, lit, AggregateFunction, Between, BinaryExpr, BuiltinScalarFunction, Cast,
Expr, ExprSchemable, GetIndexedField, Like, Operator, TryCast,
Expand Down Expand Up @@ -500,48 +500,37 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {

/// Find all `PlaceHolder` tokens in a logical plan, and try to infer their type from context
fn infer_placeholder_types(expr: Expr, schema: DFSchema) -> Result<Expr> {
struct PlaceholderReplacer {
schema: DFSchema,
}

impl ExprRewriter for PlaceholderReplacer {
fn mutate(&mut self, expr: Expr) -> Result<Expr> {
let expr = match expr {
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
let left = (*left).clone();
let right = (*right).clone();
let lt = left.get_type(&self.schema);
let rt = right.get_type(&self.schema);
let left = match (&left, rt) {
(Expr::Placeholder { id, data_type }, Ok(dt)) => {
Expr::Placeholder {
id: id.clone(),
data_type: Some(data_type.clone().unwrap_or(dt)),
}
}
_ => left.clone(),
};
let right = match (&right, lt) {
(Expr::Placeholder { id, data_type }, Ok(dt)) => {
Expr::Placeholder {
id: id.clone(),
data_type: Some(data_type.clone().unwrap_or(dt)),
}
}
_ => right.clone(),
};
Expr::BinaryExpr(BinaryExpr {
left: Box::new(left),
op,
right: Box::new(right),
})
}
_ => expr.clone(),
};
Ok(expr)
}
}
expr.rewrite(&mut PlaceholderReplacer { schema })
rewrite_expr(expr, |expr| {
let expr = match expr {
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
let left = (*left).clone();
let right = (*right).clone();
let lt = left.get_type(&schema);
let rt = right.get_type(&schema);
let left = match (&left, rt) {
(Expr::Placeholder { id, data_type }, Ok(dt)) => Expr::Placeholder {
id: id.clone(),
data_type: Some(data_type.clone().unwrap_or(dt)),
},
_ => left.clone(),
};
let right = match (&right, lt) {
(Expr::Placeholder { id, data_type }, Ok(dt)) => Expr::Placeholder {
id: id.clone(),
data_type: Some(data_type.clone().unwrap_or(dt)),
},
_ => right.clone(),
};
Expr::BinaryExpr(BinaryExpr {
left: Box::new(left),
op,
right: Box::new(right),
})
}
_ => expr.clone(),
};
Ok(expr)
})
}

fn plan_key(key: SQLExpr) -> Result<ScalarValue> {
Expand Down

0 comments on commit bd64527

Please sign in to comment.