Skip to content

Commit 2e8c134

Browse files
benoit9126jbrockmendel
authored andcommitted
ENH: Implement rounding for floating dtype array pandas-dev#38844 (pandas-dev#39751)
1 parent bf5aa60 commit 2e8c134

File tree

4 files changed

+95
-10
lines changed

4 files changed

+95
-10
lines changed

doc/source/whatsnew/v1.3.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ Other enhancements
140140
- :meth:`pandas.read_stata` and :class:`StataReader` support reading data from compressed files.
141141
- Add support for parsing ``ISO 8601``-like timestamps with negative signs to :meth:`pandas.Timedelta` (:issue:`37172`)
142142
- Add support for unary operators in :class:`FloatingArray` (:issue:`38749`)
143+
- :meth:`round` being enabled for the nullable integer and floating dtypes (:issue:`38844`)
143144

144145
.. ---------------------------------------------------------------------------
145146

pandas/core/arrays/numeric.py

+32
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
TYPE_CHECKING,
77
Any,
88
List,
9+
TypeVar,
910
Union,
1011
)
1112

@@ -15,6 +16,7 @@
1516
Timedelta,
1617
missing as libmissing,
1718
)
19+
from pandas.compat.numpy import function as nv
1820
from pandas.errors import AbstractMethodError
1921

2022
from pandas.core.dtypes.common import (
@@ -34,6 +36,8 @@
3436
if TYPE_CHECKING:
3537
import pyarrow
3638

39+
T = TypeVar("T", bound="NumericArray")
40+
3741

3842
class NumericDtype(BaseMaskedDtype):
3943
def __from_arrow__(
@@ -208,3 +212,31 @@ def __pos__(self):
208212

209213
def __abs__(self):
210214
return type(self)(abs(self._data), self._mask.copy())
215+
216+
def round(self: T, decimals: int = 0, *args, **kwargs) -> T:
217+
"""
218+
Round each value in the array a to the given number of decimals.
219+
220+
Parameters
221+
----------
222+
decimals : int, default 0
223+
Number of decimal places to round to. If decimals is negative,
224+
it specifies the number of positions to the left of the decimal point.
225+
*args, **kwargs
226+
Additional arguments and keywords have no effect but might be
227+
accepted for compatibility with NumPy.
228+
229+
Returns
230+
-------
231+
NumericArray
232+
Rounded values of the NumericArray.
233+
234+
See Also
235+
--------
236+
numpy.around : Round values of an np.array.
237+
DataFrame.round : Round values of a DataFrame.
238+
Series.round : Round values of a Series.
239+
"""
240+
nv.validate_round(args, kwargs)
241+
values = np.round(self._data, decimals=decimals, **kwargs)
242+
return type(self)(values, self._mask.copy())
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import numpy as np
2+
import pytest
3+
4+
from pandas.core.dtypes.common import is_integer_dtype
5+
6+
import pandas as pd
7+
import pandas._testing as tm
8+
9+
arrays = [pd.array([1, 2, 3, None], dtype=dtype) for dtype in tm.ALL_EA_INT_DTYPES]
10+
arrays += [
11+
pd.array([0.141, -0.268, 5.895, None], dtype=dtype) for dtype in tm.FLOAT_EA_DTYPES
12+
]
13+
14+
15+
@pytest.fixture(params=arrays, ids=[a.dtype.name for a in arrays])
16+
def data(request):
17+
return request.param
18+
19+
20+
@pytest.fixture()
21+
def numpy_dtype(data):
22+
# For integer dtype, the numpy conversion must be done to float
23+
if is_integer_dtype(data):
24+
numpy_dtype = float
25+
else:
26+
numpy_dtype = data.dtype.type
27+
return numpy_dtype
28+
29+
30+
def test_round(data, numpy_dtype):
31+
# No arguments
32+
result = data.round()
33+
expected = pd.array(
34+
np.round(data.to_numpy(dtype=numpy_dtype, na_value=None)), dtype=data.dtype
35+
)
36+
tm.assert_extension_array_equal(result, expected)
37+
38+
# Decimals argument
39+
result = data.round(decimals=2)
40+
expected = pd.array(
41+
np.round(data.to_numpy(dtype=numpy_dtype, na_value=None), decimals=2),
42+
dtype=data.dtype,
43+
)
44+
tm.assert_extension_array_equal(result, expected)

pandas/tests/series/methods/test_round.py

+18-10
Original file line numberDiff line numberDiff line change
@@ -16,33 +16,41 @@ def test_round(self, datetime_series):
1616
tm.assert_series_equal(result, expected)
1717
assert result.name == datetime_series.name
1818

19-
def test_round_numpy(self):
19+
def test_round_numpy(self, any_float_allowed_nullable_dtype):
2020
# See GH#12600
21-
ser = Series([1.53, 1.36, 0.06])
21+
ser = Series([1.53, 1.36, 0.06], dtype=any_float_allowed_nullable_dtype)
2222
out = np.round(ser, decimals=0)
23-
expected = Series([2.0, 1.0, 0.0])
23+
expected = Series([2.0, 1.0, 0.0], dtype=any_float_allowed_nullable_dtype)
2424
tm.assert_series_equal(out, expected)
2525

2626
msg = "the 'out' parameter is not supported"
2727
with pytest.raises(ValueError, match=msg):
2828
np.round(ser, decimals=0, out=ser)
2929

30-
def test_round_numpy_with_nan(self):
30+
def test_round_numpy_with_nan(self, any_float_allowed_nullable_dtype):
3131
# See GH#14197
32-
ser = Series([1.53, np.nan, 0.06])
32+
ser = Series([1.53, np.nan, 0.06], dtype=any_float_allowed_nullable_dtype)
3333
with tm.assert_produces_warning(None):
3434
result = ser.round()
35-
expected = Series([2.0, np.nan, 0.0])
35+
expected = Series([2.0, np.nan, 0.0], dtype=any_float_allowed_nullable_dtype)
3636
tm.assert_series_equal(result, expected)
3737

38-
def test_round_builtin(self):
39-
ser = Series([1.123, 2.123, 3.123], index=range(3))
38+
def test_round_builtin(self, any_float_allowed_nullable_dtype):
39+
ser = Series(
40+
[1.123, 2.123, 3.123],
41+
index=range(3),
42+
dtype=any_float_allowed_nullable_dtype,
43+
)
4044
result = round(ser)
41-
expected_rounded0 = Series([1.0, 2.0, 3.0], index=range(3))
45+
expected_rounded0 = Series(
46+
[1.0, 2.0, 3.0], index=range(3), dtype=any_float_allowed_nullable_dtype
47+
)
4248
tm.assert_series_equal(result, expected_rounded0)
4349

4450
decimals = 2
45-
expected_rounded = Series([1.12, 2.12, 3.12], index=range(3))
51+
expected_rounded = Series(
52+
[1.12, 2.12, 3.12], index=range(3), dtype=any_float_allowed_nullable_dtype
53+
)
4654
result = round(ser, decimals)
4755
tm.assert_series_equal(result, expected_rounded)
4856

0 commit comments

Comments
 (0)