18
18
import numpy as np
19
19
import pytest
20
20
21
- from pandas .core .dtypes .cast import can_hold_element
22
21
from pandas .core .dtypes .dtypes import NumpyEADtype
23
22
24
23
import pandas as pd
25
24
import pandas ._testing as tm
26
25
from pandas .api .types import is_object_dtype
27
26
from pandas .core .arrays .numpy_ import NumpyExtensionArray
28
- from pandas .core .internals import blocks
29
27
from pandas .tests .extension import base
30
28
31
-
32
- def _can_hold_element_patched (obj , element ) -> bool :
33
- if isinstance (element , NumpyExtensionArray ):
34
- element = element .to_numpy ()
35
- return can_hold_element (obj , element )
36
-
37
-
38
29
orig_assert_attr_equal = tm .assert_attr_equal
39
30
40
31
@@ -78,7 +69,6 @@ def allow_in_pandas(monkeypatch):
78
69
"""
79
70
with monkeypatch .context () as m :
80
71
m .setattr (NumpyExtensionArray , "_typ" , "extension" )
81
- m .setattr (blocks , "can_hold_element" , _can_hold_element_patched )
82
72
m .setattr (tm .asserters , "assert_attr_equal" , _assert_attr_equal )
83
73
yield
84
74
@@ -175,15 +165,7 @@ def skip_numpy_object(dtype, request):
175
165
skip_nested = pytest .mark .usefixtures ("skip_numpy_object" )
176
166
177
167
178
- class BaseNumPyTests :
179
- pass
180
-
181
-
182
- class TestCasting (BaseNumPyTests , base .BaseCastingTests ):
183
- pass
184
-
185
-
186
- class TestConstructors (BaseNumPyTests , base .BaseConstructorsTests ):
168
+ class TestNumpyExtensionArray (base .ExtensionTests ):
187
169
@pytest .mark .skip (reason = "We don't register our dtype" )
188
170
# We don't want to register. This test should probably be split in two.
189
171
def test_from_dtype (self , data ):
@@ -194,8 +176,6 @@ def test_series_constructor_scalar_with_index(self, data, dtype):
194
176
# ValueError: Length of passed values is 1, index implies 3.
195
177
super ().test_series_constructor_scalar_with_index (data , dtype )
196
178
197
-
198
- class TestDtype (BaseNumPyTests , base .BaseDtypeTests ):
199
179
def test_check_dtype (self , data , request , using_infer_string ):
200
180
if data .dtype .numpy_dtype == "object" :
201
181
request .applymarker (
@@ -214,26 +194,11 @@ def test_is_not_object_type(self, dtype, request):
214
194
else :
215
195
super ().test_is_not_object_type (dtype )
216
196
217
-
218
- class TestGetitem (BaseNumPyTests , base .BaseGetitemTests ):
219
197
@skip_nested
220
198
def test_getitem_scalar (self , data ):
221
199
# AssertionError
222
200
super ().test_getitem_scalar (data )
223
201
224
-
225
- class TestGroupby (BaseNumPyTests , base .BaseGroupbyTests ):
226
- pass
227
-
228
-
229
- class TestInterface (BaseNumPyTests , base .BaseInterfaceTests ):
230
- @skip_nested
231
- def test_array_interface (self , data ):
232
- # NumPy array shape inference
233
- super ().test_array_interface (data )
234
-
235
-
236
- class TestMethods (BaseNumPyTests , base .BaseMethodsTests ):
237
202
@skip_nested
238
203
def test_shift_fill_value (self , data ):
239
204
# np.array shape inference. Shift implementation fails.
@@ -251,7 +216,9 @@ def test_fillna_copy_series(self, data_missing):
251
216
252
217
@skip_nested
253
218
def test_searchsorted (self , data_for_sorting , as_series ):
254
- # Test setup fails.
219
+ # TODO: NumpyExtensionArray.searchsorted calls ndarray.searchsorted which
220
+ # isn't quite what we want in nested data cases. Instead we need to
221
+ # adapt something like libindex._bin_search.
255
222
super ().test_searchsorted (data_for_sorting , as_series )
256
223
257
224
@pytest .mark .xfail (reason = "NumpyExtensionArray.diff may fail on dtype" )
@@ -270,38 +237,60 @@ def test_insert_invalid(self, data, invalid_scalar):
270
237
# NumpyExtensionArray[object] can hold anything, so skip
271
238
super ().test_insert_invalid (data , invalid_scalar )
272
239
273
-
274
- class TestArithmetics (BaseNumPyTests , base .BaseArithmeticOpsTests ):
275
240
divmod_exc = None
276
241
series_scalar_exc = None
277
242
frame_scalar_exc = None
278
243
series_array_exc = None
279
244
280
- @skip_nested
281
245
def test_divmod (self , data ):
246
+ divmod_exc = None
247
+ if data .dtype .kind == "O" :
248
+ divmod_exc = TypeError
249
+ self .divmod_exc = divmod_exc
282
250
super ().test_divmod (data )
283
251
284
- @skip_nested
285
- def test_arith_series_with_scalar (self , data , all_arithmetic_operators ):
252
+ def test_divmod_series_array (self , data ):
253
+ ser = pd .Series (data )
254
+ exc = None
255
+ if data .dtype .kind == "O" :
256
+ exc = TypeError
257
+ self .divmod_exc = exc
258
+ self ._check_divmod_op (ser , divmod , data )
259
+
260
+ def test_arith_series_with_scalar (self , data , all_arithmetic_operators , request ):
261
+ opname = all_arithmetic_operators
262
+ series_scalar_exc = None
263
+ if data .dtype .numpy_dtype == object :
264
+ if opname in ["__mul__" , "__rmul__" ]:
265
+ mark = pytest .mark .xfail (
266
+ reason = "the Series.combine step raises but not the Series method."
267
+ )
268
+ request .node .add_marker (mark )
269
+ series_scalar_exc = TypeError
270
+ self .series_scalar_exc = series_scalar_exc
286
271
super ().test_arith_series_with_scalar (data , all_arithmetic_operators )
287
272
288
- def test_arith_series_with_array (self , data , all_arithmetic_operators , request ):
273
+ def test_arith_series_with_array (self , data , all_arithmetic_operators ):
289
274
opname = all_arithmetic_operators
275
+ series_array_exc = None
290
276
if data .dtype .numpy_dtype == object and opname not in ["__add__" , "__radd__" ]:
291
- mark = pytest . mark . xfail ( reason = "Fails for object dtype" )
292
- request . applymarker ( mark )
277
+ series_array_exc = TypeError
278
+ self . series_array_exc = series_array_exc
293
279
super ().test_arith_series_with_array (data , all_arithmetic_operators )
294
280
295
- @skip_nested
296
- def test_arith_frame_with_scalar (self , data , all_arithmetic_operators ):
281
+ def test_arith_frame_with_scalar (self , data , all_arithmetic_operators , request ):
282
+ opname = all_arithmetic_operators
283
+ frame_scalar_exc = None
284
+ if data .dtype .numpy_dtype == object :
285
+ if opname in ["__mul__" , "__rmul__" ]:
286
+ mark = pytest .mark .xfail (
287
+ reason = "the Series.combine step raises but not the Series method."
288
+ )
289
+ request .node .add_marker (mark )
290
+ frame_scalar_exc = TypeError
291
+ self .frame_scalar_exc = frame_scalar_exc
297
292
super ().test_arith_frame_with_scalar (data , all_arithmetic_operators )
298
293
299
-
300
- class TestPrinting (BaseNumPyTests , base .BasePrintingTests ):
301
- pass
302
-
303
-
304
- class TestReduce (BaseNumPyTests , base .BaseReduceTests ):
305
294
def _supports_reduction (self , ser : pd .Series , op_name : str ) -> bool :
306
295
if ser .dtype .kind == "O" :
307
296
return op_name in ["sum" , "min" , "max" , "any" , "all" ]
@@ -328,8 +317,6 @@ def check_reduce(self, ser: pd.Series, op_name: str, skipna: bool):
328
317
def test_reduce_frame (self , data , all_numeric_reductions , skipna ):
329
318
pass
330
319
331
-
332
- class TestMissing (BaseNumPyTests , base .BaseMissingTests ):
333
320
@skip_nested
334
321
def test_fillna_series (self , data_missing ):
335
322
# Non-scalar "scalar" values.
@@ -340,12 +327,6 @@ def test_fillna_frame(self, data_missing):
340
327
# Non-scalar "scalar" values.
341
328
super ().test_fillna_frame (data_missing )
342
329
343
-
344
- class TestReshaping (BaseNumPyTests , base .BaseReshapingTests ):
345
- pass
346
-
347
-
348
- class TestSetitem (BaseNumPyTests , base .BaseSetitemTests ):
349
330
@skip_nested
350
331
def test_setitem_invalid (self , data , invalid_scalar ):
351
332
# object dtype can hold anything, so doesn't raise
@@ -431,11 +412,25 @@ def test_setitem_with_expansion_dataframe_column(self, data, full_indexer):
431
412
expected = pd .DataFrame ({"data" : data .to_numpy ()})
432
413
tm .assert_frame_equal (result , expected , check_column_type = False )
433
414
415
+ @pytest .mark .xfail (reason = "NumpyEADtype is unpacked" )
416
+ def test_index_from_listlike_with_dtype (self , data ):
417
+ super ().test_index_from_listlike_with_dtype (data )
434
418
435
- @skip_nested
436
- class TestParsing (BaseNumPyTests , base .BaseParsingTests ):
437
- pass
419
+ @skip_nested
420
+ @pytest .mark .parametrize ("engine" , ["c" , "python" ])
421
+ def test_EA_types (self , engine , data , request ):
422
+ super ().test_EA_types (engine , data , request )
423
+
424
+ @pytest .mark .xfail (reason = "Expect NumpyEA, get np.ndarray" )
425
+ def test_compare_array (self , data , comparison_op ):
426
+ super ().test_compare_array (data , comparison_op )
427
+
428
+ def test_compare_scalar (self , data , comparison_op , request ):
429
+ if data .dtype .kind == "f" or comparison_op .__name__ in ["eq" , "ne" ]:
430
+ mark = pytest .mark .xfail (reason = "Expect NumpyEA, get np.ndarray" )
431
+ request .applymarker (mark )
432
+ super ().test_compare_scalar (data , comparison_op )
438
433
439
434
440
- class Test2DCompat (BaseNumPyTests , base .NDArrayBacked2DTests ):
435
+ class Test2DCompat (base .NDArrayBacked2DTests ):
441
436
pass
0 commit comments