From 088b8b909a8f1ca65fdc89391ef6192017d5ac6c Mon Sep 17 00:00:00 2001 From: clarkzinzow Date: Thu, 10 Aug 2023 17:57:02 -0700 Subject: [PATCH] Add support for explode. --- daft/execution/rust_physical_plan_shim.py | 26 +++++++++++++++++++++ daft/logical/rust_logical_plan.py | 16 ++++++++++++- src/daft-plan/src/builder.rs | 19 +++++++++++++++ src/daft-plan/src/logical_plan.rs | 10 ++++++++ src/daft-plan/src/ops/explode.rs | 28 +++++++++++++++++++++++ src/daft-plan/src/ops/mod.rs | 2 ++ src/daft-plan/src/physical_ops/explode.rs | 21 +++++++++++++++++ src/daft-plan/src/physical_ops/mod.rs | 2 ++ src/daft-plan/src/physical_plan.rs | 16 +++++++++++++ src/daft-plan/src/planner.rs | 17 +++++++++++--- tests/dataframe/test_explode.py | 6 ++--- 11 files changed, 156 insertions(+), 7 deletions(-) create mode 100644 src/daft-plan/src/ops/explode.rs create mode 100644 src/daft-plan/src/physical_ops/explode.rs diff --git a/daft/execution/rust_physical_plan_shim.py b/daft/execution/rust_physical_plan_shim.py index 44714e707d..bc5fa03cf6 100644 --- a/daft/execution/rust_physical_plan_shim.py +++ b/daft/execution/rust_physical_plan_shim.py @@ -6,6 +6,7 @@ from daft.daft import FileFormat, FileFormatConfig, PyExpr, PySchema, PyTable from daft.execution import execution_step, physical_plan from daft.expressions import Expression, ExpressionsProjection +from daft.logical.map_partition_ops import MapPartitionOp from daft.logical.schema import Schema from daft.resource_request import ResourceRequest from daft.table import Table @@ -52,6 +53,31 @@ def project( ) +class ShimExplodeOp(MapPartitionOp): + explode_columns: ExpressionsProjection + + def __init__(self, explode_columns: ExpressionsProjection) -> None: + self.explode_columns = explode_columns + + def get_output_schema(self) -> Schema: + raise NotImplementedError("Output schema shouldn't be needed at execution time") + + def run(self, input_partition: Table) -> Table: + return input_partition.explode(self.explode_columns) + + +def explode( + input: physical_plan.InProgressPhysicalPlan[PartitionT], explode_exprs: list[PyExpr] +) -> physical_plan.InProgressPhysicalPlan[PartitionT]: + explode_expr_projection = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in explode_exprs]) + explode_op = ShimExplodeOp(explode_expr_projection) + return physical_plan.pipeline_instruction( + child_plan=input, + pipeable_instruction=execution_step.MapPartition(explode_op), + resource_request=ResourceRequest(), # TODO(Clark): Use real ResourceRequest. + ) + + def sort( input: physical_plan.InProgressPhysicalPlan[PartitionT], sort_by: list[PyExpr], diff --git a/daft/logical/rust_logical_plan.py b/daft/logical/rust_logical_plan.py index e4b0d49ea7..9f723949df 100644 --- a/daft/logical/rust_logical_plan.py +++ b/daft/logical/rust_logical_plan.py @@ -120,7 +120,21 @@ def limit(self, num_rows: int) -> RustLogicalPlanBuilder: return RustLogicalPlanBuilder(builder) def explode(self, explode_expressions: ExpressionsProjection) -> RustLogicalPlanBuilder: - raise NotImplementedError("not implemented") + # TODO(Clark): Move this logic to Rust side after we've ported ExpressionsProjection. + explode_expressions = ExpressionsProjection([expr._explode() for expr in explode_expressions]) + input_schema = self.schema() + explode_schema = explode_expressions.resolve_schema(input_schema) + output_fields = [] + for f in input_schema: + if f.name in explode_schema.column_names(): + output_fields.append(explode_schema[f.name]) + else: + output_fields.append(f) + + exploded_schema = Schema._from_field_name_and_types([(f.name, f.dtype) for f in output_fields]) + explode_pyexprs = [expr._expr for expr in explode_expressions] + builder = self._builder.explode(explode_pyexprs, exploded_schema._schema) + return RustLogicalPlanBuilder(builder) def count(self) -> RustLogicalPlanBuilder: raise NotImplementedError("not implemented") diff --git a/src/daft-plan/src/builder.rs b/src/daft-plan/src/builder.rs index 26200d076e..ea4cece56b 100644 --- a/src/daft-plan/src/builder.rs +++ b/src/daft-plan/src/builder.rs @@ -111,6 +111,25 @@ impl LogicalPlanBuilder { Ok(logical_plan_builder) } + pub fn explode( + &self, + explode_pyexprs: Vec, + exploded_schema: &PySchema, + ) -> PyResult { + let explode_exprs = explode_pyexprs + .iter() + .map(|e| e.clone().into()) + .collect::>(); + let logical_plan: LogicalPlan = ops::Explode::new( + explode_exprs, + exploded_schema.clone().into(), + self.plan.clone(), + ) + .into(); + let logical_plan_builder = LogicalPlanBuilder::new(logical_plan.into()); + Ok(logical_plan_builder) + } + pub fn sort( &self, sort_by: Vec, diff --git a/src/daft-plan/src/logical_plan.rs b/src/daft-plan/src/logical_plan.rs index 302279019b..6cceaf312e 100644 --- a/src/daft-plan/src/logical_plan.rs +++ b/src/daft-plan/src/logical_plan.rs @@ -10,6 +10,7 @@ pub enum LogicalPlan { Project(Project), Filter(Filter), Limit(Limit), + Explode(Explode), Sort(Sort), Repartition(Repartition), Coalesce(Coalesce), @@ -28,6 +29,9 @@ impl LogicalPlan { }) => projected_schema.clone(), Self::Filter(Filter { input, .. }) => input.schema(), Self::Limit(Limit { input, .. }) => input.schema(), + Self::Explode(Explode { + exploded_schema, .. + }) => exploded_schema.clone(), Self::Sort(Sort { input, .. }) => input.schema(), Self::Repartition(Repartition { input, .. }) => input.schema(), Self::Coalesce(Coalesce { input, .. }) => input.schema(), @@ -44,6 +48,7 @@ impl LogicalPlan { Self::Project(Project { input, .. }) => input.partition_spec(), Self::Filter(Filter { input, .. }) => input.partition_spec(), Self::Limit(Limit { input, .. }) => input.partition_spec(), + Self::Explode(Explode { input, .. }) => input.partition_spec(), Self::Sort(Sort { input, sort_by, .. }) => PartitionSpec::new_internal( PartitionScheme::Range, input.partition_spec().num_partitions, @@ -82,6 +87,7 @@ impl LogicalPlan { Self::Project(Project { input, .. }) => vec![input], Self::Filter(Filter { input, .. }) => vec![input], Self::Limit(Limit { input, .. }) => vec![input], + Self::Explode(Explode { input, .. }) => vec![input], Self::Sort(Sort { input, .. }) => vec![input], Self::Repartition(Repartition { input, .. }) => vec![input], Self::Coalesce(Coalesce { input, .. }) => vec![input], @@ -98,6 +104,9 @@ impl LogicalPlan { Self::Project(Project { projection, .. }) => vec![format!("Project: {projection:?}")], Self::Filter(Filter { predicate, .. }) => vec![format!("Filter: {predicate}")], Self::Limit(Limit { limit, .. }) => vec![format!("Limit: {limit}")], + Self::Explode(Explode { explode_exprs, .. }) => { + vec![format!("Explode: {explode_exprs:?}")] + } Self::Sort(sort) => sort.multiline_display(), Self::Repartition(repartition) => repartition.multiline_display(), Self::Coalesce(Coalesce { num_to, .. }) => vec![format!("Coalesce: {num_to}")], @@ -129,6 +138,7 @@ impl_from_data_struct_for_logical_plan!(Source); impl_from_data_struct_for_logical_plan!(Project); impl_from_data_struct_for_logical_plan!(Filter); impl_from_data_struct_for_logical_plan!(Limit); +impl_from_data_struct_for_logical_plan!(Explode); impl_from_data_struct_for_logical_plan!(Sort); impl_from_data_struct_for_logical_plan!(Repartition); impl_from_data_struct_for_logical_plan!(Coalesce); diff --git a/src/daft-plan/src/ops/explode.rs b/src/daft-plan/src/ops/explode.rs new file mode 100644 index 0000000000..9f98c3c487 --- /dev/null +++ b/src/daft-plan/src/ops/explode.rs @@ -0,0 +1,28 @@ +use std::sync::Arc; + +use daft_core::schema::SchemaRef; +use daft_dsl::Expr; + +use crate::LogicalPlan; + +#[derive(Clone, Debug)] +pub struct Explode { + pub explode_exprs: Vec, + pub exploded_schema: SchemaRef, + // Upstream node. + pub input: Arc, +} + +impl Explode { + pub(crate) fn new( + explode_exprs: Vec, + exploded_schema: SchemaRef, + input: Arc, + ) -> Self { + Self { + explode_exprs, + exploded_schema, + input, + } + } +} diff --git a/src/daft-plan/src/ops/mod.rs b/src/daft-plan/src/ops/mod.rs index 066001b348..c140e9bcf8 100644 --- a/src/daft-plan/src/ops/mod.rs +++ b/src/daft-plan/src/ops/mod.rs @@ -2,6 +2,7 @@ mod agg; mod coalesce; mod concat; mod distinct; +mod explode; mod filter; mod limit; mod project; @@ -14,6 +15,7 @@ pub use agg::Aggregate; pub use coalesce::Coalesce; pub use concat::Concat; pub use distinct::Distinct; +pub use explode::Explode; pub use filter::Filter; pub use limit::Limit; pub use project::Project; diff --git a/src/daft-plan/src/physical_ops/explode.rs b/src/daft-plan/src/physical_ops/explode.rs new file mode 100644 index 0000000000..5e75bc5559 --- /dev/null +++ b/src/daft-plan/src/physical_ops/explode.rs @@ -0,0 +1,21 @@ +use std::sync::Arc; + +use daft_dsl::Expr; + +use crate::physical_plan::PhysicalPlan; + +#[derive(Clone, Debug)] +pub struct Explode { + pub explode_exprs: Vec, + // Upstream node. + pub input: Arc, +} + +impl Explode { + pub(crate) fn new(explode_exprs: Vec, input: Arc) -> Self { + Self { + explode_exprs, + input, + } + } +} diff --git a/src/daft-plan/src/physical_ops/mod.rs b/src/daft-plan/src/physical_ops/mod.rs index dc74b1c865..cbce384e61 100644 --- a/src/daft-plan/src/physical_ops/mod.rs +++ b/src/daft-plan/src/physical_ops/mod.rs @@ -2,6 +2,7 @@ mod agg; mod coalesce; mod concat; mod csv; +mod explode; mod fanout; mod filter; mod flatten; @@ -19,6 +20,7 @@ pub use agg::Aggregate; pub use coalesce::Coalesce; pub use concat::Concat; pub use csv::{TabularScanCsv, TabularWriteCsv}; +pub use explode::Explode; pub use fanout::{FanoutByHash, FanoutByRange, FanoutRandom}; pub use filter::Filter; pub use flatten::Flatten; diff --git a/src/daft-plan/src/physical_plan.rs b/src/daft-plan/src/physical_plan.rs index 50dfa5bf77..61bb8981a2 100644 --- a/src/daft-plan/src/physical_plan.rs +++ b/src/daft-plan/src/physical_plan.rs @@ -28,6 +28,7 @@ pub enum PhysicalPlan { Project(Project), Filter(Filter), Limit(Limit), + Explode(Explode), Sort(Sort), Split(Split), Flatten(Flatten), @@ -222,6 +223,21 @@ impl PhysicalPlan { .call1((local_limit_iter, *limit, *num_partitions))?; Ok(global_limit_iter.into()) } + PhysicalPlan::Explode(Explode { + input, + explode_exprs, + }) => { + let upstream_iter = input.to_partition_tasks(py, psets)?; + let explode_pyexprs: Vec = explode_exprs + .iter() + .map(|expr| PyExpr::from(expr.clone())) + .collect(); + let py_iter = py + .import(pyo3::intern!(py, "daft.execution.rust_physical_plan_shim"))? + .getattr(pyo3::intern!(py, "explode"))? + .call1((upstream_iter, explode_pyexprs))?; + Ok(py_iter.into()) + } PhysicalPlan::Sort(Sort { input, sort_by, diff --git a/src/daft-plan/src/planner.rs b/src/daft-plan/src/planner.rs index 9ed9069152..14421e6096 100644 --- a/src/daft-plan/src/planner.rs +++ b/src/daft-plan/src/planner.rs @@ -6,9 +6,9 @@ use daft_dsl::Expr; use crate::logical_plan::LogicalPlan; use crate::ops::{ Aggregate as LogicalAggregate, Coalesce as LogicalCoalesce, Concat as LogicalConcat, - Distinct as LogicalDistinct, Filter as LogicalFilter, Limit as LogicalLimit, - Project as LogicalProject, Repartition as LogicalRepartition, Sink as LogicalSink, - Sort as LogicalSort, Source, + Distinct as LogicalDistinct, Explode as LogicalExplode, Filter as LogicalFilter, + Limit as LogicalLimit, Project as LogicalProject, Repartition as LogicalRepartition, + Sink as LogicalSink, Sort as LogicalSort, Source, }; use crate::physical_ops::*; use crate::physical_plan::PhysicalPlan; @@ -88,6 +88,17 @@ pub fn plan(logical_plan: &LogicalPlan) -> DaftResult { Arc::new(input_physical), ))) } + LogicalPlan::Explode(LogicalExplode { + input, + explode_exprs, + .. + }) => { + let input_physical = plan(input)?; + Ok(PhysicalPlan::Explode(Explode::new( + explode_exprs.clone(), + input_physical.into(), + ))) + } LogicalPlan::Sort(LogicalSort { input, sort_by, diff --git a/tests/dataframe/test_explode.py b/tests/dataframe/test_explode.py index 22190bc2b7..83b58838ae 100644 --- a/tests/dataframe/test_explode.py +++ b/tests/dataframe/test_explode.py @@ -15,7 +15,7 @@ Series.from_arrow(pa.array([[1, 2], [3, 4], None, []], type=pa.large_list(pa.int64()))), ], ) -def test_explode(data): +def test_explode(data, use_new_planner): df = daft.from_pydict({"nested": data, "sidecar": ["a", "b", "c", "d"]}) df = df.explode(col("nested")) assert df.to_pydict() == {"nested": [1, 2, 3, 4, None, None], "sidecar": ["a", "a", "b", "b", "c", "d"]} @@ -28,7 +28,7 @@ def test_explode(data): Series.from_arrow(pa.array([[1, 2], [3, 4], None, []], type=pa.large_list(pa.int64()))), ], ) -def test_explode_multiple_cols(data): +def test_explode_multiple_cols(data, use_new_planner): df = daft.from_pydict({"nested": data, "nested2": data, "sidecar": ["a", "b", "c", "d"]}) df = df.explode(col("nested"), col("nested2")) assert df.to_pydict() == { @@ -38,7 +38,7 @@ def test_explode_multiple_cols(data): } -def test_explode_bad_col_type(): +def test_explode_bad_col_type(use_new_planner): df = daft.from_pydict({"a": [1, 2, 3]}) with pytest.raises(ValueError, match="Datatype cannot be exploded:"): df = df.explode(col("a"))