Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT]: [SQL] struct subscript and json_query #2891

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion src/daft-sql/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ common-error = {path = "../common/error"}
daft-core = {path = "../daft-core"}
daft-dsl = {path = "../daft-dsl"}
daft-functions = {path = "../daft-functions"}
daft-functions-json = {path = "../daft-functions-json"}
daft-plan = {path = "../daft-plan"}
once_cell = {workspace = true}
pyo3 = {workspace = true, optional = true}
Expand All @@ -14,7 +15,7 @@ snafu.workspace = true
rstest = {workspace = true}

[features]
python = ["dep:pyo3", "common-error/python", "daft-functions/python"]
python = ["dep:pyo3", "common-error/python", "daft-functions/python", "daft-functions-json/python"]

[package]
name = "daft-sql"
Expand Down
35 changes: 31 additions & 4 deletions src/daft-sql/src/modules/json.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,38 @@
use super::SQLModule;
use crate::functions::SQLFunctions;
use crate::{
functions::{SQLFunction, SQLFunctions},
invalid_operation_err,
};

pub struct SQLModuleJson;

impl SQLModule for SQLModuleJson {
fn register(_parent: &mut SQLFunctions) {
// use FunctionExpr::Json as f;
// TODO
fn register(parent: &mut SQLFunctions) {
parent.add_fn("json_query", JsonQuery);
}
}

/// Planner hook for the SQL `json_query(input, query)` function.
struct JsonQuery;

impl SQLFunction for JsonQuery {
    /// Lowers `json_query(input, query)` into a `daft_functions_json::json_query`
    /// expression. The `query` argument must be a string literal.
    fn to_expr(
        &self,
        inputs: &[sqlparser::ast::FunctionArg],
        planner: &crate::planner::SQLPlanner,
    ) -> crate::error::SQLPlannerResult<daft_dsl::ExprRef> {
        match inputs {
            [input, query] => {
                let input_expr = planner.plan_function_arg(input)?;
                let query_expr = planner.plan_function_arg(query)?;
                // The JSON path must be known at plan time, so it has to be a
                // literal string rather than an arbitrary expression.
                match query_expr.as_literal().and_then(|lit| lit.as_str()) {
                    Some(path) => Ok(daft_functions_json::json_query(input_expr, path)),
                    None => {
                        invalid_operation_err!("Expected a string literal for the query argument")
                    }
                }
            }
            _ => invalid_operation_err!(
                "invalid arguments for json_query. expected json_query(input, query)"
            ),
        }
    }
}
30 changes: 26 additions & 4 deletions src/daft-sql/src/modules/map.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,33 @@
use super::SQLModule;
use crate::functions::SQLFunctions;
use crate::{
functions::{SQLFunction, SQLFunctions},
invalid_operation_err,
};

pub struct SQLModuleMap;

impl SQLModule for SQLModuleMap {
fn register(_parent: &mut SQLFunctions) {
// use FunctionExpr::Map as f;
// TODO
fn register(parent: &mut SQLFunctions) {
parent.add_fn("map_get", MapGet);
parent.add_fn("map_extract", MapGet);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this an alias to maintain ANSI/postgres compatibility?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AFAIK, ANSI doesn't really have any standard for what you can/cant call your functions. I mostly added the alias because I've seen it in other db systems (such as duckdb).

}
}

/// Planner hook for the SQL `map_get(input, key)` function (alias: `map_extract`).
pub struct MapGet;

impl SQLFunction for MapGet {
    /// Lowers `map_get(input, key)` into a map value-lookup expression.
    fn to_expr(
        &self,
        inputs: &[sqlparser::ast::FunctionArg],
        planner: &crate::planner::SQLPlanner,
    ) -> crate::error::SQLPlannerResult<daft_dsl::ExprRef> {
        match inputs {
            [input, key] => {
                let input = planner.plan_function_arg(input)?;
                let key = planner.plan_function_arg(key)?;
                Ok(daft_dsl::functions::map::get(input, key))
            }
            // Descriptive arity error, consistent with the style used by `json_query`.
            _ => invalid_operation_err!(
                "invalid arguments for map_get. expected map_get(input, key)"
            ),
        }
    }
}
34 changes: 30 additions & 4 deletions src/daft-sql/src/modules/structs.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,37 @@
use super::SQLModule;
use crate::functions::SQLFunctions;
use crate::{
functions::{SQLFunction, SQLFunctions},
invalid_operation_err,
};

pub struct SQLModuleStructs;

impl SQLModule for SQLModuleStructs {
fn register(_parent: &mut SQLFunctions) {
// use FunctionExpr::Struct as f;
// TODO
fn register(parent: &mut SQLFunctions) {
parent.add_fn("struct_get", StructGet);
parent.add_fn("struct_extract", StructGet);
}
}

/// Planner hook for the SQL `struct_get(input, key)` function (alias: `struct_extract`).
pub struct StructGet;

impl SQLFunction for StructGet {
    /// Lowers `struct_get(input, key)` into a struct field-access expression.
    /// The `key` argument must be a string literal naming the field.
    fn to_expr(
        &self,
        inputs: &[sqlparser::ast::FunctionArg],
        planner: &crate::planner::SQLPlanner,
    ) -> crate::error::SQLPlannerResult<daft_dsl::ExprRef> {
        match inputs {
            [input, key] => {
                let input = planner.plan_function_arg(input)?;
                let key = planner.plan_function_arg(key)?;
                // Struct field names must be resolvable at plan time, hence the
                // string-literal requirement.
                if let Some(lit) = key.as_literal().and_then(|lit| lit.as_str()) {
                    Ok(daft_dsl::functions::struct_::get(input, lit))
                } else {
                    invalid_operation_err!("Expected key to be a string literal")
                }
            }
            // Descriptive arity error, consistent with the style used by `json_query`.
            _ => invalid_operation_err!(
                "invalid arguments for struct_get. expected struct_get(input, key)"
            ),
        }
    }
}
91 changes: 60 additions & 31 deletions src/daft-sql/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -551,9 +551,6 @@ impl SQLPlanner {
.plan_compound_identifier(idents.as_slice())
.map(|e| e[0].clone()),

SQLExpr::JsonAccess { .. } => {
unsupported_sql_err!("json access")
}
SQLExpr::CompositeAccess { .. } => {
unsupported_sql_err!("composite access")
}
Expand Down Expand Up @@ -638,7 +635,6 @@ impl SQLPlanner {
unsupported_sql_err!("TypedString with data type {:?}", dtype)
}
},
SQLExpr::MapAccess { .. } => unsupported_sql_err!("MAP ACCESS"),
SQLExpr::Function(func) => self.plan_function(func),
SQLExpr::Case {
operand,
Expand Down Expand Up @@ -706,33 +702,7 @@ impl SQLPlanner {
SQLExpr::Named { .. } => unsupported_sql_err!("NAMED"),
SQLExpr::Dictionary(_) => unsupported_sql_err!("DICTIONARY"),
SQLExpr::Map(_) => unsupported_sql_err!("MAP"),
SQLExpr::Subscript { expr, subscript } => match subscript.as_ref() {
Subscript::Index { index } => {
let index = self.plan_expr(index)?;
let expr = self.plan_expr(expr)?;
Ok(daft_functions::list::get(expr, index, null_lit()))
}
Subscript::Slice {
lower_bound,
upper_bound,
stride,
} => {
if stride.is_some() {
unsupported_sql_err!("stride");
}
match (lower_bound, upper_bound) {
(Some(lower), Some(upper)) => {
let lower = self.plan_expr(lower)?;
let upper = self.plan_expr(upper)?;
let expr = self.plan_expr(expr)?;
Ok(daft_functions::list::slice(expr, lower, upper))
}
_ => {
unsupported_sql_err!("slice with only one bound not yet supported");
}
}
}
},
SQLExpr::Subscript { expr, subscript } => self.plan_subscript(expr, subscript.as_ref()),
SQLExpr::Array(_) => unsupported_sql_err!("ARRAY"),
SQLExpr::Interval(_) => unsupported_sql_err!("INTERVAL"),
SQLExpr::MatchAgainst { .. } => unsupported_sql_err!("MATCH AGAINST"),
Expand All @@ -741,6 +711,9 @@ impl SQLPlanner {
SQLExpr::OuterJoin(_) => unsupported_sql_err!("OUTER JOIN"),
SQLExpr::Prior(_) => unsupported_sql_err!("PRIOR"),
SQLExpr::Lambda(_) => unsupported_sql_err!("LAMBDA"),
SQLExpr::JsonAccess { .. } | SQLExpr::MapAccess { .. } => {
unreachable!("Not reachable in our dialect, should always be parsed as subscript")
}
}
}

Expand Down Expand Up @@ -936,6 +909,62 @@ impl SQLPlanner {
other => unsupported_sql_err!("unary operator {:?}", other),
})
}
/// Plans a SQL subscript expression (`expr[index]` or `expr[lower:upper]`).
///
/// Index subscripts dispatch on the dtype of `expr`:
/// - list / fixed-size list: element access (out-of-bounds yields null),
/// - struct: field access (index must be a string literal),
/// - map: value lookup by key.
///
/// Slice subscripts currently require both bounds and no stride.
fn plan_subscript(
    &self,
    expr: &sqlparser::ast::Expr,
    subscript: &Subscript,
) -> SQLPlannerResult<ExprRef> {
    match subscript {
        Subscript::Index { index } => {
            let expr = self.plan_expr(expr)?;
            let index = self.plan_expr(index)?;
            // Subscript semantics depend on the column's dtype, so we need the
            // current relation's schema to resolve the expression's type.
            let schema = self
                .current_relation
                .as_ref()
                .ok_or_else(|| {
                    PlannerError::invalid_operation("subscript without a current relation")
                })
                .map(|p| p.schema())?;
            let expr_field = expr.to_field(schema.as_ref())?;
            match expr_field.dtype {
                DataType::List(_) | DataType::FixedSizeList(_, _) => {
                    // Missing elements become null rather than erroring.
                    Ok(daft_functions::list::get(expr, index, null_lit()))
                }
                DataType::Struct(_) => {
                    // Struct access requires a compile-time field name.
                    if let Some(s) = index.as_literal().and_then(|l| l.as_str()) {
                        Ok(daft_dsl::functions::struct_::get(expr, s))
                    } else {
                        invalid_operation_err!("Index must be a string literal")
                    }
                }
                DataType::Map(_) => Ok(daft_dsl::functions::map::get(expr, index)),
                dtype => {
                    invalid_operation_err!("nested access on column with type: {}", dtype)
                }
            }
        }
        Subscript::Slice {
            lower_bound,
            upper_bound,
            stride,
        } => {
            if stride.is_some() {
                unsupported_sql_err!("stride cannot be provided when slicing an expression");
            }
            match (lower_bound, upper_bound) {
                (Some(lower), Some(upper)) => {
                    let lower = self.plan_expr(lower)?;
                    let upper = self.plan_expr(upper)?;
                    let expr = self.plan_expr(expr)?;
                    Ok(daft_functions::list::slice(expr, lower, upper))
                }
                _ => {
                    // TODO: support single-bound slices — a missing lower bound
                    // should default to 0 and a missing upper bound can be passed
                    // as a null literal to `list::slice` (per review discussion);
                    // confirm slice's null handling before enabling.
                    unsupported_sql_err!("slice with only one bound not yet supported");
                }
            }
        }
    }
}
}

/// Checks if the SQL query is valid syntax and doesn't use unsupported features.
Expand Down
44 changes: 44 additions & 0 deletions tests/sql/test_nested_access.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import daft
from daft.sql.sql import SQLCatalog


def test_nested_access():
    """SQL nested-access syntax should match the equivalent DataFrame expressions."""
    df = daft.from_pydict(
        {
            "json": ['{"a": 1, "b": {"c": 2}}', '{"a": 3, "b": {"c": 4}}', '{"a": 5, "b": {"c": 6}}'],
            "list": [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
            "dict": [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}],
        }
    )

    catalog = SQLCatalog({"test": df})

    query = """
    select
        json_query(json, '.b.c') as json_b_c,
        list[1] as list_1,
        list[0:1] as list_slice,
        dict['a'] as dict_a,
        struct_get(dict, 'a') as dict_a_2,
        cast(list as int[3])[1] as fsl_1,
        cast(list as int[3])[0:1] as fsl_slice
    from test
    """
    actual = daft.sql(query, catalog).collect()

    # `int[3]` in SQL corresponds to a fixed-size list of int32 with length 3.
    fsl_dtype = daft.DataType.fixed_size_list(daft.DataType.int32(), 3)
    expected = df.select(
        daft.col("json").json.query(".b.c").alias("json_b_c"),
        daft.col("list").list.get(1).alias("list_1"),
        daft.col("list").list.slice(0, 1).alias("list_slice"),
        daft.col("dict").struct.get("a").alias("dict_a"),
        daft.col("dict").struct.get("a").alias("dict_a_2"),
        daft.col("list").cast(fsl_dtype).list.get(1).alias("fsl_1"),
        daft.col("list").cast(fsl_dtype).list.slice(0, 1).alias("fsl_slice"),
    ).collect()

    assert actual.to_pydict() == expected.to_pydict()
Loading