Skip to content

Commit e1e4585

Browse files
Testsuite now passes with both Pandas 2.0.2 and 2.1.0rc0
Restore some Pandas 2.0.x interfaces and use pandas_version_info from pint_array to condition logic. Also change some more `self.assert_series_equal` to `tm.assert_series_equal` in the restored code to accommodate Pandas 2.1 BaseExtensionTests behavior. Signed-off-by: Michael Tiemann <[email protected]>
1 parent 630c2ae commit e1e4585

File tree

2 files changed

+71
-20
lines changed

2 files changed

+71
-20
lines changed

pint_pandas/testsuite/test_issues.py

+3
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from pint.testsuite import helpers
1111

1212
from pint_pandas import PintArray, PintType
13+
from pint_pandas.pint_array import pandas_version_info
1314

1415
ureg = PintType.ureg
1516

@@ -172,6 +173,8 @@ def test_issue_127():
172173

173174
class TestIssue174(BaseExtensionTests):
174175
def test_sum(self):
176+
if pandas_version_info < (2, 1):
177+
pytest.skip("Pandas reduce functions strip units prior to version 2.1.0")
175178
a = pd.DataFrame([[0, 1, 2], [3, 4, 5]]).astype("pint[m]")
176179
row_sum = a.sum(axis=0)
177180
expected_1 = pd.Series([3, 5, 7], dtype="pint[m]")

pint_pandas/testsuite/test_pandas_extensiontests.py

+68-20
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from pint.errors import DimensionalityError
2323

2424
from pint_pandas import PintArray, PintType
25-
from pint_pandas.pint_array import dtypemap
25+
from pint_pandas.pint_array import dtypemap, pandas_version_info
2626

2727
ureg = PintType.ureg
2828

@@ -328,6 +328,10 @@ def test_apply_simple_series(self, data):
328328
@pytest.mark.parametrize("na_action", [None, "ignore"])
329329
def test_map(self, data_missing, na_action):
330330
s = pd.Series(data_missing)
331+
if pandas_version_info < (2, 1) and na_action is not None:
332+
pytest.skip(
333+
"Pandas EA map function only accepts None as na_action parameter"
334+
)
331335
result = s.map(lambda x: x, na_action=na_action)
332336
expected = s
333337
tm.assert_series_equal(result, expected)
@@ -338,10 +342,6 @@ def test_insert_invalid(self):
338342

339343

340344
class TestArithmeticOps(base.BaseArithmeticOpsTests):
341-
# With Pint 0.21, series and scalar need to have compatible units for
342-
# the arithmetic to work
343-
# series & scalar
344-
345345
divmod_exc = None
346346
series_scalar_exc = None
347347
frame_scalar_exc = None
@@ -428,31 +428,73 @@ def _get_expected_exception(
428428
# Fall through...
429429
return exc
430430

431+
# The following methods are needed to work with Pandas < 2.1
432+
def _check_divmod_op(self, s, op, other, exc=None):
433+
# divmod has multiple return values, so check separately
434+
if exc is None:
435+
result_div, result_mod = op(s, other)
436+
if op is divmod:
437+
expected_div, expected_mod = s // other, s % other
438+
else:
439+
expected_div, expected_mod = other // s, other % s
440+
tm.assert_series_equal(result_div, expected_div)
441+
tm.assert_series_equal(result_mod, expected_mod)
442+
else:
443+
with pytest.raises(exc):
444+
divmod(s, other)
445+
446+
def _get_exception(self, data, op_name):
447+
if data.data.dtype == pd.core.dtypes.dtypes.PandasDtype("complex128"):
448+
if op_name in ["__floordiv__", "__rfloordiv__", "__mod__", "__rmod__"]:
449+
return op_name, TypeError
450+
if op_name in ["__pow__", "__rpow__"]:
451+
return op_name, DimensionalityError
452+
453+
return op_name, None
454+
431455
def test_arith_series_with_scalar(self, data, all_arithmetic_operators):
456+
# With Pint 0.21, series and scalar need to have compatible units for
457+
# the arithmetic to work
432458
# series & scalar
433-
op_name = all_arithmetic_operators
434-
ser = pd.Series(data)
435-
self.check_opname(ser, op_name, ser.iloc[0])
436-
437-
def test_arith_frame_with_scalar(self, data, all_arithmetic_operators):
438-
# frame & scalar
439-
op_name = all_arithmetic_operators
440-
df = pd.DataFrame({"A": data})
441-
self.check_opname(df, op_name, data[0])
459+
if pandas_version_info < (2, 1):
460+
op_name, exc = self._get_exception(data, all_arithmetic_operators)
461+
s = pd.Series(data)
462+
self.check_opname(s, op_name, s.iloc[0], exc=exc)
463+
else:
464+
op_name = all_arithmetic_operators
465+
ser = pd.Series(data)
466+
self.check_opname(ser, op_name, ser.iloc[0])
442467

443468
def test_arith_series_with_array(self, data, all_arithmetic_operators):
444469
# ndarray & other series
445-
op_name = all_arithmetic_operators
446-
ser = pd.Series(data)
447-
self.check_opname(ser, op_name, pd.Series([ser.iloc[0]] * len(ser)))
470+
if pandas_version_info < (2, 1):
471+
op_name, exc = self._get_exception(data, all_arithmetic_operators)
472+
ser = pd.Series(data)
473+
self.check_opname(ser, op_name, pd.Series([ser.iloc[0]] * len(ser)), exc)
474+
else:
475+
op_name = all_arithmetic_operators
476+
ser = pd.Series(data)
477+
self.check_opname(ser, op_name, pd.Series([ser.iloc[0]] * len(ser)))
478+
479+
def test_arith_frame_with_scalar(self, data, all_arithmetic_operators):
480+
# frame & scalar
481+
if pandas_version_info < (2, 1):
482+
op_name, exc = self._get_exception(data, all_arithmetic_operators)
483+
df = pd.DataFrame({"A": data})
484+
self.check_opname(df, op_name, data[0], exc=exc)
485+
else:
486+
op_name = all_arithmetic_operators
487+
df = pd.DataFrame({"A": data})
488+
self.check_opname(df, op_name, data[0])
448489

449490
# parameterise this to try divisor not equal to 1 Mm
450491
@pytest.mark.parametrize("numeric_dtype", _base_numeric_dtypes, indirect=True)
451492
def test_divmod(self, data):
452-
s = pd.Series(data)
453-
self._check_divmod_op(s, divmod, 1 * ureg.Mm)
454-
self._check_divmod_op(1 * ureg.Mm, ops.rdivmod, s)
493+
ser = pd.Series(data)
494+
self._check_divmod_op(ser, divmod, 1 * ureg.Mm)
495+
self._check_divmod_op(1 * ureg.Mm, ops.rdivmod, ser)
455496

497+
@pytest.mark.parametrize("numeric_dtype", _base_numeric_dtypes, indirect=True)
456498
def test_divmod_series_array(self, data, data_for_twos):
457499
ser = pd.Series(data)
458500
self._check_divmod_op(ser, divmod, data)
@@ -615,6 +657,12 @@ def test_setitem_2d_values(self, data):
615657

616658

617659
class TestAccumulate(base.BaseAccumulateTests):
660+
@pytest.mark.parametrize("skipna", [True, False])
661+
def test_accumulate_series_raises(self, data, all_numeric_accumulations, skipna):
662+
if pandas_version_info < (2, 1):
663+
# Should this be skip? Historic code simply used pass.
664+
pass
665+
618666
def _supports_accumulation(self, ser: pd.Series, op_name: str) -> bool:
619667
return True
620668

0 commit comments

Comments
 (0)