Skip to content

Commit

Permalink
Handle table reuse in semi and anti join (#6059)
Browse files Browse the repository at this point in the history
cargo fmt

cargo clippy --fix

cleanup
  • Loading branch information
nseekhao authored Apr 23, 2023
1 parent 72d7db2 commit abbfdce
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 6 deletions.
20 changes: 17 additions & 3 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -305,11 +305,25 @@ pub async fn from_substrait_rel(
from_substrait_rel(ctx, join.right.as_ref().unwrap(), extensions).await?,
);
let join_type = from_substrait_jointype(join.r#type)?;
let schema =
build_join_schema(left.schema(), right.schema(), &JoinType::Inner)?;
// The join condition expression needs full input schema and not the output schema from join since we lose columns from
// certain join types such as semi and anti joins
// - if left and right schemas are different, we combine (join) the schema to include all fields
// - if left and right schemas are the same, we handle the duplicate fields by using `build_join_schema()`, which discard the unused schema
// TODO: Handle duplicate fields error for other join types (non-semi/anti). The current approach does not work due to Substrait's inability
// to encode aliases
let join_schema = match left.schema().join(right.schema()) {
Ok(schema) => Ok(schema),
Err(DataFusionError::SchemaError(
datafusion::common::SchemaError::DuplicateQualifiedField {
qualifier: _,
name: _,
},
)) => build_join_schema(left.schema(), right.schema(), &join_type),
Err(e) => Err(e),
};
let on = from_substrait_rex(
join.expression.as_ref().unwrap(),
&schema,
&join_schema?,
extensions,
)
.await?;
Expand Down
15 changes: 12 additions & 3 deletions datafusion/substrait/src/logical_plan/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,16 @@ pub fn to_substrait_rel(
// join schema from left and right to maintain all nececesary columns from inputs
// note that we cannot simple use join.schema here since we discard some input columns
// when performing semi and anti joins
let join_schema = join.left.schema().join(join.right.schema());
let join_schema = match join.left.schema().join(join.right.schema()) {
Ok(schema) => Ok(schema),
Err(DataFusionError::SchemaError(
datafusion::common::SchemaError::DuplicateQualifiedField {
qualifier: _,
name: _,
},
)) => Ok(join.schema.as_ref().clone()),
Err(e) => Err(e),
};
if let Some(e) = join_expression {
Ok(Box::new(Rel {
rel_type: Some(RelType::Join(Box::new(JoinRel {
Expand Down Expand Up @@ -1329,11 +1338,11 @@ mod test {
}

fn round_trip_literal(scalar: ScalarValue) -> Result<()> {
println!("Checking round trip of {:?}", scalar);
println!("Checking round trip of {scalar:?}");

let substrait = to_substrait_literal(&scalar)?;
let Expression { rex_type: Some(RexType::Literal(substrait_literal)) } = substrait else {
panic!("Expected Literal expression, got {:?}", substrait);
panic!("Expected Literal expression, got {substrait:?}");
};

let roundtrip_scalar = from_substrait_literal(&substrait_literal)?;
Expand Down
13 changes: 13 additions & 0 deletions datafusion/substrait/tests/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,19 @@ mod tests {
.await
}

#[tokio::test]
async fn simple_intersect_table_reuse() -> Result<()> {
assert_expected_plan(
"SELECT COUNT(*) FROM (SELECT data.a FROM data INTERSECT SELECT data.a FROM data);",
"Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\
\n LeftSemi Join: data.a = data.a\
\n Aggregate: groupBy=[[data.a]], aggr=[[]]\
\n TableScan: data projection=[a]\
\n TableScan: data projection=[a]",
)
.await
}

#[tokio::test]
async fn simple_window_function() -> Result<()> {
roundtrip("SELECT RANK() OVER (PARTITION BY a ORDER BY b), d, SUM(b) OVER (PARTITION BY a) FROM data;").await
Expand Down

0 comments on commit abbfdce

Please sign in to comment.