Skip to content

Commit edae9d1

Browse files
authored
Implement __array__ on ExtensionIndex (#32255)
1 parent ed8df2d commit edae9d1

File tree

8 files changed

+43
-18
lines changed

8 files changed

+43
-18
lines changed

doc/source/whatsnew/v1.1.0.rst

+2
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,8 @@ Backwards incompatible API changes
9191
now raise a ``TypeError`` if a not-accepted keyword argument is passed into it.
9292
Previously an ``UnsupportedFunctionCall`` was raised (``AssertionError`` if ``min_count`` was passed into :meth:`~DataFrameGroupby.median`) (:issue:`31485`)
9393
- :meth:`DataFrame.at` and :meth:`Series.at` will raise a ``TypeError`` instead of a ``ValueError`` if an incompatible key is passed, and ``KeyError`` if a missing key is passed, matching the behavior of ``.loc[]`` (:issue:`31722`)
94+
- Passing an integer dtype other than ``int64`` to ``np.array(period_index, dtype=...)`` will now raise ``TypeError`` instead of incorrectly using ``int64`` (:issue:`32255`)
95+
-
9496

9597
.. _whatsnew_110.api_breaking.indexing_raises_key_errors:
9698

pandas/core/arrays/period.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,12 @@ def freq(self):
282282
return self.dtype.freq
283283

284284
def __array__(self, dtype=None) -> np.ndarray:
285-
# overriding DatetimelikeArray
285+
if dtype == "i8":
286+
return self.asi8
287+
elif dtype == bool:
288+
return ~self._isnan
289+
290+
# This will raise TypeError for non-object dtypes
286291
return np.array(list(self), dtype=object)
287292

288293
def __arrow_array__(self, type=None):

pandas/core/indexes/category.py

-4
Original file line numberDiff line numberDiff line change
@@ -364,10 +364,6 @@ def __contains__(self, key: Any) -> bool:
364364
hash(key)
365365
return contains(self, key, container=self._engine)
366366

367-
def __array__(self, dtype=None) -> np.ndarray:
368-
""" the array interface, return my values """
369-
return np.array(self._data, dtype=dtype)
370-
371367
@Appender(Index.astype.__doc__)
372368
def astype(self, dtype, copy=True):
373369
if is_interval_dtype(dtype):

pandas/core/indexes/datetimes.py

-3
Original file line numberDiff line numberDiff line change
@@ -267,9 +267,6 @@ def _simple_new(cls, values, name=None, freq=None, tz=None, dtype=None):
267267

268268
# --------------------------------------------------------------------
269269

270-
def __array__(self, dtype=None) -> np.ndarray:
271-
return np.asarray(self._data, dtype=dtype)
272-
273270
@cache_readonly
274271
def _is_dates_only(self) -> bool:
275272
"""

pandas/core/indexes/extension.py

+3
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,9 @@ def __iter__(self):
224224

225225
# ---------------------------------------------------------------------
226226

227+
def __array__(self, dtype=None) -> np.ndarray:
228+
return np.asarray(self._data, dtype=dtype)
229+
227230
@property
228231
def _ndarray_values(self) -> np.ndarray:
229232
return self._data._ndarray_values

pandas/core/indexes/period.py

-7
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
is_dtype_equal,
2020
is_float,
2121
is_integer,
22-
is_integer_dtype,
2322
is_object_dtype,
2423
is_scalar,
2524
pandas_dtype,
@@ -338,12 +337,6 @@ def _int64index(self) -> Int64Index:
338337
# ------------------------------------------------------------------------
339338
# Index Methods
340339

341-
def __array__(self, dtype=None) -> np.ndarray:
342-
if is_integer_dtype(dtype):
343-
return self.asi8
344-
else:
345-
return self.astype(object).values
346-
347340
def __array_wrap__(self, result, context=None):
348341
"""
349342
Gets called after a ufunc. Needs additional handling as

pandas/tests/arrays/test_datetimelike.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -687,10 +687,10 @@ def test_array_interface(self, period_index):
687687
result = np.asarray(arr, dtype=object)
688688
tm.assert_numpy_array_equal(result, expected)
689689

690-
# to other dtypes
691-
with pytest.raises(TypeError):
692-
np.asarray(arr, dtype="int64")
690+
result = np.asarray(arr, dtype="int64")
691+
tm.assert_numpy_array_equal(result, arr.asi8)
693692

693+
# to other dtypes
694694
with pytest.raises(TypeError):
695695
np.asarray(arr, dtype="float64")
696696

pandas/tests/indexes/period/test_period.py

+29
Original file line numberDiff line numberDiff line change
@@ -681,3 +681,32 @@ def test_is_monotonic_with_nat():
681681
assert not obj.is_monotonic_increasing
682682
assert not obj.is_monotonic_decreasing
683683
assert obj.is_unique
684+
685+
686+
@pytest.mark.parametrize("array", [True, False])
687+
def test_dunder_array(array):
688+
obj = PeriodIndex(["2000-01-01", "2001-01-01"], freq="D")
689+
if array:
690+
obj = obj._data
691+
692+
expected = np.array([obj[0], obj[1]], dtype=object)
693+
result = np.array(obj)
694+
tm.assert_numpy_array_equal(result, expected)
695+
696+
result = np.asarray(obj)
697+
tm.assert_numpy_array_equal(result, expected)
698+
699+
expected = obj.asi8
700+
for dtype in ["i8", "int64", np.int64]:
701+
result = np.array(obj, dtype=dtype)
702+
tm.assert_numpy_array_equal(result, expected)
703+
704+
result = np.asarray(obj, dtype=dtype)
705+
tm.assert_numpy_array_equal(result, expected)
706+
707+
for dtype in ["float64", "int32", "uint64"]:
708+
msg = "argument must be"
709+
with pytest.raises(TypeError, match=msg):
710+
np.array(obj, dtype=dtype)
711+
with pytest.raises(TypeError, match=msg):
712+
np.array(obj, dtype=getattr(np, dtype))

0 commit comments

Comments
 (0)