From 521c547610267a6e247e58af2a5287a7784c6483 Mon Sep 17 00:00:00 2001 From: vicky1999 Date: Mon, 16 Sep 2024 22:43:55 +0530 Subject: [PATCH 1/6] WIP: on strings --- src/daft-core/src/array/ops/concat_agg.rs | 46 +++++++++++++++++++++-- src/daft-core/src/series/ops/agg.rs | 11 +++++- src/daft-dsl/src/expr.rs | 1 + tests/table/test_table_aggs.py | 30 +++++++++++++++ 4 files changed, 84 insertions(+), 4 deletions(-) diff --git a/src/daft-core/src/array/ops/concat_agg.rs b/src/daft-core/src/array/ops/concat_agg.rs index ca81098a9b..59c4ad405f 100644 --- a/src/daft-core/src/array/ops/concat_agg.rs +++ b/src/daft-core/src/array/ops/concat_agg.rs @@ -1,6 +1,9 @@ -use crate::array::{ - growable::{make_growable, Growable}, - ListArray, +use crate::{ + array::{ + growable::{make_growable, Growable}, + DataArray, ListArray, + }, + prelude::Utf8Type, }; use arrow2::{bitmap::utils::SlicesIterator, offset::OffsetsBuffer, types::Index}; use common_error::DaftResult; @@ -146,6 +149,43 @@ impl DaftConcatAggable for ListArray { } } +impl DaftConcatAggable for DataArray { + type Output = DaftResult; + + fn concat(&self) -> Self::Output { + let mut concat_result = String::new(); + + for idx in 0..self.len() { + if self.get(idx).is_some() { + let x = self.get(idx).unwrap().to_string(); + concat_result.push_str(&x); + } + } + let result_box = Box::new(arrow2::array::Utf8Array::::from_slice([concat_result])); + + DataArray::new(self.field().clone().into(), result_box) + } + + fn grouped_concat(&self, groups: &super::GroupIndices) -> Self::Output { + let mut result_data = vec![]; + for group in groups { + let mut group_res = String::new(); + for idx in group { + let ind: usize = idx.to_usize(); + if self.get(ind).is_some() { + let x = self.get(ind).unwrap().to_string(); + group_res.push_str(&x); + } + } + result_data.push(group_res); + } + + let result_box = Box::new(arrow2::array::Utf8Array::::from_slice(result_data)); + + DataArray::new(self.field().clone().into(), result_box) + } +} + #[cfg(test)] mod test { use std::iter::repeat; diff --git a/src/daft-core/src/series/ops/agg.rs b/src/daft-core/src/series/ops/agg.rs index 5042b57ce9..84f5263a08 100644 --- a/src/daft-core/src/series/ops/agg.rs +++ b/src/daft-core/src/series/ops/agg.rs @@ -244,8 +244,17 @@ impl Series { None => Ok(DaftConcatAggable::concat(downcasted)?.into_series()), } } + DataType::Utf8 => { + let downcasted = self.downcast::()?; + match groups { + Some(groups) => { + Ok(DaftConcatAggable::grouped_concat(downcasted, groups)?.into_series()) + } + None => Ok(DaftConcatAggable::concat(downcasted)?.into_series()), + } + } _ => Err(DaftError::TypeError(format!( - "concat aggregation is only valid for List or Python types, got {}", + "concat aggregation is only valid for List, Python types, or Utf8, got {}", self.data_type() ))), } diff --git a/src/daft-dsl/src/expr.rs b/src/daft-dsl/src/expr.rs index 0afbe81c2e..b908588b43 100644 --- a/src/daft-dsl/src/expr.rs +++ b/src/daft-dsl/src/expr.rs @@ -392,6 +392,7 @@ impl AggExpr { let field = expr.to_field(schema)?; match field.dtype { DataType::List(..) => Ok(field), + DataType::Utf8 => Ok(field), #[cfg(feature = "python")] DataType::Python => Ok(field), _ => Err(DaftError::TypeError(format!( diff --git a/tests/table/test_table_aggs.py b/tests/table/test_table_aggs.py index fa7a26b3e4..2cbb4bb029 100644 --- a/tests/table/test_table_aggs.py +++ b/tests/table/test_table_aggs.py @@ -874,3 +874,33 @@ def test_groupby_struct(dtype) -> None: expected = [[0, 1, 4], [2, 6], [3, 5]] for lt in expected: assert lt in res["b"] + + +def test_agg_concat_on_string() -> None: + df3 = from_pydict({"a": ["the", " quick", " brown", " fox"]}) + res = df3.agg(col("a").agg_concat()).to_pydict() + assert res["a"] == ["the quick brown fox"] + + +def test_agg_concat_on_string_groupby() -> None: + df3 = from_pydict({"a": ["the", " quick", " brown", " fox"], "b": [1, 2, 1, 2]}) + res = df3.groupby("b").agg_concat("a").to_pydict() + expected = ["the brown", " quick fox"] + for txt in expected: + assert txt in res["a"] + + +def test_agg_concat_on_string_null() -> None: + df3 = from_pydict({"a": ["the", " quick", None, " fox"]}) + res = df3.agg(col("a").agg_concat()).to_pydict() + expected = ["the quick fox"] + for txt in expected: + assert txt in res["a"] + + +def test_agg_concat_on_string_groupby_null() -> None: + df3 = from_pydict({"a": ["the", " quick", None, " fox"], "b": [1, 2, 1, 2]}) + res = df3.groupby("b").agg_concat("a").to_pydict() + expected = ["the", " quick fox"] + for txt in expected: + assert txt in res["a"] From a0dbc548ef7651eb7fa7707b9a9ec2381438f886 Mon Sep 17 00:00:00 2001 From: vicky1999 Date: Thu, 19 Sep 2024 19:24:39 +0530 Subject: [PATCH 2/6] [Fix]: string concat optimization --- src/daft-core/src/array/ops/concat_agg.rs | 52 +++++++++++------------ 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/src/daft-core/src/array/ops/concat_agg.rs b/src/daft-core/src/array/ops/concat_agg.rs index d69dfa92bd..90e6eb94b7 100644 --- a/src/daft-core/src/array/ops/concat_agg.rs +++ b/src/daft-core/src/array/ops/concat_agg.rs @@ -1,4 +1,6 @@ -use arrow2::{bitmap::utils::SlicesIterator, offset::OffsetsBuffer, types::Index}; +use arrow2::{ + array::Utf8Array, bitmap::utils::SlicesIterator, offset::OffsetsBuffer, types::Index, +}; use common_error::DaftResult; use super::{as_arrow::AsArrow, DaftConcatAggable}; @@ -153,36 +155,34 @@ impl DaftConcatAggable for DataArray { type Output = DaftResult; fn concat(&self) -> Self::Output { - let mut concat_result = String::new(); - - for idx in 0..self.len() { - if self.get(idx).is_some() { - let x = self.get(idx).unwrap().to_string(); - concat_result.push_str(&x); - } - } - let result_box = Box::new(arrow2::array::Utf8Array::::from_slice([concat_result])); + let arrow_array = self.as_arrow(); + let new_offsets = OffsetsBuffer::::try_from(vec![0, *arrow_array.offsets().last()])?; + let output = Utf8Array::new( + arrow_array.data_type().clone(), + new_offsets, + arrow_array.values().clone(), + None, + ); + let result_box = Box::new(output); DataArray::new(self.field().clone().into(), result_box) } fn grouped_concat(&self, groups: &super::GroupIndices) -> Self::Output { - let mut result_data = vec![]; - for group in groups { - let mut group_res = String::new(); - for idx in group { - let ind: usize = idx.to_usize(); - if self.get(ind).is_some() { - let x = self.get(ind).unwrap().to_string(); - group_res.push_str(&x); - } - } - result_data.push(group_res); - } - - let result_box = Box::new(arrow2::array::Utf8Array::::from_slice(result_data)); - - DataArray::new(self.field().clone().into(), result_box) + let arrow_array = self.as_arrow(); + let concat_per_group = Box::new(Utf8Array::from_trusted_len_iter(groups.iter().map(|g| { + let mut group_res = vec![]; + g.iter().for_each(|index| { + let idx = *index as usize; + group_res.push(arrow_array.value(idx)); + }); + Some(group_res.concat()) + }))); + + Ok(DataArray::from(( + self.field.name.as_ref(), + concat_per_group, + ))) } } From 145d8ca61f253712d8a4d7191bf3ec657f814019 Mon Sep 17 00:00:00 2001 From: vicky1999 Date: Wed, 25 Sep 2024 02:12:41 +0530 Subject: [PATCH 3/6] [Fix]: Null array testcase handled --- src/daft-core/src/array/ops/concat_agg.rs | 35 +++++++++++++++++++++-- src/daft-core/src/series/ops/agg.rs | 9 ++++++ src/daft-dsl/src/expr.rs | 1 + tests/table/test_table_aggs.py | 17 +++++++++++ 4 files changed, 59 insertions(+), 3 deletions(-) diff --git a/src/daft-core/src/array/ops/concat_agg.rs b/src/daft-core/src/array/ops/concat_agg.rs index 90e6eb94b7..40f4b3b957 100644 --- a/src/daft-core/src/array/ops/concat_agg.rs +++ b/src/daft-core/src/array/ops/concat_agg.rs @@ -1,5 +1,8 @@ use arrow2::{ - array::Utf8Array, bitmap::utils::SlicesIterator, offset::OffsetsBuffer, types::Index, + array::{Array, NullArray, Utf8Array}, + bitmap::utils::SlicesIterator, + offset::OffsetsBuffer, + types::Index, }; use common_error::DaftResult; @@ -9,7 +12,7 @@ use crate::{ growable::{make_growable, Growable}, DataArray, ListArray, }, - prelude::Utf8Type, + prelude::{NullType, Utf8Type}, }; #[cfg(feature = "python")] @@ -155,13 +158,20 @@ impl DaftConcatAggable for DataArray { type Output = DaftResult; fn concat(&self) -> Self::Output { + let new_validity = match self.validity() { + Some(validity) if validity.unset_bits() == self.len() => { + Some(arrow2::bitmap::Bitmap::from(vec![false])) + } + _ => None, + }; + let arrow_array = self.as_arrow(); let new_offsets = OffsetsBuffer::::try_from(vec![0, *arrow_array.offsets().last()])?; let output = Utf8Array::new( arrow_array.data_type().clone(), new_offsets, arrow_array.values().clone(), - None, + new_validity, ); let result_box = Box::new(output); @@ -186,6 +196,25 @@ impl DaftConcatAggable for DataArray { } } +impl DaftConcatAggable for DataArray { + type Output = DaftResult; + + fn concat(&self) -> Self::Output { + let arrow_array = self.as_arrow(); + let result_box = Box::new(NullArray::new_null(arrow_array.data_type().clone(), 1)); + DataArray::new(self.field().clone().into(), result_box) + } + + fn grouped_concat(&self, groups: &super::GroupIndices) -> Self::Output { + let arrow_array = self.as_arrow(); + let result_box = Box::new(NullArray::new_null( + arrow_array.data_type().clone(), + groups.len(), + )); + DataArray::new(self.field().clone().into(), result_box) + } +} + #[cfg(test)] mod test { use std::iter::repeat; diff --git a/src/daft-core/src/series/ops/agg.rs b/src/daft-core/src/series/ops/agg.rs index 353c6ca25d..802117a89c 100644 --- a/src/daft-core/src/series/ops/agg.rs +++ b/src/daft-core/src/series/ops/agg.rs @@ -253,6 +253,15 @@ impl Series { None => Ok(DaftConcatAggable::concat(downcasted)?.into_series()), } } + DataType::Null => { + let downcasted = self.downcast::()?; + match groups { + Some(groups) => { + Ok(DaftConcatAggable::grouped_concat(downcasted, groups)?.into_series()) + } + None => Ok(DaftConcatAggable::concat(downcasted)?.into_series()), + } + } _ => Err(DaftError::TypeError(format!( "concat aggregation is only valid for List, Python types, or Utf8, got {}", self.data_type() diff --git a/src/daft-dsl/src/expr.rs b/src/daft-dsl/src/expr.rs index f8c5deb247..14558b7f52 100644 --- a/src/daft-dsl/src/expr.rs +++ b/src/daft-dsl/src/expr.rs @@ -391,6 +391,7 @@ impl AggExpr { match field.dtype { DataType::List(..) => Ok(field), DataType::Utf8 => Ok(field), + DataType::Null => Ok(field), #[cfg(feature = "python")] DataType::Python => Ok(field), _ => Err(DaftError::TypeError(format!( diff --git a/tests/table/test_table_aggs.py b/tests/table/test_table_aggs.py index 2cbb4bb029..969c2b6941 100644 --- a/tests/table/test_table_aggs.py +++ b/tests/table/test_table_aggs.py @@ -904,3 +904,20 @@ def test_agg_concat_on_string_groupby_null() -> None: expected = ["the", " quick fox"] for txt in expected: assert txt in res["a"] + + +def test_agg_concat_on_string_null_list() -> None: + df3 = from_pydict({"a": [None, None, None, None], "b": [1, 2, 1, 2]}) + res = df3.agg(col("a").agg_concat()).to_pydict() + print(res) + expected = [None] + assert res["a"] == expected + assert len(res["a"]) == 1 + + +def test_agg_concat_on_string_groupby_null_list() -> None: + df3 = from_pydict({"a": [None, None, None, None], "b": [1, 2, 1, 2]}) + res = df3.groupby("b").agg_concat("a").to_pydict() + expected = [None, None] + assert res["a"] == expected + assert len(res["a"]) == len(expected) From 96e835740cc465f8fc3d06730799565cf0a8e0d5 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Tue, 24 Sep 2024 18:39:52 -0700 Subject: [PATCH 4/6] Apply suggestions from code review --- src/daft-core/src/array/ops/concat_agg.rs | 53 ++++++++++++----------- src/daft-core/src/series/ops/agg.rs | 9 ---- src/daft-dsl/src/expr.rs | 1 - tests/table/test_table_aggs.py | 7 ++- 4 files changed, 30 insertions(+), 40 deletions(-) diff --git a/src/daft-core/src/array/ops/concat_agg.rs b/src/daft-core/src/array/ops/concat_agg.rs index 40f4b3b957..e57e8ff76d 100644 --- a/src/daft-core/src/array/ops/concat_agg.rs +++ b/src/daft-core/src/array/ops/concat_agg.rs @@ -180,14 +180,33 @@ impl DaftConcatAggable for DataArray { fn grouped_concat(&self, groups: &super::GroupIndices) -> Self::Output { let arrow_array = self.as_arrow(); - let concat_per_group = Box::new(Utf8Array::from_trusted_len_iter(groups.iter().map(|g| { - let mut group_res = vec![]; - g.iter().for_each(|index| { - let idx = *index as usize; - group_res.push(arrow_array.value(idx)); - }); - Some(group_res.concat()) - }))); + let concat_per_group = if arrow_array.null_count() > 0 { + Box::new(Utf8Array::from_trusted_len_iter(groups.iter().map(|g| { + let to_concat = g + .iter() + .filter_map(|index| { + let idx = *index as usize; + arrow_array.get(idx) + }) + .collect::>(); + if to_concat.is_empty() { + None + } else { + Some(to_concat.concat()) + } + }))) + } else { + Box::new(Utf8Array::from_trusted_len_values_iter(groups.iter().map( + |g| { + g.iter() + .map(|index| { + let idx = *index as usize; + arrow_array.value(idx) + }) + .collect::() + }, + ))) + }; Ok(DataArray::from(( self.field.name.as_ref(), @@ -196,24 +215,6 @@ impl DaftConcatAggable for DataArray { } } -impl DaftConcatAggable for DataArray { - type Output = DaftResult; - - fn concat(&self) -> Self::Output { - let arrow_array = self.as_arrow(); - let result_box = Box::new(NullArray::new_null(arrow_array.data_type().clone(), 1)); - DataArray::new(self.field().clone().into(), result_box) - } - - fn grouped_concat(&self, groups: &super::GroupIndices) -> Self::Output { - let arrow_array = self.as_arrow(); - let result_box = Box::new(NullArray::new_null( - arrow_array.data_type().clone(), - groups.len(), - )); - DataArray::new(self.field().clone().into(), result_box) - } -} #[cfg(test)] mod test { diff --git a/src/daft-core/src/series/ops/agg.rs b/src/daft-core/src/series/ops/agg.rs index 802117a89c..353c6ca25d 100644 --- a/src/daft-core/src/series/ops/agg.rs +++ b/src/daft-core/src/series/ops/agg.rs @@ -253,15 +253,6 @@ impl Series { None => Ok(DaftConcatAggable::concat(downcasted)?.into_series()), } } - DataType::Null => { - let downcasted = self.downcast::()?; - match groups { - Some(groups) => { - Ok(DaftConcatAggable::grouped_concat(downcasted, groups)?.into_series()) - } - None => Ok(DaftConcatAggable::concat(downcasted)?.into_series()), - } - } _ => Err(DaftError::TypeError(format!( "concat aggregation is only valid for List, Python types, or Utf8, got {}", self.data_type() diff --git a/src/daft-dsl/src/expr.rs b/src/daft-dsl/src/expr.rs index 14558b7f52..f8c5deb247 100644 --- a/src/daft-dsl/src/expr.rs +++ b/src/daft-dsl/src/expr.rs @@ -391,7 +391,6 @@ impl AggExpr { match field.dtype { DataType::List(..) => Ok(field), DataType::Utf8 => Ok(field), - DataType::Null => Ok(field), #[cfg(feature = "python")] DataType::Python => Ok(field), _ => Err(DaftError::TypeError(format!( diff --git a/tests/table/test_table_aggs.py b/tests/table/test_table_aggs.py index 969c2b6941..5a50adca9a 100644 --- a/tests/table/test_table_aggs.py +++ b/tests/table/test_table_aggs.py @@ -894,8 +894,7 @@ def test_agg_concat_on_string_null() -> None: df3 = from_pydict({"a": ["the", " quick", None, " fox"]}) res = df3.agg(col("a").agg_concat()).to_pydict() expected = ["the quick fox"] - for txt in expected: - assert txt in res["a"] + assert res["a"] == expected def test_agg_concat_on_string_groupby_null() -> None: @@ -907,7 +906,7 @@ def test_agg_concat_on_string_groupby_null() -> None: def test_agg_concat_on_string_null_list() -> None: - df3 = from_pydict({"a": [None, None, None, None], "b": [1, 2, 1, 2]}) + df3 = from_pydict({"a": [None, None, None, None], "b": [1, 2, 1, 2]}).with_column("a", col("a").cast(DataType.string())) res = df3.agg(col("a").agg_concat()).to_pydict() print(res) expected = [None] @@ -916,7 +915,7 @@ def test_agg_concat_on_string_null_list() -> None: def test_agg_concat_on_string_groupby_null_list() -> None: - df3 = from_pydict({"a": [None, None, None, None], "b": [1, 2, 1, 2]}) + df3 = from_pydict({"a": [None, None, None, None], "b": [1, 2, 1, 2]}).with_column("a", col("a").cast(DataType.string())) res = df3.groupby("b").agg_concat("a").to_pydict() expected = [None, None] assert res["a"] == expected From 48bfa25dd3500eebc60d57c9d94045f59a9ca3e3 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Tue, 24 Sep 2024 20:31:45 -0700 Subject: [PATCH 5/6] style --- src/daft-core/src/array/ops/concat_agg.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/daft-core/src/array/ops/concat_agg.rs b/src/daft-core/src/array/ops/concat_agg.rs index e57e8ff76d..09ebb0876e 100644 --- a/src/daft-core/src/array/ops/concat_agg.rs +++ b/src/daft-core/src/array/ops/concat_agg.rs @@ -1,5 +1,5 @@ use arrow2::{ - array::{Array, NullArray, Utf8Array}, + array::{Array, Utf8Array}, bitmap::utils::SlicesIterator, offset::OffsetsBuffer, types::Index, @@ -12,7 +12,7 @@ use crate::{ growable::{make_growable, Growable}, DataArray, ListArray, }, - prelude::{NullType, Utf8Type}, + prelude::Utf8Type, }; #[cfg(feature = "python")] @@ -215,7 +215,6 @@ impl DaftConcatAggable for DataArray { } } - #[cfg(test)] mod test { use std::iter::repeat; From f6455af80a9587cd5a3b154c5027dd54dfbef7ee Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Tue, 24 Sep 2024 20:42:44 -0700 Subject: [PATCH 6/6] style --- tests/table/test_table_aggs.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/table/test_table_aggs.py b/tests/table/test_table_aggs.py index 5a50adca9a..01749a1cdb 100644 --- a/tests/table/test_table_aggs.py +++ b/tests/table/test_table_aggs.py @@ -906,7 +906,9 @@ def test_agg_concat_on_string_groupby_null() -> None: def test_agg_concat_on_string_null_list() -> None: - df3 = from_pydict({"a": [None, None, None, None], "b": [1, 2, 1, 2]}).with_column("a", col("a").cast(DataType.string())) + df3 = from_pydict({"a": [None, None, None, None], "b": [1, 2, 1, 2]}).with_column( + "a", col("a").cast(DataType.string()) + ) res = df3.agg(col("a").agg_concat()).to_pydict() print(res) expected = [None] @@ -915,7 +917,9 @@ def test_agg_concat_on_string_null_list() -> None: def test_agg_concat_on_string_groupby_null_list() -> None: - df3 = from_pydict({"a": [None, None, None, None], "b": [1, 2, 1, 2]}).with_column("a", col("a").cast(DataType.string())) + df3 = from_pydict({"a": [None, None, None, None], "b": [1, 2, 1, 2]}).with_column( + "a", col("a").cast(DataType.string()) + ) res = df3.groupby("b").agg_concat("a").to_pydict() expected = [None, None] assert res["a"] == expected