diff --git a/src/ast/mod.rs b/src/ast/mod.rs index efdad164e..841703862 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -599,6 +599,22 @@ pub enum CeilFloorKind { Scale(Value), } +/// A WHEN clause in a CASE expression containing both +/// the condition and its corresponding result +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub struct CaseWhen { + pub condition: Expr, + pub result: Expr, +} + +impl fmt::Display for CaseWhen { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "WHEN {} THEN {}", self.condition, self.result) + } +} + /// An SQL expression of any type. /// /// # Semantics / Type Checking @@ -917,8 +933,7 @@ pub enum Expr { /// Case { operand: Option>, - conditions: Vec, - results: Vec, + conditions: Vec, else_result: Option>, }, /// An exists expression `[ NOT ] EXISTS(SELECT ...)`, used in expressions like @@ -1612,17 +1627,15 @@ impl fmt::Display for Expr { Expr::Case { operand, conditions, - results, else_result, } => { write!(f, "CASE")?; if let Some(operand) = operand { write!(f, " {operand}")?; } - for (c, r) in conditions.iter().zip(results) { - write!(f, " WHEN {c} THEN {r}")?; + for when in conditions { + write!(f, " {when}")?; } - if let Some(else_result) = else_result { write!(f, " ELSE {else_result}")?; } diff --git a/src/ast/spans.rs b/src/ast/spans.rs index de39e50d6..fb2b318c0 100644 --- a/src/ast/spans.rs +++ b/src/ast/spans.rs @@ -1444,15 +1444,15 @@ impl Spanned for Expr { Expr::Case { operand, conditions, - results, else_result, } => union_spans( operand .as_ref() .map(|i| i.span()) .into_iter() - .chain(conditions.iter().map(|i| i.span())) - .chain(results.iter().map(|i| i.span())) + .chain(conditions.iter().flat_map(|case_when| { + [case_when.condition.span(), case_when.result.span()] + })) .chain(else_result.as_ref().map(|i| i.span())), ), Expr::Exists { subquery, .. } => subquery.span(), diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 9c021d918..04a9fc0bf 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -2020,11 +2020,11 @@ impl<'a> Parser<'a> { self.expect_keyword_is(Keyword::WHEN)?; } let mut conditions = vec![]; - let mut results = vec![]; loop { - conditions.push(self.parse_expr()?); + let condition = self.parse_expr()?; self.expect_keyword_is(Keyword::THEN)?; - results.push(self.parse_expr()?); + let result = self.parse_expr()?; + conditions.push(CaseWhen { condition, result }); if !self.parse_keyword(Keyword::WHEN) { break; } @@ -2038,7 +2038,6 @@ impl<'a> Parser<'a> { Ok(Expr::Case { operand, conditions, - results, else_result, }) } diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index a4e83be06..260cde9a8 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -6539,22 +6539,26 @@ fn parse_searched_case_expr() { &Case { operand: None, conditions: vec![ - IsNull(Box::new(Identifier(Ident::new("bar")))), - BinaryOp { - left: Box::new(Identifier(Ident::new("bar"))), - op: Eq, - right: Box::new(Expr::Value(number("0"))), + CaseWhen { + condition: IsNull(Box::new(Identifier(Ident::new("bar")))), + result: Expr::Value(Value::SingleQuotedString("null".to_string())), }, - BinaryOp { - left: Box::new(Identifier(Ident::new("bar"))), - op: GtEq, - right: Box::new(Expr::Value(number("0"))), + CaseWhen { + condition: BinaryOp { + left: Box::new(Identifier(Ident::new("bar"))), + op: Eq, + right: Box::new(Expr::Value(number("0"))), + }, + result: Expr::Value(Value::SingleQuotedString("=0".to_string())), + }, + CaseWhen { + condition: BinaryOp { + left: Box::new(Identifier(Ident::new("bar"))), + op: GtEq, + right: Box::new(Expr::Value(number("0"))), + }, + result: Expr::Value(Value::SingleQuotedString(">=0".to_string())), }, - ], - results: vec![ - Expr::Value(Value::SingleQuotedString("null".to_string())), - Expr::Value(Value::SingleQuotedString("=0".to_string())), - Expr::Value(Value::SingleQuotedString(">=0".to_string())), ], else_result: Some(Box::new(Expr::Value(Value::SingleQuotedString( "<0".to_string() @@ -6573,8 +6577,10 @@ fn parse_simple_case_expr() { assert_eq!( &Case { operand: Some(Box::new(Identifier(Ident::new("foo")))), - conditions: vec![Expr::Value(number("1"))], - results: vec![Expr::Value(Value::SingleQuotedString("Y".to_string()))], + conditions: vec![CaseWhen { + condition: Expr::Value(number("1")), + result: Expr::Value(Value::SingleQuotedString("Y".to_string())), + }], else_result: Some(Box::new(Expr::Value(Value::SingleQuotedString( "N".to_string() )))), @@ -13734,6 +13740,31 @@ fn test_trailing_commas_in_from() { ); } +#[test] +#[cfg(feature = "visitor")] +fn test_visit_order() { + let sql = "SELECT CASE a WHEN 1 THEN 2 WHEN 3 THEN 4 ELSE 5 END"; + let stmt = verified_stmt(sql); + let mut visited = vec![]; + sqlparser::ast::visit_expressions(&stmt, |expr| { + visited.push(expr.to_string()); + core::ops::ControlFlow::<()>::Continue(()) + }); + + assert_eq!( + visited, + [ + "CASE a WHEN 1 THEN 2 WHEN 3 THEN 4 ELSE 5 END", + "a", + "1", + "2", + "3", + "4", + "5" + ] + ); +} + #[test] fn test_lambdas() { let dialects = all_dialects_where(|d| d.supports_lambda_functions()); @@ -13761,28 +13792,30 @@ fn test_lambdas() { body: Box::new(Expr::Case { operand: None, conditions: vec![ - Expr::BinaryOp { - left: Box::new(Expr::Identifier(Ident::new("p1"))), - op: BinaryOperator::Eq, - right: Box::new(Expr::Identifier(Ident::new("p2"))) + CaseWhen { + condition: Expr::BinaryOp { + left: Box::new(Expr::Identifier(Ident::new("p1"))), + op: BinaryOperator::Eq, + right: Box::new(Expr::Identifier(Ident::new("p2"))) + }, + result: Expr::Value(number("0")) }, - Expr::BinaryOp { - left: Box::new(call( - "reverse", - [Expr::Identifier(Ident::new("p1"))] - )), - op: BinaryOperator::Lt, - right: Box::new(call( - "reverse", - [Expr::Identifier(Ident::new("p2"))] - )) - } - ], - results: vec![ - Expr::Value(number("0")), - Expr::UnaryOp { - op: UnaryOperator::Minus, - expr: Box::new(Expr::Value(number("1"))) + CaseWhen { + condition: Expr::BinaryOp { + left: Box::new(call( + "reverse", + [Expr::Identifier(Ident::new("p1"))] + )), + op: BinaryOperator::Lt, + right: Box::new(call( + "reverse", + [Expr::Identifier(Ident::new("p2"))] + )) + }, + result: Expr::UnaryOp { + op: UnaryOperator::Minus, + expr: Box::new(Expr::Value(number("1"))) + } } ], else_result: Some(Box::new(Expr::Value(number("1")))) diff --git a/tests/sqlparser_databricks.rs b/tests/sqlparser_databricks.rs index 8338a0e71..724bedf47 100644 --- a/tests/sqlparser_databricks.rs +++ b/tests/sqlparser_databricks.rs @@ -83,6 +83,71 @@ fn test_databricks_exists() { ); } +#[test] +fn test_databricks_lambdas() { + #[rustfmt::skip] + let sql = concat!( + "SELECT array_sort(array('Hello', 'World'), ", + "(p1, p2) -> CASE WHEN p1 = p2 THEN 0 ", + "WHEN reverse(p1) < reverse(p2) THEN -1 ", + "ELSE 1 END)", + ); + pretty_assertions::assert_eq!( + SelectItem::UnnamedExpr(call( + "array_sort", + [ + call( + "array", + [ + Expr::Value(Value::SingleQuotedString("Hello".to_owned())), + Expr::Value(Value::SingleQuotedString("World".to_owned())) + ] + ), + Expr::Lambda(LambdaFunction { + params: OneOrManyWithParens::Many(vec![Ident::new("p1"), Ident::new("p2")]), + body: Box::new(Expr::Case { + operand: None, + conditions: vec![ + CaseWhen { + condition: Expr::BinaryOp { + left: Box::new(Expr::Identifier(Ident::new("p1"))), + op: BinaryOperator::Eq, + right: Box::new(Expr::Identifier(Ident::new("p2"))) + }, + result: Expr::Value(number("0")) + }, + CaseWhen { + condition: Expr::BinaryOp { + left: Box::new(call( + "reverse", + [Expr::Identifier(Ident::new("p1"))] + )), + op: BinaryOperator::Lt, + right: Box::new(call( + "reverse", + [Expr::Identifier(Ident::new("p2"))] + )), + }, + result: Expr::UnaryOp { + op: UnaryOperator::Minus, + expr: Box::new(Expr::Value(number("1"))) + } + }, + ], + else_result: Some(Box::new(Expr::Value(number("1")))) + }) + }) + ] + )), + databricks().verified_only_select(sql).projection[0] + ); + + databricks().verified_expr( + "map_zip_with(map(1, 'a', 2, 'b'), map(1, 'x', 2, 'y'), (k, v1, v2) -> concat(v1, v2))", + ); + databricks().verified_expr("transform(array(1, 2, 3), x -> x + 1)"); +} + #[test] fn test_values_clause() { let values = Values {