feat(stream,agg): enable distinct agg support in backend #8100

Merged · 3 commits · Feb 24, 2023
28 changes: 9 additions & 19 deletions src/frontend/planner_test/tests/testdata/agg.yaml
@@ -779,13 +779,10 @@
└─BatchScan { table: t, columns: [t.a, t.b, t.c], distribution: SomeShard }
stream_plan: |
StreamMaterialize { columns: [a, distinct_b_num, sum_c], pk_columns: [a], pk_conflict: "no check" }
- └─StreamProject { exprs: [t.a, count(t.b), sum(sum(t.c))] }
- └─StreamHashAgg { group_key: [t.a], aggs: [count, count(t.b), sum(sum(t.c))] }
+ └─StreamProject { exprs: [t.a, count(distinct t.b), sum(t.c)] }
+ └─StreamHashAgg { group_key: [t.a], aggs: [count, count(distinct t.b), sum(t.c)] }
└─StreamExchange { dist: HashShard(t.a) }
- └─StreamProject { exprs: [t.a, t.b, sum(t.c)] }
- └─StreamHashAgg { group_key: [t.a, t.b], aggs: [count, sum(t.c)] }
- └─StreamExchange { dist: HashShard(t.a, t.b) }
- └─StreamTableScan { table: t, columns: [t.a, t.b, t.c, t._row_id], pk: [t._row_id], dist: UpstreamHashShard(t._row_id) }
+ └─StreamTableScan { table: t, columns: [t.a, t.b, t.c, t._row_id], pk: [t._row_id], dist: UpstreamHashShard(t._row_id) }
- name: distinct agg and non-distinct agg with intersected argument
sql: |
create table t(a int, b int, c int);
@@ -805,14 +802,10 @@
└─BatchScan { table: t, columns: [t.a, t.b, t.c], distribution: SomeShard }
stream_plan: |
StreamMaterialize { columns: [a, distinct_b_num, distinct_c_sum, sum_c], pk_columns: [a], pk_conflict: "no check" }
- └─StreamProject { exprs: [t.a, count(t.b) filter((flag = 1:Int64)), count(t.c) filter((flag = 0:Int64)), sum(sum(t.c)) filter((flag = 0:Int64))] }
- └─StreamHashAgg { group_key: [t.a], aggs: [count, count(t.b) filter((flag = 1:Int64)), count(t.c) filter((flag = 0:Int64)), sum(sum(t.c)) filter((flag = 0:Int64))] }
+ └─StreamProject { exprs: [t.a, count(distinct t.b), count(distinct t.c), sum(t.c)] }
+ └─StreamHashAgg { group_key: [t.a], aggs: [count, count(distinct t.b), count(distinct t.c), sum(t.c)] }
└─StreamExchange { dist: HashShard(t.a) }
- └─StreamProject { exprs: [t.a, t.b, t.c, flag, sum(t.c)] }
- └─StreamHashAgg { group_key: [t.a, t.b, t.c, flag], aggs: [count, sum(t.c)] }
- └─StreamExchange { dist: HashShard(t.a, t.b, t.c, flag) }
- └─StreamExpand { column_subsets: [[t.a, t.c], [t.a, t.b]] }
- └─StreamTableScan { table: t, columns: [t.a, t.b, t.c, t._row_id], pk: [t._row_id], dist: UpstreamHashShard(t._row_id) }
+ └─StreamTableScan { table: t, columns: [t.a, t.b, t.c, t._row_id], pk: [t._row_id], dist: UpstreamHashShard(t._row_id) }
- name: distinct agg with filter
sql: |
create table t(a int, b int, c int);
@@ -830,13 +823,10 @@
└─BatchScan { table: t, columns: [t.a, t.b, t.c], distribution: SomeShard }
stream_plan: |
StreamMaterialize { columns: [a, count, sum], pk_columns: [a], pk_conflict: "no check" }
- └─StreamProject { exprs: [t.a, count(t.b) filter((count filter((t.b < 100:Int32)) > 0:Int64)), sum(sum(t.c))] }
- └─StreamHashAgg { group_key: [t.a], aggs: [count, count(t.b) filter((count filter((t.b < 100:Int32)) > 0:Int64)), sum(sum(t.c))] }
+ └─StreamProject { exprs: [t.a, count(distinct t.b) filter((t.b < 100:Int32)), sum(t.c)] }
+ └─StreamHashAgg { group_key: [t.a], aggs: [count, count(distinct t.b) filter((t.b < 100:Int32)), sum(t.c)] }
└─StreamExchange { dist: HashShard(t.a) }
- └─StreamProject { exprs: [t.a, t.b, count filter((t.b < 100:Int32)), sum(t.c)] }
- └─StreamHashAgg { group_key: [t.a, t.b], aggs: [count, count filter((t.b < 100:Int32)), sum(t.c)] }
- └─StreamExchange { dist: HashShard(t.a, t.b) }
- └─StreamTableScan { table: t, columns: [t.a, t.b, t.c, t._row_id], pk: [t._row_id], dist: UpstreamHashShard(t._row_id) }
+ └─StreamTableScan { table: t, columns: [t.a, t.b, t.c, t._row_id], pk: [t._row_id], dist: UpstreamHashShard(t._row_id) }
- name: non-distinct agg with filter
sql: |
create table t(a int, b int, c int);
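Note on the plan changes above: with distinct-agg support in the stream backend, the planner no longer rewrites `count(distinct b)` into a three-level Expand/two-phase plan. It emits a single `StreamHashAgg` whose agg calls carry `distinct` directly, and deduplication happens at runtime against per-(group key, distinct value) dedup state. A minimal, self-contained sketch of that dedup idea (illustrative names and types, not the actual RisingWave API):

```rust
use std::collections::HashMap;

/// Stream operations: streaming chunks carry retractions, not just inserts.
enum Op {
    Insert,
    Delete,
}

/// Hypothetical in-memory stand-in for a distinct-dedup state table.
#[derive(Default)]
struct DedupState {
    /// (group key, distinct column value) -> number of live rows with that value.
    counts: HashMap<(i64, i64), u64>,
}

impl DedupState {
    /// Returns whether the row should remain visible to the distinct agg call:
    /// only the first insert of a value (0 -> 1) and the removal of its last
    /// occurrence (1 -> 0) may reach the agg state.
    fn dedup(&mut self, op: Op, group_key: i64, value: i64) -> bool {
        let cnt = self.counts.entry((group_key, value)).or_insert(0);
        match op {
            Op::Insert => {
                *cnt += 1;
                *cnt == 1
            }
            Op::Delete => {
                assert!(*cnt > 0, "retraction of a value never inserted");
                *cnt -= 1;
                *cnt == 0
            }
        }
    }
}

fn main() {
    let mut state = DedupState::default();
    assert!(state.dedup(Op::Insert, 1, 42)); // first 42 in group 1: visible
    assert!(!state.dedup(Op::Insert, 1, 42)); // duplicate: masked out
    assert!(!state.dedup(Op::Delete, 1, 42)); // one copy remains: masked out
    assert!(state.dedup(Op::Delete, 1, 42)); // last copy gone: visible retraction
}
```

Counts rather than a plain set are needed because streams carry retractions: a delete must reveal whether the last row holding a value just vanished. The real executor persists these counts in the `distinct_dedup_tables` state tables that appear later in this diff.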
71 changes: 23 additions & 48 deletions src/frontend/planner_test/tests/testdata/nexmark.yaml

Large diffs are not rendered by default.

77 changes: 26 additions & 51 deletions src/frontend/planner_test/tests/testdata/nexmark_source.yaml

Large diffs are not rendered by default.

48 changes: 21 additions & 27 deletions src/frontend/planner_test/tests/testdata/tpch.yaml
@@ -2691,40 +2691,35 @@
└─BatchScan { table: partsupp, columns: [partsupp.ps_partkey, partsupp.ps_suppkey], distribution: UpstreamHashShard(partsupp.ps_partkey, partsupp.ps_suppkey) }
stream_plan: |
StreamMaterialize { columns: [p_brand, p_type, p_size, supplier_cnt], pk_columns: [p_brand, p_type, p_size], order_descs: [supplier_cnt, p_brand, p_type, p_size], pk_conflict: "no check" }
- └─StreamProject { exprs: [part.p_brand, part.p_type, part.p_size, count(partsupp.ps_suppkey)] }
- └─StreamHashAgg { group_key: [part.p_brand, part.p_type, part.p_size], aggs: [count, count(partsupp.ps_suppkey)] }
+ └─StreamProject { exprs: [part.p_brand, part.p_type, part.p_size, count(distinct partsupp.ps_suppkey)] }
+ └─StreamHashAgg { group_key: [part.p_brand, part.p_type, part.p_size], aggs: [count, count(distinct partsupp.ps_suppkey)] }
└─StreamExchange { dist: HashShard(part.p_brand, part.p_type, part.p_size) }
- └─StreamProject { exprs: [part.p_brand, part.p_type, part.p_size, partsupp.ps_suppkey] }
- └─StreamHashAgg { group_key: [part.p_brand, part.p_type, part.p_size, partsupp.ps_suppkey], aggs: [count] }
- └─StreamHashJoin { type: LeftAnti, predicate: partsupp.ps_suppkey = supplier.s_suppkey, output: [part.p_brand, part.p_type, part.p_size, partsupp.ps_suppkey, partsupp.ps_partkey, part.p_partkey] }
- ├─StreamExchange { dist: HashShard(partsupp.ps_suppkey) }
- | └─StreamHashJoin { type: Inner, predicate: partsupp.ps_partkey = part.p_partkey, output: [partsupp.ps_suppkey, part.p_brand, part.p_type, part.p_size, partsupp.ps_partkey, part.p_partkey] }
- | ├─StreamExchange { dist: HashShard(partsupp.ps_partkey) }
- | | └─StreamTableScan { table: partsupp, columns: [partsupp.ps_partkey, partsupp.ps_suppkey], pk: [partsupp.ps_partkey, partsupp.ps_suppkey], dist: UpstreamHashShard(partsupp.ps_partkey, partsupp.ps_suppkey) }
- | └─StreamExchange { dist: HashShard(part.p_partkey) }
- | └─StreamFilter { predicate: (part.p_brand <> 'Brand#45':Varchar) AND (Not((part.p_type >= 'SMALL PLATED':Varchar)) OR Not((part.p_type < 'SMALL PLATEE':Varchar))) AND In(part.p_size, 19:Int32, 17:Int32, 16:Int32, 23:Int32, 10:Int32, 4:Int32, 38:Int32, 11:Int32) }
- | └─StreamTableScan { table: part, columns: [part.p_partkey, part.p_brand, part.p_type, part.p_size], pk: [part.p_partkey], dist: UpstreamHashShard(part.p_partkey) }
- └─StreamExchange { dist: HashShard(supplier.s_suppkey) }
- └─StreamProject { exprs: [supplier.s_suppkey] }
- └─StreamFilter { predicate: Like(supplier.s_comment, '%Customer%Complaints%':Varchar) }
- └─StreamTableScan { table: supplier, columns: [supplier.s_suppkey, supplier.s_comment], pk: [supplier.s_suppkey], dist: UpstreamHashShard(supplier.s_suppkey) }
+ └─StreamHashJoin { type: LeftAnti, predicate: partsupp.ps_suppkey = supplier.s_suppkey, output: [part.p_brand, part.p_type, part.p_size, partsupp.ps_suppkey, partsupp.ps_partkey, part.p_partkey] }
+ ├─StreamExchange { dist: HashShard(partsupp.ps_suppkey) }
+ | └─StreamHashJoin { type: Inner, predicate: partsupp.ps_partkey = part.p_partkey, output: [partsupp.ps_suppkey, part.p_brand, part.p_type, part.p_size, partsupp.ps_partkey, part.p_partkey] }
+ | ├─StreamExchange { dist: HashShard(partsupp.ps_partkey) }
+ | | └─StreamTableScan { table: partsupp, columns: [partsupp.ps_partkey, partsupp.ps_suppkey], pk: [partsupp.ps_partkey, partsupp.ps_suppkey], dist: UpstreamHashShard(partsupp.ps_partkey, partsupp.ps_suppkey) }
+ | └─StreamExchange { dist: HashShard(part.p_partkey) }
+ | └─StreamFilter { predicate: (part.p_brand <> 'Brand#45':Varchar) AND (Not((part.p_type >= 'SMALL PLATED':Varchar)) OR Not((part.p_type < 'SMALL PLATEE':Varchar))) AND In(part.p_size, 19:Int32, 17:Int32, 16:Int32, 23:Int32, 10:Int32, 4:Int32, 38:Int32, 11:Int32) }
+ | └─StreamTableScan { table: part, columns: [part.p_partkey, part.p_brand, part.p_type, part.p_size], pk: [part.p_partkey], dist: UpstreamHashShard(part.p_partkey) }
+ └─StreamExchange { dist: HashShard(supplier.s_suppkey) }
+ └─StreamProject { exprs: [supplier.s_suppkey] }
+ └─StreamFilter { predicate: Like(supplier.s_comment, '%Customer%Complaints%':Varchar) }
+ └─StreamTableScan { table: supplier, columns: [supplier.s_suppkey, supplier.s_comment], pk: [supplier.s_suppkey], dist: UpstreamHashShard(supplier.s_suppkey) }
stream_dist_plan: |
Fragment 0
StreamMaterialize { columns: [p_brand, p_type, p_size, supplier_cnt], pk_columns: [p_brand, p_type, p_size], order_descs: [supplier_cnt, p_brand, p_type, p_size], pk_conflict: "no check" }
materialized table: 4294967294
- StreamProject { exprs: [part.p_brand, part.p_type, part.p_size, count(partsupp.ps_suppkey)] }
- StreamHashAgg { group_key: [part.p_brand, part.p_type, part.p_size], aggs: [count, count(partsupp.ps_suppkey)] }
+ StreamProject { exprs: [part.p_brand, part.p_type, part.p_size, count(distinct partsupp.ps_suppkey)] }
+ StreamHashAgg { group_key: [part.p_brand, part.p_type, part.p_size], aggs: [count, count(distinct partsupp.ps_suppkey)] }
result table: 0, state tables: []
StreamExchange Hash([0, 1, 2]) from 1

Fragment 1
- StreamProject { exprs: [part.p_brand, part.p_type, part.p_size, partsupp.ps_suppkey] }
- StreamHashAgg { group_key: [part.p_brand, part.p_type, part.p_size, partsupp.ps_suppkey], aggs: [count] }
- result table: 1, state tables: []
- StreamHashJoin { type: LeftAnti, predicate: partsupp.ps_suppkey = supplier.s_suppkey, output: [part.p_brand, part.p_type, part.p_size, partsupp.ps_suppkey, partsupp.ps_partkey, part.p_partkey] }
- left table: 2, right table: 4, left degree table: 3, right degree table: 5,
- StreamExchange Hash([0]) from 2
- StreamExchange Hash([0]) from 5
+ StreamHashJoin { type: LeftAnti, predicate: partsupp.ps_suppkey = supplier.s_suppkey, output: [part.p_brand, part.p_type, part.p_size, partsupp.ps_suppkey, partsupp.ps_partkey, part.p_partkey] }
+ left table: 2, right table: 4, left degree table: 3, right degree table: 5,
+ StreamExchange Hash([0]) from 2
+ StreamExchange Hash([0]) from 5

Fragment 2
StreamHashJoin { type: Inner, predicate: partsupp.ps_partkey = part.p_partkey, output: [partsupp.ps_suppkey, part.p_brand, part.p_type, part.p_size, partsupp.ps_partkey, part.p_partkey] }
@@ -2750,8 +2745,7 @@
Upstream
BatchPlanNode

- Table 0 { columns: [part_p_brand, part_p_type, part_p_size, count, count(partsupp_ps_suppkey)], primary key: [$0 ASC, $1 ASC, $2 ASC], value indices: [3, 4], distribution key: [0, 1, 2] }
- Table 1 { columns: [part_p_brand, part_p_type, part_p_size, partsupp_ps_suppkey, count], primary key: [$0 ASC, $1 ASC, $2 ASC, $3 ASC], value indices: [4], distribution key: [3] }
+ Table 0 { columns: [part_p_brand, part_p_type, part_p_size, count, count(distinct partsupp_ps_suppkey)], primary key: [$0 ASC, $1 ASC, $2 ASC], value indices: [3, 4], distribution key: [0, 1, 2] }
Table 2 { columns: [partsupp_ps_suppkey, part_p_brand, part_p_type, part_p_size, partsupp_ps_partkey, part_p_partkey], primary key: [$0 ASC, $4 ASC, $5 ASC], value indices: [0, 1, 2, 3, 4, 5], distribution key: [0] }
Table 3 { columns: [partsupp_ps_suppkey, partsupp_ps_partkey, part_p_partkey, _degree], primary key: [$0 ASC, $1 ASC, $2 ASC], value indices: [3], distribution key: [0] }
Table 4 { columns: [supplier_s_suppkey], primary key: [$0 ASC], value indices: [0], distribution key: [0] }
5 changes: 4 additions & 1 deletion src/frontend/src/optimizer/mod.rs
@@ -408,7 +408,10 @@ impl PlanRoot {
plan = self.optimize_by_rules(
plan,
"Convert Distinct Aggregation".to_string(),
- vec![UnionToDistinctRule::create(), DistinctAggRule::create()],
+ vec![
+ UnionToDistinctRule::create(),
+ DistinctAggRule::create(for_stream),
+ ],
ApplyOrder::TopDown,
);

19 changes: 14 additions & 5 deletions src/frontend/src/optimizer/rule/distinct_agg_rule.rs
@@ -29,13 +29,22 @@ use crate::optimizer::PlanRef;
use crate::utils::{ColIndexMapping, Condition};

/// Transform distinct aggregates to `LogicalAgg` -> `LogicalAgg` -> `Expand` -> `Input`.
- pub struct DistinctAggRule {}
+ pub struct DistinctAggRule {
+ for_stream: bool,
+ }

impl Rule for DistinctAggRule {
fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
let agg: &LogicalAgg = plan.as_logical_agg()?;
let (mut agg_calls, mut agg_group_keys, input) = agg.clone().decompose();
- let original_group_keys_len = agg_group_keys.len();

+ if self.for_stream && !agg_group_keys.is_empty() {
+ // Due to performance issue, we don't do 2-phase agg for stream distinct agg with group
+ // by. See https://github.com/risingwavelabs/risingwave/issues/7271 for more.
+ return None;
+ }
+
+ let original_group_keys_len = agg_group_keys.len();
let (node, flag_values, has_expand) =
Self::build_expand(input, &mut agg_group_keys, &mut agg_calls)?;
let mid_agg = Self::build_middle_agg(node, agg_group_keys, agg_calls.clone(), has_expand);
@@ -50,8 +59,8 @@
}

impl DistinctAggRule {
- pub fn create() -> BoxedRule {
- Box::new(DistinctAggRule {})
+ pub fn create(for_stream: bool) -> BoxedRule {
+ Box::new(DistinctAggRule { for_stream })
}

/// Construct `Expand` for distinct aggregates.
@@ -110,7 +119,7 @@

let n_different_distinct = distinct_aggs
.iter()
- .unique_by(|agg_call| agg_call.input_indices())
+ .unique_by(|agg_call| agg_call.input_indices()[0])
.count();
assert_ne!(n_different_distinct, 0); // since `distinct_aggs` is not empty here
if n_different_distinct == 1 {
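Two things change in this rule: it now bails out for stream plans with a group key (deferring to the executor-side dedup sketched earlier), and `unique_by` keys on each distinct call's single input column when counting how many different distinct columns exist. For the cases the rule still handles, here is a runnable toy model of the two-level rewrite for one distinct column, with plain collections standing in for `LogicalAgg` plan nodes (purely illustrative, not the planner API):

```rust
use std::collections::HashMap;

/// Computes `select a, count(distinct b), sum(c) from t group by a` the way
/// the rewritten plan does: a middle agg keyed by (a, b), then a final agg
/// keyed by a. Rows are (a, b, c) triples.
fn distinct_count_and_sum(rows: &[(i32, i32, i64)]) -> HashMap<i32, (usize, i64)> {
    // Middle agg: group by (a, b). Duplicates of b collapse here, and sum(c)
    // is pre-aggregated per (a, b) pair.
    let mut mid: HashMap<(i32, i32), i64> = HashMap::new();
    for &(a, b, c) in rows {
        *mid.entry((a, b)).or_insert(0) += c;
    }
    // Final agg: group by a. Counting the surviving (a, b) pairs yields
    // count(distinct b); summing the partial sums restores sum(c).
    let mut out: HashMap<i32, (usize, i64)> = HashMap::new();
    for ((a, _b), partial) in mid {
        let entry = out.entry(a).or_insert((0, 0));
        entry.0 += 1;
        entry.1 += partial;
    }
    out
}

fn main() {
    let rows = [(1, 10, 5), (1, 10, 7), (1, 20, 1), (2, 30, 2)];
    let result = distinct_count_and_sum(&rows);
    assert_eq!(result[&1], (2, 13)); // two distinct b values, sum(c) = 13
    assert_eq!(result[&2], (1, 2));
}
```

With several different distinct columns, the `Expand` node and the `flag` column seen in the plan tests come into play: each input row is duplicated once per column subset so that every distinct column gets its own deduplicated view, and the final agg filters on the flag.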
2 changes: 1 addition & 1 deletion src/stream/src/executor/agg_common.rs
@@ -40,6 +40,7 @@ pub struct AggExecutorArgs<S: StateStore> {
pub storages: Vec<AggStateStorage<S>>,
pub result_table: StateTable<S>,
pub distinct_dedup_tables: HashMap<usize, StateTable<S>>,
+ pub watermark_epoch: AtomicU64Ref,

// extra
pub extra: Option<AggExecutorArgsExtra>,
@@ -53,5 +54,4 @@ pub struct AggExecutorArgsExtra {
// things only used by hash agg currently
pub metrics: Arc<StreamingMetrics>,
pub chunk_size: usize,
- pub watermark_epoch: AtomicU64Ref,
}
22 changes: 2 additions & 20 deletions src/stream/src/executor/aggregation/agg_group.rs
@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.

- use std::collections::HashMap;
use std::fmt::Debug;

use itertools::Itertools;
@@ -26,7 +25,7 @@ use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_storage::StateStore;

use super::agg_state::{AggState, AggStateStorage};
- use super::{AggCall, DistinctDeduplicater};
+ use super::AggCall;
use crate::common::table::state_table::StateTable;
use crate::executor::error::StreamExecutorResult;
use crate::executor::PkIndices;
@@ -39,9 +38,6 @@ pub struct AggGroup<S: StateStore> {
/// Current managed states for all [`AggCall`]s.
states: Vec<AggState<S>>,

- /// Distinct deduplicater to deduplicate input rows for each distinct agg call.
- distinct_dedup: DistinctDeduplicater<S>,
-
/// Previous outputs of managed states. Initializing with `None`.
prev_outputs: Option<OwnedRow>,
}
@@ -102,7 +98,6 @@ impl<S: StateStore> AggGroup<S> {
Ok(Self {
group_key,
states,
- distinct_dedup: DistinctDeduplicater::new(agg_calls),
prev_outputs,
})
}
@@ -127,24 +122,13 @@

/// Apply input chunk to all managed agg states.
/// `visibilities` contains the row visibility of the input chunk for each agg call.
- pub async fn apply_chunk(
+ pub fn apply_chunk(
&mut self,
storages: &mut [AggStateStorage<S>],
ops: &[Op],
columns: &[Column],
visibilities: Vec<Option<Bitmap>>,
- distinct_dedup_tables: &mut HashMap<usize, StateTable<S>>,
) -> StreamExecutorResult<()> {
- let visibilities = self
- .distinct_dedup
- .dedup_chunk(
- ops,
- columns,
- visibilities,
- distinct_dedup_tables,
- self.group_key.as_ref(),
- )
- .await?;
let columns = columns.iter().map(|col| col.array_ref()).collect_vec();
for ((state, storage), visibility) in self
.states
@@ -163,7 +147,6 @@
pub async fn flush_state_if_needed(
&self,
storages: &mut [AggStateStorage<S>],
- distinct_dedup_tables: &mut HashMap<usize, StateTable<S>>,
) -> StreamExecutorResult<()> {
futures::future::try_join_all(self.states.iter().zip_eq_fast(storages).filter_map(
|(state, storage)| match state {
@@ -175,7 +158,6 @@
},
))
.await?;
- self.distinct_dedup.flush(distinct_dedup_tables)?;
Ok(())
}

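With deduplication hoisted out of `AggGroup`, the agg executors now own a single `DistinctDeduplicater` shared across all groups and apply it to a chunk's visibilities before calling the now-synchronous `apply_chunk`. A toy model of that split (illustrative structs and signatures, not the real executor types):

```rust
use std::collections::HashMap;

/// Illustrative stand-in for `DistinctDeduplicater`: owned by the executor,
/// shared across all agg groups. (The real one persists its counts in the
/// `distinct_dedup_tables` state tables and also handles retractions.)
#[derive(Default)]
struct Deduplicater {
    counts: HashMap<(i32, i32), u64>, // (group key, distinct value) -> rows seen
}

impl Deduplicater {
    /// Masks out rows whose distinct value was already seen in their group.
    fn dedup_chunk(&mut self, rows: &[(i32, i32)], visibility: &mut [bool]) {
        for (i, &(group, value)) in rows.iter().enumerate() {
            let cnt = self.counts.entry((group, value)).or_insert(0);
            *cnt += 1;
            visibility[i] = visibility[i] && *cnt == 1;
        }
    }
}

/// Illustrative stand-in for `AggGroup` after this diff: dedup-agnostic, and
/// `apply_chunk` is a plain synchronous method that trusts the visibility
/// computed upstream by the executor.
#[derive(Default)]
struct AggGroup {
    distinct_count: u64,
}

impl AggGroup {
    fn apply_chunk(&mut self, visible: bool) {
        if visible {
            self.distinct_count += 1;
        }
    }
}

fn main() {
    let rows = [(1, 10), (1, 10), (1, 20), (2, 10)];
    let mut vis = [true; 4];
    let mut dedup = Deduplicater::default();
    dedup.dedup_chunk(&rows, &mut vis); // executor-level dedup, once per chunk

    let mut groups: HashMap<i32, AggGroup> = HashMap::new();
    for (&(group, _), &v) in rows.iter().zip(vis.iter()) {
        groups.entry(group).or_default().apply_chunk(v); // per-group, now sync
    }
    assert_eq!(groups[&1].distinct_count, 2); // distinct values 10 and 20
    assert_eq!(groups[&2].distinct_count, 1);
}
```

This also explains the `agg_common.rs` change above: `watermark_epoch` moves out of the hash-agg-only extras into the shared arguments, presumably because every agg executor now drives the dedup machinery, not just hash agg.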