ENH/TST: Add BaseMethodsTests tests for ArrowExtensionArray (#47552)
* ENH/TST: Add BaseMethodsTests tests for ArrowExtensionArray

* Passing test now

* add xfails for arraymanager

* Fix typo

* Trigger CI

* Add xfails for min version and data manager

* Adjust more tests
mroeschke authored Jul 5, 2022
1 parent 4d17588 commit 700ef33
Showing 4 changed files with 317 additions and 10 deletions.
10 changes: 6 additions & 4 deletions pandas/core/base.py
@@ -983,10 +983,12 @@ def unique(self):

if not isinstance(values, np.ndarray):
result: ArrayLike = values.unique()
if self.dtype.kind in ["m", "M"] and isinstance(self, ABCSeries):
# GH#31182 Series._values returns EA, unpack for backward-compat
if getattr(self.dtype, "tz", None) is None:
result = np.asarray(result)
if (
isinstance(self.dtype, np.dtype) and self.dtype.kind in ["m", "M"]
) and isinstance(self, ABCSeries):
# GH#31182 Series._values returns EA
# unpack numpy datetime for backward-compat
result = np.asarray(result)
else:
result = unique1d(values)

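(A minimal sketch, not part of the diff, of the backward-compat path that the new isinstance(self.dtype, np.dtype) check keeps: a tz-naive, numpy-dtype datetime Series still unwraps its unique() result to a plain numpy array, while non-numpy dtypes such as pyarrow-backed timestamps should now keep their extension-array result.)

import numpy as np
import pandas as pd

# unique() on a numpy-dtype datetime Series keeps unwrapping to np.ndarray
ser = pd.Series(pd.date_range("2022-01-01", periods=3).repeat(2))
assert isinstance(ser.unique(), np.ndarray)
assert ser.unique().dtype == "datetime64[ns]"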
4 changes: 2 additions & 2 deletions pandas/tests/extension/base/methods.py
@@ -5,6 +5,7 @@
import pytest

from pandas.core.dtypes.common import is_bool_dtype
from pandas.core.dtypes.missing import na_value_for_dtype

import pandas as pd
import pandas._testing as tm
@@ -49,8 +50,7 @@ def test_value_counts_with_normalize(self, data):
else:
expected = pd.Series(0.0, index=result.index)
expected[result > 0] = 1 / len(values)

if isinstance(data.dtype, pd.core.dtypes.dtypes.BaseMaskedDtype):
if na_value_for_dtype(data.dtype) is pd.NA:
# TODO(GH#44692): avoid special-casing
expected = expected.astype("Float64")

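(A short illustration, not from the diff, of why the check was rewritten: na_value_for_dtype reports pd.NA for masked dtypes and for ArrowDtype alike, so both now take the Float64 branch without naming BaseMaskedDtype explicitly.)

import numpy as np
import pandas as pd
from pandas.core.dtypes.missing import na_value_for_dtype

# Nullable (masked) dtypes report pd.NA as their missing value ...
assert na_value_for_dtype(pd.Int64Dtype()) is pd.NA
# ... while plain numpy float dtypes report NaN, so they keep the default path.
assert na_value_for_dtype(np.dtype("float64")) is np.nan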
309 changes: 308 additions & 1 deletion pandas/tests/extension/test_arrow.py
@@ -153,6 +153,32 @@ def data_for_grouping(dtype):
return pd.array([B, B, None, None, A, A, B, C], dtype=dtype)


@pytest.fixture
def data_for_sorting(data_for_grouping):
"""
Length-3 array with a known sort order.
This should be three items [B, C, A] with
A < B < C
"""
return type(data_for_grouping)._from_sequence(
[data_for_grouping[0], data_for_grouping[7], data_for_grouping[4]]
)


@pytest.fixture
def data_missing_for_sorting(data_for_grouping):
"""
Length-3 array with a known sort order.
This should be three items [B, NA, A] with
A < B and NA missing.
"""
return type(data_for_grouping)._from_sequence(
[data_for_grouping[0], data_for_grouping[2], data_for_grouping[4]]
)


@pytest.fixture
def na_value():
"""The scalar missing value for this type. Default 'None'"""
@@ -654,7 +680,7 @@ def test_setitem_loc_scalar_single(self, data, using_array_manager, request):
if pa_version_under2p0 and tz not in (None, "UTC"):
request.node.add_marker(
pytest.mark.xfail(
reason=(f"Not supported by pyarrow < 2.0 with timestamp type {tz}")
reason=f"Not supported by pyarrow < 2.0 with timestamp type {tz}"
)
)
elif using_array_manager and pa.types.is_duration(data.dtype.pyarrow_dtype):
@@ -988,6 +1014,287 @@ def test_EA_types(self, engine, data, request):
super().test_EA_types(engine, data)


class TestBaseMethods(base.BaseMethodsTests):
@pytest.mark.parametrize("dropna", [True, False])
def test_value_counts(self, all_data, dropna, request):
pa_dtype = all_data.dtype.pyarrow_dtype
if pa.types.is_date(pa_dtype) or (
pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None
):
request.node.add_marker(
pytest.mark.xfail(
raises=AttributeError,
reason="GH 34986",
)
)
elif pa.types.is_duration(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
raises=pa.ArrowNotImplementedError,
reason=f"value_count has no kernel for {pa_dtype}",
)
)
super().test_value_counts(all_data, dropna)

def test_value_counts_with_normalize(self, data, request):
pa_dtype = data.dtype.pyarrow_dtype
if pa.types.is_date(pa_dtype) or (
pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None
):
request.node.add_marker(
pytest.mark.xfail(
raises=AttributeError,
reason="GH 34986",
)
)
elif pa.types.is_duration(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
raises=pa.ArrowNotImplementedError,
reason=f"value_count has no pyarrow kernel for {pa_dtype}",
)
)
super().test_value_counts_with_normalize(data)

def test_argmin_argmax(
self, data_for_sorting, data_missing_for_sorting, na_value, request
):
pa_dtype = data_for_sorting.dtype.pyarrow_dtype
if pa.types.is_boolean(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
reason=f"{pa_dtype} only has 2 unique possible values",
)
)
super().test_argmin_argmax(data_for_sorting, data_missing_for_sorting, na_value)

@pytest.mark.parametrize("ascending", [True, False])
def test_sort_values(self, data_for_sorting, ascending, sort_by_key, request):
pa_dtype = data_for_sorting.dtype.pyarrow_dtype
if pa.types.is_duration(pa_dtype) and not ascending and not pa_version_under2p0:
request.node.add_marker(
pytest.mark.xfail(
raises=pa.ArrowNotImplementedError,
reason=(
f"unique has no pyarrow kernel "
f"for {pa_dtype} when ascending={ascending}"
),
)
)
super().test_sort_values(data_for_sorting, ascending, sort_by_key)

@pytest.mark.parametrize("ascending", [True, False])
def test_sort_values_frame(self, data_for_sorting, ascending, request):
pa_dtype = data_for_sorting.dtype.pyarrow_dtype
if pa.types.is_duration(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
raises=pa.ArrowNotImplementedError,
reason=(
f"dictionary_encode has no pyarrow kernel "
f"for {pa_dtype} when ascending={ascending}"
),
)
)
super().test_sort_values_frame(data_for_sorting, ascending)

@pytest.mark.parametrize("box", [pd.Series, lambda x: x])
@pytest.mark.parametrize("method", [lambda x: x.unique(), pd.unique])
def test_unique(self, data, box, method, request):
pa_dtype = data.dtype.pyarrow_dtype
if pa.types.is_duration(pa_dtype) and not pa_version_under2p0:
request.node.add_marker(
pytest.mark.xfail(
raises=pa.ArrowNotImplementedError,
reason=f"unique has no pyarrow kernel for {pa_dtype}.",
)
)
super().test_unique(data, box, method)

@pytest.mark.parametrize("na_sentinel", [-1, -2])
def test_factorize(self, data_for_grouping, na_sentinel, request):
pa_dtype = data_for_grouping.dtype.pyarrow_dtype
if pa.types.is_duration(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
raises=pa.ArrowNotImplementedError,
reason=f"dictionary_encode has no pyarrow kernel for {pa_dtype}",
)
)
elif pa.types.is_boolean(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
reason=f"{pa_dtype} only has 2 unique possible values",
)
)
super().test_factorize(data_for_grouping, na_sentinel)

@pytest.mark.parametrize("na_sentinel", [-1, -2])
def test_factorize_equivalence(self, data_for_grouping, na_sentinel, request):
pa_dtype = data_for_grouping.dtype.pyarrow_dtype
if pa.types.is_duration(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
raises=pa.ArrowNotImplementedError,
reason=f"dictionary_encode has no pyarrow kernel for {pa_dtype}",
)
)
super().test_factorize_equivalence(data_for_grouping, na_sentinel)

def test_factorize_empty(self, data, request):
pa_dtype = data.dtype.pyarrow_dtype
if pa.types.is_duration(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
raises=pa.ArrowNotImplementedError,
reason=f"dictionary_encode has no pyarrow kernel for {pa_dtype}",
)
)
super().test_factorize_empty(data)

def test_fillna_copy_frame(self, data_missing, request, using_array_manager):
pa_dtype = data_missing.dtype.pyarrow_dtype
if using_array_manager and pa.types.is_duration(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
reason=f"Checking ndim when using arraymanager with {pa_dtype}"
)
)
super().test_fillna_copy_frame(data_missing)

def test_fillna_copy_series(self, data_missing, request, using_array_manager):
pa_dtype = data_missing.dtype.pyarrow_dtype
if using_array_manager and pa.types.is_duration(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
reason=f"Checking ndim when using arraymanager with {pa_dtype}"
)
)
super().test_fillna_copy_series(data_missing)

def test_shift_fill_value(self, data, request):
pa_dtype = data.dtype.pyarrow_dtype
tz = getattr(pa_dtype, "tz", None)
if pa_version_under2p0 and tz not in (None, "UTC"):
request.node.add_marker(
pytest.mark.xfail(
reason=f"Not supported by pyarrow < 2.0 with timestamp type {tz}"
)
)
super().test_shift_fill_value(data)

@pytest.mark.parametrize("repeats", [0, 1, 2, [1, 2, 3]])
def test_repeat(self, data, repeats, as_series, use_numpy, request):
pa_dtype = data.dtype.pyarrow_dtype
tz = getattr(pa_dtype, "tz", None)
if pa_version_under2p0 and tz not in (None, "UTC") and repeats != 0:
request.node.add_marker(
pytest.mark.xfail(
reason=(
f"Not supported by pyarrow < 2.0 with "
f"timestamp type {tz} when repeats={repeats}"
)
)
)
super().test_repeat(data, repeats, as_series, use_numpy)

def test_insert(self, data, request):
pa_dtype = data.dtype.pyarrow_dtype
tz = getattr(pa_dtype, "tz", None)
if pa_version_under2p0 and tz not in (None, "UTC"):
request.node.add_marker(
pytest.mark.xfail(
reason=f"Not supported by pyarrow < 2.0 with timestamp type {tz}"
)
)
super().test_insert(data)

def test_combine_first(self, data, request, using_array_manager):
pa_dtype = data.dtype.pyarrow_dtype
tz = getattr(pa_dtype, "tz", None)
if using_array_manager and pa.types.is_duration(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
reason=f"Checking ndim when using arraymanager with {pa_dtype}"
)
)
elif pa_version_under2p0 and tz not in (None, "UTC"):
request.node.add_marker(
pytest.mark.xfail(
reason=f"Not supported by pyarrow < 2.0 with timestamp type {tz}"
)
)
super().test_combine_first(data)

@pytest.mark.parametrize("frame", [True, False])
@pytest.mark.parametrize(
"periods, indices",
[(-2, [2, 3, 4, -1, -1]), (0, [0, 1, 2, 3, 4]), (2, [-1, -1, 0, 1, 2])],
)
def test_container_shift(
self, data, frame, periods, indices, request, using_array_manager
):
pa_dtype = data.dtype.pyarrow_dtype
if (
using_array_manager
and pa.types.is_duration(pa_dtype)
and periods in (-2, 2)
):
request.node.add_marker(
pytest.mark.xfail(
reason=(
f"Checking ndim when using arraymanager with "
f"{pa_dtype} and periods={periods}"
)
)
)
super().test_container_shift(data, frame, periods, indices)

@pytest.mark.xfail(
reason="result dtype pyarrow[bool] better than expected dtype object"
)
def test_combine_le(self, data_repeated):
super().test_combine_le(data_repeated)

def test_combine_add(self, data_repeated, request):
pa_dtype = next(data_repeated(1)).dtype.pyarrow_dtype
if pa.types.is_temporal(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
raises=TypeError,
reason=f"{pa_dtype} cannot be added to {pa_dtype}",
)
)
super().test_combine_add(data_repeated)

def test_searchsorted(self, data_for_sorting, as_series, request):
pa_dtype = data_for_sorting.dtype.pyarrow_dtype
if pa.types.is_boolean(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
reason=f"{pa_dtype} only has 2 unique possible values",
)
)
super().test_searchsorted(data_for_sorting, as_series)

def test_where_series(self, data, na_value, as_frame, request, using_array_manager):
pa_dtype = data.dtype.pyarrow_dtype
if using_array_manager and pa.types.is_duration(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
reason=f"Checking ndim when using arraymanager with {pa_dtype}"
)
)
elif pa.types.is_temporal(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
raises=pa.ArrowNotImplementedError,
reason=f"Unsupported cast from double to {pa_dtype}",
)
)
super().test_where_series(data, na_value, as_frame)


def test_arrowdtype_construct_from_string_type_with_unsupported_parameters():
with pytest.raises(NotImplementedError, match="Passing pyarrow type"):
ArrowDtype.construct_from_string("timestamp[s, tz=UTC][pyarrow]")
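(The TestBaseMethods class above leans on one pytest idiom throughout: marking individual parametrizations xfail at runtime via request.node.add_marker, so only the combinations that hit a missing pyarrow kernel are expected to fail. A standalone sketch of that idiom, with illustrative parameter values and a stand-in assertion rather than the real base-class test body:)

import pytest

@pytest.mark.parametrize("pa_type", ["int64", "duration[s]"])
def test_conditional_xfail(pa_type, request):
    # Only the duration case gets the xfail marker; the int64 case must pass.
    if pa_type.startswith("duration"):
        request.node.add_marker(
            pytest.mark.xfail(reason=f"no pyarrow kernel for {pa_type}")
        )
    # Stand-in for the real check exercised by the base test.
    assert not pa_type.startswith("duration")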
4 changes: 1 addition & 3 deletions pandas/tests/extension/test_string.py
@@ -167,9 +167,7 @@ def test_reduce_series_numeric(self, data, all_numeric_reductions, skipna):


class TestMethods(base.BaseMethodsTests):
@pytest.mark.xfail(reason="returns nullable: GH 44692")
def test_value_counts_with_normalize(self, data):
super().test_value_counts_with_normalize(data)
pass


class TestCasting(base.BaseCastingTests):
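(The class-level xfail removed above can go because StringDtype's missing value is also pd.NA, so the reworked base test now builds the nullable Float64 expected for it as well; roughly:)

import pandas as pd
from pandas.core.dtypes.missing import na_value_for_dtype

# StringDtype takes the same branch as the masked and Arrow-backed dtypes.
assert na_value_for_dtype(pd.StringDtype()) is pd.NA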
