-
Notifications
You must be signed in to change notification settings - Fork 186
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix(join): joining on different types #3716
Changes from 6 commits
4aa0032
a4d8eb6
767e697
c4e3380
01f53a9
3dc2398
91b4aae
fbbdac7
449742e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1455,8 +1455,8 @@ class PyMicroPartition: | |
right: PyMicroPartition, | ||
left_on: list[PyExpr], | ||
right_on: list[PyExpr], | ||
null_equals_nulls: list[bool] | None, | ||
how: JoinType, | ||
null_equals_nulls: list[bool] | None = None, | ||
) -> PyMicroPartition: ... | ||
def pivot( | ||
self, | ||
|
@@ -1584,6 +1584,11 @@ class AdaptivePhysicalPlanScheduler: | |
num_rows: int, | ||
) -> None: ... | ||
|
||
# Stub for the column-renaming options of a join: an optional `prefix`, an | ||
# optional `suffix`, and a required `merge_matching_join_keys` flag. | ||
# NOTE(review): how the prefix/suffix are applied to column names is not visible | ||
# in this stub; confirm against the native implementation. | ||
class JoinColumnRenamingParams: | ||
def __new__( | ||
cls, prefix: str | None, suffix: str | None, merge_matching_join_keys: bool | ||
) -> JoinColumnRenamingParams: ... | ||
|
||
class LogicalPlanBuilder: | ||
"""A logical plan builder, which simplifies constructing logical plans via a fluent interface. | ||
|
||
|
@@ -1642,9 +1647,8 @@ class LogicalPlanBuilder: | |
left_on: list[PyExpr], | ||
right_on: list[PyExpr], | ||
join_type: JoinType, | ||
strategy: JoinStrategy | None = None, | ||
join_prefix: str | None = None, | ||
join_suffix: str | None = None, | ||
join_strategy: JoinStrategy | None = None, | ||
column_renaming_params: JoinColumnRenamingParams | None = None, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
i'd prefer deprecating [inline code missing from extraction]
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
i'm also kind of on the fence about leaving these as is for the dataframe API. IMO, it's a bit cleaner to do `df.join(df2, prefix="df2.", suffix="_joined")` instead of:
from daft.daft import JoinColumnRenamingParams
df.join(df2, column_renaming_params=JoinColumnRenamingParams(prefix="df2.", suffix="_joined"))
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
The API for the actual dataframe operation has not changed; this is just for the builder. Do you think we should be concerned about breaking the builder API? |
||
) -> LogicalPlanBuilder: ... | ||
def concat(self, other: LogicalPlanBuilder) -> LogicalPlanBuilder: ... | ||
def intersect(self, other: LogicalPlanBuilder, is_all: bool) -> LogicalPlanBuilder: ... | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
use common_error::DaftResult; | ||
use daft_core::{prelude::*, utils::supertype::try_get_supertype}; | ||
use indexmap::IndexSet; | ||
|
||
use crate::{deduplicate_expr_names, ExprRef}; | ||
|
||
/// Yields the names of the columns that are present in both schemas. | ||
/// | ||
/// Walks the field names of `left_schema` and keeps each one for which | ||
/// `right_schema.has_field` returns true, so names come out in the left | ||
/// schema's field iteration order. Nothing is collected here; the caller | ||
/// consumes the returned iterator, which borrows from the schemas for `'a`. | ||
pub fn get_common_join_cols<'a>( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Common columns now determined via schema instead of join keys so that join keys can be modified (ex: casting to supertype) without side effects. This does introduce a small API change in the order of the join schema: common columns are now sorted by left side schema instead of by join keys. I think this is fine but we could introduce a project under the left side to reorder if necessary. |
||
left_schema: &'a SchemaRef, | ||
right_schema: &'a SchemaRef, | ||
) -> impl Iterator<Item = &'a String> { | ||
left_schema | ||
.fields | ||
.keys() | ||
.filter(|name| right_schema.has_field(name)) | ||
} | ||
|
||
/// Infer the schema of a join operation | ||
pub fn infer_join_schema( | ||
left_schema: &SchemaRef, | ||
right_schema: &SchemaRef, | ||
join_type: JoinType, | ||
) -> DaftResult<SchemaRef> { | ||
if matches!(join_type, JoinType::Anti | JoinType::Semi) { | ||
Ok(left_schema.clone()) | ||
} else { | ||
let common_cols = get_common_join_cols(left_schema, right_schema).collect::<IndexSet<_>>(); | ||
|
||
// common columns, then unique left fields, then unique right fields | ||
let fields = common_cols | ||
.iter() | ||
.map(|name| { | ||
let left_field = left_schema.get_field(name).unwrap(); | ||
let right_field = right_schema.get_field(name).unwrap(); | ||
|
||
Ok(match join_type { | ||
JoinType::Inner => left_field.clone(), | ||
JoinType::Left => left_field.clone(), | ||
JoinType::Right => right_field.clone(), | ||
JoinType::Outer => { | ||
let supertype = try_get_supertype(&left_field.dtype, &right_field.dtype)?; | ||
|
||
Field::new(*name, supertype) | ||
} | ||
JoinType::Anti | JoinType::Semi => unreachable!(), | ||
}) | ||
}) | ||
.chain( | ||
left_schema | ||
.fields | ||
.iter() | ||
.chain(right_schema.fields.iter()) | ||
.filter_map(|(name, field)| { | ||
if common_cols.contains(name) { | ||
None | ||
} else { | ||
Some(field.clone()) | ||
} | ||
}) | ||
.map(Ok), | ||
) | ||
.collect::<DaftResult<_>>()?; | ||
|
||
Ok(Schema::new(fields)?.into()) | ||
} | ||
} | ||
|
||
/// Casts join keys to the same types and make their names unique. | ||
/// | ||
/// Left and right keys are paired positionally. For each pair the dtype of | ||
/// each key is resolved against its own schema, the supertype of the two | ||
/// dtypes is computed, and a cast is added only on the side whose dtype | ||
/// differs from that supertype. Afterwards `deduplicate_expr_names` is run | ||
/// separately over the left keys and over the right keys. | ||
/// | ||
/// Returns the rewritten `(left_on, right_on)` key lists. | ||
/// | ||
/// # Errors | ||
/// | ||
/// Propagates any error from resolving a key with `to_field` or from | ||
/// `try_get_supertype`. | ||
/// | ||
/// NOTE(review): `zip` stops at the shorter list, so surplus keys on the | ||
/// longer side are dropped without an error. Presumably callers validate | ||
/// that both lists have equal length; confirm. | ||
pub fn normalize_join_keys( | ||
left_on: Vec<ExprRef>, | ||
right_on: Vec<ExprRef>, | ||
left_schema: SchemaRef, | ||
right_schema: SchemaRef, | ||
) -> DaftResult<(Vec<ExprRef>, Vec<ExprRef>)> { | ||
let (left_on, right_on) = left_on | ||
.into_iter() | ||
.zip(right_on) | ||
.map(|(mut l, mut r)| { | ||
let l_dtype = l.to_field(&left_schema)?.dtype; | ||
let r_dtype = r.to_field(&right_schema)?.dtype; | ||
|
||
let supertype = try_get_supertype(&l_dtype, &r_dtype)?; | ||
|
||
// Only wrap a key in a cast when its dtype is not already the supertype. | ||
if l_dtype != supertype { | ||
l = l.cast(&supertype); | ||
} | ||
|
||
if r_dtype != supertype { | ||
r = r.cast(&supertype); | ||
} | ||
|
||
Ok((l, r)) | ||
}) | ||
// Collecting here both surfaces the first error and splits the pairs into two vectors. | ||
.collect::<DaftResult<(Vec<_>, Vec<_>)>>()?; | ||
|
||
// Names are de-duplicated per side, after any casts have been applied. | ||
let left_on = deduplicate_expr_names(&left_on); | ||
let right_on = deduplicate_expr_names(&right_on); | ||
Comment on lines
+95
to
+96
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: can we make There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We need to materialize here anyway to get the DaftResult out of the iterator as well as split the iterator into two vecs. |
||
|
||
Ok((left_on, right_on)) | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Small drive-by fix in the typing to match the Rust definition.