Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support aggr push down in distributed query #1232

Merged
merged 13 commits into from
Sep 30, 2023
Merged
154 changes: 129 additions & 25 deletions df_engine_extensions/src/dist_sql_query/physical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@ use datafusion::{
execution::TaskContext,
physical_expr::PhysicalSortExpr,
physical_plan::{
aggregates::{AggregateExec, AggregateMode},
coalesce_batches::CoalesceBatchesExec,
coalesce_partitions::CoalescePartitionsExec,
displayable,
filter::FilterExec,
projection::ProjectionExec,
repartition::RepartitionExec,
DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream,
SendableRecordBatchStream as DfSendableRecordBatchStream, Statistics,
},
Expand Down Expand Up @@ -111,33 +118,114 @@ impl DisplayAs for UnresolvedPartitionedScan {
/// related nodes to execute.
#[derive(Debug)]
pub struct ResolvedPartitionedScan {
pub remote_executor: Arc<dyn RemotePhysicalPlanExecutor>,
pub remote_exec_plans: Vec<(TableIdentifier, Arc<dyn ExecutionPlan>)>,
pub remote_exec_ctx: Arc<RemoteExecContext>,
pub pushing_down: bool,
}

impl ResolvedPartitionedScan {
pub fn extend_remote_exec_plans(
pub fn try_to_push_down_more(
&self,
extended_node: Arc<dyn ExecutionPlan>,
) -> DfResult<Arc<ResolvedPartitionedScan>> {
cur_node: Arc<dyn ExecutionPlan>,
) -> DfResult<Arc<dyn ExecutionPlan>> {
// Can not push more...
if !self.pushing_down {
return cur_node.with_new_children(vec![self.push_down_finished()]);
}

// Push down more, and when occur the terminated push down able node, we need to
// set `can_push_down_more` false.
let push_down_able_opt = PushDownAble::try_new(cur_node.clone());
let (node, can_push_down_more) = match push_down_able_opt {
Some(PushDownAble::Continue(node)) => (node, true),
Some(PushDownAble::Terminated(node)) => (node, false),
None => {
let partitioned_scan = self.push_down_finished();
return cur_node.with_new_children(vec![partitioned_scan]);
}
};

let new_plans = self
.remote_exec_plans
.remote_exec_ctx
.plans
.iter()
.map(|(table, plan)| {
extended_node
.clone()
node.clone()
.with_new_children(vec![plan.clone()])
.map(|extended_plan| (table.clone(), extended_plan))
})
.collect::<DfResult<Vec<_>>>()?;

let remote_exec_ctx = Arc::new(RemoteExecContext {
executor: self.remote_exec_ctx.executor.clone(),
plans: new_plans,
});
let plan = ResolvedPartitionedScan {
remote_executor: self.remote_executor.clone(),
remote_exec_plans: new_plans,
remote_exec_ctx,
pushing_down: can_push_down_more,
};

Ok(Arc::new(plan))
}

/// Creates a `ResolvedPartitionedScan` that is still in pushing-down mode
/// (`pushing_down: true`), so further plan nodes may still be pushed down
/// to the per-subtable remote plans via `try_to_push_down_more`.
///
/// The executor and the `(table, plan)` pairs are wrapped into a shared
/// `RemoteExecContext` so clones of this node (see `push_down_finished`)
/// can share them cheaply.
pub fn new(
remote_executor: Arc<dyn RemotePhysicalPlanExecutor>,
remote_exec_plans: Vec<(TableIdentifier, Arc<dyn ExecutionPlan>)>,
) -> Self {
let remote_exec_ctx = Arc::new(RemoteExecContext {
executor: remote_executor,
plans: remote_exec_plans,
});

Self {
remote_exec_ctx,
pushing_down: true,
}
}

/// Returns a copy of this scan sealed against further push-down
/// (`pushing_down: false`); only a sealed scan may be executed
/// (see the guard in `execute`). The remote context is shared, not cloned.
pub fn push_down_finished(&self) -> Arc<dyn ExecutionPlan> {
Arc::new(Self {
remote_exec_ctx: self.remote_exec_ctx.clone(),
pushing_down: false,
})
}
}

/// Shared context for executing the partitioned scan remotely:
/// the executor used to ship plans, plus one physical plan per subtable.
#[derive(Debug)]
pub struct RemoteExecContext {
/// Sends a plan to a remote node for execution.
executor: Arc<dyn RemotePhysicalPlanExecutor>,
/// One `(subtable identifier, plan to run against it)` pair per partition;
/// `execute(partition, ..)` indexes into this vec.
plans: Vec<(TableIdentifier, Arc<dyn ExecutionPlan>)>,
}

/// Classification of a plan node that can be pushed down to the remote plans.
pub enum PushDownAble {
/// Node can be pushed down, and nodes above it may still be pushed down too.
Continue(Arc<dyn ExecutionPlan>),
/// Node can be pushed down, but it ends the push-down chain
/// (e.g. a partial aggregation — its parent must stay local).
Terminated(Arc<dyn ExecutionPlan>),
}

impl PushDownAble {
    /// Classifies `plan` for push-down to remote subtable plans.
    ///
    /// Returns:
    /// - `Some(Terminated(_))` for a partial aggregation — pushable, but it
    ///   ends the chain (the final aggregation must run locally);
    /// - `Some(Continue(_))` for filter / projection / repartition /
    ///   coalesce nodes — pushable with no restriction on what comes above;
    /// - `None` for anything else (not pushable).
    pub fn try_new(plan: Arc<dyn ExecutionPlan>) -> Option<Self> {
        // Aggregations: only the `Partial` mode may run remotely.
        if let Some(aggr) = plan.as_any().downcast_ref::<AggregateExec>() {
            let is_partial = matches!(aggr.mode(), AggregateMode::Partial);
            return is_partial.then(|| Self::Terminated(plan));
        }

        // Simple row-shaping / partition-shaping nodes are freely pushable.
        let continuable = plan.as_any().downcast_ref::<FilterExec>().is_some()
            || plan.as_any().downcast_ref::<ProjectionExec>().is_some()
            || plan.as_any().downcast_ref::<RepartitionExec>().is_some()
            || plan
                .as_any()
                .downcast_ref::<CoalescePartitionsExec>()
                .is_some()
            || plan
                .as_any()
                .downcast_ref::<CoalesceBatchesExec>()
                .is_some();

        continuable.then(|| Self::Continue(plan))
    }
}

impl ExecutionPlan for ResolvedPartitionedScan {
Expand All @@ -146,31 +234,36 @@ impl ExecutionPlan for ResolvedPartitionedScan {
}

fn schema(&self) -> ArrowSchemaRef {
self.remote_exec_plans
self.remote_exec_ctx
.plans
.first()
.expect("remote_exec_plans should not be empty")
.1
.schema()
}

fn output_partitioning(&self) -> Partitioning {
Partitioning::UnknownPartitioning(self.remote_exec_plans.len())
Partitioning::UnknownPartitioning(self.remote_exec_ctx.plans.len())
}

fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
None
}

fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
vec![]
self.remote_exec_ctx
.plans
.iter()
.map(|(_, plan)| plan.clone())
.collect()
}

/// Rebuilding this node from arbitrary children is unsupported: its real
/// children are the per-subtable remote plans held inside
/// `remote_exec_ctx`, which must be modified through
/// `try_to_push_down_more` instead.
fn with_new_children(
    self: Arc<Self>,
    _children: Vec<Arc<dyn ExecutionPlan>>,
) -> DfResult<Arc<dyn ExecutionPlan>> {
    // Fix: the message previously named `UnresolvedPartitionedScan`,
    // but this impl is `ExecutionPlan for ResolvedPartitionedScan`.
    Err(DataFusionError::Internal(
        "ResolvedPartitionedScan can't be built directly from new children".to_string(),
    ))
}

Expand All @@ -179,11 +272,19 @@ impl ExecutionPlan for ResolvedPartitionedScan {
partition: usize,
context: Arc<TaskContext>,
) -> DfResult<DfSendableRecordBatchStream> {
let (sub_table, plan) = &self.remote_exec_plans[partition];
if self.pushing_down {
return Err(DataFusionError::Internal(format!(
"partitioned scan can't be executed before pushing down finished, plan:{}",
displayable(self).indent(true)
)));
}

let (sub_table, plan) = &self.remote_exec_ctx.plans[partition];

// Send plan for remote execution.
let stream_future =
self.remote_executor
self.remote_exec_ctx
.executor
.execute(sub_table.clone(), &context, plan.clone())?;
let record_stream = PartitionedScanStream::new(stream_future, plan.schema());

Expand Down Expand Up @@ -280,15 +381,18 @@ pub(crate) enum StreamState {
Polling(DfSendableRecordBatchStream),
}

// TODO: make display for the plan more pretty.
impl DisplayAs for ResolvedPartitionedScan {
fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"ResolvedPartitionedScan: remote_exec_plans:{:?}, partition_count={}",
self.remote_exec_plans,
self.output_partitioning().partition_count(),
)
// Both Default and Verbose currently share one compact form showing the
// push-down state and partition count; the remote sub-plans themselves are
// intentionally omitted (see the TODO above about prettier display).
fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
match t {
DisplayFormatType::Default | DisplayFormatType::Verbose => {
write!(
f,
"ResolvedPartitionedScan: pushing_down:{}, partition_count:{}",
self.pushing_down,
self.remote_exec_ctx.plans.len()
)
}
}
}
}

Expand Down Expand Up @@ -352,7 +456,7 @@ impl DisplayAs for UnresolvedSubTableScan {
fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"UnresolvedSubTableScan: table={:?}, read_request:{:?}, partition_count={}",
"UnresolvedSubTableScan: table:{:?}, request:{:?}, partition_count:{}",
self.table,
self.read_request,
self.output_partitioning().partition_count(),
Expand Down
96 changes: 90 additions & 6 deletions df_engine_extensions/src/dist_sql_query/resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,22 @@ impl Resolver {
/// Resolves `plan` recursively (see `resolve_partitioned_scan_internal`),
/// then seals any `ResolvedPartitionedScan` left at the root so the
/// returned plan is executable (push-down mode forbids execution).
pub fn resolve_partitioned_scan(
    &self,
    plan: Arc<dyn ExecutionPlan>,
) -> DfResult<Arc<dyn ExecutionPlan>> {
    let resolved = self.resolve_partitioned_scan_internal(plan)?;

    match resolved.as_any().downcast_ref::<ResolvedPartitionedScan>() {
        // The root is still in pushing-down mode; mark push-down finished.
        Some(scan) => Ok(scan.push_down_finished()),
        None => Ok(resolved),
    }
}

pub fn resolve_partitioned_scan_internal(
&self,
plan: Arc<dyn ExecutionPlan>,
) -> DfResult<Arc<dyn ExecutionPlan>> {
// Leave node, let's resolve it and return.
if let Some(unresolved) = plan.as_any().downcast_ref::<UnresolvedPartitionedScan>() {
Expand All @@ -73,10 +89,10 @@ impl Resolver {
})
.collect::<Vec<_>>();

return Ok(Arc::new(ResolvedPartitionedScan {
remote_executor: self.remote_executor.clone(),
remote_exec_plans: remote_plans,
}));
return Ok(Arc::new(ResolvedPartitionedScan::new(
self.remote_executor.clone(),
remote_plans,
)));
}

let children = plan.children().clone();
Expand All @@ -88,12 +104,47 @@ impl Resolver {
// Resolve children if exist.
let mut new_children = Vec::with_capacity(children.len());
for child in children {
let child = self.resolve_partitioned_scan(child)?;
let child = self.resolve_partitioned_scan_internal(child)?;

new_children.push(child);
}

plan.with_new_children(new_children)
Self::maybe_push_down_to_remote_plans(new_children, plan)
}

/// Decides, after the children have been resolved, whether `current_node`
/// can be pushed down into a child `ResolvedPartitionedScan`'s remote plans.
///
/// - No children: return `current_node` unchanged.
/// - Multiple children: `current_node` can't be pushed down; any
///   `ResolvedPartitionedScan` children are sealed via `push_down_finished`.
/// - Exactly one child which is a `ResolvedPartitionedScan`: delegate to
///   `try_to_push_down_more`, which decides per node type.
fn maybe_push_down_to_remote_plans(
mut new_children: Vec<Arc<dyn ExecutionPlan>>,
current_node: Arc<dyn ExecutionPlan>,
) -> DfResult<Arc<dyn ExecutionPlan>> {
// No children, just return.
if new_children.is_empty() {
return Ok(current_node);
}

// This node has multiple children, so it can't be pushed down to remote.
// But it's possible that `ResolvedPartitionedScan`s exist among its children;
// we need to find those children and mark them push down finished.
if new_children.len() > 1 {
new_children.iter_mut().for_each(|child| {
if let Some(plan) = child.as_any().downcast_ref::<ResolvedPartitionedScan>() {
*child = plan.push_down_finished();
}
});
return current_node.with_new_children(new_children);
}

// We have ensured this node has exactly one child; if that child is a
// `ResolvedPartitionedScan`, try to push `current_node` down into its
// remote plans.
let child = new_children.first().unwrap();
let partitioned_scan =
if let Some(plan) = child.as_any().downcast_ref::<ResolvedPartitionedScan>() {
plan
} else {
return current_node.with_new_children(new_children);
};

partitioned_scan.try_to_push_down_more(current_node.clone())
}

#[async_recursion]
Expand Down Expand Up @@ -203,4 +254,37 @@ mod test {

assert_eq!(original_plan_display, new_plan_display);
}

// Snapshot test: a simple aggregation over the partitioned scan should be
// split into a remote partial aggregation (pushed into each subtable plan)
// and a local final aggregation.
#[test]
fn test_aggr_push_down() {
let ctx = TestContext::new();
let plan = ctx.build_aggr_push_down_plan();
let resolver = ctx.resolver();
let new_plan = displayable(resolver.resolve_partitioned_scan(plan).unwrap().as_ref())
.indent(true)
.to_string();
insta::assert_snapshot!(new_plan);
}

// Snapshot test: aggregation stacked with other push-down-able nodes
// (filter/projection/repartition etc.) — the whole pushable prefix should
// end up inside the remote subtable plans.
#[test]
fn test_compounded_aggr_push_down() {
let ctx = TestContext::new();
let plan = ctx.build_compounded_aggr_push_down_plan();
let resolver = ctx.resolver();
let new_plan = displayable(resolver.resolve_partitioned_scan(plan).unwrap().as_ref())
.indent(true)
.to_string();
insta::assert_snapshot!(new_plan);
}

// Snapshot test: a node with several `ResolvedPartitionedScan` children
// (e.g. a union) can't be pushed down itself; each scan child must be
// sealed (`push_down_finished`) so the resulting plan is executable.
#[test]
fn test_node_with_multiple_partitioned_scan_children() {
let ctx = TestContext::new();
let plan = ctx.build_union_plan();
let resolver = ctx.resolver();
let new_plan = displayable(resolver.resolve_partitioned_scan(plan).unwrap().as_ref())
.indent(true)
.to_string();
insta::assert_snapshot!(new_plan);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
---
source: df_engine_extensions/src/dist_sql_query/resolver.rs
assertion_line: 269
expression: new_plan
---
AggregateExec: mode=Final, gby=[tag1@1 as tag1, tag2@2 as tag2], aggr=[COUNT(value), COUNT(field2)]
CoalescePartitionsExec
ResolvedPartitionedScan: pushing_down:false, partition_count:3
AggregateExec: mode=Partial, gby=[tag1@1 as tag1, tag2@2 as tag2], aggr=[COUNT(value), COUNT(field2)]
UnresolvedSubTableScan: table:TableIdentifier { catalog: "test_catalog", schema: "test_schema", table: "__test_1" }, request:ReadRequest { request_id: RequestId(42), opts: ReadOptions { batch_size: 10000, read_parallelism: 8, deadline: None }, projected: "[time,tag1,tag2,value,field2]", predicate: "[time < TimestampMillisecond(1691974518000, None) AND tag1 = Utf8(\"test_tag\")]" }, partition_count:8
AggregateExec: mode=Partial, gby=[tag1@1 as tag1, tag2@2 as tag2], aggr=[COUNT(value), COUNT(field2)]
UnresolvedSubTableScan: table:TableIdentifier { catalog: "test_catalog", schema: "test_schema", table: "__test_2" }, request:ReadRequest { request_id: RequestId(42), opts: ReadOptions { batch_size: 10000, read_parallelism: 8, deadline: None }, projected: "[time,tag1,tag2,value,field2]", predicate: "[time < TimestampMillisecond(1691974518000, None) AND tag1 = Utf8(\"test_tag\")]" }, partition_count:8
AggregateExec: mode=Partial, gby=[tag1@1 as tag1, tag2@2 as tag2], aggr=[COUNT(value), COUNT(field2)]
UnresolvedSubTableScan: table:TableIdentifier { catalog: "test_catalog", schema: "test_schema", table: "__test_3" }, request:ReadRequest { request_id: RequestId(42), opts: ReadOptions { batch_size: 10000, read_parallelism: 8, deadline: None }, projected: "[time,tag1,tag2,value,field2]", predicate: "[time < TimestampMillisecond(1691974518000, None) AND tag1 = Utf8(\"test_tag\")]" }, partition_count:8

Loading