Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Determine causal window frames to produce early results. #8842

Merged
merged 15 commits into from
Jan 15, 2024
2 changes: 1 addition & 1 deletion datafusion-examples/examples/advanced_udwf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ async fn main() -> Result<()> {
vec![col("speed")], // smooth_it(speed)
vec![col("car")], // PARTITION BY car
vec![col("time").sort(true, true)], // ORDER BY time ASC
WindowFrame::new(false),
WindowFrame::new(None),
);
let df = ctx.table("cars").await?.window(vec![window_expr])?;

Expand Down
2 changes: 1 addition & 1 deletion datafusion-examples/examples/simple_udwf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ async fn main() -> Result<()> {
vec![col("speed")], // smooth_it(speed)
vec![col("car")], // PARTITION BY car
vec![col("time").sort(true, true)], // ORDER BY time ASC
WindowFrame::new(false),
WindowFrame::new(None),
);
let df = ctx.table("cars").await?.window(vec![window_expr])?;

Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1531,7 +1531,7 @@ mod tests {
vec![col("aggregate_test_100.c1")],
vec![col("aggregate_test_100.c2")],
vec![],
WindowFrame::new(false),
WindowFrame::new(None),
));
let t2 = t.select(vec![col("c1"), first_row])?;
let plan = t2.plan.clone();
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/physical_optimizer/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ pub fn bounded_window_exec(
&[col(col_name, &schema).unwrap()],
&[],
&sort_exprs,
Arc::new(WindowFrame::new(true)),
Arc::new(WindowFrame::new(Some(false))),
schema.as_ref(),
)
.unwrap()],
Expand Down
10 changes: 5 additions & 5 deletions datafusion/core/tests/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,11 +174,11 @@ async fn test_count_wildcard_on_window() -> Result<()> {
vec![wildcard()],
vec![],
vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))],
WindowFrame {
units: WindowFrameUnits::Range,
start_bound: WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))),
end_bound: WindowFrameBound::Following(ScalarValue::UInt32(Some(2))),
},
WindowFrame::try_new(
WindowFrameUnits::Range,
WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))),
WindowFrameBound::Following(ScalarValue::UInt32(Some(2))),
)?,
))])?
.explain(false, false)?
.collect()
Expand Down
14 changes: 4 additions & 10 deletions datafusion/core/tests/fuzz_cases/window_fuzz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -343,11 +343,8 @@ fn get_random_window_frame(rng: &mut StdRng, is_linear: bool) -> WindowFrame {
} else {
WindowFrameBound::Following(ScalarValue::Int32(Some(end_bound.val)))
};
let mut window_frame = WindowFrame {
units,
start_bound,
end_bound,
};
let mut window_frame =
WindowFrame::try_new(units, start_bound, end_bound).unwrap();
// with 10% use unbounded preceding in tests
if rng.gen_range(0..10) == 0 {
window_frame.start_bound =
Expand Down Expand Up @@ -375,11 +372,8 @@ fn get_random_window_frame(rng: &mut StdRng, is_linear: bool) -> WindowFrame {
end_bound.val as u64,
)))
};
let mut window_frame = WindowFrame {
units,
start_bound,
end_bound,
};
let mut window_frame =
WindowFrame::try_new(units, start_bound, end_bound).unwrap();
// with 10% use unbounded preceding in tests
if rng.gen_range(0..10) == 0 {
window_frame.start_bound =
Expand Down
2 changes: 1 addition & 1 deletion datafusion/expr/src/udwf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ where
/// vec![col("speed")], // smooth_it(speed)
/// vec![col("car")], // PARTITION BY car
/// vec![col("time").sort(true, true)], // ORDER BY time ASC
/// WindowFrame::new(false),
/// WindowFrame::new(None),
/// );
/// ```
pub trait WindowUDFImpl {
Expand Down
20 changes: 10 additions & 10 deletions datafusion/expr/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1252,28 +1252,28 @@ mod tests {
vec![col("name")],
vec![],
vec![],
WindowFrame::new(false),
WindowFrame::new(None),
));
let max2 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max),
vec![col("name")],
vec![],
vec![],
WindowFrame::new(false),
WindowFrame::new(None),
));
let min3 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min),
vec![col("name")],
vec![],
vec![],
WindowFrame::new(false),
WindowFrame::new(None),
));
let sum4 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum),
vec![col("age")],
vec![],
vec![],
WindowFrame::new(false),
WindowFrame::new(None),
));
let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
let result = group_window_expr_by_sort_keys(exprs.to_vec())?;
Expand All @@ -1295,28 +1295,28 @@ mod tests {
vec![col("name")],
vec![],
vec![age_asc.clone(), name_desc.clone()],
WindowFrame::new(true),
WindowFrame::new(Some(false)),
));
let max2 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max),
vec![col("name")],
vec![],
vec![],
WindowFrame::new(false),
WindowFrame::new(None),
));
let min3 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min),
vec![col("name")],
vec![],
vec![age_asc.clone(), name_desc.clone()],
WindowFrame::new(true),
WindowFrame::new(Some(false)),
));
let sum4 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum),
vec![col("age")],
vec![],
vec![name_desc.clone(), age_asc.clone(), created_at_desc.clone()],
WindowFrame::new(true),
WindowFrame::new(Some(false)),
));
// FIXME use as_ref
let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
Expand Down Expand Up @@ -1350,7 +1350,7 @@ mod tests {
Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)),
Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)),
],
WindowFrame::new(true),
WindowFrame::new(Some(false)),
)),
Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum),
Expand All @@ -1361,7 +1361,7 @@ mod tests {
Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)),
Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)),
],
WindowFrame::new(true),
WindowFrame::new(Some(false)),
)),
];
let expected = vec![
Expand Down
101 changes: 88 additions & 13 deletions datafusion/expr/src/window_frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,24 @@ use sqlparser::ast;
use sqlparser::parser::ParserError::ParserError;
use std::convert::{From, TryFrom};
use std::fmt;
use std::fmt::Formatter;
use std::hash::Hash;

/// The frame-spec determines which output rows are read by an aggregate window function.
///
/// The ending frame boundary can be omitted (if the BETWEEN and AND keywords that surround the
/// starting frame boundary are also omitted), in which case the ending frame boundary defaults to
/// CURRENT ROW.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct WindowFrame {
/// A frame type - either ROWS, RANGE or GROUPS
pub units: WindowFrameUnits,
/// A starting frame boundary
pub start_bound: WindowFrameBound,
/// An ending frame boundary
pub end_bound: WindowFrameBound,
/// Flag indicates whether window frame is causal.
is_causal: bool,
}

impl fmt::Display for WindowFrame {
Expand All @@ -58,6 +61,17 @@ impl fmt::Display for WindowFrame {
}
}

impl fmt::Debug for WindowFrame {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
write!(
f,
"WindowFrame {{ units: {:?}, start_bound: {:?}, end_bound: {:?} }}",
self.units, self.start_bound, self.end_bound
)?;
Ok(())
}
}

impl TryFrom<ast::WindowFrame> for WindowFrame {
type Error = DataFusionError;

Expand All @@ -81,35 +95,47 @@ impl TryFrom<ast::WindowFrame> for WindowFrame {
)?
}
};
let units = value.units.into();
let is_causal = is_frame_causal(&units, &end_bound)?;
Ok(Self {
units: value.units.into(),
units,
start_bound,
end_bound,
is_causal,
})
}
}

impl WindowFrame {
/// Creates a new, default window frame (with the meaning of default depending on whether the
/// frame contains an `ORDER BY` clause.
pub fn new(has_order_by: bool) -> Self {
if has_order_by {
// This window frame covers the table (or partition if `PARTITION BY` is used)
// from beginning to the `CURRENT ROW` (with same rank). It is used when the `OVER`
// clause contains an `ORDER BY` clause but no frame.
/// Creates a new, default window frame (with the meaning of default
/// depending on whether the frame contains an `ORDER BY` clause and this
/// ordering is strict (i.e. no ties).
pub fn new(order_by: Option<bool>) -> Self {
if let Some(strict) = order_by {
// This window frame covers the table (or partition if `PARTITION BY`
// is used) from beginning to the `CURRENT ROW` (with same rank). It
// is used when the `OVER` clause contains an `ORDER BY` clause but
// no frame.
WindowFrame {
units: WindowFrameUnits::Range,
units: if strict {
WindowFrameUnits::Rows
} else {
WindowFrameUnits::Range
},
start_bound: WindowFrameBound::Preceding(ScalarValue::Null),
end_bound: WindowFrameBound::CurrentRow,
// When mode is Rows, it is causal when mode is Range it is not
is_causal: strict,
}
} else {
// This window frame covers the whole table (or partition if `PARTITION BY` is used).
// It is used when the `OVER` clause does not contain an `ORDER BY` clause and there is
// no frame.
// This window frame covers the whole table (or partition if `PARTITION BY`
// is used). It is used when the `OVER` clause does not contain an
// `ORDER BY` clause and there is no frame.
WindowFrame {
units: WindowFrameUnits::Rows,
start_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
end_bound: WindowFrameBound::Following(ScalarValue::UInt64(None)),
is_causal: false,
}
}
}
Expand All @@ -136,12 +162,61 @@ impl WindowFrame {
}
WindowFrameBound::CurrentRow => WindowFrameBound::CurrentRow,
};
// Units and end bound types do not change, cannot produce error.
let is_causal = is_frame_causal(&self.units, &end_bound).unwrap();
WindowFrame {
units: self.units,
start_bound,
end_bound,
is_causal,
}
}

/// Get whether window frame is causal
pub fn is_causal(&self) -> bool {
self.is_causal
}

/// Initializes window frame from units (window bound type), start bound and end bound
pub fn try_new(
units: WindowFrameUnits,
start_bound: WindowFrameBound,
end_bound: WindowFrameBound,
) -> Result<Self> {
let is_causal = is_frame_causal(&units, &end_bound)?;
Ok(WindowFrame {
units,
start_bound,
end_bound,
is_causal,
})
}
}

/// Calculate whether window frame is causal or not.
fn is_frame_causal(
frame_units: &WindowFrameUnits,
end_bound: &WindowFrameBound,
) -> Result<bool> {
Ok(match frame_units {
WindowFrameUnits::Rows => matches!(
end_bound,
WindowFrameBound::Preceding(_) | WindowFrameBound::CurrentRow
),
WindowFrameUnits::Range | WindowFrameUnits::Groups => match end_bound {
WindowFrameBound::Preceding(val) => {
// val can be either numeric type or Utf8 type (which is initial type after parsing)
// In subsequent stages, Utf8 type converted to the appropriate types.
if let ScalarValue::Utf8(Some(val)) = val {
val != "0"
} else {
let zero = ScalarValue::new_zero(&val.data_type())?;
val.gt(&zero)
}
}
_ => false,
},
})
}

/// Regularizes ORDER BY clause for window definition for implicit corner cases.
Expand Down
30 changes: 15 additions & 15 deletions datafusion/expr/src/window_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -682,11 +682,11 @@ mod tests {

#[test]
fn test_window_frame_group_boundaries() -> Result<()> {
let window_frame = Arc::new(WindowFrame {
units: WindowFrameUnits::Groups,
start_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1))),
end_bound: WindowFrameBound::Following(ScalarValue::UInt64(Some(1))),
});
let window_frame = Arc::new(WindowFrame::try_new(
WindowFrameUnits::Groups,
WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1))),
WindowFrameBound::Following(ScalarValue::UInt64(Some(1))),
)?);
let expected_results = vec![
(Range { start: 0, end: 2 }, 0),
(Range { start: 0, end: 4 }, 1),
Expand All @@ -703,11 +703,11 @@ mod tests {

#[test]
fn test_window_frame_group_boundaries_both_following() -> Result<()> {
let window_frame = Arc::new(WindowFrame {
units: WindowFrameUnits::Groups,
start_bound: WindowFrameBound::Following(ScalarValue::UInt64(Some(1))),
end_bound: WindowFrameBound::Following(ScalarValue::UInt64(Some(2))),
});
let window_frame = Arc::new(WindowFrame::try_new(
WindowFrameUnits::Groups,
WindowFrameBound::Following(ScalarValue::UInt64(Some(1))),
WindowFrameBound::Following(ScalarValue::UInt64(Some(2))),
)?);
let expected_results = vec![
(Range::<usize> { start: 1, end: 4 }, 0),
(Range::<usize> { start: 2, end: 5 }, 1),
Expand All @@ -724,11 +724,11 @@ mod tests {

#[test]
fn test_window_frame_group_boundaries_both_preceding() -> Result<()> {
let window_frame = Arc::new(WindowFrame {
units: WindowFrameUnits::Groups,
start_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))),
end_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1))),
});
let window_frame = Arc::new(WindowFrame::try_new(
WindowFrameUnits::Groups,
WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))),
WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1))),
)?);
let expected_results = vec![
(Range::<usize> { start: 0, end: 0 }, 0),
(Range::<usize> { start: 0, end: 1 }, 1),
Expand Down
Loading