
fix(join): joining on different types #3716

Merged 9 commits on Jan 29, 2025
Changes from 6 commits
12 changes: 8 additions & 4 deletions daft/daft/__init__.pyi
@@ -1455,8 +1455,8 @@ class PyMicroPartition:
right: PyMicroPartition,
left_on: list[PyExpr],
right_on: list[PyExpr],
null_equals_nulls: list[bool] | None,
how: JoinType,
null_equals_nulls: list[bool] | None = None,
Comment on lines -1458 to +1459
@kevinzwang (Member, Author) commented on Jan 22, 2025:

small drive-by fix in typing to match rust definition

) -> PyMicroPartition: ...
def pivot(
self,
@@ -1584,6 +1584,11 @@ class AdaptivePhysicalPlanScheduler:
num_rows: int,
) -> None: ...

class JoinColumnRenamingParams:
def __new__(
cls, prefix: str | None, suffix: str | None, merge_matching_join_keys: bool
) -> JoinColumnRenamingParams: ...

class LogicalPlanBuilder:
"""A logical plan builder, which simplifies constructing logical plans via a fluent interface.

@@ -1642,9 +1647,8 @@ class LogicalPlanBuilder:
left_on: list[PyExpr],
right_on: list[PyExpr],
join_type: JoinType,
strategy: JoinStrategy | None = None,
join_prefix: str | None = None,
join_suffix: str | None = None,
join_strategy: JoinStrategy | None = None,
column_renaming_params: JoinColumnRenamingParams | None = None,
Contributor commented:

i'd prefer deprecating join_prefix and join_suffix instead of just flat out removing them.

Contributor commented:

i'm kind of also on the fence of leaving these as is for the dataframe api. IMO, it's a bit cleaner to do

df.join(df2, prefix="df2.", suffix="_joined")

instead of

from daft.daft import JoinColumnRenamingParams

df.join(df2, column_renaming_params=JoinColumnRenamingParams(prefix="df2.", suffix="_joined"))

Member Author replied:

The API for the actual dataframe operation has not changed, this is just for the builder. Do you think we should be concerned about breaking the builder API?

) -> LogicalPlanBuilder: ...
def concat(self, other: LogicalPlanBuilder) -> LogicalPlanBuilder: ...
def intersect(self, other: LogicalPlanBuilder, is_all: bool) -> LogicalPlanBuilder: ...
5 changes: 2 additions & 3 deletions daft/dataframe/dataframe.py
@@ -32,7 +32,7 @@
from daft.api_annotations import DataframePublicAPI
from daft.context import get_context
from daft.convert import InputListType
from daft.daft import FileFormat, IOConfig, JoinStrategy, JoinType, check_column_name_validity
from daft.daft import FileFormat, IOConfig, JoinColumnRenamingParams, JoinStrategy, JoinType, check_column_name_validity
from daft.dataframe.preview import DataFramePreview
from daft.datatype import DataType
from daft.errors import ExpressionTypeError
@@ -1897,8 +1897,7 @@ def join(
right_on=right_exprs,
how=join_type,
strategy=join_strategy,
join_prefix=prefix,
join_suffix=suffix,
column_renaming_params=JoinColumnRenamingParams(prefix, suffix, True),
)
return DataFrame(builder)

7 changes: 3 additions & 4 deletions daft/logical/builder.py
@@ -8,6 +8,7 @@
CountMode,
FileFormat,
IOConfig,
JoinColumnRenamingParams,
JoinStrategy,
JoinType,
PyDaftExecutionConfig,
@@ -257,17 +258,15 @@ def join( # type: ignore[override]
right_on: list[Expression],
how: JoinType = JoinType.Inner,
strategy: JoinStrategy | None = None,
join_suffix: str | None = None,
join_prefix: str | None = None,
column_renaming_params: JoinColumnRenamingParams | None = None,
) -> LogicalPlanBuilder:
builder = self._builder.join(
right._builder,
[expr._expr for expr in left_on],
[expr._expr for expr in right_on],
how,
strategy,
join_suffix,
join_prefix,
column_renaming_params,
)
return LogicalPlanBuilder(builder)

42 changes: 32 additions & 10 deletions src/arrow2/src/array/dyn_ord.rs
@@ -1,19 +1,16 @@
use std::cmp::Ordering;

use num_traits::Float;
use ord::total_cmp;

use std::cmp::Ordering;

use crate::datatypes::*;
use crate::error::Error;
use crate::offset::Offset;
use crate::{array::*, types::NativeType};
use crate::{array::*, datatypes::*, error::Error, offset::Offset, types::NativeType};

/// Compare the values at two arbitrary indices in two arbitrary arrays.
pub type DynArrayComparator =
Box<dyn Fn(&dyn Array, &dyn Array, usize, usize) -> Ordering + Send + Sync>;

#[inline]
unsafe fn is_valid<A: Array>(arr: &A, i: usize) -> bool {
unsafe fn is_valid(arr: &dyn Array, i: usize) -> bool {
// avoid dyn function hop by using generic
arr.validity()
.as_ref()
@@ -22,9 +19,9 @@ unsafe fn is_valid<A: Array>(arr: &A, i: usize) -> bool {
}

#[inline]
fn compare_with_nulls<A: Array, F: FnOnce() -> Ordering>(
left: &A,
right: &A,
fn compare_with_nulls<F: FnOnce() -> Ordering>(
left: &dyn Array,
right: &dyn Array,
i: usize,
j: usize,
nulls_equal: bool,
@@ -122,6 +119,30 @@ fn compare_dyn_boolean(nulls_equal: bool) -> DynArrayComparator {
})
}

fn compare_dyn_null(nulls_equal: bool) -> DynArrayComparator {
Box::new(move |left, right, i, j| {
assert!(i < left.len());
assert!(j < right.len());
// need the extra datatype check in match because the validity of a null array
// is quizzically always true and not false
match (
unsafe { is_valid(left, i) } && *left.data_type() != DataType::Null,
unsafe { is_valid(right, j) } && *right.data_type() != DataType::Null,
) {
(true, true) => unreachable!(),
(false, true) => Ordering::Greater,
(true, false) => Ordering::Less,
(false, false) => {
if nulls_equal {
Ordering::Equal
} else {
Ordering::Less
}
}
}
})
}
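The null-handling branch in `compare_dyn_null` above can be illustrated with a small Python sketch (the function name is hypothetical; the real implementation is the Rust closure above). A slot counts as valid only when its validity bit is set and the array's dtype is not Null, which is why the `(true, true)` arm is unreachable:

```python
def compare_null_slots(left_valid: bool, right_valid: bool, nulls_equal: bool) -> int:
    """Mirror of the (validity, validity) match in compare_dyn_null."""
    if left_valid and right_valid:
        # This comparator is only built when at least one side is Null-typed.
        raise AssertionError("unreachable: one side must be null")
    if not left_valid and right_valid:
        return 1   # Ordering::Greater: null sorts after a value
    if left_valid and not right_valid:
        return -1  # Ordering::Less
    # Both null: equal only when null_equals_nulls is requested.
    return 0 if nulls_equal else -1
```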

pub fn build_dyn_array_compare(
left: &DataType,
right: &DataType,
@@ -187,6 +208,7 @@ pub fn build_dyn_array_compare(
// }
// }
// }
(Null, _) | (_, Null) => compare_dyn_null(nulls_equal),
(lhs, _) => {
return Err(Error::InvalidArgumentError(format!(
"The data type type {lhs:?} has no natural order"
14 changes: 10 additions & 4 deletions src/arrow2/src/array/ord.rs
@@ -2,10 +2,7 @@

use std::cmp::Ordering;

use crate::datatypes::*;
use crate::error::Error;
use crate::offset::Offset;
use crate::{array::*, types::NativeType};
use crate::{array::*, datatypes::*, error::Error, offset::Offset, types::NativeType};

/// Compare the values at two arbitrary indices in two arrays.
pub type DynComparator = Box<dyn Fn(usize, usize) -> Ordering + Send + Sync>;
@@ -157,6 +154,14 @@ macro_rules! dyn_dict {
}};
}

fn compare_null(_left: &dyn Array, _right: &dyn Array) -> DynComparator {
Box::new(move |_i: usize, _j: usize| {
// nulls do not have a canonical ordering, but it is trivially implemented so that
// null arrays can be used in things that depend on `build_compare`
Ordering::Less
})
}
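As the comment in `compare_null` notes, nulls have no canonical ordering; the comparator just returns a constant so that null arrays can flow through code paths that depend on `build_compare`. A minimal Python equivalent of that behavior (illustrative only):

```python
def compare_null(_i: int, _j: int) -> int:
    # Nulls do not have a canonical ordering; returning a constant
    # "less" (-1) keeps null arrays usable wherever a comparator is required.
    return -1
```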

/// returns a comparison function that compares values at two different slots
/// between two [`Array`].
/// # Example
@@ -243,6 +248,7 @@ pub fn build_compare(left: &dyn Array, right: &dyn Array) -> Result<DynComparato
}
}
}
(Null, _) | (_, Null) => compare_null(left, right),
(lhs, _) => {
return Err(Error::InvalidArgumentError(format!(
"The data type type {lhs:?} has no natural order"
36 changes: 32 additions & 4 deletions src/daft-dsl/src/expr/mod.rs
@@ -3,6 +3,7 @@

use std::{
any::Any,
collections::HashSet,
hash::{DefaultHasher, Hash, Hasher},
io::{self, Write},
str::FromStr,
@@ -21,7 +22,6 @@
utils::supertype::try_get_supertype,
};
use derive_more::Display;
use itertools::Itertools;
use serde::{Deserialize, Serialize};

use super::functions::FunctionExpr;
@@ -1320,9 +1320,9 @@
// Check if one set of columns is a reordering of the other
pub fn is_partition_compatible(a: &[ExprRef], b: &[ExprRef]) -> bool {
// sort a and b by name
let a: Vec<&str> = a.iter().map(|a| a.name()).sorted().collect();
let b: Vec<&str> = b.iter().map(|a| a.name()).sorted().collect();
a == b
let a_set: HashSet<&ExprRef> = HashSet::from_iter(a);
let b_set: HashSet<&ExprRef> = HashSet::from_iter(b);
a_set == b_set
}
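The change to `is_partition_compatible` swaps sorted-name comparison for set equality over the full expressions, so two partitioning specs are compatible exactly when one is a reordering of the other. A Python sketch of the same idea (strings stand in for expressions):

```python
def is_partition_compatible(a: list, b: list) -> bool:
    # Compatible when one side is a reordering of the other; comparing
    # whole expressions (not just sorted names) avoids conflating
    # distinct expressions that happen to share a name.
    return set(a) == set(b)
```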

pub fn has_agg(expr: &ExprRef) -> bool {
@@ -1443,3 +1443,31 @@
.collect::<DaftResult<_>>()?;
Ok(Arc::new(Schema::new(fields)?))
}

/// Adds aliases as appropriate to ensure that all expressions have unique names.
pub fn deduplicate_expr_names(exprs: &[ExprRef]) -> Vec<ExprRef> {
let mut names_so_far = HashSet::new();

exprs
.iter()
.map(|e| {
let curr_name = e.name();

let mut i = 0;
let mut new_name = curr_name.to_string();

while names_so_far.contains(&new_name) {
i += 1;
new_name = format!("{}_{}", curr_name, i);
}

Codecov warning: added lines src/daft-dsl/src/expr/mod.rs#L1460-L1462 were not covered by tests.

names_so_far.insert(new_name.clone());

if i == 0 {
e.clone()
} else {
e.alias(new_name)

Codecov warning: added line src/daft-dsl/src/expr/mod.rs#L1469 was not covered by tests.
}
})
.collect()
}
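The aliasing scheme in `deduplicate_expr_names` can be sketched in plain Python: each duplicate name gets a `_1`, `_2`, ... suffix, counting up until the candidate has not been used yet (here over bare strings rather than expressions):

```python
def deduplicate_expr_names(names):
    # Mirror of the Rust deduplicate_expr_names: append "_{i}" to a
    # duplicate name, incrementing i until the candidate is unique.
    seen = set()
    out = []
    for name in names:
        i = 0
        new_name = name
        while new_name in seen:
            i += 1
            new_name = f"{name}_{i}"
        seen.add(new_name)
        out.append(new_name)
    return out
```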
99 changes: 99 additions & 0 deletions src/daft-dsl/src/join.rs
@@ -0,0 +1,99 @@
use common_error::DaftResult;
use daft_core::{prelude::*, utils::supertype::try_get_supertype};
use indexmap::IndexSet;

use crate::{deduplicate_expr_names, ExprRef};

pub fn get_common_join_cols<'a>(
@kevinzwang (Member, Author) commented on Jan 22, 2025:
Common columns now determined via schema instead of join keys so that join keys can be modified (ex: casting to supertype) without side effects. This does introduce a small API change in the order of the join schema: common columns are now sorted by left side schema instead of by join keys. I think this is fine but we could introduce a project under the left side to reorder if necessary.

left_schema: &'a SchemaRef,
right_schema: &'a SchemaRef,
) -> impl Iterator<Item = &'a String> {
left_schema
.fields
.keys()
.filter(|name| right_schema.has_field(name))
}
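Per the author's comment, common columns are now derived from the two schemas and ordered by the left side rather than by the join keys. An equivalent Python sketch over lists of field names:

```python
def get_common_join_cols(left_fields, right_fields):
    # Common columns come from the schemas themselves, ordered by the
    # left-side schema (not by the order of the join keys).
    right = set(right_fields)
    return [name for name in left_fields if name in right]
```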

/// Infer the schema of a join operation
pub fn infer_join_schema(
left_schema: &SchemaRef,
right_schema: &SchemaRef,
join_type: JoinType,
) -> DaftResult<SchemaRef> {
if matches!(join_type, JoinType::Anti | JoinType::Semi) {
Ok(left_schema.clone())
} else {
let common_cols = get_common_join_cols(left_schema, right_schema).collect::<IndexSet<_>>();

// common columns, then unique left fields, then unique right fields
let fields = common_cols
.iter()
.map(|name| {
let left_field = left_schema.get_field(name).unwrap();
let right_field = right_schema.get_field(name).unwrap();

Ok(match join_type {
JoinType::Inner => left_field.clone(),
JoinType::Left => left_field.clone(),
JoinType::Right => right_field.clone(),
JoinType::Outer => {
let supertype = try_get_supertype(&left_field.dtype, &right_field.dtype)?;

Field::new(*name, supertype)
}
JoinType::Anti | JoinType::Semi => unreachable!(),

Codecov warning: added line src/daft-dsl/src/join.rs#L44 was not covered by tests.
})
})
.chain(
left_schema
.fields
.iter()
.chain(right_schema.fields.iter())
.filter_map(|(name, field)| {
if common_cols.contains(name) {
None
} else {
Some(field.clone())
}
})
.map(Ok),
)
.collect::<DaftResult<_>>()?;

Ok(Schema::new(fields)?.into())
}
}
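The schema-inference rules above (common columns first, taking the left dtype for inner/left joins, the right dtype for right joins, and the supertype for outer joins; then unique left fields, then unique right fields) can be sketched in Python with ordered dicts mapping name to dtype. `supertype` here is a stand-in callable, not the real resolver:

```python
def infer_join_schema(left, right, join_type, supertype):
    # left/right: ordered dicts of column name -> dtype.
    if join_type in ("anti", "semi"):
        return dict(left)
    common = [n for n in left if n in right]
    fields = {}
    for n in common:
        if join_type in ("inner", "left"):
            fields[n] = left[n]
        elif join_type == "right":
            fields[n] = right[n]
        else:  # outer: resolve a common supertype
            fields[n] = supertype(left[n], right[n])
    # Unique left fields, then unique right fields, in schema order.
    for n, dt in list(left.items()) + list(right.items()):
        if n not in fields:
            fields[n] = dt
    return fields
```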

/// Casts join keys to the same types and make their names unique.
pub fn normalize_join_keys(
left_on: Vec<ExprRef>,
right_on: Vec<ExprRef>,
left_schema: SchemaRef,
right_schema: SchemaRef,
) -> DaftResult<(Vec<ExprRef>, Vec<ExprRef>)> {
let (left_on, right_on) = left_on
.into_iter()
.zip(right_on)
.map(|(mut l, mut r)| {
let l_dtype = l.to_field(&left_schema)?.dtype;
let r_dtype = r.to_field(&right_schema)?.dtype;

let supertype = try_get_supertype(&l_dtype, &r_dtype)?;

if l_dtype != supertype {
l = l.cast(&supertype);
}

if r_dtype != supertype {
r = r.cast(&supertype);
}

Ok((l, r))
})
.collect::<DaftResult<(Vec<_>, Vec<_>)>>()?;

let left_on = deduplicate_expr_names(&left_on);
let right_on = deduplicate_expr_names(&right_on);
Comment on lines +95 to +96
Contributor commented:

nit: can we make deduplicate_expr_names take in an iter so we don't have to materialize twice?

Member Author replied:

We need to materialize here anyway to get the DaftResult out of the iterator as well as split the iterator into two vecs.


Ok((left_on, right_on))
}
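The casting half of `normalize_join_keys` can be sketched as follows: each pair of join keys is cast to their common supertype, and a key already at the supertype is left uncast. Keys are modeled as `(name, dtype)` pairs, name deduplication is omitted, and `supertype` is a stand-in callable:

```python
def normalize_join_keys(left_on, right_on, supertype):
    # Cast each key pair to the pair's common supertype; a key is only
    # wrapped in a cast when its dtype differs from that supertype.
    def maybe_cast(name, dtype, st):
        return (f"cast({name} as {st})", st) if dtype != st else (name, dtype)

    new_left, new_right = [], []
    for (ln, ld), (rn, rd) in zip(left_on, right_on):
        st = supertype(ld, rd)
        new_left.append(maybe_cast(ln, ld, st))
        new_right.append(maybe_cast(rn, rd, st))
    return new_left, new_right
```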