Skip to content

Commit a1fecf4

Browse files
committed
Disallow NumPyExtensionArray in pandas
1 parent 8729566 commit a1fecf4

File tree

8 files changed

+48
-9
lines changed

8 files changed

+48
-9
lines changed

pandas/core/arrays/numpy_.py

+1
Original file line number | Diff line number | Diff line change
@@ -56,6 +56,7 @@ def itemsize(self):
5656

5757

5858
class NumPyExtensionArray(ExtensionArray, ExtensionOpsMixin):
59+
_typ = "npy_extension"
5960
__array_priority__ = 1000
6061

6162
def __init__(self, values):

pandas/core/dtypes/generic.py

+4
Original file line number | Diff line number | Diff line change
@@ -67,7 +67,11 @@ def _check(cls, inst):
6767
("extension",
6868
"categorical",
6969
"periodarray",
70+
"npy_extension",
7071
))
72+
ABCNumPyExtensionArray = create_pandas_abc_type("ABCNumPyExtensionArray",
73+
"_typ",
74+
("npy_extension",))
7175

7276

7377
class _ABCGeneric(type):

pandas/core/indexes/base.py

+5-2
Original file line number | Diff line number | Diff line change
@@ -26,8 +26,8 @@
2626
import pandas.core.dtypes.concat as _concat
2727
from pandas.core.dtypes.generic import (
2828
ABCDataFrame, ABCDateOffset, ABCDatetimeIndex, ABCIndexClass,
29-
ABCMultiIndex, ABCPeriodIndex, ABCSeries, ABCTimedeltaArray,
30-
ABCTimedeltaIndex)
29+
ABCMultiIndex, ABCNumPyExtensionArray, ABCPeriodIndex, ABCSeries,
30+
ABCTimedeltaArray, ABCTimedeltaIndex)
3131
from pandas.core.dtypes.missing import array_equivalent, isna
3232

3333
from pandas.core import ops
@@ -261,6 +261,9 @@ def __new__(cls, data=None, dtype=None, copy=False, name=None,
261261
return cls._simple_new(data, name)
262262

263263
from .range import RangeIndex
264+
if isinstance(data, ABCNumPyExtensionArray):
265+
# ensure users don't accidentally put a NumPyEA in an index.
266+
data = data._ndarray
264267

265268
# range
266269
if isinstance(data, RangeIndex):

pandas/core/internals/construction.py

+5-2
Original file line number | Diff line number | Diff line change
@@ -23,8 +23,8 @@
2323
is_extension_array_dtype, is_extension_type, is_float_dtype,
2424
is_integer_dtype, is_iterator, is_list_like, is_object_dtype, pandas_dtype)
2525
from pandas.core.dtypes.generic import (
26-
ABCDataFrame, ABCDatetimeIndex, ABCIndexClass, ABCPeriodIndex, ABCSeries,
27-
ABCTimedeltaIndex)
26+
ABCDataFrame, ABCDatetimeIndex, ABCIndexClass, ABCNumPyExtensionArray,
27+
ABCPeriodIndex, ABCSeries, ABCTimedeltaIndex)
2828
from pandas.core.dtypes.missing import isna
2929

3030
from pandas.core import algorithms, common as com
@@ -577,6 +577,9 @@ def sanitize_array(data, index, dtype=None, copy=False,
577577
# we will try to copy by definition here
578578
subarr = _try_cast(data, True, dtype, copy, raise_cast_failure)
579579

580+
elif isinstance(data, ABCNumPyExtensionArray):
581+
# don't let people put NumPy EAs into Series.
582+
subarr = data._ndarray
580583
elif isinstance(data, ExtensionArray):
581584
subarr = data
582585

pandas/tests/extension/test_numpy.py

+12-5
Original file line number | Diff line number | Diff line change
@@ -13,12 +13,19 @@ def dtype():
1313

1414

1515
@pytest.fixture
16-
def data():
16+
def allow_in_pandas(monkeypatch):
17+
with monkeypatch.context() as m:
18+
m.setattr(NumPyExtensionArray, '_typ', 'extension')
19+
yield
20+
21+
22+
@pytest.fixture
23+
def data(allow_in_pandas):
1724
return NumPyExtensionArray(np.arange(100))
1825

1926

2027
@pytest.fixture
21-
def data_missing():
28+
def data_missing(allow_in_pandas):
2229
return NumPyExtensionArray(np.array([np.nan, 1.0]))
2330

2431

@@ -35,7 +42,7 @@ def cmp(a, b):
3542

3643

3744
@pytest.fixture
38-
def data_for_sorting():
45+
def data_for_sorting(allow_in_pandas):
3946
"""Length-3 array with a known sort order.
4047
4148
This should be three items [B, C, A] with
@@ -47,7 +54,7 @@ def data_for_sorting():
4754

4855

4956
@pytest.fixture
50-
def data_missing_for_sorting():
57+
def data_missing_for_sorting(allow_in_pandas):
5158
"""Length-3 array with a known sort order.
5259
5360
This should be three items [B, NA, A] with
@@ -59,7 +66,7 @@ def data_missing_for_sorting():
5966

6067

6168
@pytest.fixture
62-
def data_for_grouping():
69+
def data_for_grouping(allow_in_pandas):
6370
"""Data for factorization, grouping, and unique tests.
6471
6572
Expected to be like [B, B, NA, NA, A, A, B, C]

pandas/tests/frame/test_constructors.py

+8
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@
2323
import pandas as pd
2424
import pandas.util.testing as tm
2525
from pandas.core.dtypes.cast import construct_1d_object_array_from_listlike
26+
from pandas.core.internals.blocks import IntBlock
2627

2728
from pandas.tests.frame.common import TestData
2829

@@ -2165,6 +2166,13 @@ def test_constructor_range_dtype(self, dtype):
21652166
result = DataFrame({'A': range(5)}, dtype=dtype)
21662167
tm.assert_frame_equal(result, expected)
21672168

2169+
def test_constructor_no_numpy_backed_ea(self):
2170+
arr = pd.Series([1, 2, 3]).array
2171+
result = pd.DataFrame({"A": arr})
2172+
expected = pd.DataFrame({"A": [1, 2, 3]})
2173+
tm.assert_frame_equal(result, expected)
2174+
assert isinstance(result._data.blocks[0], IntBlock)
2175+
21682176

21692177
class TestDataFrameConstructorWithDatetimeTZ(TestData):
21702178

pandas/tests/indexes/test_base.py

+6
Original file line number | Diff line number | Diff line change
@@ -260,6 +260,12 @@ def test_constructor_int_dtype_nan_raises(self, dtype):
260260
with pytest.raises(ValueError, match=msg):
261261
Index(data, dtype=dtype)
262262

263+
def test_constructor_no_numpy_backed_ea(self):
264+
ser = pd.Series([1, 2, 3])
265+
result = pd.Index(ser.array)
266+
expected = pd.Index([1, 2, 3])
267+
tm.assert_index_equal(result, expected)
268+
263269
@pytest.mark.parametrize("klass,dtype,na_val", [
264270
(pd.Float64Index, np.float64, np.nan),
265271
(pd.DatetimeIndex, 'datetime64[ns]', pd.NaT)

pandas/tests/series/test_constructors.py

+7
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@
2222
Timestamp, date_range, isna, period_range, timedelta_range)
2323
from pandas.api.types import CategoricalDtype
2424
from pandas.core.arrays import period_array
25+
from pandas.core.internals.blocks import IntBlock
2526
import pandas.util.testing as tm
2627
from pandas.util.testing import assert_series_equal
2728

@@ -1238,3 +1239,9 @@ def test_constructor_tz_mixed_data(self):
12381239
result = Series(dt_list)
12391240
expected = Series(dt_list, dtype=object)
12401241
tm.assert_series_equal(result, expected)
1242+
1243+
def test_constructor_no_numpy_backed_ea(self):
1244+
ser = pd.Series([1, 2, 3])
1245+
result = pd.Series(ser.array)
1246+
tm.assert_series_equal(ser, result)
1247+
assert isinstance(result._data.blocks[0], IntBlock)

0 commit comments

Comments
 (0)