-
Notifications
You must be signed in to change notification settings - Fork 186
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[FEAT]: [SQL] struct subscript and json_query #2891
Changes from all commits
9bad26d
2524964
ffe17d2
9d159bd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,38 @@ | ||
use super::SQLModule; | ||
use crate::functions::SQLFunctions; | ||
use crate::{ | ||
functions::{SQLFunction, SQLFunctions}, | ||
invalid_operation_err, | ||
}; | ||
|
||
pub struct SQLModuleJson; | ||
|
||
impl SQLModule for SQLModuleJson { | ||
fn register(_parent: &mut SQLFunctions) { | ||
// use FunctionExpr::Json as f; | ||
// TODO | ||
fn register(parent: &mut SQLFunctions) { | ||
parent.add_fn("json_query", JsonQuery); | ||
} | ||
} | ||
|
||
struct JsonQuery; | ||
|
||
impl SQLFunction for JsonQuery { | ||
fn to_expr( | ||
&self, | ||
inputs: &[sqlparser::ast::FunctionArg], | ||
planner: &crate::planner::SQLPlanner, | ||
) -> crate::error::SQLPlannerResult<daft_dsl::ExprRef> { | ||
match inputs { | ||
[input, query] => { | ||
let input = planner.plan_function_arg(input)?; | ||
let query = planner.plan_function_arg(query)?; | ||
if let Some(q) = query.as_literal().and_then(|l| l.as_str()) { | ||
Ok(daft_functions_json::json_query(input, q)) | ||
} else { | ||
invalid_operation_err!("Expected a string literal for the query argument") | ||
} | ||
} | ||
_ => invalid_operation_err!( | ||
"invalid arguments for json_query. expected json_query(input, query)" | ||
), | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,33 @@ | ||
use super::SQLModule; | ||
use crate::functions::SQLFunctions; | ||
use crate::{ | ||
functions::{SQLFunction, SQLFunctions}, | ||
invalid_operation_err, | ||
}; | ||
|
||
pub struct SQLModuleMap; | ||
|
||
impl SQLModule for SQLModuleMap { | ||
fn register(_parent: &mut SQLFunctions) { | ||
// use FunctionExpr::Map as f; | ||
// TODO | ||
fn register(parent: &mut SQLFunctions) { | ||
parent.add_fn("map_get", MapGet); | ||
parent.add_fn("map_extract", MapGet); | ||
} | ||
} | ||
|
||
pub struct MapGet; | ||
|
||
impl SQLFunction for MapGet { | ||
fn to_expr( | ||
&self, | ||
inputs: &[sqlparser::ast::FunctionArg], | ||
planner: &crate::planner::SQLPlanner, | ||
) -> crate::error::SQLPlannerResult<daft_dsl::ExprRef> { | ||
match inputs { | ||
[input, key] => { | ||
let input = planner.plan_function_arg(input)?; | ||
let key = planner.plan_function_arg(key)?; | ||
Ok(daft_dsl::functions::map::get(input, key)) | ||
} | ||
_ => invalid_operation_err!("Expected 2 input args"), | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,37 @@ | ||
use super::SQLModule; | ||
use crate::functions::SQLFunctions; | ||
use crate::{ | ||
functions::{SQLFunction, SQLFunctions}, | ||
invalid_operation_err, | ||
}; | ||
|
||
pub struct SQLModuleStructs; | ||
|
||
impl SQLModule for SQLModuleStructs { | ||
fn register(_parent: &mut SQLFunctions) { | ||
// use FunctionExpr::Struct as f; | ||
// TODO | ||
fn register(parent: &mut SQLFunctions) { | ||
parent.add_fn("struct_get", StructGet); | ||
parent.add_fn("struct_extract", StructGet); | ||
} | ||
} | ||
|
||
pub struct StructGet; | ||
|
||
impl SQLFunction for StructGet { | ||
fn to_expr( | ||
&self, | ||
inputs: &[sqlparser::ast::FunctionArg], | ||
planner: &crate::planner::SQLPlanner, | ||
) -> crate::error::SQLPlannerResult<daft_dsl::ExprRef> { | ||
match inputs { | ||
[input, key] => { | ||
let input = planner.plan_function_arg(input)?; | ||
let key = planner.plan_function_arg(key)?; | ||
if let Some(lit) = key.as_literal().and_then(|lit| lit.as_str()) { | ||
Ok(daft_dsl::functions::struct_::get(input, lit)) | ||
} else { | ||
invalid_operation_err!("Expected key to be a string literal") | ||
} | ||
} | ||
_ => invalid_operation_err!("Expected 2 input args"), | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -551,9 +551,6 @@ impl SQLPlanner { | |
.plan_compound_identifier(idents.as_slice()) | ||
.map(|e| e[0].clone()), | ||
|
||
SQLExpr::JsonAccess { .. } => { | ||
unsupported_sql_err!("json access") | ||
} | ||
SQLExpr::CompositeAccess { .. } => { | ||
unsupported_sql_err!("composite access") | ||
} | ||
|
@@ -638,7 +635,6 @@ impl SQLPlanner { | |
unsupported_sql_err!("TypedString with data type {:?}", dtype) | ||
} | ||
}, | ||
SQLExpr::MapAccess { .. } => unsupported_sql_err!("MAP ACCESS"), | ||
SQLExpr::Function(func) => self.plan_function(func), | ||
SQLExpr::Case { | ||
operand, | ||
|
@@ -706,33 +702,7 @@ impl SQLPlanner { | |
SQLExpr::Named { .. } => unsupported_sql_err!("NAMED"), | ||
SQLExpr::Dictionary(_) => unsupported_sql_err!("DICTIONARY"), | ||
SQLExpr::Map(_) => unsupported_sql_err!("MAP"), | ||
SQLExpr::Subscript { expr, subscript } => match subscript.as_ref() { | ||
Subscript::Index { index } => { | ||
let index = self.plan_expr(index)?; | ||
let expr = self.plan_expr(expr)?; | ||
Ok(daft_functions::list::get(expr, index, null_lit())) | ||
} | ||
Subscript::Slice { | ||
lower_bound, | ||
upper_bound, | ||
stride, | ||
} => { | ||
if stride.is_some() { | ||
unsupported_sql_err!("stride"); | ||
} | ||
match (lower_bound, upper_bound) { | ||
(Some(lower), Some(upper)) => { | ||
let lower = self.plan_expr(lower)?; | ||
let upper = self.plan_expr(upper)?; | ||
let expr = self.plan_expr(expr)?; | ||
Ok(daft_functions::list::slice(expr, lower, upper)) | ||
} | ||
_ => { | ||
unsupported_sql_err!("slice with only one bound not yet supported"); | ||
} | ||
} | ||
} | ||
}, | ||
SQLExpr::Subscript { expr, subscript } => self.plan_subscript(expr, subscript.as_ref()), | ||
SQLExpr::Array(_) => unsupported_sql_err!("ARRAY"), | ||
SQLExpr::Interval(_) => unsupported_sql_err!("INTERVAL"), | ||
SQLExpr::MatchAgainst { .. } => unsupported_sql_err!("MATCH AGAINST"), | ||
|
@@ -741,6 +711,9 @@ impl SQLPlanner { | |
SQLExpr::OuterJoin(_) => unsupported_sql_err!("OUTER JOIN"), | ||
SQLExpr::Prior(_) => unsupported_sql_err!("PRIOR"), | ||
SQLExpr::Lambda(_) => unsupported_sql_err!("LAMBDA"), | ||
SQLExpr::JsonAccess { .. } | SQLExpr::MapAccess { .. } => { | ||
unreachable!("Not reachable in our dialect, should always be parsed as subscript") | ||
} | ||
} | ||
} | ||
|
||
|
@@ -936,6 +909,62 @@ impl SQLPlanner { | |
other => unsupported_sql_err!("unary operator {:?}", other), | ||
}) | ||
} | ||
fn plan_subscript( | ||
&self, | ||
expr: &sqlparser::ast::Expr, | ||
subscript: &Subscript, | ||
) -> SQLPlannerResult<ExprRef> { | ||
match subscript { | ||
Subscript::Index { index } => { | ||
let expr = self.plan_expr(expr)?; | ||
let index = self.plan_expr(index)?; | ||
let schema = self | ||
.current_relation | ||
.as_ref() | ||
.ok_or_else(|| { | ||
PlannerError::invalid_operation("subscript without a current relation") | ||
}) | ||
.map(|p| p.schema())?; | ||
let expr_field = expr.to_field(schema.as_ref())?; | ||
match expr_field.dtype { | ||
DataType::List(_) | DataType::FixedSizeList(_, _) => { | ||
Ok(daft_functions::list::get(expr, index, null_lit())) | ||
} | ||
DataType::Struct(_) => { | ||
if let Some(s) = index.as_literal().and_then(|l| l.as_str()) { | ||
Ok(daft_dsl::functions::struct_::get(expr, s)) | ||
} else { | ||
invalid_operation_err!("Index must be a string literal") | ||
} | ||
} | ||
DataType::Map(_) => Ok(daft_dsl::functions::map::get(expr, index)), | ||
dtype => { | ||
invalid_operation_err!("nested access on column with type: {}", dtype) | ||
} | ||
} | ||
} | ||
Subscript::Slice { | ||
lower_bound, | ||
upper_bound, | ||
stride, | ||
} => { | ||
if stride.is_some() { | ||
unsupported_sql_err!("stride cannot be provided when slicing an expression"); | ||
} | ||
match (lower_bound, upper_bound) { | ||
(Some(lower), Some(upper)) => { | ||
let lower = self.plan_expr(lower)?; | ||
let upper = self.plan_expr(upper)?; | ||
let expr = self.plan_expr(expr)?; | ||
Ok(daft_functions::list::slice(expr, lower, upper)) | ||
} | ||
_ => { | ||
unsupported_sql_err!("slice with only one bound not yet supported"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I believe we do currently support it. If lower is not specified, it should be 0. If upper is not specified, we can pass in a null literal to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah, just havent gotten around to implementing it yet. will follow up with support for it. |
||
} | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
/// Checks if the SQL query is valid syntax and doesn't use unsupported features. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
import daft
from daft.sql.sql import SQLCatalog


def test_nested_access():
    """SQL nested access: json_query, list/struct/map subscripts, struct_get."""
    fsl3 = daft.DataType.fixed_size_list(daft.DataType.int32(), 3)

    df = daft.from_pydict(
        {
            "json": ['{"a": 1, "b": {"c": 2}}', '{"a": 3, "b": {"c": 4}}', '{"a": 5, "b": {"c": 6}}'],
            "list": [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
            "dict": [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}],
        }
    )
    catalog = SQLCatalog({"test": df})

    actual = daft.sql(
        """
    select
        json_query(json, '.b.c') as json_b_c,
        list[1] as list_1,
        list[0:1] as list_slice,
        dict['a'] as dict_a,
        struct_get(dict, 'a') as dict_a_2,
        cast(list as int[3])[1] as fsl_1,
        cast(list as int[3])[0:1] as fsl_slice
    from test
    """,
        catalog,
    ).collect()

    expected = df.select(
        daft.col("json").json.query(".b.c").alias("json_b_c"),
        daft.col("list").list.get(1).alias("list_1"),
        daft.col("list").list.slice(0, 1).alias("list_slice"),
        daft.col("dict").struct.get("a").alias("dict_a"),
        daft.col("dict").struct.get("a").alias("dict_a_2"),
        daft.col("list").cast(fsl3).list.get(1).alias("fsl_1"),
        daft.col("list").cast(fsl3).list.slice(0, 1).alias("fsl_slice"),
    ).collect()

    assert actual.to_pydict() == expected.to_pydict()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this an alias to maintain ANSI/postgres compatibility?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
AFAIK, ANSI doesn't really have any standard for what you can/can't call your functions. I mostly added the alias because I've seen it in other DB systems (such as DuckDB).