Skip to content

ENH: allow storing ExtensionArrays in Index #43930

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 110 commits into from
Dec 31, 2021
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
110 commits
Select commit Hold shift + click to select a range
df9c228
ENH/WIP/POC: EA-backed Index
jbrockmendel Oct 8, 2021
3952027
Merge branch 'master' into enh-nullable-index
jbrockmendel Oct 8, 2021
95e0129
BUG: NumericIndex.insert
jbrockmendel Oct 8, 2021
c52d459
Merge branch 'bug-insert' into enh-nullable-index
jbrockmendel Oct 8, 2021
cf0c171
Merge branch 'master' into enh-nullable-index
jbrockmendel Oct 9, 2021
0a3b7d7
Merge branch 'master' into enh-nullable-index
jbrockmendel Oct 9, 2021
d53377d
fix a few more tests; ignoring linting for now
jbrockmendel Oct 10, 2021
69fb0bd
Merge branch 'master' into enh-nullable-index
jbrockmendel Oct 10, 2021
1952cd7
Merge branch 'master' into enh-nullable-index
jbrockmendel Oct 11, 2021
1ed588a
fix test
jbrockmendel Oct 11, 2021
34d5dde
down to 38 tests failing
jbrockmendel Oct 11, 2021
42be4e6
Merge branch 'master' into enh-nullable-index
jbrockmendel Oct 11, 2021
91b3716
Merge branch 'master' into enh-nullable-index
jbrockmendel Oct 12, 2021
544d9fe
down to 15 tests failing
jbrockmendel Oct 13, 2021
e14d6f1
Merge branch 'master' into enh-nullable-index
jbrockmendel Oct 13, 2021
22a0939
Merge branch 'master' into enh-nullable-index
jbrockmendel Oct 13, 2021
900978c
fix value_counts
jbrockmendel Oct 16, 2021
f9c8791
Merge branch 'master' into enh-nullable-index
jbrockmendel Oct 16, 2021
c0ae18c
fix map test
jbrockmendel Oct 18, 2021
4ab7f0d
Merge branch 'master' into enh-nullable-index
jbrockmendel Oct 20, 2021
a9ef37e
fix some tests
jbrockmendel Oct 21, 2021
a2a8de9
Merge branch 'master' into enh-nullable-index
jbrockmendel Oct 21, 2021
41acf3f
ENH: ExtensionArray.insert
jbrockmendel Oct 22, 2021
37d36ad
Fix usage
jbrockmendel Oct 22, 2021
bafb23f
Fix TimedeltaIndex.insert test
jbrockmendel Oct 22, 2021
bf76950
Merge branch 'enh-ea-insert' into enh-nullable-index
jbrockmendel Oct 22, 2021
a3a349d
pass a few more tests
jbrockmendel Oct 23, 2021
ebbb7a4
Merge branch 'master' into enh-nullable-index
jbrockmendel Oct 23, 2021
4229dbf
tests
jbrockmendel Oct 23, 2021
17dedee
Merge branch 'master' into enh-nullable-index
jbrockmendel Oct 23, 2021
45fdffb
Merge branch 'master' into enh-nullable-index
jbrockmendel Oct 24, 2021
2fc4798
Merge branch 'master' into enh-nullable-index
jbrockmendel Oct 24, 2021
2e1843a
REF: share ExtensionIndex.insert-> Index.insert
jbrockmendel Oct 24, 2021
1516f66
Merge branch 'bug-ei-inserts' into enh-nullable-index
jbrockmendel Oct 24, 2021
c92bea5
Merge branch 'master' into enh-nullable-index
jbrockmendel Oct 25, 2021
6bede7a
Merge branch 'master' into enh-nullable-index
jbrockmendel Oct 25, 2021
2bb1dea
handle a few more tests
jbrockmendel Oct 26, 2021
e692899
Merge branch 'master' into enh-nullable-index
jbrockmendel Oct 26, 2021
36b6629
update test
jbrockmendel Oct 27, 2021
1ab2df1
Merge branch 'master' into enh-nullable-index
jbrockmendel Oct 27, 2021
94376ba
Merge branch 'master' into enh-nullable-index
jbrockmendel Oct 28, 2021
1881599
Fix remaining tests
jbrockmendel Oct 29, 2021
bb90bfd
Merge branch 'master' into enh-nullable-index
jbrockmendel Oct 29, 2021
1f97325
no-pyarrow-compat
jbrockmendel Oct 29, 2021
076cada
mypy fixups
jbrockmendel Oct 29, 2021
1e8a31f
remove assertion
jbrockmendel Oct 30, 2021
1d076fe
Merge branch 'master' into enh-nullable-index
jbrockmendel Oct 31, 2021
2d5fa6d
restor astype
jbrockmendel Oct 31, 2021
adf3ddb
older numpy compat
jbrockmendel Nov 1, 2021
95be963
Merge branch 'master' into enh-nullable-index
jbrockmendel Nov 1, 2021
fd6880e
xfail
jbrockmendel Nov 1, 2021
3d9b9af
mypy fixup
jbrockmendel Nov 1, 2021
31d547c
Merge branch 'master' into enh-nullable-index
jbrockmendel Nov 2, 2021
2d75377
lint fixups
jbrockmendel Nov 2, 2021
37b9370
avoid warnings
jbrockmendel Nov 3, 2021
0e56218
avoid FutureWarnings
jbrockmendel Nov 5, 2021
e8987cd
catch RuntimeWarning
jbrockmendel Nov 5, 2021
fef88a7
remove unreachable
jbrockmendel Nov 5, 2021
d33e306
Merge branch 'master' into enh-nullable-index
jbrockmendel Nov 11, 2021
41c040d
Merge branch 'master' into enh-nullable-index
jbrockmendel Nov 16, 2021
e307963
Merge branch 'master' into enh-nullable-index
jbrockmendel Nov 19, 2021
63f26ba
revert no-longer-necessary
jbrockmendel Nov 19, 2021
11d3564
Share ExtensionEngine/NullableEngine methods
jbrockmendel Nov 20, 2021
ca747fc
Merge branch 'master' into enh-nullable-index
jbrockmendel Nov 20, 2021
cdb08ca
Merge branch 'master' into enh-nullable-index
jbrockmendel Nov 21, 2021
09d8bf1
lint fixup
jbrockmendel Nov 21, 2021
7d783b1
revert no-longer-necessary
jbrockmendel Nov 21, 2021
2553db4
Merge branch 'master' into enh-nullable-index
jbrockmendel Nov 22, 2021
7a44a79
Merge branch 'master' into enh-nullable-index
jbrockmendel Nov 23, 2021
de64249
remove unnecessary from test_setops
jbrockmendel Nov 23, 2021
1bb2901
suggested edits
jbrockmendel Nov 23, 2021
5a76d5d
Merge branch 'master' into enh-nullable-index
jbrockmendel Nov 24, 2021
a779c75
Merge branch 'master' into enh-nullable-index
jbrockmendel Nov 25, 2021
23bb325
actually run the new base extension tests for all EAs
jorisvandenbossche Nov 26, 2021
a3fba50
Merge branch 'master' into enh-nullable-index
jbrockmendel Nov 30, 2021
4abd60e
update tests
jbrockmendel Nov 30, 2021
a9f7ea9
Merge branch 'master' into enh-nullable-index
jbrockmendel Nov 30, 2021
7f9741e
older np compat
jbrockmendel Nov 30, 2021
c22d6e9
Merge branch 'master' into enh-nullable-index
jbrockmendel Dec 1, 2021
40e861b
32bit compat
jbrockmendel Dec 1, 2021
6e79350
simplify, docstring
jbrockmendel Dec 1, 2021
ae6bb0c
Merge branch 'master' into enh-nullable-index
jbrockmendel Dec 1, 2021
90366e9
32bit compat
jbrockmendel Dec 1, 2021
c8b1d7d
Merge branch 'master' into enh-nullable-index
jbrockmendel Dec 1, 2021
70debb2
Address comments
jbrockmendel Dec 1, 2021
b6a3a47
Merge branch 'master' into enh-nullable-index
jbrockmendel Dec 3, 2021
c74a7a7
Merge branch 'master' into enh-nullable-index
jbrockmendel Dec 5, 2021
0339e69
simplify
jbrockmendel Dec 5, 2021
e737850
Merge branch 'master' into enh-nullable-index
jbrockmendel Dec 5, 2021
96b25aa
dont catch np.float16 too early
jbrockmendel Dec 5, 2021
788eda1
Merge branch 'master' into enh-nullable-index
jbrockmendel Dec 7, 2021
812da93
de-xfail
jbrockmendel Dec 7, 2021
e394394
remove edits made extraneous by other PRs
jbrockmendel Dec 7, 2021
7718bb1
Merge branch 'master' into enh-nullable-index
jbrockmendel Dec 13, 2021
3e1ec00
suggested edits
jbrockmendel Dec 13, 2021
3596bcf
Merge branch 'master' into enh-nullable-index
jbrockmendel Dec 15, 2021
267b1b3
Remove NullableEngine, ExtensionEngine
jbrockmendel Dec 15, 2021
70cad91
Merge branch 'master' into enh-nullable-index
jbrockmendel Dec 21, 2021
c8072c5
revert
jbrockmendel Dec 21, 2021
7e0ac18
remove no-longer-necessary
jbrockmendel Dec 21, 2021
80453b4
whatsnew
jbrockmendel Dec 21, 2021
7231a9e
deprecation for SparseArray
jbrockmendel Dec 22, 2021
f78aa0f
share _na_value method
jbrockmendel Dec 22, 2021
453d6ae
mypy fixup, npdev catch warnings
jbrockmendel Dec 22, 2021
8750248
mypy fixup
jbrockmendel Dec 22, 2021
0b01bf9
Merge branch 'master' into enh-nullable-index
jbrockmendel Dec 23, 2021
d2e0266
compat for older numpy
jbrockmendel Dec 23, 2021
8daa2dc
Merge branch 'master' into enh-nullable-index
jbrockmendel Dec 23, 2021
cf95f32
Merge branch 'master' into enh-nullable-index
jbrockmendel Dec 25, 2021
7fba6a2
Merge branch 'master' into enh-nullable-index
jbrockmendel Dec 29, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
505 changes: 504 additions & 1 deletion pandas/_libs/index.pyx

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions pandas/_libs/lib.pxd
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
from numpy cimport ndarray

cdef bint c_is_list_like(object, bint) except -1

cpdef ndarray eq_NA_compat(ndarray[object] arr, object key)
20 changes: 20 additions & 0 deletions pandas/_libs/lib.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -3028,3 +3028,23 @@ def is_bool_list(obj: list) -> bool:

# Note: we return True for empty list
return True


cpdef ndarray eq_NA_compat(ndarray[object] arr, object key):
cdef:
ndarray[uint8_t, cast=True] result = np.empty(len(arr), dtype=bool)
Py_ssize_t i
object item

if key is C_NA:
for i in range(len(arr)):
item = arr[i]
result[i] = item is C_NA
else:
for i in range(len(arr)):
item = arr[i]
if item is C_NA:
result[i] = False
else:
result[i] = item == key # FIXME: compat for other NAs
return result
16 changes: 12 additions & 4 deletions pandas/_libs/testing.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ from pandas._libs.util cimport (
is_real_number_object,
)

from pandas._libs.missing cimport is_matching_na
from pandas.core.dtypes.common import is_dtype_equal
from pandas.core.dtypes.missing import (
array_equivalent,
Expand Down Expand Up @@ -174,11 +175,18 @@ cpdef assert_almost_equal(a, b,
# classes can't be the same, to raise error
assert_class_equal(a, b, obj=obj)

if isna(a) and isna(b):
# TODO: Should require same-dtype NA?
# nan / None comparison
return True
if isna(a):
if isna(b):
# TODO: Should require same-dtype NA?
# nan / None comparison
return True

assert False, f"expected {a} but got {b}"

elif isna(b):
assert False, f"expected {a} but got {b}"

# TODO: test for tm.assert_whatever with pd.NA that would raise here
if a == b:
# object comparison
return True
Expand Down
10 changes: 9 additions & 1 deletion pandas/_testing/asserters.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,8 +401,16 @@ def _get_ilevel_values(index, level):
# skip exact index checking when `check_categorical` is False
if check_exact and check_categorical:
if not left.equals(right):
mismatch = left._values != right._values

if not isinstance(mismatch, np.ndarray):
# i.e. its a MaskedArray
mismatch = mismatch.to_numpy(dtype=int, na_value=0)
mismask = left._values._mask ^ right._values._mask
mismatch[mismask] = 1

diff = (
np.sum((left._values != right._values).astype(int)) * 100.0 / len(left)
np.sum(mismatch.astype(int)) * 100.0 / len(left)
)
msg = f"{obj} values are different ({np.round(diff, 5)} %)"
raise_assert_detail(obj, msg, left, right)
Expand Down
8 changes: 8 additions & 0 deletions pandas/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,14 @@ def _create_mi_with_dt64tz_level():
"mi-with-dt64tz-level": _create_mi_with_dt64tz_level(),
"multi": _create_multiindex(),
"repeats": Index([0, 0, 1, 1, 2, 2]),
"nullable_int": Index(np.arange(100), dtype="Int64"),
"nullable_float": Index(np.arange(100), dtype="Float32"),
"nullable_bool": Index(np.arange(100).astype(bool), dtype="boolean"),
#"nullable_int-na": Index(np.arange(100), dtype="Int64").insert(1, pd.NA),
#"nullable_float-na": Index(np.arange(100), dtype="Float32").insert(1, pd.NA),
#"nullable_bool-na": Index(np.arange(100).astype(bool), dtype="boolean").insert(
# 1, pd.NA
#),
}


Expand Down
19 changes: 19 additions & 0 deletions pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1349,6 +1349,25 @@ def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
# ------------------------------------------------------------------------
# Non-Optimized Default Methods

def putmask(self, mask: np.ndarray, value) -> None:
"""
Analogue to np.putmask(self, mask, value)

Parameters
----------
mask : np.ndarray[bool]
value : scalar or listlike

Raises
------
TypeError
If value cannot be inserted into self.
"""
if not is_list_like(value):
self[mask] = value
else:
self[mask] = value[mask]

def tolist(self) -> list:
"""
Return a list of the values.
Expand Down
7 changes: 6 additions & 1 deletion pandas/core/arrays/floating.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ def coerce_to_array(
if dtype is None and hasattr(values, "dtype"):
if is_float_dtype(values.dtype):
dtype = values.dtype
if dtype == "float16":
raise TypeError("FloatingArray does not support float16 dtype")

if dtype is not None:
if isinstance(dtype, str) and dtype.startswith("Float"):
Expand Down Expand Up @@ -254,7 +256,8 @@ def dtype(self) -> FloatingDtype:
return FLOAT_STR_TO_DTYPE[str(self._data.dtype)]

def __init__(self, values: np.ndarray, mask: np.ndarray, copy: bool = False):
if not (isinstance(values, np.ndarray) and values.dtype.kind == "f"):
if not (isinstance(values, np.ndarray) and values.dtype.kind == "f" and values.dtype.itemsize > 2):
# We do not support float16
raise TypeError(
"values should be floating numpy array. Use "
"the 'pd.array' function instead"
Expand Down Expand Up @@ -422,6 +425,8 @@ def _maybe_mask_result(self, result, mask, other, op_name: str):

return type(self)(result, mask, copy=False)

def isna(self):
return self._mask | np.isnan(self._data)

_dtype_docstring = """
An ExtensionDtype for {dtype} data.
Expand Down
16 changes: 11 additions & 5 deletions pandas/core/arrays/masked.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,9 @@ def reconstruct(x):
m = mask.copy()
return IntegerArray(x, m)
elif is_float_dtype(x.dtype):
if x.dtype.itemsize <= 2:
# we don't support float16
x = x.astype(np.float32)
m = mask.copy()
return FloatingArray(x, m)
else:
Expand Down Expand Up @@ -564,7 +567,7 @@ def value_counts(self, dropna: bool = True) -> Series:
# TODO(extension)
# if we have allow Index to hold an ExtensionArray
# this is easier
index = value_counts.index._values.astype(object)
index = value_counts.index # ._values.astype(object)

# if we want nans, count the mask
if dropna:
Expand All @@ -574,10 +577,13 @@ def value_counts(self, dropna: bool = True) -> Series:
counts[:-1] = value_counts
counts[-1] = self._mask.sum()

index = Index(
np.concatenate([index, np.array([self.dtype.na_value], dtype=object)]),
dtype=object,
)
index = index.insert(-1, self.dtype.na_value)
# index = Index(
# np.concatenate([index, np.array([self.dtype.na_value], dtype=object)]),
# dtype=object,
# )

index = index.astype(self.dtype)

mask = np.zeros(len(counts), dtype="bool")
counts = IntegerArray(counts, mask)
Expand Down
48 changes: 40 additions & 8 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
algos as libalgos,
index as libindex,
lib,
missing as libmissing,
)
import pandas._libs.join as libjoin
from pandas._libs.lib import (
Expand Down Expand Up @@ -135,6 +136,7 @@
tz_to_dtype,
validate_tz_from_dtype,
)
from pandas.core.arrays.masked import BaseMaskedArray
from pandas.core.arrays.sparse import SparseDtype
from pandas.core.base import (
IndexOpsMixin,
Expand Down Expand Up @@ -356,7 +358,7 @@ def _outer_indexer(

_typ: str = "index"
_data: ExtensionArray | np.ndarray
_data_cls: type[np.ndarray] | type[ExtensionArray] = np.ndarray
_data_cls: type[np.ndarray] | type[ExtensionArray] = (np.ndarray, ExtensionArray)
_id: object | None = None
_name: Hashable = None
# MultiIndex.levels previously allowed setting the index name. We
Expand Down Expand Up @@ -411,8 +413,9 @@ def __new__(
validate_tz_from_dtype(dtype, tz)
dtype = tz_to_dtype(tz)

if isinstance(data, PandasArray):
# ensure users don't accidentally put a PandasArray in an index.
if type(data) is PandasArray:
# ensure users don't accidentally put a PandasArray in an index,
# but don't unpack StringArray
data = data.to_numpy()
if isinstance(dtype, PandasDtype):
dtype = dtype.numpy_dtype
Expand All @@ -434,7 +437,6 @@ def __new__(

ea_cls = dtype.construct_array_type()
data = ea_cls._from_sequence(data, dtype=dtype, copy=copy)
data = np.asarray(data, dtype=object)
disallow_kwargs(kwargs)
return Index._simple_new(data, name=name)

Expand All @@ -446,8 +448,8 @@ def __new__(
return result.astype(dtype, copy=False)
return result

data = np.array(data, dtype=object, copy=copy)
disallow_kwargs(kwargs)
data = extract_array(data, extract_numpy=True)
return Index._simple_new(data, name=name)

# index-like
Expand Down Expand Up @@ -657,6 +659,7 @@ def _with_infer(cls, *args, **kwargs):
Constructor that uses the 1.0.x behavior inferring numeric dtypes
for ndarray[object] inputs.
"""

with warnings.catch_warnings():
warnings.filterwarnings("ignore", ".*the Index constructor", FutureWarning)
result = cls(*args, **kwargs)
Expand Down Expand Up @@ -812,6 +815,15 @@ def _cleanup(self) -> None:
def _engine(self) -> libindex.IndexEngine:
# For base class (object dtype) we get ObjectEngine

if isinstance(self._values, BaseMaskedArray):
return libindex.NullableEngine(self._values)
elif (
isinstance(self._values, ExtensionArray)
and self._engine_type is libindex.ObjectEngine
):
return libindex.ExtensionEngine(self._values)

assert self.dtype != "boolean"
# to avoid a reference cycle, bind `target_values` to a local variable, so
# `self` is not passed into the lambda.
target_values = self._get_engine_target()
Expand Down Expand Up @@ -1025,9 +1037,15 @@ def take(

# Note: we discard fill_value and use self._na_value, only relevant
# in the case where allow_fill is True and fill_value is not None
taken = algos.take(
self._values, indices, allow_fill=allow_fill, fill_value=self._na_value
)
values = self._values
if isinstance(values, np.ndarray):
taken = algos.take(
values, indices, allow_fill=allow_fill, fill_value=self._na_value
)
else:
taken = values.take(
indices, allow_fill=allow_fill, fill_value=self._na_value
)
# _constructor so RangeIndex->Int64Index
return self._constructor._simple_new(taken, name=self.name)

Expand Down Expand Up @@ -3572,6 +3590,7 @@ def get_indexer(

indexer = self._engine.get_indexer(target.codes)
if self.hasnans and target.hasnans:
#loc = self.get_loc(libmissing.NA)
loc = self.get_loc(np.nan)
mask = target.isna()
indexer[mask] = loc
Expand All @@ -3590,6 +3609,7 @@ def get_indexer(
# Exclude MultiIndex because hasnans raises NotImplementedError
# we should only get here if we are unique, so loc is an integer
# GH#41934
#loc = self.get_loc(libmissing.NA)
loc = self.get_loc(np.nan)
mask = target.isna()
indexer[mask] = loc
Expand Down Expand Up @@ -6353,6 +6373,18 @@ def insert(self, loc: int, item) -> Index:

arr = self._values

if isinstance(arr, ExtensionArray):
# TODO: need EA.insert
try:
arr2 = type(arr)._from_sequence([item], dtype=arr.dtype)
except TypeError:
# TODO: make this into _validate_fill_value
dtype = self._find_common_type_compat(item)
return self.astype(dtype).insert(loc, item)

res_values = arr._concat_same_type([arr[:loc], arr2, arr[loc:]])
return type(self)._simple_new(res_values, name=self.name)

if arr.dtype != object or not isinstance(
item, (tuple, np.datetime64, np.timedelta64)
):
Expand Down
7 changes: 6 additions & 1 deletion pandas/core/indexes/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,7 +725,12 @@ def _get_indexer_pointwise(
if isinstance(locs, slice):
# Only needed for get_indexer_non_unique
locs = np.arange(locs.start, locs.stop, locs.step, dtype="intp")
locs = np.array(locs, ndmin=1)
elif lib.is_integer(locs):
locs = np.array(locs, ndmin=1)
else:
# FIXME: This is wrong; its boolean; not reached
assert locs.dtype.kind == "i"

except KeyError:
missing.append(i)
locs = np.array([-1])
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -964,7 +964,7 @@ def _validate_key(self, key, axis: int):
# slice of labels (where start-end in labels)
# slice of integers (only if in the labels)
# boolean not in slice and with boolean index
if isinstance(key, bool) and not is_bool_dtype(self.obj.index):
if isinstance(key, bool) and not (is_bool_dtype(self.obj.index) or self.obj.index.dtype.name == "boolean"):
raise KeyError(
f"{key}: boolean label can not be used without a boolean index"
)
Expand Down
23 changes: 16 additions & 7 deletions pandas/tests/arithmetic/test_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -1403,14 +1403,23 @@ def test_integer_array_add_list_like(
left = container + box_1d_array(data)
right = box_1d_array(data) + container

if Series == box_pandas_1d_array:
expected = Series(expected_data, dtype="Int64")
elif Series == box_1d_array:
expected = Series(expected_data, dtype="object")
elif Index in (box_pandas_1d_array, box_1d_array):
expected = Int64Index(expected_data)
if Series in [box_1d_array, box_pandas_1d_array]:
cls = Series
elif Index in [box_1d_array, box_pandas_1d_array]:
cls = Index
else:
expected = np.array(expected_data, dtype="object")
cls = np.array

if box_pandas_1d_array in [Index, Series]:
expected = cls(expected_data, dtype="Int64")

elif box_1d_array == Index:
# tm.to_array casts to object, Index constructor does inference
expected = cls(expected_data, dtype="int64")

else:
# tm.to_array casts to object, no inference
expected = cls(expected_data, dtype="object")

tm.assert_equal(left, expected)
tm.assert_equal(right, expected)
Expand Down
13 changes: 8 additions & 5 deletions pandas/tests/arrays/boolean/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,18 +77,21 @@ def test_ufunc_reduce_raises(values):
def test_value_counts_na():
arr = pd.array([True, False, pd.NA], dtype="boolean")
result = arr.value_counts(dropna=False)
expected = pd.Series([1, 1, 1], index=[True, False, pd.NA], dtype="Int64")
expected = pd.Series([1, 1, 1], index=arr, dtype="Int64")
assert expected.index.dtype == arr.dtype
tm.assert_series_equal(result, expected)

result = arr.value_counts(dropna=True)
expected = pd.Series([1, 1], index=[True, False], dtype="Int64")
expected = pd.Series([1, 1], index=arr[:-1], dtype="Int64")
assert expected.index.dtype == arr.dtype
tm.assert_series_equal(result, expected)


def test_value_counts_with_normalize():
s = pd.Series([True, False, pd.NA], dtype="boolean")
result = s.value_counts(normalize=True)
expected = pd.Series([1, 1], index=[True, False], dtype="Float64") / 2
ser = pd.Series([True, False, pd.NA], dtype="boolean")
result = ser.value_counts(normalize=True)
expected = pd.Series([1, 1], index=ser[:-1], dtype="Float64") / 2
assert expected.index.dtype == "boolean"
tm.assert_series_equal(result, expected)


Expand Down
Loading