Skip to content

Commit

Permalink
Add support for explode.
Browse files Browse the repository at this point in the history
  • Loading branch information
clarkzinzow committed Aug 11, 2023
1 parent 6754d67 commit 088b8b9
Show file tree
Hide file tree
Showing 11 changed files with 156 additions and 7 deletions.
26 changes: 26 additions & 0 deletions daft/execution/rust_physical_plan_shim.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from daft.daft import FileFormat, FileFormatConfig, PyExpr, PySchema, PyTable
from daft.execution import execution_step, physical_plan
from daft.expressions import Expression, ExpressionsProjection
from daft.logical.map_partition_ops import MapPartitionOp
from daft.logical.schema import Schema
from daft.resource_request import ResourceRequest
from daft.table import Table
Expand Down Expand Up @@ -52,6 +53,31 @@ def project(
)


class ShimExplodeOp(MapPartitionOp):
    """Map-partition op that explodes a fixed set of columns on each partition it is run on."""

    # Expressions handed to ``Table.explode`` for every partition.
    explode_columns: ExpressionsProjection

    def __init__(self, explode_columns: ExpressionsProjection) -> None:
        self.explode_columns = explode_columns

    def get_output_schema(self) -> Schema:
        """Always raises ``NotImplementedError``; this op does not track an output schema."""
        raise NotImplementedError("Output schema shouldn't be needed at execution time")

    def run(self, input_partition: Table) -> Table:
        """Return the result of exploding ``self.explode_columns`` on ``input_partition``."""
        exploded = input_partition.explode(self.explode_columns)
        return exploded


def explode(
    input: physical_plan.InProgressPhysicalPlan[PartitionT], explode_exprs: list[PyExpr]
) -> physical_plan.InProgressPhysicalPlan[PartitionT]:
    """Pipeline an explode step onto ``input``.

    Each ``PyExpr`` in ``explode_exprs`` is wrapped into an ``Expression``; the resulting
    projection is carried by a ``ShimExplodeOp`` that runs as a ``MapPartition`` instruction.
    """
    projection = ExpressionsProjection([Expression._from_pyexpr(pyexpr) for pyexpr in explode_exprs])
    return physical_plan.pipeline_instruction(
        child_plan=input,
        pipeable_instruction=execution_step.MapPartition(ShimExplodeOp(projection)),
        resource_request=ResourceRequest(),  # TODO(Clark): Use real ResourceRequest.
    )


def sort(
input: physical_plan.InProgressPhysicalPlan[PartitionT],
sort_by: list[PyExpr],
Expand Down
16 changes: 15 additions & 1 deletion daft/logical/rust_logical_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,21 @@ def limit(self, num_rows: int) -> RustLogicalPlanBuilder:
return RustLogicalPlanBuilder(builder)

def explode(self, explode_expressions: ExpressionsProjection) -> RustLogicalPlanBuilder:
    """Return a new builder with an explode step appended to this plan.

    The output schema keeps the input schema's field order; any field whose name appears in
    the resolved explode schema is replaced by the field from that schema.
    """
    # TODO(Clark): Move this logic to Rust side after we've ported ExpressionsProjection.
    explode_expressions = ExpressionsProjection([expr._explode() for expr in explode_expressions])
    input_schema = self.schema()
    explode_schema = explode_expressions.resolve_schema(input_schema)
    # Look the exploded names up once instead of rebuilding the list for every input field.
    exploded_names = set(explode_schema.column_names())
    output_fields = [explode_schema[f.name] if f.name in exploded_names else f for f in input_schema]

    exploded_schema = Schema._from_field_name_and_types([(f.name, f.dtype) for f in output_fields])
    explode_pyexprs = [expr._expr for expr in explode_expressions]
    builder = self._builder.explode(explode_pyexprs, exploded_schema._schema)
    return RustLogicalPlanBuilder(builder)

def count(self) -> RustLogicalPlanBuilder:
raise NotImplementedError("not implemented")
Expand Down
19 changes: 19 additions & 0 deletions src/daft-plan/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,25 @@ impl LogicalPlanBuilder {
Ok(logical_plan_builder)
}

/// Returns a new builder whose plan is an `Explode` node on top of the current plan.
///
/// `explode_pyexprs` are converted into `Expr`s, and `exploded_schema` is stored on the node
/// as the schema it reports.
pub fn explode(
    &self,
    explode_pyexprs: Vec<PyExpr>,
    exploded_schema: &PySchema,
) -> PyResult<LogicalPlanBuilder> {
    // The vector is owned, so move each expression out rather than cloning it.
    let explode_exprs = explode_pyexprs
        .into_iter()
        .map(Into::into)
        .collect::<Vec<Expr>>();
    let logical_plan: LogicalPlan = ops::Explode::new(
        explode_exprs,
        exploded_schema.clone().into(),
        self.plan.clone(),
    )
    .into();
    let logical_plan_builder = LogicalPlanBuilder::new(logical_plan.into());
    Ok(logical_plan_builder)
}

pub fn sort(
&self,
sort_by: Vec<PyExpr>,
Expand Down
10 changes: 10 additions & 0 deletions src/daft-plan/src/logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ pub enum LogicalPlan {
Project(Project),
Filter(Filter),
Limit(Limit),
Explode(Explode),
Sort(Sort),
Repartition(Repartition),
Coalesce(Coalesce),
Expand All @@ -28,6 +29,9 @@ impl LogicalPlan {
}) => projected_schema.clone(),
Self::Filter(Filter { input, .. }) => input.schema(),
Self::Limit(Limit { input, .. }) => input.schema(),
Self::Explode(Explode {
exploded_schema, ..
}) => exploded_schema.clone(),
Self::Sort(Sort { input, .. }) => input.schema(),
Self::Repartition(Repartition { input, .. }) => input.schema(),
Self::Coalesce(Coalesce { input, .. }) => input.schema(),
Expand All @@ -44,6 +48,7 @@ impl LogicalPlan {
Self::Project(Project { input, .. }) => input.partition_spec(),
Self::Filter(Filter { input, .. }) => input.partition_spec(),
Self::Limit(Limit { input, .. }) => input.partition_spec(),
Self::Explode(Explode { input, .. }) => input.partition_spec(),
Self::Sort(Sort { input, sort_by, .. }) => PartitionSpec::new_internal(
PartitionScheme::Range,
input.partition_spec().num_partitions,
Expand Down Expand Up @@ -82,6 +87,7 @@ impl LogicalPlan {
Self::Project(Project { input, .. }) => vec![input],
Self::Filter(Filter { input, .. }) => vec![input],
Self::Limit(Limit { input, .. }) => vec![input],
Self::Explode(Explode { input, .. }) => vec![input],
Self::Sort(Sort { input, .. }) => vec![input],
Self::Repartition(Repartition { input, .. }) => vec![input],
Self::Coalesce(Coalesce { input, .. }) => vec![input],
Expand All @@ -98,6 +104,9 @@ impl LogicalPlan {
Self::Project(Project { projection, .. }) => vec![format!("Project: {projection:?}")],
Self::Filter(Filter { predicate, .. }) => vec![format!("Filter: {predicate}")],
Self::Limit(Limit { limit, .. }) => vec![format!("Limit: {limit}")],
Self::Explode(Explode { explode_exprs, .. }) => {
vec![format!("Explode: {explode_exprs:?}")]
}
Self::Sort(sort) => sort.multiline_display(),
Self::Repartition(repartition) => repartition.multiline_display(),
Self::Coalesce(Coalesce { num_to, .. }) => vec![format!("Coalesce: {num_to}")],
Expand Down Expand Up @@ -129,6 +138,7 @@ impl_from_data_struct_for_logical_plan!(Source);
impl_from_data_struct_for_logical_plan!(Project);
impl_from_data_struct_for_logical_plan!(Filter);
impl_from_data_struct_for_logical_plan!(Limit);
impl_from_data_struct_for_logical_plan!(Explode);
impl_from_data_struct_for_logical_plan!(Sort);
impl_from_data_struct_for_logical_plan!(Repartition);
impl_from_data_struct_for_logical_plan!(Coalesce);
Expand Down
28 changes: 28 additions & 0 deletions src/daft-plan/src/ops/explode.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
use std::sync::Arc;

use daft_core::schema::SchemaRef;
use daft_dsl::Expr;

use crate::LogicalPlan;

/// Logical plan node describing an explode over its input.
#[derive(Clone, Debug)]
pub struct Explode {
    /// Expressions to explode.
    pub explode_exprs: Vec<Expr>,
    /// Schema stored for this node's output.
    pub exploded_schema: SchemaRef,
    /// Upstream node.
    pub input: Arc<LogicalPlan>,
}

impl Explode {
    /// Assembles an `Explode` node from its expressions, output schema, and upstream plan.
    pub(crate) fn new(
        explode_exprs: Vec<Expr>,
        exploded_schema: SchemaRef,
        input: Arc<LogicalPlan>,
    ) -> Self {
        Explode {
            explode_exprs,
            exploded_schema,
            input,
        }
    }
}
2 changes: 2 additions & 0 deletions src/daft-plan/src/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ mod agg;
mod coalesce;
mod concat;
mod distinct;
mod explode;
mod filter;
mod limit;
mod project;
Expand All @@ -14,6 +15,7 @@ pub use agg::Aggregate;
pub use coalesce::Coalesce;
pub use concat::Concat;
pub use distinct::Distinct;
pub use explode::Explode;
pub use filter::Filter;
pub use limit::Limit;
pub use project::Project;
Expand Down
21 changes: 21 additions & 0 deletions src/daft-plan/src/physical_ops/explode.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
use std::sync::Arc;

use daft_dsl::Expr;

use crate::physical_plan::PhysicalPlan;

/// Physical plan node describing an explode over its input.
#[derive(Clone, Debug)]
pub struct Explode {
    /// Expressions to explode.
    pub explode_exprs: Vec<Expr>,
    /// Upstream node.
    pub input: Arc<PhysicalPlan>,
}

impl Explode {
    /// Assembles an `Explode` node from its expressions and upstream physical plan.
    pub(crate) fn new(explode_exprs: Vec<Expr>, input: Arc<PhysicalPlan>) -> Self {
        Explode {
            explode_exprs,
            input,
        }
    }
}
2 changes: 2 additions & 0 deletions src/daft-plan/src/physical_ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ mod agg;
mod coalesce;
mod concat;
mod csv;
mod explode;
mod fanout;
mod filter;
mod flatten;
Expand All @@ -19,6 +20,7 @@ pub use agg::Aggregate;
pub use coalesce::Coalesce;
pub use concat::Concat;
pub use csv::{TabularScanCsv, TabularWriteCsv};
pub use explode::Explode;
pub use fanout::{FanoutByHash, FanoutByRange, FanoutRandom};
pub use filter::Filter;
pub use flatten::Flatten;
Expand Down
16 changes: 16 additions & 0 deletions src/daft-plan/src/physical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ pub enum PhysicalPlan {
Project(Project),
Filter(Filter),
Limit(Limit),
Explode(Explode),
Sort(Sort),
Split(Split),
Flatten(Flatten),
Expand Down Expand Up @@ -222,6 +223,21 @@ impl PhysicalPlan {
.call1((local_limit_iter, *limit, *num_partitions))?;
Ok(global_limit_iter.into())
}
PhysicalPlan::Explode(Explode {
input,
explode_exprs,
}) => {
let upstream_iter = input.to_partition_tasks(py, psets)?;
let explode_pyexprs: Vec<PyExpr> = explode_exprs
.iter()
.map(|expr| PyExpr::from(expr.clone()))
.collect();
let py_iter = py
.import(pyo3::intern!(py, "daft.execution.rust_physical_plan_shim"))?
.getattr(pyo3::intern!(py, "explode"))?
.call1((upstream_iter, explode_pyexprs))?;
Ok(py_iter.into())
}
PhysicalPlan::Sort(Sort {
input,
sort_by,
Expand Down
17 changes: 14 additions & 3 deletions src/daft-plan/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ use daft_dsl::Expr;
use crate::logical_plan::LogicalPlan;
use crate::ops::{
Aggregate as LogicalAggregate, Coalesce as LogicalCoalesce, Concat as LogicalConcat,
Distinct as LogicalDistinct, Filter as LogicalFilter, Limit as LogicalLimit,
Project as LogicalProject, Repartition as LogicalRepartition, Sink as LogicalSink,
Sort as LogicalSort, Source,
Distinct as LogicalDistinct, Explode as LogicalExplode, Filter as LogicalFilter,
Limit as LogicalLimit, Project as LogicalProject, Repartition as LogicalRepartition,
Sink as LogicalSink, Sort as LogicalSort, Source,
};
use crate::physical_ops::*;
use crate::physical_plan::PhysicalPlan;
Expand Down Expand Up @@ -88,6 +88,17 @@ pub fn plan(logical_plan: &LogicalPlan) -> DaftResult<PhysicalPlan> {
Arc::new(input_physical),
)))
}
LogicalPlan::Explode(LogicalExplode {
input,
explode_exprs,
..
}) => {
let input_physical = plan(input)?;
Ok(PhysicalPlan::Explode(Explode::new(
explode_exprs.clone(),
input_physical.into(),
)))
}
LogicalPlan::Sort(LogicalSort {
input,
sort_by,
Expand Down
6 changes: 3 additions & 3 deletions tests/dataframe/test_explode.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
Series.from_arrow(pa.array([[1, 2], [3, 4], None, []], type=pa.large_list(pa.int64()))),
],
)
def test_explode(data):
def test_explode(data, use_new_planner):
df = daft.from_pydict({"nested": data, "sidecar": ["a", "b", "c", "d"]})
df = df.explode(col("nested"))
assert df.to_pydict() == {"nested": [1, 2, 3, 4, None, None], "sidecar": ["a", "a", "b", "b", "c", "d"]}
Expand All @@ -28,7 +28,7 @@ def test_explode(data):
Series.from_arrow(pa.array([[1, 2], [3, 4], None, []], type=pa.large_list(pa.int64()))),
],
)
def test_explode_multiple_cols(data):
def test_explode_multiple_cols(data, use_new_planner):
df = daft.from_pydict({"nested": data, "nested2": data, "sidecar": ["a", "b", "c", "d"]})
df = df.explode(col("nested"), col("nested2"))
assert df.to_pydict() == {
Expand All @@ -38,7 +38,7 @@ def test_explode_multiple_cols(data):
}


def test_explode_bad_col_type(use_new_planner):
    # `use_new_planner` is requested but never read here; presumably a pytest fixture
    # defined elsewhere in the test suite -- TODO confirm.
    df = daft.from_pydict({"a": [1, 2, 3]})
    # Exploding a non-list column is expected to fail with a ValueError.
    with pytest.raises(ValueError, match="Datatype cannot be exploded:"):
        df = df.explode(col("a"))

0 comments on commit 088b8b9

Please sign in to comment.