From 9bad26d3997fc5998f1b91ab30c7100d1463ac56 Mon Sep 17 00:00:00 2001 From: universalmind303 Date: Mon, 23 Sep 2024 14:09:20 -0500 Subject: [PATCH 1/3] feat(sql): struct subscript and json_query --- Cargo.lock | 1 + src/daft-sql/Cargo.toml | 3 +- src/daft-sql/src/modules/json.rs | 35 ++++++++++-- src/daft-sql/src/planner.rs | 92 +++++++++++++++++++++----------- tests/sql/test_nested_access.py | 42 +++++++++++++++ 5 files changed, 136 insertions(+), 37 deletions(-) create mode 100644 tests/sql/test_nested_access.py diff --git a/Cargo.lock b/Cargo.lock index 8dc27f4c89..43b2d3fcb0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2171,6 +2171,7 @@ dependencies = [ "daft-core", "daft-dsl", "daft-functions", + "daft-functions-json", "daft-plan", "once_cell", "pyo3", diff --git a/src/daft-sql/Cargo.toml b/src/daft-sql/Cargo.toml index c15a71f948..86f3baa11c 100644 --- a/src/daft-sql/Cargo.toml +++ b/src/daft-sql/Cargo.toml @@ -4,6 +4,7 @@ common-error = {path = "../common/error"} daft-core = {path = "../daft-core"} daft-dsl = {path = "../daft-dsl"} daft-functions = {path = "../daft-functions"} +daft-functions-json = {path = "../daft-functions-json"} daft-plan = {path = "../daft-plan"} once_cell = {workspace = true} pyo3 = {workspace = true, optional = true} @@ -14,7 +15,7 @@ snafu.workspace = true rstest = {workspace = true} [features] -python = ["dep:pyo3", "common-error/python", "daft-functions/python"] +python = ["dep:pyo3", "common-error/python", "daft-functions/python", "daft-functions-json/python"] [package] name = "daft-sql" diff --git a/src/daft-sql/src/modules/json.rs b/src/daft-sql/src/modules/json.rs index 845be622c0..f0d600daea 100644 --- a/src/daft-sql/src/modules/json.rs +++ b/src/daft-sql/src/modules/json.rs @@ -1,11 +1,38 @@ use super::SQLModule; -use crate::functions::SQLFunctions; +use crate::{ + functions::{SQLFunction, SQLFunctions}, + invalid_operation_err, +}; pub struct SQLModuleJson; impl SQLModule for SQLModuleJson { - fn register(_parent: &mut SQLFunctions) { - // use FunctionExpr::Json as f; - // TODO + fn register(parent: &mut SQLFunctions) { + parent.add_fn("json_query", JsonQuery); + } +} + +struct JsonQuery; + +impl SQLFunction for JsonQuery { + fn to_expr( + &self, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> crate::error::SQLPlannerResult { + match inputs { + [input, query] => { + let input = planner.plan_function_arg(input)?; + let query = planner.plan_function_arg(query)?; + if let Some(q) = query.as_literal().and_then(|l| l.as_str()) { + Ok(daft_functions_json::json_query(input, q)) + } else { + invalid_operation_err!("Expected a string literal for the query argument") + } + } + _ => invalid_operation_err!( + "invalid arguments for json_query. expected json_query(input, query)" + ), + } } } diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index cf82f72743..2c0cd0a2a0 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -551,9 +551,6 @@ impl SQLPlanner { .plan_compound_identifier(idents.as_slice()) .map(|e| e[0].clone()), - SQLExpr::JsonAccess { .. } => { - unsupported_sql_err!("json access") - } SQLExpr::CompositeAccess { .. } => { unsupported_sql_err!("composite access") } @@ -626,9 +623,7 @@ impl SQLPlanner { SQLExpr::Collate { .. } => unsupported_sql_err!("COLLATE"), SQLExpr::Nested(_) => unsupported_sql_err!("NESTED"), SQLExpr::IntroducedString { .. } => unsupported_sql_err!("INTRODUCED STRING"), - SQLExpr::TypedString { .. } => unsupported_sql_err!("TYPED STRING"), - SQLExpr::MapAccess { .. } => unsupported_sql_err!("MAP ACCESS"), SQLExpr::Function(func) => self.plan_function(func), SQLExpr::Case { operand, @@ -696,33 +691,7 @@ impl SQLPlanner { SQLExpr::Named { .. } => unsupported_sql_err!("NAMED"), SQLExpr::Dictionary(_) => unsupported_sql_err!("DICTIONARY"), SQLExpr::Map(_) => unsupported_sql_err!("MAP"), - SQLExpr::Subscript { expr, subscript } => match subscript.as_ref() { - Subscript::Index { index } => { - let index = self.plan_expr(index)?; - let expr = self.plan_expr(expr)?; - Ok(daft_functions::list::get(expr, index, null_lit())) - } - Subscript::Slice { - lower_bound, - upper_bound, - stride, - } => { - if stride.is_some() { - unsupported_sql_err!("stride"); - } - match (lower_bound, upper_bound) { - (Some(lower), Some(upper)) => { - let lower = self.plan_expr(lower)?; - let upper = self.plan_expr(upper)?; - let expr = self.plan_expr(expr)?; - Ok(daft_functions::list::slice(expr, lower, upper)) - } - _ => { - unsupported_sql_err!("slice with only one bound not yet supported"); - } - } - } - }, + SQLExpr::Subscript { expr, subscript } => self.plan_subscript(expr, subscript.as_ref()), SQLExpr::Array(_) => unsupported_sql_err!("ARRAY"), SQLExpr::Interval(_) => unsupported_sql_err!("INTERVAL"), SQLExpr::MatchAgainst { .. } => unsupported_sql_err!("MATCH AGAINST"), @@ -731,6 +700,9 @@ impl SQLPlanner { SQLExpr::OuterJoin(_) => unsupported_sql_err!("OUTER JOIN"), SQLExpr::Prior(_) => unsupported_sql_err!("PRIOR"), SQLExpr::Lambda(_) => unsupported_sql_err!("LAMBDA"), + SQLExpr::JsonAccess { .. } | SQLExpr::MapAccess { .. } => { + unreachable!("Not reachable in our dialect, should always be parsed as subscript") + } } } @@ -926,6 +898,62 @@ impl SQLPlanner { other => unsupported_sql_err!("unary operator {:?}", other), }) } + fn plan_subscript( + &self, + expr: &sqlparser::ast::Expr, + subscript: &Subscript, + ) -> SQLPlannerResult { + match subscript { + Subscript::Index { index } => { + let expr = self.plan_expr(expr)?; + let index = self.plan_expr(index)?; + let schema = self + .current_relation + .as_ref() + .ok_or_else(|| { + PlannerError::invalid_operation("subscript without a current relation") + }) + .map(|p| p.schema())?; + let expr_field = expr.to_field(schema.as_ref())?; + match expr_field.dtype { + DataType::List(_) | DataType::FixedSizeList(_, _) => { + Ok(daft_functions::list::get(expr, index, null_lit())) + } + DataType::Struct(_) => { + if let Some(s) = index.as_literal().and_then(|l| l.as_str()) { + Ok(daft_dsl::functions::struct_::get(expr, s)) + } else { + invalid_operation_err!("Index must be a string literal") + } + } + DataType::Map(_) => Ok(daft_dsl::functions::map::get(expr, index)), + dtype => { + invalid_operation_err!("nested access on column with type: {}", dtype) + } + } + } + Subscript::Slice { + lower_bound, + upper_bound, + stride, + } => { + if stride.is_some() { + unsupported_sql_err!("stride"); + } + match (lower_bound, upper_bound) { + (Some(lower), Some(upper)) => { + let lower = self.plan_expr(lower)?; + let upper = self.plan_expr(upper)?; + let expr = self.plan_expr(expr)?; + Ok(daft_functions::list::slice(expr, lower, upper)) + } + _ => { + unsupported_sql_err!("slice with only one bound not yet supported"); + } + } + } + } + } } /// Checks if the SQL query is valid syntax and doesn't use unsupported features. diff --git a/tests/sql/test_nested_access.py b/tests/sql/test_nested_access.py new file mode 100644 index 0000000000..8811afee59 --- /dev/null +++ b/tests/sql/test_nested_access.py @@ -0,0 +1,42 @@ +import daft +from daft.sql.sql import SQLCatalog + + +def test_nested_access(): + df = daft.from_pydict( + { + "json": ['{"a": 1, "b": {"c": 2}}', '{"a": 3, "b": {"c": 4}}', '{"a": 5, "b": {"c": 6}}'], + "list": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + "dict": [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}], + } + ) + + catalog = SQLCatalog({"test": df}) + + actual = daft.sql( + """ + select + json_query(json, '.b.c') as json_b_c, + list[1] as list_1, + list[0:1] as list_slice, + dict['a'] as dict_a, + cast(list as int[3])[1] as fsl_1, + cast(list as int[3])[0:1] as fsl_slice + from test + """, + catalog, + ).collect() + + expected = df.select( + daft.col("json").json.query(".b.c").alias("json_b_c"), + daft.col("list").list.get(1).alias("list_1"), + daft.col("list").list.slice(0, 1).alias("list_slice"), + daft.col("dict").struct.get("a").alias("dict_a"), + daft.col("list").cast(daft.DataType.fixed_size_list(daft.DataType.int32(), 3)).list.get(1).alias("fsl_1"), + daft.col("list") + .cast(daft.DataType.fixed_size_list(daft.DataType.int32(), 3)) + .list.slice(0, 1) + .alias("fsl_slice"), + ).collect() + + assert actual.to_pydict() == expected.to_pydict() From ffe17d2d389a0dc143b07844f095e9884b2433f0 Mon Sep 17 00:00:00 2001 From: universalmind303 Date: Tue, 24 Sep 2024 13:06:48 -0500 Subject: [PATCH 2/3] fix struct get --- src/daft-sql/src/modules/structs.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/daft-sql/src/modules/structs.rs b/src/daft-sql/src/modules/structs.rs index 5767b42d73..66be42d8e3 100644 --- a/src/daft-sql/src/modules/structs.rs +++ b/src/daft-sql/src/modules/structs.rs @@ -25,7 +25,11 @@ impl SQLFunction for StructGet { [input, key] => { let input = planner.plan_function_arg(input)?; let key = planner.plan_function_arg(key)?; - Ok(daft_dsl::functions::map::get(input, key)) + if let Some(lit) = key.as_literal().and_then(|lit| lit.as_str()) { + Ok(daft_dsl::functions::struct_::get(input, lit)) + } else { + invalid_operation_err!("Expected key to be a string literal") + } } _ => invalid_operation_err!("Expected 2 input args"), } From 9d159bd31d5d738e4ae2ba3c80ef7e5bc2d58b37 Mon Sep 17 00:00:00 2001 From: Cory Grinstead Date: Tue, 24 Sep 2024 17:49:35 -0500 Subject: [PATCH 3/3] Update src/daft-sql/src/planner.rs Co-authored-by: Jay Chia <17691182+jaychia@users.noreply.github.com> --- src/daft-sql/src/planner.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index 7f088199cf..460836ccde 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -949,7 +949,7 @@ impl SQLPlanner { stride, } => { if stride.is_some() { - unsupported_sql_err!("stride"); + unsupported_sql_err!("stride cannot be provided when slicing an expression"); } match (lower_bound, upper_bound) { (Some(lower), Some(upper)) => {