Skip to content

Commit

Permalink
Replace parallel condition/result vectors with single CaseWhen vector…
Browse files Browse the repository at this point in the history
… in Expr::Case

The primary motivation for this change is to fix the visitor traversal order for CASE expressions. In SQL, CASE expressions follow a specific syntactic order (e.g., `CASE a WHEN 1 THEN 2 WHEN 3 THEN 4 ELSE 5`), AST visitors now process nodes in the same order as they appear in the source code. The previous implementation, using separate `conditions` and `results` vectors, would visit all conditions first and then all results, which didn't match the source order. The new `CaseWhen` structure ensures visitors process expressions in the correct order: `a,1,2,3,4,5`.

A secondary benefit is making invalid states unrepresentable in the type system. The previous implementation using parallel vectors (`conditions` and `results`) made it possible to create invalid CASE expressions where the number of conditions didn't match the number of results. When this happened, the `Display` implementation would silently drop elements from the longer list, potentially masking bugs. The new `CaseWhen` struct couples each condition with its result, making it impossible to create such mismatched states.

While this is a breaking change to the AST structure, sqlparser has a history of making such changes when they improve correctness. I don't expect significant downstream breakages, and the benefits of correct visitor ordering and type safety are significant, so I think the trade-off is worthwhile.
  • Loading branch information
lovasoa committed Feb 20, 2025
1 parent 97f0be6 commit 0503160
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 50 deletions.
25 changes: 19 additions & 6 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,22 @@ pub enum CeilFloorKind {
Scale(Value),
}

/// A WHEN clause in a CASE expression containing both
/// the condition and its corresponding result
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct CaseWhen {
pub condition: Expr,
pub result: Expr,
}

impl fmt::Display for CaseWhen {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "WHEN {} THEN {}", self.condition, self.result)
}
}

/// An SQL expression of any type.
///
/// # Semantics / Type Checking
Expand Down Expand Up @@ -917,8 +933,7 @@ pub enum Expr {
/// <https://jakewheat.github.io/sql-overview/sql-2011-foundation-grammar.html#simple-when-clause>
Case {
operand: Option<Box<Expr>>,
conditions: Vec<Expr>,
results: Vec<Expr>,
conditions: Vec<CaseWhen>,
else_result: Option<Box<Expr>>,
},
/// An exists expression `[ NOT ] EXISTS(SELECT ...)`, used in expressions like
Expand Down Expand Up @@ -1612,17 +1627,15 @@ impl fmt::Display for Expr {
Expr::Case {
operand,
conditions,
results,
else_result,
} => {
write!(f, "CASE")?;
if let Some(operand) = operand {
write!(f, " {operand}")?;
}
for (c, r) in conditions.iter().zip(results) {
write!(f, " WHEN {c} THEN {r}")?;
for when in conditions {
write!(f, " {when}")?;
}

if let Some(else_result) = else_result {
write!(f, " ELSE {else_result}")?;
}
Expand Down
6 changes: 3 additions & 3 deletions src/ast/spans.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1444,15 +1444,15 @@ impl Spanned for Expr {
Expr::Case {
operand,
conditions,
results,
else_result,
} => union_spans(
operand
.as_ref()
.map(|i| i.span())
.into_iter()
.chain(conditions.iter().map(|i| i.span()))
.chain(results.iter().map(|i| i.span()))
.chain(conditions.iter().flat_map(|case_when| {
[case_when.condition.span(), case_when.result.span()]
}))
.chain(else_result.as_ref().map(|i| i.span())),
),
Expr::Exists { subquery, .. } => subquery.span(),
Expand Down
7 changes: 3 additions & 4 deletions src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2020,11 +2020,11 @@ impl<'a> Parser<'a> {
self.expect_keyword_is(Keyword::WHEN)?;
}
let mut conditions = vec![];
let mut results = vec![];
loop {
conditions.push(self.parse_expr()?);
let condition = self.parse_expr()?;
self.expect_keyword_is(Keyword::THEN)?;
results.push(self.parse_expr()?);
let result = self.parse_expr()?;
conditions.push(CaseWhen { condition, result });
if !self.parse_keyword(Keyword::WHEN) {
break;
}
Expand All @@ -2038,7 +2038,6 @@ impl<'a> Parser<'a> {
Ok(Expr::Case {
operand,
conditions,
results,
else_result,
})
}
Expand Down
107 changes: 70 additions & 37 deletions tests/sqlparser_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6539,22 +6539,26 @@ fn parse_searched_case_expr() {
&Case {
operand: None,
conditions: vec![
IsNull(Box::new(Identifier(Ident::new("bar")))),
BinaryOp {
left: Box::new(Identifier(Ident::new("bar"))),
op: Eq,
right: Box::new(Expr::Value(number("0"))),
CaseWhen {
condition: IsNull(Box::new(Identifier(Ident::new("bar")))),
result: Expr::Value(Value::SingleQuotedString("null".to_string())),
},
BinaryOp {
left: Box::new(Identifier(Ident::new("bar"))),
op: GtEq,
right: Box::new(Expr::Value(number("0"))),
CaseWhen {
condition: BinaryOp {
left: Box::new(Identifier(Ident::new("bar"))),
op: Eq,
right: Box::new(Expr::Value(number("0"))),
},
result: Expr::Value(Value::SingleQuotedString("=0".to_string())),
},
CaseWhen {
condition: BinaryOp {
left: Box::new(Identifier(Ident::new("bar"))),
op: GtEq,
right: Box::new(Expr::Value(number("0"))),
},
result: Expr::Value(Value::SingleQuotedString(">=0".to_string())),
},
],
results: vec![
Expr::Value(Value::SingleQuotedString("null".to_string())),
Expr::Value(Value::SingleQuotedString("=0".to_string())),
Expr::Value(Value::SingleQuotedString(">=0".to_string())),
],
else_result: Some(Box::new(Expr::Value(Value::SingleQuotedString(
"<0".to_string()
Expand All @@ -6573,8 +6577,10 @@ fn parse_simple_case_expr() {
assert_eq!(
&Case {
operand: Some(Box::new(Identifier(Ident::new("foo")))),
conditions: vec![Expr::Value(number("1"))],
results: vec![Expr::Value(Value::SingleQuotedString("Y".to_string()))],
conditions: vec![CaseWhen {
condition: Expr::Value(number("1")),
result: Expr::Value(Value::SingleQuotedString("Y".to_string())),
}],
else_result: Some(Box::new(Expr::Value(Value::SingleQuotedString(
"N".to_string()
)))),
Expand Down Expand Up @@ -13734,6 +13740,31 @@ fn test_trailing_commas_in_from() {
);
}

#[test]
#[cfg(feature = "visitor")]
fn test_visit_order() {
let sql = "SELECT CASE a WHEN 1 THEN 2 WHEN 3 THEN 4 ELSE 5 END";
let stmt = verified_stmt(sql);
let mut visited = vec![];
sqlparser::ast::visit_expressions(&stmt, |expr| {
visited.push(expr.to_string());
core::ops::ControlFlow::<()>::Continue(())
});

assert_eq!(
visited,
[
"CASE a WHEN 1 THEN 2 WHEN 3 THEN 4 ELSE 5 END",
"a",
"1",
"2",
"3",
"4",
"5"
]
);
}

#[test]
fn test_lambdas() {
let dialects = all_dialects_where(|d| d.supports_lambda_functions());
Expand Down Expand Up @@ -13761,28 +13792,30 @@ fn test_lambdas() {
body: Box::new(Expr::Case {
operand: None,
conditions: vec![
Expr::BinaryOp {
left: Box::new(Expr::Identifier(Ident::new("p1"))),
op: BinaryOperator::Eq,
right: Box::new(Expr::Identifier(Ident::new("p2")))
CaseWhen {
condition: Expr::BinaryOp {
left: Box::new(Expr::Identifier(Ident::new("p1"))),
op: BinaryOperator::Eq,
right: Box::new(Expr::Identifier(Ident::new("p2")))
},
result: Expr::Value(number("0"))
},
Expr::BinaryOp {
left: Box::new(call(
"reverse",
[Expr::Identifier(Ident::new("p1"))]
)),
op: BinaryOperator::Lt,
right: Box::new(call(
"reverse",
[Expr::Identifier(Ident::new("p2"))]
))
}
],
results: vec![
Expr::Value(number("0")),
Expr::UnaryOp {
op: UnaryOperator::Minus,
expr: Box::new(Expr::Value(number("1")))
CaseWhen {
condition: Expr::BinaryOp {
left: Box::new(call(
"reverse",
[Expr::Identifier(Ident::new("p1"))]
)),
op: BinaryOperator::Lt,
right: Box::new(call(
"reverse",
[Expr::Identifier(Ident::new("p2"))]
))
},
result: Expr::UnaryOp {
op: UnaryOperator::Minus,
expr: Box::new(Expr::Value(number("1")))
}
}
],
else_result: Some(Box::new(Expr::Value(number("1"))))
Expand Down
65 changes: 65 additions & 0 deletions tests/sqlparser_databricks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,71 @@ fn test_databricks_exists() {
);
}

#[test]
fn test_databricks_lambdas() {
#[rustfmt::skip]
let sql = concat!(
"SELECT array_sort(array('Hello', 'World'), ",
"(p1, p2) -> CASE WHEN p1 = p2 THEN 0 ",
"WHEN reverse(p1) < reverse(p2) THEN -1 ",
"ELSE 1 END)",
);
pretty_assertions::assert_eq!(
SelectItem::UnnamedExpr(call(
"array_sort",
[
call(
"array",
[
Expr::Value(Value::SingleQuotedString("Hello".to_owned())),
Expr::Value(Value::SingleQuotedString("World".to_owned()))
]
),
Expr::Lambda(LambdaFunction {
params: OneOrManyWithParens::Many(vec![Ident::new("p1"), Ident::new("p2")]),
body: Box::new(Expr::Case {
operand: None,
conditions: vec![
CaseWhen {
condition: Expr::BinaryOp {
left: Box::new(Expr::Identifier(Ident::new("p1"))),
op: BinaryOperator::Eq,
right: Box::new(Expr::Identifier(Ident::new("p2")))
},
result: Expr::Value(number("0"))
},
CaseWhen {
condition: Expr::BinaryOp {
left: Box::new(call(
"reverse",
[Expr::Identifier(Ident::new("p1"))]
)),
op: BinaryOperator::Lt,
right: Box::new(call(
"reverse",
[Expr::Identifier(Ident::new("p2"))]
)),
},
result: Expr::UnaryOp {
op: UnaryOperator::Minus,
expr: Box::new(Expr::Value(number("1")))
}
},
],
else_result: Some(Box::new(Expr::Value(number("1"))))
})
})
]
)),
databricks().verified_only_select(sql).projection[0]
);

databricks().verified_expr(
"map_zip_with(map(1, 'a', 2, 'b'), map(1, 'x', 2, 'y'), (k, v1, v2) -> concat(v1, v2))",
);
databricks().verified_expr("transform(array(1, 2, 3), x -> x + 1)");
}

#[test]
fn test_values_clause() {
let values = Values {
Expand Down

0 comments on commit 0503160

Please sign in to comment.