Skip to content

Commit bb14870

Browse files
authored
TST: use one-class pattern in test_numpy (#56512)
* TST: use one-class pattern in test_numpy * revert accidentally-commited
1 parent ee4ceec commit bb14870

File tree

2 files changed

+64
-65
lines changed

2 files changed

+64
-65
lines changed

pandas/tests/extension/base/interface.py

+4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
import pytest
33

4+
from pandas.core.dtypes.cast import construct_1d_object_array_from_listlike
45
from pandas.core.dtypes.common import is_extension_array_dtype
56
from pandas.core.dtypes.dtypes import ExtensionDtype
67

@@ -65,6 +66,9 @@ def test_array_interface(self, data):
6566

6667
result = np.array(data, dtype=object)
6768
expected = np.array(list(data), dtype=object)
69+
if expected.ndim > 1:
70+
# nested data, explicitly construct as 1D
71+
expected = construct_1d_object_array_from_listlike(list(data))
6872
tm.assert_numpy_array_equal(result, expected)
6973

7074
def test_is_extension_array_dtype(self, data):

pandas/tests/extension/test_numpy.py

+60-65
Original file line numberDiff line numberDiff line change
@@ -18,23 +18,14 @@
1818
import numpy as np
1919
import pytest
2020

21-
from pandas.core.dtypes.cast import can_hold_element
2221
from pandas.core.dtypes.dtypes import NumpyEADtype
2322

2423
import pandas as pd
2524
import pandas._testing as tm
2625
from pandas.api.types import is_object_dtype
2726
from pandas.core.arrays.numpy_ import NumpyExtensionArray
28-
from pandas.core.internals import blocks
2927
from pandas.tests.extension import base
3028

31-
32-
def _can_hold_element_patched(obj, element) -> bool:
33-
if isinstance(element, NumpyExtensionArray):
34-
element = element.to_numpy()
35-
return can_hold_element(obj, element)
36-
37-
3829
orig_assert_attr_equal = tm.assert_attr_equal
3930

4031

@@ -78,7 +69,6 @@ def allow_in_pandas(monkeypatch):
7869
"""
7970
with monkeypatch.context() as m:
8071
m.setattr(NumpyExtensionArray, "_typ", "extension")
81-
m.setattr(blocks, "can_hold_element", _can_hold_element_patched)
8272
m.setattr(tm.asserters, "assert_attr_equal", _assert_attr_equal)
8373
yield
8474

@@ -175,15 +165,7 @@ def skip_numpy_object(dtype, request):
175165
skip_nested = pytest.mark.usefixtures("skip_numpy_object")
176166

177167

178-
class BaseNumPyTests:
179-
pass
180-
181-
182-
class TestCasting(BaseNumPyTests, base.BaseCastingTests):
183-
pass
184-
185-
186-
class TestConstructors(BaseNumPyTests, base.BaseConstructorsTests):
168+
class TestNumpyExtensionArray(base.ExtensionTests):
187169
@pytest.mark.skip(reason="We don't register our dtype")
188170
# We don't want to register. This test should probably be split in two.
189171
def test_from_dtype(self, data):
@@ -194,8 +176,6 @@ def test_series_constructor_scalar_with_index(self, data, dtype):
194176
# ValueError: Length of passed values is 1, index implies 3.
195177
super().test_series_constructor_scalar_with_index(data, dtype)
196178

197-
198-
class TestDtype(BaseNumPyTests, base.BaseDtypeTests):
199179
def test_check_dtype(self, data, request, using_infer_string):
200180
if data.dtype.numpy_dtype == "object":
201181
request.applymarker(
@@ -214,26 +194,11 @@ def test_is_not_object_type(self, dtype, request):
214194
else:
215195
super().test_is_not_object_type(dtype)
216196

217-
218-
class TestGetitem(BaseNumPyTests, base.BaseGetitemTests):
219197
@skip_nested
220198
def test_getitem_scalar(self, data):
221199
# AssertionError
222200
super().test_getitem_scalar(data)
223201

224-
225-
class TestGroupby(BaseNumPyTests, base.BaseGroupbyTests):
226-
pass
227-
228-
229-
class TestInterface(BaseNumPyTests, base.BaseInterfaceTests):
230-
@skip_nested
231-
def test_array_interface(self, data):
232-
# NumPy array shape inference
233-
super().test_array_interface(data)
234-
235-
236-
class TestMethods(BaseNumPyTests, base.BaseMethodsTests):
237202
@skip_nested
238203
def test_shift_fill_value(self, data):
239204
# np.array shape inference. Shift implementation fails.
@@ -251,7 +216,9 @@ def test_fillna_copy_series(self, data_missing):
251216

252217
@skip_nested
253218
def test_searchsorted(self, data_for_sorting, as_series):
254-
# Test setup fails.
219+
# TODO: NumpyExtensionArray.searchsorted calls ndarray.searchsorted which
220+
# isn't quite what we want in nested data cases. Instead we need to
221+
# adapt something like libindex._bin_search.
255222
super().test_searchsorted(data_for_sorting, as_series)
256223

257224
@pytest.mark.xfail(reason="NumpyExtensionArray.diff may fail on dtype")
@@ -270,38 +237,60 @@ def test_insert_invalid(self, data, invalid_scalar):
270237
# NumpyExtensionArray[object] can hold anything, so skip
271238
super().test_insert_invalid(data, invalid_scalar)
272239

273-
274-
class TestArithmetics(BaseNumPyTests, base.BaseArithmeticOpsTests):
275240
divmod_exc = None
276241
series_scalar_exc = None
277242
frame_scalar_exc = None
278243
series_array_exc = None
279244

280-
@skip_nested
281245
def test_divmod(self, data):
246+
divmod_exc = None
247+
if data.dtype.kind == "O":
248+
divmod_exc = TypeError
249+
self.divmod_exc = divmod_exc
282250
super().test_divmod(data)
283251

284-
@skip_nested
285-
def test_arith_series_with_scalar(self, data, all_arithmetic_operators):
252+
def test_divmod_series_array(self, data):
253+
ser = pd.Series(data)
254+
exc = None
255+
if data.dtype.kind == "O":
256+
exc = TypeError
257+
self.divmod_exc = exc
258+
self._check_divmod_op(ser, divmod, data)
259+
260+
def test_arith_series_with_scalar(self, data, all_arithmetic_operators, request):
261+
opname = all_arithmetic_operators
262+
series_scalar_exc = None
263+
if data.dtype.numpy_dtype == object:
264+
if opname in ["__mul__", "__rmul__"]:
265+
mark = pytest.mark.xfail(
266+
reason="the Series.combine step raises but not the Series method."
267+
)
268+
request.node.add_marker(mark)
269+
series_scalar_exc = TypeError
270+
self.series_scalar_exc = series_scalar_exc
286271
super().test_arith_series_with_scalar(data, all_arithmetic_operators)
287272

288-
def test_arith_series_with_array(self, data, all_arithmetic_operators, request):
273+
def test_arith_series_with_array(self, data, all_arithmetic_operators):
289274
opname = all_arithmetic_operators
275+
series_array_exc = None
290276
if data.dtype.numpy_dtype == object and opname not in ["__add__", "__radd__"]:
291-
mark = pytest.mark.xfail(reason="Fails for object dtype")
292-
request.applymarker(mark)
277+
series_array_exc = TypeError
278+
self.series_array_exc = series_array_exc
293279
super().test_arith_series_with_array(data, all_arithmetic_operators)
294280

295-
@skip_nested
296-
def test_arith_frame_with_scalar(self, data, all_arithmetic_operators):
281+
def test_arith_frame_with_scalar(self, data, all_arithmetic_operators, request):
282+
opname = all_arithmetic_operators
283+
frame_scalar_exc = None
284+
if data.dtype.numpy_dtype == object:
285+
if opname in ["__mul__", "__rmul__"]:
286+
mark = pytest.mark.xfail(
287+
reason="the Series.combine step raises but not the Series method."
288+
)
289+
request.node.add_marker(mark)
290+
frame_scalar_exc = TypeError
291+
self.frame_scalar_exc = frame_scalar_exc
297292
super().test_arith_frame_with_scalar(data, all_arithmetic_operators)
298293

299-
300-
class TestPrinting(BaseNumPyTests, base.BasePrintingTests):
301-
pass
302-
303-
304-
class TestReduce(BaseNumPyTests, base.BaseReduceTests):
305294
def _supports_reduction(self, ser: pd.Series, op_name: str) -> bool:
306295
if ser.dtype.kind == "O":
307296
return op_name in ["sum", "min", "max", "any", "all"]
@@ -328,8 +317,6 @@ def check_reduce(self, ser: pd.Series, op_name: str, skipna: bool):
328317
def test_reduce_frame(self, data, all_numeric_reductions, skipna):
329318
pass
330319

331-
332-
class TestMissing(BaseNumPyTests, base.BaseMissingTests):
333320
@skip_nested
334321
def test_fillna_series(self, data_missing):
335322
# Non-scalar "scalar" values.
@@ -340,12 +327,6 @@ def test_fillna_frame(self, data_missing):
340327
# Non-scalar "scalar" values.
341328
super().test_fillna_frame(data_missing)
342329

343-
344-
class TestReshaping(BaseNumPyTests, base.BaseReshapingTests):
345-
pass
346-
347-
348-
class TestSetitem(BaseNumPyTests, base.BaseSetitemTests):
349330
@skip_nested
350331
def test_setitem_invalid(self, data, invalid_scalar):
351332
# object dtype can hold anything, so doesn't raise
@@ -431,11 +412,25 @@ def test_setitem_with_expansion_dataframe_column(self, data, full_indexer):
431412
expected = pd.DataFrame({"data": data.to_numpy()})
432413
tm.assert_frame_equal(result, expected, check_column_type=False)
433414

415+
@pytest.mark.xfail(reason="NumpyEADtype is unpacked")
416+
def test_index_from_listlike_with_dtype(self, data):
417+
super().test_index_from_listlike_with_dtype(data)
434418

435-
@skip_nested
436-
class TestParsing(BaseNumPyTests, base.BaseParsingTests):
437-
pass
419+
@skip_nested
420+
@pytest.mark.parametrize("engine", ["c", "python"])
421+
def test_EA_types(self, engine, data, request):
422+
super().test_EA_types(engine, data, request)
423+
424+
@pytest.mark.xfail(reason="Expect NumpyEA, get np.ndarray")
425+
def test_compare_array(self, data, comparison_op):
426+
super().test_compare_array(data, comparison_op)
427+
428+
def test_compare_scalar(self, data, comparison_op, request):
429+
if data.dtype.kind == "f" or comparison_op.__name__ in ["eq", "ne"]:
430+
mark = pytest.mark.xfail(reason="Expect NumpyEA, get np.ndarray")
431+
request.applymarker(mark)
432+
super().test_compare_scalar(data, comparison_op)
438433

439434

440-
class Test2DCompat(BaseNumPyTests, base.NDArrayBacked2DTests):
435+
class Test2DCompat(base.NDArrayBacked2DTests):
441436
pass

0 commit comments

Comments
 (0)