Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] agg_concat doesn't work on strings #2847

Merged
merged 8 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 73 additions & 4 deletions src/daft-core/src/array/ops/concat_agg.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
use arrow2::{bitmap::utils::SlicesIterator, offset::OffsetsBuffer, types::Index};
use arrow2::{
array::{Array, Utf8Array},
bitmap::utils::SlicesIterator,
offset::OffsetsBuffer,
types::Index,
};
use common_error::DaftResult;

use super::{as_arrow::AsArrow, DaftConcatAggable};
use crate::array::{
growable::{make_growable, Growable},
ListArray,
use crate::{
array::{
growable::{make_growable, Growable},
DataArray, ListArray,
},
prelude::Utf8Type,
};

#[cfg(feature = "python")]
Expand Down Expand Up @@ -146,6 +154,67 @@ impl DaftConcatAggable for ListArray {
}
}

impl DaftConcatAggable for DataArray<Utf8Type> {
type Output = DaftResult<Self>;

fn concat(&self) -> Self::Output {
let new_validity = match self.validity() {
Some(validity) if validity.unset_bits() == self.len() => {
Some(arrow2::bitmap::Bitmap::from(vec![false]))
}
_ => None,
};

let arrow_array = self.as_arrow();
let new_offsets = OffsetsBuffer::<i64>::try_from(vec![0, *arrow_array.offsets().last()])?;
let output = Utf8Array::new(
arrow_array.data_type().clone(),
new_offsets,
arrow_array.values().clone(),
new_validity,
);

let result_box = Box::new(output);
DataArray::new(self.field().clone().into(), result_box)
}

fn grouped_concat(&self, groups: &super::GroupIndices) -> Self::Output {
let arrow_array = self.as_arrow();
let concat_per_group = if arrow_array.null_count() > 0 {
Box::new(Utf8Array::from_trusted_len_iter(groups.iter().map(|g| {
let to_concat = g
.iter()
.filter_map(|index| {
let idx = *index as usize;
arrow_array.get(idx)
})
.collect::<Vec<&str>>();
if to_concat.is_empty() {
None
} else {
Some(to_concat.concat())
}
})))
} else {
Box::new(Utf8Array::from_trusted_len_values_iter(groups.iter().map(
|g| {
g.iter()
.map(|index| {
let idx = *index as usize;
arrow_array.value(idx)
})
.collect::<String>()
},
)))
};

Ok(DataArray::from((
self.field.name.as_ref(),
concat_per_group,
)))
}
}

#[cfg(test)]
mod test {
use std::iter::repeat;
Expand Down
11 changes: 10 additions & 1 deletion src/daft-core/src/series/ops/agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,17 @@ impl Series {
None => Ok(DaftConcatAggable::concat(downcasted)?.into_series()),
}
}
DataType::Utf8 => {
let downcasted = self.downcast::<Utf8Array>()?;
match groups {
Some(groups) => {
Ok(DaftConcatAggable::grouped_concat(downcasted, groups)?.into_series())
}
None => Ok(DaftConcatAggable::concat(downcasted)?.into_series()),
}
}
_ => Err(DaftError::TypeError(format!(
"concat aggregation is only valid for List or Python types, got {}",
"concat aggregation is only valid for List, Python types, or Utf8, got {}",
self.data_type()
))),
}
Expand Down
1 change: 1 addition & 0 deletions src/daft-dsl/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,7 @@ impl AggExpr {
let field = expr.to_field(schema)?;
match field.dtype {
DataType::List(..) => Ok(field),
DataType::Utf8 => Ok(field),
#[cfg(feature = "python")]
DataType::Python => Ok(field),
_ => Err(DaftError::TypeError(format!(
Expand Down
50 changes: 50 additions & 0 deletions tests/table/test_table_aggs.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,3 +874,53 @@ def test_groupby_struct(dtype) -> None:
expected = [[0, 1, 4], [2, 6], [3, 5]]
for lt in expected:
assert lt in res["b"]


def test_agg_concat_on_string() -> None:
df3 = from_pydict({"a": ["the", " quick", " brown", " fox"]})
res = df3.agg(col("a").agg_concat()).to_pydict()
assert res["a"] == ["the quick brown fox"]


def test_agg_concat_on_string_groupby() -> None:
df3 = from_pydict({"a": ["the", " quick", " brown", " fox"], "b": [1, 2, 1, 2]})
res = df3.groupby("b").agg_concat("a").to_pydict()
expected = ["the brown", " quick fox"]
for txt in expected:
assert txt in res["a"]


def test_agg_concat_on_string_null() -> None:
df3 = from_pydict({"a": ["the", " quick", None, " fox"]})
res = df3.agg(col("a").agg_concat()).to_pydict()
expected = ["the quick fox"]
assert res["a"] == expected


def test_agg_concat_on_string_groupby_null() -> None:
df3 = from_pydict({"a": ["the", " quick", None, " fox"], "b": [1, 2, 1, 2]})
res = df3.groupby("b").agg_concat("a").to_pydict()
expected = ["the", " quick fox"]
for txt in expected:
assert txt in res["a"]


def test_agg_concat_on_string_null_list() -> None:
df3 = from_pydict({"a": [None, None, None, None], "b": [1, 2, 1, 2]}).with_column(
"a", col("a").cast(DataType.string())
)
res = df3.agg(col("a").agg_concat()).to_pydict()
print(res)
expected = [None]
assert res["a"] == expected
assert len(res["a"]) == 1


def test_agg_concat_on_string_groupby_null_list() -> None:
df3 = from_pydict({"a": [None, None, None, None], "b": [1, 2, 1, 2]}).with_column(
"a", col("a").cast(DataType.string())
)
res = df3.groupby("b").agg_concat("a").to_pydict()
expected = [None, None]
assert res["a"] == expected
assert len(res["a"]) == len(expected)
Loading