From 200bd5276b1b259e71d96c0157f3848931254dd4 Mon Sep 17 00:00:00 2001 From: frankzfli Date: Sun, 15 Dec 2024 01:48:55 +0800 Subject: [PATCH 1/3] feat: add download_url/upload_url to sql --- src/daft-functions/src/uri/download.rs | 10 +- src/daft-functions/src/uri/mod.rs | 4 +- src/daft-functions/src/uri/upload.rs | 12 +- src/daft-sql/src/functions.rs | 3 +- src/daft-sql/src/modules/mod.rs | 2 + src/daft-sql/src/modules/url.rs | 192 +++++++++++++++++++++++++ tests/io/test_url_download_local.py | 17 ++- tests/io/test_url_upload_local.py | 93 ++++++------ 8 files changed, 273 insertions(+), 60 deletions(-) create mode 100644 src/daft-sql/src/modules/url.rs diff --git a/src/daft-functions/src/uri/download.rs b/src/daft-functions/src/uri/download.rs index 24d3f89d33..59dd8ec649 100644 --- a/src/daft-functions/src/uri/download.rs +++ b/src/daft-functions/src/uri/download.rs @@ -12,11 +12,11 @@ use snafu::prelude::*; use crate::InvalidArgumentSnafu; #[derive(Debug, Clone, Serialize, serde::Deserialize, PartialEq, Eq, Hash)] -pub(super) struct DownloadFunction { - pub(super) max_connections: usize, - pub(super) raise_error_on_failure: bool, - pub(super) multi_thread: bool, - pub(super) config: Arc<IOConfig>, +pub struct DownloadFunction { + pub max_connections: usize, + pub raise_error_on_failure: bool, + pub multi_thread: bool, + pub config: Arc<IOConfig>, } #[typetag::serde] diff --git a/src/daft-functions/src/uri/mod.rs b/src/daft-functions/src/uri/mod.rs index 67418fa1df..19c52afeab 100644 --- a/src/daft-functions/src/uri/mod.rs +++ b/src/daft-functions/src/uri/mod.rs @@ -1,5 +1,5 @@ -mod download; -mod upload; +pub mod download; +pub mod upload; use common_io_config::IOConfig; use daft_dsl::{functions::ScalarFunction, ExprRef}; diff --git a/src/daft-functions/src/uri/upload.rs b/src/daft-functions/src/uri/upload.rs index 5b01858b94..7d08a889cb 100644 --- a/src/daft-functions/src/uri/upload.rs +++ b/src/daft-functions/src/uri/upload.rs @@ -9,12 +9,12 @@ use futures::{StreamExt, TryStreamExt}; use serde::Serialize; #[derive(Debug, Clone, Serialize, serde::Deserialize, PartialEq, Eq, Hash)] -pub(super) struct UploadFunction { - pub(super) max_connections: usize, - pub(super) raise_error_on_failure: bool, - pub(super) multi_thread: bool, - pub(super) is_single_folder: bool, - pub(super) config: Arc<IOConfig>, +pub struct UploadFunction { + pub max_connections: usize, + pub raise_error_on_failure: bool, + pub multi_thread: bool, + pub is_single_folder: bool, + pub config: Arc<IOConfig>, } #[typetag::serde] diff --git a/src/daft-sql/src/functions.rs b/src/daft-sql/src/functions.rs index d75f090072..4e969394a6 100644 --- a/src/daft-sql/src/functions.rs +++ b/src/daft-sql/src/functions.rs @@ -14,7 +14,7 @@ use crate::{ coalesce::SQLCoalesce, hashing, SQLModule, SQLModuleAggs, SQLModuleConfig, SQLModuleFloat, SQLModuleImage, SQLModuleJson, SQLModuleList, SQLModuleMap, SQLModuleNumeric, SQLModulePartitioning, SQLModulePython, SQLModuleSketch, SQLModuleStructs, - SQLModuleTemporal, SQLModuleUtf8, + SQLModuleTemporal, SQLModuleURL, SQLModuleUtf8, }, planner::SQLPlanner, unsupported_sql_err, @@ -37,6 +37,7 @@ pub(crate) static SQL_FUNCTIONS: Lazy<SQLFunctions> = Lazy::new(|| { functions.register::(); functions.register::(); functions.register::(); + functions.register::<SQLModuleURL>(); functions.register::(); functions.add_fn("coalesce", SQLCoalesce {}); functions diff --git a/src/daft-sql/src/modules/mod.rs b/src/daft-sql/src/modules/mod.rs index 30195dc52f..c3f654adff 100644 --- a/src/daft-sql/src/modules/mod.rs +++ b/src/daft-sql/src/modules/mod.rs @@ -15,6 +15,7 @@
pub mod python; pub mod sketch; pub mod structs; pub mod temporal; +pub mod url; pub mod utf8; pub use aggs::SQLModuleAggs; @@ -30,6 +31,7 @@ pub use python::SQLModulePython; pub use sketch::SQLModuleSketch; pub use structs::SQLModuleStructs; pub use temporal::SQLModuleTemporal; +pub use url::SQLModuleURL; pub use utf8::SQLModuleUtf8; /// A [SQLModule] is a collection of SQL functions that can be registered with a [SQLFunctions] instance. diff --git a/src/daft-sql/src/modules/url.rs b/src/daft-sql/src/modules/url.rs new file mode 100644 index 0000000000..916da5db1c --- /dev/null +++ b/src/daft-sql/src/modules/url.rs @@ -0,0 +1,192 @@ +use std::sync::Arc; + +use daft_dsl::{Expr, ExprRef, LiteralValue}; +use daft_functions::uri::{download, download::DownloadFunction, upload, upload::UploadFunction}; +use sqlparser::ast::FunctionArg; + +use super::SQLModule; +use crate::{ + error::{PlannerError, SQLPlannerResult}, + functions::{SQLFunction, SQLFunctionArguments, SQLFunctions}, + modules::config::expr_to_iocfg, + planner::SQLPlanner, + unsupported_sql_err, +}; + +pub struct SQLModuleURL; + +impl SQLModule for SQLModuleURL { + fn register(parent: &mut SQLFunctions) { + parent.add_fn("url_download", UrlDownload); + parent.add_fn("url_upload", UrlUpload); + } +} + +impl TryFrom<SQLFunctionArguments> for DownloadFunction { + type Error = PlannerError; + + fn try_from(args: SQLFunctionArguments) -> Result<Self, Self::Error> { + let max_connections = args.try_get_named("max_connections")?.unwrap_or(32); + let raise_error_on_failure = args + .get_named("on_error") + .map(|arg| match arg.as_ref() { + Expr::Literal(LiteralValue::Utf8(s)) => match s.as_ref() { + "raise" => Ok(true), + "null" => Ok(false), + _ => unsupported_sql_err!("Expected on_error to be 'raise' or 'null'"), + }, + _ => unsupported_sql_err!("Expected on_error to be 'raise' or 'null'"), + }) + .transpose()? + .unwrap_or(true); + + // TODO: choose multi_thread based on the current engine (such as Ray) + let multi_thread = args.try_get_named("multi_thread")?.unwrap_or(false); + + let config = Arc::new( + args.get_named("io_config") + .map(expr_to_iocfg) + .transpose()? + .unwrap_or_default(), + ); + + Ok(Self { + max_connections, + raise_error_on_failure, + multi_thread, + config, + }) + } +} + +struct UrlDownload; + +impl SQLFunction for UrlDownload { + fn to_expr( + &self, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> SQLPlannerResult<ExprRef> { + match inputs { + [input, args @ ..]
=> { + let input = planner.plan_function_arg(input)?; + let args: DownloadFunction = planner.plan_function_args( + args, + &["max_connections", "on_error", "multi_thread", "io_config"], + 0, + )?; + + Ok(download( + input, + args.max_connections, + args.raise_error_on_failure, + args.multi_thread, + Arc::try_unwrap(args.config).unwrap_or_default().into(), + )) + } + _ => unsupported_sql_err!("Invalid arguments for url_download: '{inputs:?}'"), + } + } + + fn docstrings(&self, _: &str) -> String { + "download data from the given url".to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &[ + "input", + "max_connections", + "on_error", + "multi_thread", + "io_config", + ] + } +} + +impl TryFrom<SQLFunctionArguments> for UploadFunction { + type Error = PlannerError; + + fn try_from(args: SQLFunctionArguments) -> Result<Self, Self::Error> { + let max_connections = args.try_get_named("max_connections")?.unwrap_or(32); + + let raise_error_on_failure = args + .get_named("on_error") + .map(|arg| match arg.as_ref() { + Expr::Literal(LiteralValue::Utf8(s)) => match s.as_ref() { + "raise" => Ok(true), + "null" => Ok(false), + _ => unsupported_sql_err!("Expected on_error to be 'raise' or 'null'"), + }, + _ => unsupported_sql_err!("Expected on_error to be 'raise' or 'null'"), + }) + .transpose()? + .unwrap_or(true); + + // TODO: choose multi_thread based on the current engine (such as Ray) + let multi_thread = args.try_get_named("multi_thread")?.unwrap_or(false); + + // by default, use row-specific URLs + let is_single_folder = false; + + let config = Arc::new( + args.get_named("io_config") + .map(expr_to_iocfg) + .transpose()? + .unwrap_or_default(), + ); + + Ok(Self { + max_connections, + raise_error_on_failure, + multi_thread, + is_single_folder, + config, + }) + } +} + +struct UrlUpload; + +impl SQLFunction for UrlUpload { + fn to_expr(&self, inputs: &[FunctionArg], planner: &SQLPlanner) -> SQLPlannerResult<ExprRef> { + match inputs { + [input, location, args @ ..]
=> { + let input = planner.plan_function_arg(input)?; + let location = planner.plan_function_arg(location)?; + let mut args: UploadFunction = planner.plan_function_args( + args, + &["max_connections", "on_error", "multi_thread", "io_config"], + 0, + )?; + if location.as_literal().is_some() { + args.is_single_folder = true; + } + Ok(upload( + input, + location, + args.max_connections, + args.raise_error_on_failure, + args.multi_thread, + args.is_single_folder, + Arc::try_unwrap(args.config).unwrap_or_default().into(), + )) + } + _ => unsupported_sql_err!("Invalid arguments for url_upload: '{inputs:?}'"), + } + } + + fn docstrings(&self, _: &str) -> String { + "upload data to the given path".to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &[ + "input", + "location", + "max_connections", + "on_error", + "multi_thread", + "io_config", + ] + } +} diff --git a/tests/io/test_url_download_local.py b/tests/io/test_url_download_local.py index 57dcbf6d9e..875fa9439c 100644 --- a/tests/io/test_url_download_local.py +++ b/tests/io/test_url_download_local.py @@ -30,18 +30,25 @@ def local_image_data_fixture(tmpdir, image_data) -> YieldFixture[list[str]]: def test_url_download_local(local_image_data_fixture, image_data): data = {"urls": local_image_data_fixture} df = daft.from_pydict(data) - df = df.with_column("data", df["urls"].url.download()) - assert df.to_pydict() == {**data, "data": [image_data for _ in range(len(local_image_data_fixture))]} + + def check_results(df): + assert df.to_pydict() == {**data, "data": [image_data for _ in range(len(local_image_data_fixture))]} + + check_results(df.with_column("data", df["urls"].url.download())) + check_results(daft.sql("SELECT urls, url_download(urls) AS data FROM df")) @pytest.mark.integration() def test_url_download_local_missing(local_image_data_fixture): data = {"urls": local_image_data_fixture + ["/missing/path/x.jpeg"]} df = daft.from_pydict(data) - df = df.with_column("data", df["urls"].url.download(on_error="raise")) - with pytest.raises(FileNotFoundError): - df.collect() + def check_results(df): + with pytest.raises(FileNotFoundError): + df.collect() + + check_results(df.with_column("data", df["urls"].url.download(on_error="raise"))) + check_results(daft.sql("SELECT urls, url_download(urls, on_error:='raise') AS data FROM df")) @pytest.mark.integration() diff --git a/tests/io/test_url_upload_local.py b/tests/io/test_url_upload_local.py index 3a23d0d8a5..6d8af376a4 100644 --- a/tests/io/test_url_upload_local.py +++ b/tests/io/test_url_upload_local.py @@ -9,17 +9,21 @@ def test_upload_local(tmpdir): bytes_data = [b"a", b"b", b"c"] data = {"data": bytes_data} df = daft.from_pydict(data) - df = df.with_column("files", df["data"].url.upload(str(tmpdir + "/nested"))) - df.collect() - results = df.to_pydict() - assert results["data"] == bytes_data - assert len(results["files"]) == len(bytes_data) - for path, expected in zip(results["files"], bytes_data): - assert path.startswith("file://") - path = path[len("file://") :] - with open(path, "rb") as f: - assert f.read() == expected + def check_upload_results(df): + df.collect() + results = df.to_pydict() + assert results["data"] == bytes_data + assert len(results["files"]) == len(bytes_data) + for path, expected in zip(results["files"], bytes_data): + assert path.startswith("file://") + path = path[len("file://") :] + with open(path, "rb") as f: + assert f.read() == expected + + # check df and sql + check_upload_results(df.with_column("files", df["data"].url.upload(str(tmpdir + 
"/nested")))) + check_upload_results(daft.sql(f"SELECT data, url_upload(data, '{tmpdir!s}') AS files FROM df")) def test_upload_local_single_file_url(tmpdir): @@ -27,23 +31,27 @@ def test_upload_local_single_file_url(tmpdir): paths = [f"{tmpdir}/0"] data = {"data": bytes_data, "paths": paths} df = daft.from_pydict(data) - # Even though there is only one row, since we pass in the upload URL via an expression, we - # should treat the given path as a per-row path and write directly to that path, instead of - # treating the path as a directory and writing to `{path}/uuid`. - df = df.with_column("files", df["data"].url.upload(df["paths"])) - df.collect() - - results = df.to_pydict() - assert results["data"] == bytes_data - assert len(results["files"]) == len(bytes_data) - for path, expected in zip(results["files"], bytes_data): - assert path.startswith("file://") - path = path[len("file://") :] - with open(path, "rb") as f: - assert f.read() == expected - # Check that data was uploaded to the correct paths. - for path, expected in zip(results["files"], paths): - assert path == "file://" + expected + + def check_upload_results(df): + # Even though there is only one row, since we pass in the upload URL via an expression, we + # should treat the given path as a per-row path and write directly to that path, instead of + # treating the path as a directory and writing to `{path}/uuid`. + df.collect() + + results = df.to_pydict() + assert results["data"] == bytes_data + assert len(results["files"]) == len(bytes_data) + for path, expected in zip(results["files"], bytes_data): + assert path.startswith("file://") + path = path[len("file://") :] + with open(path, "rb") as f: + assert f.read() == expected + # Check that data was uploaded to the correct paths. + for path, expected in zip(results["files"], paths): + assert path == "file://" + expected + + check_upload_results(df.with_column("files", df["data"].url.upload(df["paths"]))) + check_upload_results(daft.sql("SELECT data, url_upload(data, paths) AS files FROM df")) def test_upload_local_row_specifc_urls(tmpdir): @@ -51,20 +59,23 @@ def test_upload_local_row_specifc_urls(tmpdir): paths = [f"{tmpdir}/0", f"{tmpdir}/1", f"{tmpdir}/2"] data = {"data": bytes_data, "paths": paths} df = daft.from_pydict(data) - df = df.with_column("files", df["data"].url.upload(df["paths"])) - df.collect() - - results = df.to_pydict() - assert results["data"] == bytes_data - assert len(results["files"]) == len(bytes_data) - for path, expected in zip(results["files"], bytes_data): - assert path.startswith("file://") - path = path[len("file://") :] - with open(path, "rb") as f: - assert f.read() == expected - # Check that data was uploaded to the correct paths. - for path, expected in zip(results["files"], paths): - assert path == "file://" + expected + + def check_upload_results(df): + df.collect() + results = df.to_pydict() + assert results["data"] == bytes_data + assert len(results["files"]) == len(bytes_data) + for path, expected in zip(results["files"], bytes_data): + assert path.startswith("file://") + path = path[len("file://") :] + with open(path, "rb") as f: + assert f.read() == expected + # Check that data was uploaded to the correct paths. 
+ for path, expected in zip(results["files"], paths): + assert path == "file://" + expected + + check_upload_results(df.with_column("files", df["data"].url.upload(df["paths"]))) + check_upload_results(daft.sql("SELECT data, url_upload(data, paths) AS files FROM df")) def test_upload_local_no_write_permissions(tmpdir): From fbe0a6fbf2ca3eef67bd0ba15e1821a1a95b5502 Mon Sep 17 00:00:00 2001 From: frankzfli Date: Sat, 28 Dec 2024 20:25:20 +0800 Subject: [PATCH 2/3] fix comment --- src/daft-sql/src/modules/url.rs | 30 ++++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/src/daft-sql/src/modules/url.rs b/src/daft-sql/src/modules/url.rs index 916da5db1c..7540f99056 100644 --- a/src/daft-sql/src/modules/url.rs +++ b/src/daft-sql/src/modules/url.rs @@ -15,6 +15,8 @@ use crate::{ pub struct SQLModuleURL; +const DEFAULT_MAX_CONNECTIONS: usize = 32; + impl SQLModule for SQLModuleURL { fn register(parent: &mut SQLFunctions) { parent.add_fn("url_download", UrlDownload); parent.add_fn("url_upload", UrlUpload); } } @@ -26,16 +28,22 @@ impl TryFrom<SQLFunctionArguments> for DownloadFunction { type Error = PlannerError; fn try_from(args: SQLFunctionArguments) -> Result<Self, Self::Error> { - let max_connections = args.try_get_named("max_connections")?.unwrap_or(32); + let max_connections = args + .try_get_named("max_connections")? + .unwrap_or(DEFAULT_MAX_CONNECTIONS); let raise_error_on_failure = args .get_named("on_error") .map(|arg| match arg.as_ref() { Expr::Literal(LiteralValue::Utf8(s)) => match s.as_ref() { "raise" => Ok(true), "null" => Ok(false), - _ => unsupported_sql_err!("Expected on_error to be 'raise' or 'null'"), + other => unsupported_sql_err!( + "Expected on_error to be 'raise' or 'null'; instead got '{other:?}'" + ), }, - _ => unsupported_sql_err!("Expected on_error to be 'raise' or 'null'"), + other => unsupported_sql_err!( + "Expected on_error to be 'raise' or 'null'; instead got '{other:?}'" + ), }) .transpose()? .unwrap_or(true); @@ -81,7 +89,7 @@ impl SQLFunction for UrlDownload { args.max_connections, args.raise_error_on_failure, args.multi_thread, - Arc::try_unwrap(args.config).unwrap_or_default().into(), + Arc::try_unwrap(args.config).ok(), // upload requires Option<IOConfig> )) } _ => unsupported_sql_err!("Invalid arguments for url_download: '{inputs:?}'"), @@ -107,7 +115,9 @@ impl TryFrom<SQLFunctionArguments> for UploadFunction { type Error = PlannerError; fn try_from(args: SQLFunctionArguments) -> Result<Self, Self::Error> { - let max_connections = args.try_get_named("max_connections")?.unwrap_or(32); + let max_connections = args + .try_get_named("max_connections")? + .unwrap_or(DEFAULT_MAX_CONNECTIONS); let raise_error_on_failure = args .get_named("on_error") .map(|arg| match arg.as_ref() { Expr::Literal(LiteralValue::Utf8(s)) => match s.as_ref() { "raise" => Ok(true), "null" => Ok(false), - _ => unsupported_sql_err!("Expected on_error to be 'raise' or 'null'"), + other => unsupported_sql_err!( + "Expected on_error to be 'raise' or 'null'; instead got '{other:?}'" + ), }, - _ => unsupported_sql_err!("Expected on_error to be 'raise' or 'null'"), + other => unsupported_sql_err!( + "Expected on_error to be 'raise' or 'null'; instead got '{other:?}'" + ), }) .transpose()?
.unwrap_or(true); @@ -168,7 +182,7 @@ impl SQLFunction for UrlUpload { args.raise_error_on_failure, args.multi_thread, args.is_single_folder, - Arc::try_unwrap(args.config).unwrap_or_default().into(), + Arc::try_unwrap(args.config).ok(), // upload requires Option<IOConfig> )) } _ => unsupported_sql_err!("Invalid arguments for url_upload: '{inputs:?}'"), From e72676f55547a73e4b683360bd9445b9316dec46 Mon Sep 17 00:00:00 2001 From: frankzfli Date: Sun, 29 Dec 2024 09:45:36 +0800 Subject: [PATCH 3/3] fix comment 2 --- src/daft-sql/src/modules/url.rs | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/daft-sql/src/modules/url.rs b/src/daft-sql/src/modules/url.rs index 7540f99056..9f674d776f 100644 --- a/src/daft-sql/src/modules/url.rs +++ b/src/daft-sql/src/modules/url.rs @@ -89,7 +89,10 @@ impl SQLFunction for UrlDownload { args.max_connections, args.raise_error_on_failure, args.multi_thread, - Arc::try_unwrap(args.config).ok(), // upload requires Option<IOConfig> + Some(match Arc::try_unwrap(args.config) { + Ok(elem) => elem, + Err(elem) => (*elem).clone(), + }), // download requires Option<IOConfig> )) } _ => unsupported_sql_err!("Invalid arguments for url_download: '{inputs:?}'"), @@ -182,7 +185,10 @@ impl SQLFunction for UrlUpload { args.raise_error_on_failure, args.multi_thread, args.is_single_folder, - Arc::try_unwrap(args.config).ok(), // upload requires Option<IOConfig> + Some(match Arc::try_unwrap(args.config) { + Ok(elem) => elem, + Err(elem) => (*elem).clone(), + }), // upload requires Option<IOConfig> )) } _ => unsupported_sql_err!("Invalid arguments for url_upload: '{inputs:?}'"),
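
Usage sketch (illustrative only, not part of the patch series): the Python snippet below shows the two SQL functions this series registers, mirroring the calls exercised in tests/io/test_url_download_local.py and tests/io/test_url_upload_local.py. It assumes a daft build that includes these patches; the file paths and column names are made up for the example.

import daft

# Hypothetical inputs: "urls" should point at files that actually exist locally.
df = daft.from_pydict({
    "urls": ["/tmp/demo/a.bin", "/tmp/demo/b.bin"],
    "data": [b"a", b"b"],
})

# url_download: named arguments are parsed by DownloadFunction::try_from above;
# on_error:='null' returns null for rows that fail instead of raising.
downloaded = daft.sql("SELECT urls, url_download(urls, on_error:='null') AS bytes FROM df")

# url_upload: a literal location is treated as a single folder (is_single_folder = true),
# writing each row to '{location}/{uuid}'; passing a column of paths instead writes each
# row to its own row-specific path, e.g. url_upload(data, paths).
uploaded = daft.sql("SELECT data, url_upload(data, '/tmp/demo/out') AS files FROM df")
uploaded.collect()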