diff --git a/datafusion/core/src/physical_plan/aggregates/group_values/mod.rs b/datafusion/core/src/physical_plan/aggregates/group_values/mod.rs
index f10f83dfe3c8..cafa385eac39 100644
--- a/datafusion/core/src/physical_plan/aggregates/group_values/mod.rs
+++ b/datafusion/core/src/physical_plan/aggregates/group_values/mod.rs
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use arrow::record_batch::RecordBatch;
 use arrow_array::{downcast_primitive, ArrayRef};
 use arrow_schema::SchemaRef;
 use datafusion_common::Result;
@@ -42,6 +43,9 @@ pub trait GroupValues: Send {
 
     /// Emits the group values
     fn emit(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>>;
+
+    /// Clear the contents and shrink the capacity to the size of the batch (free up memory usage)
+    fn clear_shrink(&mut self, batch: &RecordBatch);
 }
 
 pub fn new_group_values(schema: SchemaRef) -> Result<Box<dyn GroupValues>> {
diff --git a/datafusion/core/src/physical_plan/aggregates/group_values/primitive.rs b/datafusion/core/src/physical_plan/aggregates/group_values/primitive.rs
index d7989fb8c4c5..7a52729d2018 100644
--- a/datafusion/core/src/physical_plan/aggregates/group_values/primitive.rs
+++ b/datafusion/core/src/physical_plan/aggregates/group_values/primitive.rs
@@ -20,6 +20,7 @@ use ahash::RandomState;
 use arrow::array::BooleanBufferBuilder;
 use arrow::buffer::NullBuffer;
 use arrow::datatypes::i256;
+use arrow::record_batch::RecordBatch;
 use arrow_array::cast::AsArray;
 use arrow_array::{ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, PrimitiveArray};
 use arrow_schema::DataType;
@@ -206,4 +207,12 @@ where
         };
         Ok(vec![Arc::new(array.with_data_type(self.data_type.clone()))])
     }
+
+    fn clear_shrink(&mut self, batch: &RecordBatch) {
+        let count = batch.num_rows();
+        self.values.clear();
+        self.values.shrink_to(count);
+        self.map.clear();
+        self.map.shrink_to(count, |_| 0); // hasher does not matter since the map is cleared
+    }
 }
diff --git a/datafusion/core/src/physical_plan/aggregates/group_values/row.rs b/datafusion/core/src/physical_plan/aggregates/group_values/row.rs
index 4eb660d52590..d711a1619116 100644
--- a/datafusion/core/src/physical_plan/aggregates/group_values/row.rs
+++ b/datafusion/core/src/physical_plan/aggregates/group_values/row.rs
@@ -17,6 +17,7 @@
 
 use crate::physical_plan::aggregates::group_values::GroupValues;
 use ahash::RandomState;
+use arrow::record_batch::RecordBatch;
 use arrow::row::{RowConverter, Rows, SortField};
 use arrow_array::ArrayRef;
 use arrow_schema::SchemaRef;
@@ -181,4 +182,15 @@ impl GroupValues for GroupValuesRows {
             }
         })
     }
+
+    fn clear_shrink(&mut self, batch: &RecordBatch) {
+        let count = batch.num_rows();
+        // FIXME: there is no good way to clear_shrink for self.group_values
+        self.group_values = self.row_converter.empty_rows(count, 0);
+        self.map.clear();
+        self.map.shrink_to(count, |_| 0); // hasher does not matter since the map is cleared
+        self.map_size = self.map.capacity() * std::mem::size_of::<(u64, usize)>();
+        self.hashes_buffer.clear();
+        self.hashes_buffer.shrink_to(count);
+    }
 }
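The three `clear_shrink` implementations above share one clear-then-shrink idiom: drop the contents, but also hand the backing allocation back down to roughly one batch, so a spilled operator's footprint returns to a small baseline instead of staying at its peak. (For hashbrown's raw table, `shrink_to` takes a rehash closure; the dummy `|_| 0` is safe only because the table was just cleared.) A minimal standalone sketch of the same idiom on a plain `Vec` — illustrative only, not DataFusion code:

```rust
// Clear the elements but keep only ~`target` capacity worth of allocation.
fn clear_shrink_vec<T>(buf: &mut Vec<T>, target: usize) {
    buf.clear();           // drop all elements, allocation unchanged
    buf.shrink_to(target); // then return memory down to `target` capacity
}

fn main() {
    let mut values: Vec<u64> = (0..1_000_000).collect();
    clear_shrink_vec(&mut values, 1024);
    println!("capacity after clear+shrink: {}", values.capacity());
}
```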
diff --git a/datafusion/core/src/physical_plan/aggregates/mod.rs b/datafusion/core/src/physical_plan/aggregates/mod.rs
index bb3f1edfa82d..bbc2b949e2ca 100644
--- a/datafusion/core/src/physical_plan/aggregates/mod.rs
+++ b/datafusion/core/src/physical_plan/aggregates/mod.rs
@@ -1296,6 +1296,7 @@ mod tests {
     use std::sync::Arc;
     use std::task::{Context, Poll};
 
+    use datafusion_execution::config::SessionConfig;
     use futures::{FutureExt, Stream};
 
     // Generate a schema which consists of 5 columns (a, b, c, d, e)
@@ -1466,7 +1467,22 @@ mod tests {
         )
     }
 
-    async fn check_grouping_sets(input: Arc<dyn ExecutionPlan>) -> Result<()> {
+    fn new_spill_ctx(batch_size: usize, max_memory: usize) -> Arc<TaskContext> {
+        let session_config = SessionConfig::new().with_batch_size(batch_size);
+        let runtime = Arc::new(
+            RuntimeEnv::new(RuntimeConfig::default().with_memory_limit(max_memory, 1.0))
+                .unwrap(),
+        );
+        let task_ctx = TaskContext::default()
+            .with_session_config(session_config)
+            .with_runtime(runtime);
+        Arc::new(task_ctx)
+    }
+
+    async fn check_grouping_sets(
+        input: Arc<dyn ExecutionPlan>,
+        spill: bool,
+    ) -> Result<()> {
         let input_schema = input.schema();
 
         let grouping_set = PhysicalGroupBy {
@@ -1491,7 +1507,11 @@ mod tests {
             DataType::Int64,
         ))];
 
-        let task_ctx = Arc::new(TaskContext::default());
+        let task_ctx = if spill {
+            new_spill_ctx(4, 1000)
+        } else {
+            Arc::new(TaskContext::default())
+        };
 
         let partial_aggregate = Arc::new(AggregateExec::try_new(
             AggregateMode::Partial,
@@ -1506,24 +1526,53 @@ mod tests {
         let result =
            common::collect(partial_aggregate.execute(0, task_ctx.clone())?).await?;
 
-        let expected = vec![
-            "+---+-----+-----------------+",
-            "| a | b   | COUNT(1)[count] |",
-            "+---+-----+-----------------+",
-            "|   | 1.0 | 2               |",
-            "|   | 2.0 | 2               |",
-            "|   | 3.0 | 2               |",
-            "|   | 4.0 | 2               |",
-            "| 2 |     | 2               |",
-            "| 2 | 1.0 | 2               |",
-            "| 3 |     | 3               |",
-            "| 3 | 2.0 | 2               |",
-            "| 3 | 3.0 | 1               |",
-            "| 4 |     | 3               |",
-            "| 4 | 3.0 | 1               |",
-            "| 4 | 4.0 | 2               |",
-            "+---+-----+-----------------+",
-        ];
+        let expected = if spill {
+            vec![
+                "+---+-----+-----------------+",
+                "| a | b   | COUNT(1)[count] |",
+                "+---+-----+-----------------+",
+                "|   | 1.0 | 1               |",
+                "|   | 1.0 | 1               |",
+                "|   | 2.0 | 1               |",
+                "|   | 2.0 | 1               |",
+                "|   | 3.0 | 1               |",
+                "|   | 3.0 | 1               |",
+                "|   | 4.0 | 1               |",
+                "|   | 4.0 | 1               |",
+                "| 2 |     | 1               |",
+                "| 2 |     | 1               |",
+                "| 2 | 1.0 | 1               |",
+                "| 2 | 1.0 | 1               |",
+                "| 3 |     | 1               |",
+                "| 3 |     | 2               |",
+                "| 3 | 2.0 | 2               |",
+                "| 3 | 3.0 | 1               |",
+                "| 4 |     | 1               |",
+                "| 4 |     | 2               |",
+                "| 4 | 3.0 | 1               |",
+                "| 4 | 4.0 | 2               |",
+                "+---+-----+-----------------+",
+            ]
+        } else {
+            vec![
+                "+---+-----+-----------------+",
+                "| a | b   | COUNT(1)[count] |",
+                "+---+-----+-----------------+",
+                "|   | 1.0 | 2               |",
+                "|   | 2.0 | 2               |",
+                "|   | 3.0 | 2               |",
+                "|   | 4.0 | 2               |",
+                "| 2 |     | 2               |",
+                "| 2 | 1.0 | 2               |",
+                "| 3 |     | 3               |",
+                "| 3 | 2.0 | 2               |",
+                "| 3 | 3.0 | 1               |",
+                "| 4 |     | 3               |",
+                "| 4 | 3.0 | 1               |",
+                "| 4 | 4.0 | 2               |",
+                "+---+-----+-----------------+",
+            ]
+        };
 
         assert_batches_sorted_eq!(expected, &result);
 
         let groups = partial_aggregate.group_expr().expr().to_vec();
@@ -1537,6 +1586,12 @@ mod tests {
 
         let final_grouping_set = PhysicalGroupBy::new_single(final_group);
 
+        let task_ctx = if spill {
+            new_spill_ctx(4, 3160)
+        } else {
+            task_ctx
+        };
+
         let merged_aggregate = Arc::new(AggregateExec::try_new(
             AggregateMode::Final,
             final_grouping_set,
@@ -1582,7 +1637,7 @@ mod tests {
     }
 
     /// build the aggregates on the data from some_data() and check the results
-    async fn check_aggregates(input: Arc<dyn ExecutionPlan>) -> Result<()> {
+    async fn check_aggregates(input: Arc<dyn ExecutionPlan>, spill: bool) -> Result<()> {
         let input_schema = input.schema();
 
         let grouping_set = PhysicalGroupBy {
@@ -1597,7 +1652,11 @@ mod tests {
             DataType::Float64,
         ))];
 
-        let task_ctx = Arc::new(TaskContext::default());
+        let task_ctx = if spill {
+            new_spill_ctx(2, 2144)
+        } else {
+            Arc::new(TaskContext::default())
+        };
 
         let partial_aggregate = Arc::new(AggregateExec::try_new(
             AggregateMode::Partial,
@@ -1612,15 +1671,29 @@ mod tests {
         let result =
             common::collect(partial_aggregate.execute(0, task_ctx.clone())?).await?;
 
-        let expected = [
-            "+---+---------------+-------------+",
-            "| a | AVG(b)[count] | AVG(b)[sum] |",
-            "+---+---------------+-------------+",
-            "| 2 | 2             | 2.0         |",
-            "| 3 | 3             | 7.0         |",
-            "| 4 | 3             | 11.0        |",
-            "+---+---------------+-------------+",
-        ];
+        let expected = if spill {
+            vec![
+                "+---+---------------+-------------+",
+                "| a | AVG(b)[count] | AVG(b)[sum] |",
+                "+---+---------------+-------------+",
+                "| 2 | 1             | 1.0         |",
+                "| 2 | 1             | 1.0         |",
+                "| 3 | 1             | 2.0         |",
+                "| 3 | 2             | 5.0         |",
+                "| 4 | 3             | 11.0        |",
+                "+---+---------------+-------------+",
+            ]
+        } else {
+            vec![
+                "+---+---------------+-------------+",
+                "| a | AVG(b)[count] | AVG(b)[sum] |",
+                "+---+---------------+-------------+",
+                "| 2 | 2             | 2.0         |",
+                "| 3 | 3             | 7.0         |",
+                "| 4 | 3             | 11.0        |",
+                "+---+---------------+-------------+",
+            ]
+        };
 
         assert_batches_sorted_eq!(expected, &result);
 
         let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate));
@@ -1663,7 +1736,13 @@ mod tests {
         let metrics = merged_aggregate.metrics().unwrap();
         let output_rows = metrics.output_rows().unwrap();
-        assert_eq!(3, output_rows);
+        if spill {
+            // When spilling, the output_rows metric counts the partial output plus the final
+            // output, because the final aggregation starts while the partial aggregation is
+            // still emitting.
+            assert_eq!(8, output_rows);
+        } else {
+            assert_eq!(3, output_rows);
+        }
 
         Ok(())
     }
@@ -1784,7 +1863,7 @@ mod tests {
         let input: Arc<dyn ExecutionPlan> =
             Arc::new(TestYieldingExec { yield_first: false });
 
-        check_aggregates(input).await
+        check_aggregates(input, false).await
     }
 
     #[tokio::test]
@@ -1792,7 +1871,7 @@ mod tests {
         let input: Arc<dyn ExecutionPlan> =
             Arc::new(TestYieldingExec { yield_first: false });
 
-        check_grouping_sets(input).await
+        check_grouping_sets(input, false).await
     }
 
     #[tokio::test]
@@ -1800,7 +1879,7 @@ mod tests {
         let input: Arc<dyn ExecutionPlan> =
             Arc::new(TestYieldingExec { yield_first: true });
 
-        check_aggregates(input).await
+        check_aggregates(input, false).await
     }
 
     #[tokio::test]
@@ -1808,7 +1887,39 @@ mod tests {
         let input: Arc<dyn ExecutionPlan> =
             Arc::new(TestYieldingExec { yield_first: true });
 
-        check_grouping_sets(input).await
+        check_grouping_sets(input, false).await
+    }
+
+    #[tokio::test]
+    async fn aggregate_source_not_yielding_with_spill() -> Result<()> {
+        let input: Arc<dyn ExecutionPlan> =
+            Arc::new(TestYieldingExec { yield_first: false });
+
+        check_aggregates(input, true).await
+    }
+
+    #[tokio::test]
+    async fn aggregate_grouping_sets_source_not_yielding_with_spill() -> Result<()> {
+        let input: Arc<dyn ExecutionPlan> =
+            Arc::new(TestYieldingExec { yield_first: false });
+
+        check_grouping_sets(input, true).await
+    }
+
+    #[tokio::test]
+    async fn aggregate_source_with_yielding_with_spill() -> Result<()> {
+        let input: Arc<dyn ExecutionPlan> =
+            Arc::new(TestYieldingExec { yield_first: true });
+
+        check_aggregates(input, true).await
+    }
+
+    #[tokio::test]
+    async fn aggregate_grouping_sets_with_yielding_with_spill() -> Result<()> {
+        let input: Arc<dyn ExecutionPlan> =
+            Arc::new(TestYieldingExec { yield_first: true });
+
+        check_grouping_sets(input, true).await
     }
 
     #[tokio::test]
@@ -1976,7 +2087,10 @@ mod tests {
     async fn run_first_last_multi_partitions() -> Result<()> {
         for use_coalesce_batches in [false, true] {
             for is_first_acc in [false, true] {
-                first_last_multi_partitions(use_coalesce_batches, is_first_acc).await?
+                for spill in [false, true] {
+                    first_last_multi_partitions(use_coalesce_batches, is_first_acc, spill)
+                        .await?
+                }
             }
         }
         Ok(())
@@ -2002,8 +2116,13 @@ mod tests {
     async fn first_last_multi_partitions(
         use_coalesce_batches: bool,
         is_first_acc: bool,
+        spill: bool,
     ) -> Result<()> {
-        let task_ctx = Arc::new(TaskContext::default());
+        let task_ctx = if spill {
+            new_spill_ctx(2, 2812)
+        } else {
+            Arc::new(TaskContext::default())
+        };
 
         let (schema, data) = some_data_v2();
         let partition1 = data[0].clone();
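The `new_spill_ctx` helper above drives these tests by shrinking the memory pool until the aggregation is forced to emit early or spill. As a hedged sketch, the user-facing equivalent could look roughly like this (constructor names follow DataFusion's public API of this era; treat the exact calls as assumptions):

```rust
use std::sync::Arc;
use datafusion::error::Result;
use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use datafusion::prelude::{SessionConfig, SessionContext};

/// Build a session whose aggregations must stay under a 16 MiB memory pool,
/// spilling to the disk manager's temp files when they would exceed it.
fn memory_limited_context() -> Result<SessionContext> {
    let runtime = Arc::new(RuntimeEnv::new(
        // 16 MiB pool; the second argument is the fraction usable by operators
        RuntimeConfig::new().with_memory_limit(16 * 1024 * 1024, 1.0),
    )?);
    let config = SessionConfig::new().with_batch_size(8192);
    Ok(SessionContext::with_config_rt(config, runtime))
}
```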
diff --git a/datafusion/core/src/physical_plan/aggregates/order/partial.rs b/datafusion/core/src/physical_plan/aggregates/order/partial.rs
index 019e61ef2688..0feac3a5ed52 100644
--- a/datafusion/core/src/physical_plan/aggregates/order/partial.rs
+++ b/datafusion/core/src/physical_plan/aggregates/order/partial.rs
@@ -241,7 +241,7 @@ impl GroupOrderingPartial {
         Ok(())
     }
 
-    /// Return the size of memor allocated by this structure
+    /// Return the size of memory allocated by this structure
    pub(crate) fn size(&self) -> usize {
         std::mem::size_of::<Self>()
             + self.order_indices.allocated_size()
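The `size()` method fixed above is part of the memory-accounting contract this patch relies on: every piece of operator state reports its allocated bytes, the stream sums them into a `MemoryReservation`, and a failed resize becomes the over-budget signal that triggers early emission or spilling. A toy illustration of that reporting pattern (our own type, not the DataFusion struct):

```rust
/// Toy state object that can answer "how many bytes do you hold?"
/// The real operator sums such sizes across group values, ordering
/// state, and accumulators before calling MemoryReservation::try_resize.
struct OrderState {
    order_indices: Vec<usize>,
}

impl OrderState {
    fn size(&self) -> usize {
        // struct itself plus the heap allocation behind the Vec
        std::mem::size_of::<Self>()
            + self.order_indices.capacity() * std::mem::size_of::<usize>()
    }
}
```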
diff --git a/datafusion/core/src/physical_plan/aggregates/row_hash.rs b/datafusion/core/src/physical_plan/aggregates/row_hash.rs
index d034bd669e55..eef25c1dc214 100644
--- a/datafusion/core/src/physical_plan/aggregates/row_hash.rs
+++ b/datafusion/core/src/physical_plan/aggregates/row_hash.rs
@@ -18,7 +18,7 @@
 //! Hash aggregation
 
 use datafusion_physical_expr::{
-    AggregateExpr, EmitTo, GroupsAccumulator, GroupsAccumulatorAdapter,
+    AggregateExpr, EmitTo, GroupsAccumulator, GroupsAccumulatorAdapter, PhysicalSortExpr,
 };
 use log::debug;
 use std::sync::Arc;
@@ -29,19 +29,28 @@ use futures::ready;
 use futures::stream::{Stream, StreamExt};
 
 use crate::physical_plan::aggregates::group_values::{new_group_values, GroupValues};
+use crate::physical_plan::aggregates::order::GroupOrderingFull;
 use crate::physical_plan::aggregates::{
     evaluate_group_by, evaluate_many, evaluate_optional, group_schema, AggregateMode,
     PhysicalGroupBy,
 };
+use crate::physical_plan::common::IPCWriter;
 use crate::physical_plan::metrics::{BaselineMetrics, RecordOutput};
+use crate::physical_plan::sorts::sort::{read_spill_as_stream, sort_batch};
+use crate::physical_plan::sorts::streaming_merge;
+use crate::physical_plan::stream::RecordBatchStreamAdapter;
 use crate::physical_plan::{aggregates, PhysicalExpr};
 use crate::physical_plan::{RecordBatchStream, SendableRecordBatchStream};
 use arrow::array::*;
 use arrow::{datatypes::SchemaRef, record_batch::RecordBatch};
-use datafusion_common::Result;
+use arrow_schema::SortOptions;
+use datafusion_common::{DataFusionError, Result};
+use datafusion_execution::disk_manager::RefCountedTempFile;
 use datafusion_execution::memory_pool::proxy::VecAllocExt;
 use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
+use datafusion_execution::runtime_env::RuntimeEnv;
 use datafusion_execution::TaskContext;
+use datafusion_physical_expr::expressions::col;
 
 #[derive(Debug, Clone)]
 /// This object tracks the aggregation phase (input/output)
@@ -56,6 +65,28 @@ pub(crate) enum ExecutionState {
 use super::order::GroupOrdering;
 use super::AggregateExec;
 
+/// This encapsulates the spilling state
+struct SpillState {
+    /// If data has previously been spilled, the locations of the
+    /// spill files (in Arrow IPC format)
+    spills: Vec<RefCountedTempFile>,
+
+    /// Sorting expression for spilling batches
+    spill_expr: Vec<PhysicalSortExpr>,
+
+    /// Schema for spilling batches
+    spill_schema: SchemaRef,
+
+    /// true when streaming merge is in progress
+    is_stream_merging: bool,
+
+    /// aggregate_arguments for merging spilled data
+    merging_aggregate_arguments: Vec<Vec<Arc<dyn PhysicalExpr>>>,
+
+    /// GROUP BY expressions for merging spilled data
+    merging_group_by: PhysicalGroupBy,
+}
+
 /// HashTable based Grouping Aggregator
 ///
 /// # Design Goals
@@ -120,6 +151,57 @@ use super::AggregateExec;
 /// hash table).
 ///
 /// [`group_values`]: Self::group_values
+///
+/// # Spilling
+///
+/// The sizes of the group values and accumulators can become large. Before they cause an
+/// out-of-memory error, this hash aggregator either outputs its partial states early (in partial
+/// aggregation mode) or spills to local disk in Arrow IPC format (in final aggregation mode).
+/// For every input [`RecordBatch`], the memory manager checks whether the new input fits within
+/// the configured memory limit; if not, the aggregator outputs or spills. When outputting early,
+/// the downstream final aggregation takes care of the re-grouping. When spilling, a streaming
+/// merge sort over the read-back spilled data performs the re-grouping. Note that rows cannot be
+/// grouped any further once they are spilled to disk, so the read-back data must be re-grouped,
+/// and because re-grouping may itself run out of memory, it has to be a sort-based aggregation.
+///
+/// ```text
+/// Partial Aggregation [batch_size = 2] (max memory = 3 rows)
+///
+///  INPUTS        PARTIALLY AGGREGATED (UPDATE BATCH)  OUTPUTS
+/// ┌─────────┐    ┌─────────────────┐                  ┌─────────────────┐
+/// │ a │ b   │    │ a │    AVG(b)   │                  │ a │    AVG(b)   │
+/// │---│-----│    │   │[count]│[sum]│                  │   │[count]│[sum]│
+/// │ 3 │ 3.0 │ ─▶ │---│-------│-----│                  │---│-------│-----│
+/// │ 2 │ 2.0 │    │ 2 │ 1     │ 2.0 │ ─▶ early emit ─▶ │ 2 │ 1     │ 2.0 │
+/// └─────────┘    │ 3 │ 2     │ 7.0 │                │ │ 3 │ 2     │ 7.0 │
+/// ┌─────────┐ ─▶ │ 4 │ 1     │ 8.0 │                │ └─────────────────┘
+/// │ 3 │ 4.0 │    └─────────────────┘                └▶┌─────────────────┐
+/// │ 4 │ 8.0 │    ┌─────────────────┐                  │ 4 │ 1     │ 8.0 │
+/// └─────────┘    │ a │    AVG(b)   │                ┌▶│ 1 │ 1     │ 1.0 │
+/// ┌─────────┐    │---│-------│-----│                │ └─────────────────┘
+/// │ 1 │ 1.0 │ ─▶ │ 1 │ 1     │ 1.0 │ ─▶ early emit ─▶ ┌─────────────────┐
+/// │ 3 │ 2.0 │    │ 3 │ 1     │ 2.0 │                  │ 3 │ 1     │ 2.0 │
+/// └─────────┘    └─────────────────┘                  └─────────────────┘
+///
+///
+/// Final Aggregation [batch_size = 2] (max memory = 3 rows)
+///
+/// PARTIALLY INPUTS       FINAL AGGREGATION (MERGE BATCH)      RE-GROUPED (SORTED)
+/// ┌─────────────────┐    [keep using the partial schema]      [Real final aggregation
+/// │ a │    AVG(b)   │    ┌─────────────────┐                   output]
+/// │   │[count]│[sum]│    │ a │    AVG(b)   │                  ┌────────────┐
+/// │---│-------│-----│ ─▶ │   │[count]│[sum]│                  │ a │ AVG(b) │
+/// │ 3 │ 3     │ 3.0 │    │---│-------│-----│ ─▶ spill ─┐      │---│--------│
+/// │ 2 │ 2     │ 1.0 │    │ 2 │ 2     │ 1.0 │           │      │ 1 │    4.0 │
+/// └─────────────────┘    │ 3 │ 4     │ 8.0 │           ▼      │ 2 │    1.0 │
+/// ┌─────────────────┐ ─▶ │ 4 │ 1     │ 7.0 │   Streaming ─▶   └────────────┘
+/// │ 3 │ 1     │ 5.0 │    └─────────────────┘   merge sort ─▶  ┌────────────┐
+/// │ 4 │ 1     │ 7.0 │    ┌─────────────────┐           ▲      │ a │ AVG(b) │
+/// └─────────────────┘    │ a │    AVG(b)   │           │      │---│--------│
+/// ┌─────────────────┐    │---│-------│-----│ ─▶ memory ─┘     │ 3 │    2.0 │
+/// │ 1 │ 2     │ 8.0 │ ─▶ │ 1 │ 2     │ 8.0 │                  │ 4 │    7.0 │
+/// │ 2 │ 2     │ 3.0 │    │ 2 │ 2     │ 3.0 │                  └────────────┘
+/// └─────────────────┘    └─────────────────┘
+/// ```
 pub(crate) struct GroupedHashAggregateStream {
     schema: SchemaRef,
     input: SendableRecordBatchStream,
@@ -178,6 +260,12 @@ pub(crate) struct GroupedHashAggregateStream {
 
     /// Have we seen the end of the input
     input_done: bool,
+
+    /// The [`RuntimeEnv`] associated with the [`TaskContext`] argument
+    runtime: Arc<RuntimeEnv>,
+
+    /// The spill state object
+    spill_state: SpillState,
 }
 
 impl GroupedHashAggregateStream {
@@ -207,6 +295,12 @@ impl GroupedHashAggregateStream {
             &agg.mode,
             agg_group_by.expr.len(),
         )?;
+        // the arguments for aggregating spilled data are the same as those for final aggregation
+        let merging_aggregate_arguments = aggregates::aggregate_expressions(
+            &agg.aggr_expr,
+            &AggregateMode::Final,
+            agg_group_by.expr.len(),
+        )?;
 
         let filter_expressions = match agg.mode {
             AggregateMode::Partial
@@ -224,6 +318,14 @@ impl GroupedHashAggregateStream {
             .collect::<Result<Vec<_>>>()?;
 
         let group_schema = group_schema(&agg_schema, agg_group_by.expr.len());
+        let spill_expr = group_schema
+            .fields
+            .into_iter()
+            .map(|field| PhysicalSortExpr {
+                expr: col(field.name(), &group_schema).unwrap(),
+                options: SortOptions::default(),
+            })
+            .collect();
 
         let name = format!("GroupedHashAggregateStream[{partition}]");
         let reservation = MemoryConsumer::new(name).register(context.memory_pool());
@@ -243,6 +345,15 @@ impl GroupedHashAggregateStream {
 
         let exec_state = ExecutionState::ReadingInput;
 
+        let spill_state = SpillState {
+            spills: vec![],
+            spill_expr,
+            spill_schema: agg_schema.clone(),
+            is_stream_merging: false,
+            merging_aggregate_arguments,
+            merging_group_by: PhysicalGroupBy::new_single(agg_group_by.expr.clone()),
+        };
+
         Ok(GroupedHashAggregateStream {
             schema: agg_schema,
             input,
@@ -259,6 +370,8 @@ impl GroupedHashAggregateStream {
             batch_size,
             group_ordering,
             input_done: false,
+            runtime: context.runtime_env(),
+            spill_state,
         })
     }
 }
@@ -310,6 +423,9 @@ impl Stream for GroupedHashAggregateStream {
                         // new batch to aggregate
                         Some(Ok(batch)) => {
                             let timer = elapsed_compute.timer();
+                            // Make sure we have enough capacity for `batch`, otherwise spill
+                            extract_ok!(self.spill_previous_if_necessary(&batch));
+
                             // Do the grouping
                             extract_ok!(self.group_aggregate_batch(batch));
 
@@ -318,9 +434,12 @@ impl Stream for GroupedHashAggregateStream {
                             assert!(!self.input_done);
 
                             if let Some(to_emit) = self.group_ordering.emit_to() {
-                                let batch = extract_ok!(self.emit(to_emit));
+                                let batch = extract_ok!(self.emit(to_emit, false));
                                 self.exec_state = ExecutionState::ProducingOutput(batch);
                             }
+
+                            extract_ok!(self.emit_early_if_necessary());
+
                             timer.done();
                         }
                         Some(Err(e)) => {
@@ -332,8 +451,14 @@ impl Stream for GroupedHashAggregateStream {
                             self.input_done = true;
                             self.group_ordering.input_done();
                             let timer = elapsed_compute.timer();
-                            let batch = extract_ok!(self.emit(EmitTo::All));
-                            self.exec_state = ExecutionState::ProducingOutput(batch);
+                            if self.spill_state.spills.is_empty() {
+                                let batch = extract_ok!(self.emit(EmitTo::All, false));
+                                self.exec_state = ExecutionState::ProducingOutput(batch);
+                            } else {
+                                // If spill files exist, stream-merge them.
+                                extract_ok!(self.update_merged_stream());
+                                self.exec_state = ExecutionState::ReadingInput;
+                            }
                             timer.done();
                         }
                     }
@@ -360,7 +485,13 @@ impl Stream for GroupedHashAggregateStream {
                     )));
                 }
 
-                ExecutionState::Done => return Poll::Ready(None),
+                ExecutionState::Done => {
+                    // Release the memory reservation: sending back the output batch itself
+                    // needs some memory, so make room for it.
+                    self.clear_all();
+                    let _ = self.update_memory_reservation();
+                    return Poll::Ready(None);
+                }
             }
         }
     }
 }
 
 impl GroupedHashAggregateStream {
@@ -376,13 +507,26 @@ impl GroupedHashAggregateStream {
     /// Perform group-by aggregation for the given [`RecordBatch`].
     fn group_aggregate_batch(&mut self, batch: RecordBatch) -> Result<()> {
         // Evaluate the grouping expressions
-        let group_by_values = evaluate_group_by(&self.group_by, &batch)?;
+        let group_by_values = if self.spill_state.is_stream_merging {
+            evaluate_group_by(&self.spill_state.merging_group_by, &batch)?
+        } else {
+            evaluate_group_by(&self.group_by, &batch)?
+        };
 
         // Evaluate the aggregation expressions.
-        let input_values = evaluate_many(&self.aggregate_arguments, &batch)?;
+        let input_values = if self.spill_state.is_stream_merging {
+            evaluate_many(&self.spill_state.merging_aggregate_arguments, &batch)?
+        } else {
+            evaluate_many(&self.aggregate_arguments, &batch)?
+        };
 
         // Evaluate the filter expressions, if any, against the inputs
-        let filter_values = evaluate_optional(&self.filter_expressions, &batch)?;
+        let filter_values = if self.spill_state.is_stream_merging {
+            let filter_expressions = vec![None; self.accumulators.len()];
+            evaluate_optional(&filter_expressions, &batch)?
+        } else {
+            evaluate_optional(&self.filter_expressions, &batch)?
+        };
 
         for group_values in &group_by_values {
             // calculate the group indices for each input row
@@ -416,7 +560,9 @@ impl GroupedHashAggregateStream {
                     match self.mode {
                         AggregateMode::Partial
                         | AggregateMode::Single
-                        | AggregateMode::SinglePartitioned => {
+                        | AggregateMode::SinglePartitioned
+                            if !self.spill_state.is_stream_merging =>
+                        {
                             acc.update_batch(
                                 values,
                                 group_indices,
@@ -424,7 +570,7 @@ impl GroupedHashAggregateStream {
                                 total_num_groups,
                             )?;
                         }
-                        AggregateMode::FinalPartitioned | AggregateMode::Final => {
+                        _ => {
                             // if aggregation is over intermediate states,
                             // use merge
                             acc.merge_batch(
@@ -438,7 +584,16 @@ impl GroupedHashAggregateStream {
             }
         }
 
-        self.update_memory_reservation()
+        match self.update_memory_reservation() {
+            // Here we can ignore `insufficient_capacity_err` because we will spill later,
+            // but at least one batch should fit in the memory
+            Err(DataFusionError::ResourcesExhausted(_))
+                if self.group_values.len() >= self.batch_size =>
+            {
+                Ok(())
+            }
+            other => other,
+        }
     }
 
     fn update_memory_reservation(&mut self) -> Result<()> {
@@ -452,9 +607,14 @@ impl GroupedHashAggregateStream {
 
     /// Create an output RecordBatch with the group keys and
     /// accumulator states/values specified in emit_to
-    fn emit(&mut self, emit_to: EmitTo) -> Result<RecordBatch> {
+    fn emit(&mut self, emit_to: EmitTo, spilling: bool) -> Result<RecordBatch> {
+        let schema = if spilling {
+            self.spill_state.spill_schema.clone()
+        } else {
+            self.schema()
+        };
         if self.group_values.is_empty() {
-            return Ok(RecordBatch::new_empty(self.schema()));
+            return Ok(RecordBatch::new_empty(schema));
         }
 
         let mut output = self.group_values.emit(emit_to)?;
@@ -466,6 +626,11 @@ impl GroupedHashAggregateStream {
         for acc in self.accumulators.iter_mut() {
             match self.mode {
                 AggregateMode::Partial => output.extend(acc.state(emit_to)?),
+                _ if spilling => {
+                    // If spilling, output the partial state, because the spilled data will
+                    // be merged and re-evaluated later.
+                    output.extend(acc.state(emit_to)?)
+                }
                 AggregateMode::Final
                 | AggregateMode::FinalPartitioned
                 | AggregateMode::Single
@@ -473,8 +638,110 @@ impl GroupedHashAggregateStream {
             }
         }
 
-        self.update_memory_reservation()?;
-        let batch = RecordBatch::try_new(self.schema(), output)?;
+        // Emitting reduces the memory usage, so ignore an Err from update_memory_reservation:
+        // even if we are still over the target memory size after emission, we can emit again
+        // rather than returning an Err.
+        let _ = self.update_memory_reservation();
+        let batch = RecordBatch::try_new(schema, output)?;
         Ok(batch)
     }
+
+    /// Optimistically, [`Self::group_aggregate_batch`] allows the memory target to be exceeded
+    /// slightly (~ 1 [`RecordBatch`]) for simplicity. In such cases, spill the data to disk and
+    /// clear the memory. Currently only [`GroupOrdering::None`] is supported for spilling.
+    fn spill_previous_if_necessary(&mut self, batch: &RecordBatch) -> Result<()> {
+        // TODO: support group_ordering for spilling
+        if self.group_values.len() > 0
+            && batch.num_rows() > 0
+            && matches!(self.group_ordering, GroupOrdering::None)
+            && !matches!(self.mode, AggregateMode::Partial)
+            && !self.spill_state.is_stream_merging
+            && self.update_memory_reservation().is_err()
+        {
+            // Use the input batch (Partial mode) schema for spilling because
+            // the spilled data will be merged and re-evaluated later.
+            self.spill_state.spill_schema = batch.schema();
+            self.spill()?;
+            self.clear_shrink(batch);
+        }
+        Ok(())
+    }
+
+    /// Emit all rows, sort them, and store them on disk.
+    fn spill(&mut self) -> Result<()> {
+        let emit = self.emit(EmitTo::All, true)?;
+        let sorted = sort_batch(&emit, &self.spill_state.spill_expr, None)?;
+        let spillfile = self.runtime.disk_manager.create_tmp_file("HashAggSpill")?;
+        let mut writer = IPCWriter::new(spillfile.path(), &emit.schema())?;
+        // TODO: slice large `sorted` and write to multiple files in parallel
+        writer.write(&sorted)?;
+        writer.finish()?;
+        self.spill_state.spills.push(spillfile);
+        Ok(())
+    }
+
+    /// Clear memory and shrink the capacities to the size of the batch.
+    fn clear_shrink(&mut self, batch: &RecordBatch) {
+        self.group_values.clear_shrink(batch);
+        self.current_group_indices.clear();
+        self.current_group_indices.shrink_to(batch.num_rows());
+    }
+
+    /// Clear memory and shrink the capacities to zero.
+    fn clear_all(&mut self) {
+        let s = self.schema();
+        self.clear_shrink(&RecordBatch::new_empty(s));
+    }
+
+    /// Emit if the used memory exceeds the target for partial aggregation.
+    /// Currently only [`GroupOrdering::None`] is supported for early emitting.
+    /// TODO: support group_ordering for early emitting
+    fn emit_early_if_necessary(&mut self) -> Result<()> {
+        if self.group_values.len() >= self.batch_size
+            && matches!(self.group_ordering, GroupOrdering::None)
+            && matches!(self.mode, AggregateMode::Partial)
+            && self.update_memory_reservation().is_err()
+        {
+            let n = self.group_values.len() / self.batch_size * self.batch_size;
+            let batch = self.emit(EmitTo::First(n), false)?;
+            self.exec_state = ExecutionState::ProducingOutput(batch);
+        }
+        Ok(())
+    }
+
+    /// At this point, all the inputs are read and there are some spills.
+    /// Emit the remaining rows and create a batch.
+    /// Conduct a streaming merge sort between the batch and spilled data. Since the stream is
+    /// fully sorted, set `self.group_ordering` to Full, then later we can read with
+    /// [`EmitTo::First`].
+    fn update_merged_stream(&mut self) -> Result<()> {
+        let batch = self.emit(EmitTo::All, true)?;
+        // clear up memory for streaming_merge
+        self.clear_all();
+        self.update_memory_reservation()?;
+        let mut streams: Vec<SendableRecordBatchStream> = vec![];
+        let expr = self.spill_state.spill_expr.clone();
+        let schema = batch.schema();
+        streams.push(Box::pin(RecordBatchStreamAdapter::new(
+            schema.clone(),
+            futures::stream::once(futures::future::lazy(move |_| {
+                sort_batch(&batch, &expr, None)
+            })),
+        )));
+        for spill in self.spill_state.spills.drain(..) {
+            let stream = read_spill_as_stream(spill, schema.clone())?;
+            streams.push(stream);
+        }
+        self.spill_state.is_stream_merging = true;
+        self.input = streaming_merge(
+            streams,
+            schema,
+            &self.spill_state.spill_expr,
+            self.baseline_metrics.clone(),
+            self.batch_size,
+            None,
+            self.reservation.new_empty(),
+        )?;
+        self.input_done = false;
+        self.group_ordering = GroupOrdering::Full(GroupOrderingFull::new());
+        Ok(())
+    }
 }
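Stepping back, the row_hash.rs changes reduce to one decision per input batch plus one decision at end of input. The sketch below is our own condensed restatement, not code from the patch; the extra guards (minimum group count, `GroupOrdering::None` only) are elided:

```rust
/// What the grouped hash aggregate stream does next, depending on mode,
/// memory pressure, and whether spill files exist at end of input.
enum Action {
    KeepAggregating,
    EmitEarly,     // Partial mode: push completed groups downstream now
    SpillToDisk,   // Final mode: sort current groups by key, write an IPC file
    StreamMerge,   // end of input with spills: merge-sort streams and re-group
}

fn choose_action(
    partial_mode: bool,
    over_budget: bool, // MemoryReservation resize failed
    input_done: bool,
    has_spills: bool,
) -> Action {
    match (input_done, has_spills) {
        (true, true) => Action::StreamMerge,
        _ if over_budget && partial_mode => Action::EmitEarly,
        _ if over_budget => Action::SpillToDisk,
        _ => Action::KeepAggregating,
    }
}
```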
diff --git a/datafusion/core/src/physical_plan/sorts/sort.rs b/datafusion/core/src/physical_plan/sorts/sort.rs
index 17b94d51c587..92fb45142ed0 100644
--- a/datafusion/core/src/physical_plan/sorts/sort.rs
+++ b/datafusion/core/src/physical_plan/sorts/sort.rs
@@ -574,7 +574,7 @@ impl Debug for ExternalSorter {
     }
 }
 
-fn sort_batch(
+pub(crate) fn sort_batch(
     batch: &RecordBatch,
     expressions: &[PhysicalSortExpr],
     fetch: Option<usize>,
@@ -608,7 +608,7 @@ async fn spill_sorted_batches(
     }
 }
 
-fn read_spill_as_stream(
+pub(crate) fn read_spill_as_stream(
     path: RefCountedTempFile,
     schema: SchemaRef,
 ) -> Result<SendableRecordBatchStream> {
diff --git a/datafusion/physical-expr/src/aggregate/first_last.rs b/datafusion/physical-expr/src/aggregate/first_last.rs
index 7e8930ce2a32..02bb466d44bd 100644
--- a/datafusion/physical-expr/src/aggregate/first_last.rs
+++ b/datafusion/physical-expr/src/aggregate/first_last.rs
@@ -165,6 +165,8 @@ struct FirstValueAccumulator {
     orderings: Vec<ScalarValue>,
     // Stores the applicable ordering requirement.
     ordering_req: LexOrdering,
+    // Whether merge_batch() has been called before
+    is_merge_called: bool,
 }
 
 impl FirstValueAccumulator {
@@ -183,6 +185,7 @@ impl FirstValueAccumulator {
             is_set: false,
             orderings,
             ordering_req,
+            is_merge_called: false,
         })
     }
 
@@ -198,7 +201,9 @@ impl Accumulator for FirstValueAccumulator {
     fn state(&self) -> Result<Vec<ScalarValue>> {
         let mut result = vec![self.first.clone()];
         result.extend(self.orderings.iter().cloned());
-        result.push(ScalarValue::Boolean(Some(self.is_set)));
+        if !self.is_merge_called {
+            result.push(ScalarValue::Boolean(Some(self.is_set)));
+        }
         Ok(result)
     }
 
@@ -213,6 +218,7 @@ impl Accumulator for FirstValueAccumulator {
     }
 
     fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        self.is_merge_called = true;
         // FIRST_VALUE(first1, first2, first3, ...)
         // last index contains is_set flag.
         let is_set_idx = states.len() - 1;
@@ -384,6 +390,8 @@ struct LastValueAccumulator {
     orderings: Vec<ScalarValue>,
     // Stores the applicable ordering requirement.
     ordering_req: LexOrdering,
+    // Whether merge_batch() has been called before
+    is_merge_called: bool,
 }
 
 impl LastValueAccumulator {
@@ -402,6 +410,7 @@ impl LastValueAccumulator {
             is_set: false,
             orderings,
             ordering_req,
+            is_merge_called: false,
         })
     }
 
@@ -417,7 +426,9 @@ impl Accumulator for LastValueAccumulator {
     fn state(&self) -> Result<Vec<ScalarValue>> {
         let mut result = vec![self.last.clone()];
         result.extend(self.orderings.clone());
-        result.push(ScalarValue::Boolean(Some(self.is_set)));
+        if !self.is_merge_called {
+            result.push(ScalarValue::Boolean(Some(self.is_set)));
+        }
         Ok(result)
     }
 
@@ -431,6 +442,7 @@ impl Accumulator for LastValueAccumulator {
     }
 
     fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        self.is_merge_called = true;
         // LAST_VALUE(last1, last2, last3, ...)
         // last index contains is_set flag.
         let is_set_idx = states.len() - 1;
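Finally, the `is_merge_called` flag guards a column-count invariant that only becomes observable with spilling, where `state()` can be called on an accumulator that has already merged serialized states: the intermediate state is `[value, orderings..., is_set]`, and appending a second `is_set` would make `merge_batch`'s `is_set_idx = states.len() - 1` point at the wrong column. A toy model of that invariant (our own function, not DataFusion's types):

```rust
/// Serialize a FIRST_VALUE/LAST_VALUE-like intermediate state: the value,
/// then any ordering values, then exactly one trailing is_set flag. If the
/// accumulator has already merged serialized states, the flag is already
/// accounted for and must not be appended again.
fn emit_state(value: &str, orderings: &[&str], merged_before: bool) -> Vec<String> {
    let mut columns = vec![value.to_string()];
    columns.extend(orderings.iter().map(|s| s.to_string()));
    if !merged_before {
        // the flag travels with the state exactly once
        columns.push("is_set".to_string());
    }
    columns
}
```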