Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH/TST: Add BaseMethodsTests tests for ArrowExtensionArray #47552

Merged
merged 10 commits into from
Jul 5, 2022
10 changes: 6 additions & 4 deletions pandas/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -983,10 +983,12 @@ def unique(self):

if not isinstance(values, np.ndarray):
result: ArrayLike = values.unique()
if self.dtype.kind in ["m", "M"] and isinstance(self, ABCSeries):
# GH#31182 Series._values returns EA, unpack for backward-compat
if getattr(self.dtype, "tz", None) is None:
result = np.asarray(result)
if (
isinstance(self.dtype, np.dtype) and self.dtype.kind in ["m", "M"]
) and isinstance(self, ABCSeries):
# GH#31182 Series._values returns EA
# unpack numpy datetime for backward-compat
result = np.asarray(result)
else:
result = unique1d(values)

Expand Down
4 changes: 2 additions & 2 deletions pandas/tests/extension/base/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest

from pandas.core.dtypes.common import is_bool_dtype
from pandas.core.dtypes.missing import na_value_for_dtype

import pandas as pd
import pandas._testing as tm
Expand Down Expand Up @@ -49,8 +50,7 @@ def test_value_counts_with_normalize(self, data):
else:
expected = pd.Series(0.0, index=result.index)
expected[result > 0] = 1 / len(values)

if isinstance(data.dtype, pd.core.dtypes.dtypes.BaseMaskedDtype):
if na_value_for_dtype(data.dtype) is pd.NA:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not for this PR, but we probably need to find a way to let authors define na_value_for_dtype

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably similar to #27825

# TODO(GH#44692): avoid special-casing
expected = expected.astype("Float64")

Expand Down
309 changes: 308 additions & 1 deletion pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,32 @@ def data_for_grouping(dtype):
return pd.array([B, B, None, None, A, A, B, C], dtype=dtype)


@pytest.fixture
def data_for_sorting(data_for_grouping):
"""
Length-3 array with a known sort order.

This should be three items [B, C, A] with
A < B < C
"""
return type(data_for_grouping)._from_sequence(
[data_for_grouping[0], data_for_grouping[7], data_for_grouping[4]]
)


@pytest.fixture
def data_missing_for_sorting(data_for_grouping):
"""
Length-3 array with a known sort order.

This should be three items [B, NA, A] with
A < B and NA missing.
"""
return type(data_for_grouping)._from_sequence(
[data_for_grouping[0], data_for_grouping[2], data_for_grouping[4]]
)


@pytest.fixture
def na_value():
"""The scalar missing value for this type. Default 'None'"""
Expand Down Expand Up @@ -654,7 +680,7 @@ def test_setitem_loc_scalar_single(self, data, using_array_manager, request):
if pa_version_under2p0 and tz not in (None, "UTC"):
request.node.add_marker(
pytest.mark.xfail(
reason=(f"Not supported by pyarrow < 2.0 with timestamp type {tz}")
reason=f"Not supported by pyarrow < 2.0 with timestamp type {tz}"
)
)
elif using_array_manager and pa.types.is_duration(data.dtype.pyarrow_dtype):
Expand Down Expand Up @@ -988,6 +1014,287 @@ def test_EA_types(self, engine, data, request):
super().test_EA_types(engine, data)


class TestBaseMethods(base.BaseMethodsTests):
@pytest.mark.parametrize("dropna", [True, False])
def test_value_counts(self, all_data, dropna, request):
pa_dtype = all_data.dtype.pyarrow_dtype
if pa.types.is_date(pa_dtype) or (
pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None
):
request.node.add_marker(
pytest.mark.xfail(
raises=AttributeError,
reason="GH 34986",
)
)
elif pa.types.is_duration(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
raises=pa.ArrowNotImplementedError,
reason=f"value_count has no kernel for {pa_dtype}",
)
)
super().test_value_counts(all_data, dropna)

def test_value_counts_with_normalize(self, data, request):
pa_dtype = data.dtype.pyarrow_dtype
if pa.types.is_date(pa_dtype) or (
pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None
):
request.node.add_marker(
pytest.mark.xfail(
raises=AttributeError,
reason="GH 34986",
)
)
elif pa.types.is_duration(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
raises=pa.ArrowNotImplementedError,
reason=f"value_count has no pyarrow kernel for {pa_dtype}",
)
)
super().test_value_counts_with_normalize(data)

def test_argmin_argmax(
self, data_for_sorting, data_missing_for_sorting, na_value, request
):
pa_dtype = data_for_sorting.dtype.pyarrow_dtype
if pa.types.is_boolean(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
reason=f"{pa_dtype} only has 2 unique possible values",
)
)
super().test_argmin_argmax(data_for_sorting, data_missing_for_sorting, na_value)

@pytest.mark.parametrize("ascending", [True, False])
def test_sort_values(self, data_for_sorting, ascending, sort_by_key, request):
pa_dtype = data_for_sorting.dtype.pyarrow_dtype
if pa.types.is_duration(pa_dtype) and not ascending and not pa_version_under2p0:
request.node.add_marker(
pytest.mark.xfail(
raises=pa.ArrowNotImplementedError,
reason=(
f"unique has no pyarrow kernel "
f"for {pa_dtype} when ascending={ascending}"
),
)
)
super().test_sort_values(data_for_sorting, ascending, sort_by_key)

@pytest.mark.parametrize("ascending", [True, False])
def test_sort_values_frame(self, data_for_sorting, ascending, request):
pa_dtype = data_for_sorting.dtype.pyarrow_dtype
if pa.types.is_duration(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
raises=pa.ArrowNotImplementedError,
reason=(
f"dictionary_encode has no pyarrow kernel "
f"for {pa_dtype} when ascending={ascending}"
),
)
)
super().test_sort_values_frame(data_for_sorting, ascending)

@pytest.mark.parametrize("box", [pd.Series, lambda x: x])
@pytest.mark.parametrize("method", [lambda x: x.unique(), pd.unique])
def test_unique(self, data, box, method, request):
pa_dtype = data.dtype.pyarrow_dtype
if pa.types.is_duration(pa_dtype) and not pa_version_under2p0:
request.node.add_marker(
pytest.mark.xfail(
raises=pa.ArrowNotImplementedError,
reason=f"unique has no pyarrow kernel for {pa_dtype}.",
)
)
super().test_unique(data, box, method)

@pytest.mark.parametrize("na_sentinel", [-1, -2])
def test_factorize(self, data_for_grouping, na_sentinel, request):
pa_dtype = data_for_grouping.dtype.pyarrow_dtype
if pa.types.is_duration(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
raises=pa.ArrowNotImplementedError,
reason=f"dictionary_encode has no pyarrow kernel for {pa_dtype}",
)
)
elif pa.types.is_boolean(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
reason=f"{pa_dtype} only has 2 unique possible values",
)
)
super().test_factorize(data_for_grouping, na_sentinel)

@pytest.mark.parametrize("na_sentinel", [-1, -2])
def test_factorize_equivalence(self, data_for_grouping, na_sentinel, request):
pa_dtype = data_for_grouping.dtype.pyarrow_dtype
if pa.types.is_duration(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
raises=pa.ArrowNotImplementedError,
reason=f"dictionary_encode has no pyarrow kernel for {pa_dtype}",
)
)
super().test_factorize_equivalence(data_for_grouping, na_sentinel)

def test_factorize_empty(self, data, request):
pa_dtype = data.dtype.pyarrow_dtype
if pa.types.is_duration(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
raises=pa.ArrowNotImplementedError,
reason=f"dictionary_encode has no pyarrow kernel for {pa_dtype}",
)
)
super().test_factorize_empty(data)

def test_fillna_copy_frame(self, data_missing, request, using_array_manager):
pa_dtype = data_missing.dtype.pyarrow_dtype
if using_array_manager and pa.types.is_duration(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
reason=f"Checking ndim when using arraymanager with {pa_dtype}"
)
)
super().test_fillna_copy_frame(data_missing)

def test_fillna_copy_series(self, data_missing, request, using_array_manager):
pa_dtype = data_missing.dtype.pyarrow_dtype
if using_array_manager and pa.types.is_duration(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
reason=f"Checking ndim when using arraymanager with {pa_dtype}"
)
)
super().test_fillna_copy_series(data_missing)

def test_shift_fill_value(self, data, request):
pa_dtype = data.dtype.pyarrow_dtype
tz = getattr(pa_dtype, "tz", None)
if pa_version_under2p0 and tz not in (None, "UTC"):
request.node.add_marker(
pytest.mark.xfail(
reason=f"Not supported by pyarrow < 2.0 with timestamp type {tz}"
)
)
super().test_shift_fill_value(data)

@pytest.mark.parametrize("repeats", [0, 1, 2, [1, 2, 3]])
def test_repeat(self, data, repeats, as_series, use_numpy, request):
pa_dtype = data.dtype.pyarrow_dtype
tz = getattr(pa_dtype, "tz", None)
if pa_version_under2p0 and tz not in (None, "UTC") and repeats != 0:
request.node.add_marker(
pytest.mark.xfail(
reason=(
f"Not supported by pyarrow < 2.0 with "
f"timestamp type {tz} when repeats={repeats}"
)
)
)
super().test_repeat(data, repeats, as_series, use_numpy)

def test_insert(self, data, request):
pa_dtype = data.dtype.pyarrow_dtype
tz = getattr(pa_dtype, "tz", None)
if pa_version_under2p0 and tz not in (None, "UTC"):
request.node.add_marker(
pytest.mark.xfail(
reason=f"Not supported by pyarrow < 2.0 with timestamp type {tz}"
)
)
super().test_insert(data)

def test_combine_first(self, data, request, using_array_manager):
pa_dtype = data.dtype.pyarrow_dtype
tz = getattr(pa_dtype, "tz", None)
if using_array_manager and pa.types.is_duration(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
reason=f"Checking ndim when using arraymanager with {pa_dtype}"
)
)
elif pa_version_under2p0 and tz not in (None, "UTC"):
request.node.add_marker(
pytest.mark.xfail(
reason=f"Not supported by pyarrow < 2.0 with timestamp type {tz}"
)
)
super().test_combine_first(data)

@pytest.mark.parametrize("frame", [True, False])
@pytest.mark.parametrize(
"periods, indices",
[(-2, [2, 3, 4, -1, -1]), (0, [0, 1, 2, 3, 4]), (2, [-1, -1, 0, 1, 2])],
)
def test_container_shift(
self, data, frame, periods, indices, request, using_array_manager
):
pa_dtype = data.dtype.pyarrow_dtype
if (
using_array_manager
and pa.types.is_duration(pa_dtype)
and periods in (-2, 2)
):
request.node.add_marker(
pytest.mark.xfail(
reason=(
f"Checking ndim when using arraymanager with "
f"{pa_dtype} and periods={periods}"
)
)
)
super().test_container_shift(data, frame, periods, indices)

@pytest.mark.xfail(
reason="result dtype pyarrow[bool] better than expected dtype object"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium-term should just override the test with the correct/better behavior

)
def test_combine_le(self, data_repeated):
super().test_combine_le(data_repeated)

def test_combine_add(self, data_repeated, request):
pa_dtype = next(data_repeated(1)).dtype.pyarrow_dtype
if pa.types.is_temporal(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
raises=TypeError,
reason=f"{pa_dtype} cannot be added to {pa_dtype}",
)
)
super().test_combine_add(data_repeated)

def test_searchsorted(self, data_for_sorting, as_series, request):
pa_dtype = data_for_sorting.dtype.pyarrow_dtype
if pa.types.is_boolean(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
reason=f"{pa_dtype} only has 2 unique possible values",
)
)
super().test_searchsorted(data_for_sorting, as_series)

def test_where_series(self, data, na_value, as_frame, request, using_array_manager):
pa_dtype = data.dtype.pyarrow_dtype
if using_array_manager and pa.types.is_duration(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
reason=f"Checking ndim when using arraymanager with {pa_dtype}"
)
)
elif pa.types.is_temporal(pa_dtype):
request.node.add_marker(
pytest.mark.xfail(
raises=pa.ArrowNotImplementedError,
reason=f"Unsupported cast from double to {pa_dtype}",
)
)
super().test_where_series(data, na_value, as_frame)


def test_arrowdtype_construct_from_string_type_with_unsupported_parameters():
with pytest.raises(NotImplementedError, match="Passing pyarrow type"):
ArrowDtype.construct_from_string("timestamp[s, tz=UTC][pyarrow]")
4 changes: 1 addition & 3 deletions pandas/tests/extension/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,7 @@ def test_reduce_series_numeric(self, data, all_numeric_reductions, skipna):


class TestMethods(base.BaseMethodsTests):
@pytest.mark.xfail(reason="returns nullable: GH 44692")
def test_value_counts_with_normalize(self, data):
super().test_value_counts_with_normalize(data)
pass


class TestCasting(base.BaseCastingTests):
Expand Down