Skip to content

Commit ace1dd5

Browse files
TST: base test for ExtensionArray.astype to its own type + copy keyword (#35116)
Co-authored-by: Joris Van den Bossche <[email protected]>
1 parent 9b6d66e commit ace1dd5

File tree

8 files changed

+37
-18
lines changed

8 files changed

+37
-18
lines changed

doc/source/whatsnew/v1.2.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,7 @@ ExtensionArray
366366
^^^^^^^^^^^^^^
367367

368368
- Fixed bug where :class:`DataFrame` column set to scalar extension type via a dict instantiation was considered an object type rather than the extension type (:issue:`35965`)
369+
- Fixed bug where ``astype()`` with equal dtype and ``copy=False`` would return a new object (:issue:`28488`)
369370
-
370371

371372

pandas/core/arrays/base.py

+5
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,11 @@ def astype(self, dtype, copy=True):
457457
from pandas.core.arrays.string_ import StringDtype
458458

459459
dtype = pandas_dtype(dtype)
460+
if is_dtype_equal(dtype, self.dtype):
461+
if not copy:
462+
return self
463+
elif copy:
464+
return self.copy()
460465
if isinstance(dtype, StringDtype): # allow conversion to StringArrays
461466
return dtype.construct_array_type()._from_sequence(self, copy=False)
462467

pandas/core/arrays/boolean.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,10 @@ def astype(self, dtype, copy: bool = True) -> ArrayLike:
375375

376376
if isinstance(dtype, BooleanDtype):
377377
values, mask = coerce_to_array(self, copy=copy)
378-
return BooleanArray(values, mask, copy=False)
378+
if not copy:
379+
return self
380+
else:
381+
return BooleanArray(values, mask, copy=False)
379382
elif isinstance(dtype, StringDtype):
380383
return dtype.construct_array_type()._from_sequence(self, copy=False)
381384

pandas/core/arrays/period.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
TD64NS_DTYPE,
3434
ensure_object,
3535
is_datetime64_dtype,
36+
is_dtype_equal,
3637
is_float_dtype,
3738
is_period_dtype,
3839
pandas_dtype,
@@ -582,7 +583,11 @@ def astype(self, dtype, copy: bool = True):
582583
# We handle Period[T] -> Period[U]
583584
# Our parent handles everything else.
584585
dtype = pandas_dtype(dtype)
585-
586+
if is_dtype_equal(dtype, self._dtype):
587+
if not copy:
588+
return self
589+
elif copy:
590+
return self.copy()
586591
if is_period_dtype(dtype):
587592
return self.asfreq(dtype.freq)
588593
return super().astype(dtype, copy=copy)

pandas/core/arrays/sparse/array.py

+5
Original file line numberDiff line numberDiff line change
@@ -1063,6 +1063,11 @@ def astype(self, dtype=None, copy=True):
10631063
IntIndex
10641064
Indices: array([2, 3], dtype=int32)
10651065
"""
1066+
if is_dtype_equal(dtype, self._dtype):
1067+
if not copy:
1068+
return self
1069+
elif copy:
1070+
return self.copy()
10661071
dtype = self.dtype.update_dtype(dtype)
10671072
subtype = dtype._subtype_with_str
10681073
# TODO copy=False is broken for astype_nansafe with int -> float, so cannot

pandas/tests/extension/base/casting.py

+9
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import pytest
23

34
import pandas as pd
45
from pandas.core.internals import ObjectBlock
@@ -56,3 +57,11 @@ def test_astype_empty_dataframe(self, dtype):
5657
df = pd.DataFrame()
5758
result = df.astype(dtype)
5859
self.assert_frame_equal(result, df)
60+
61+
@pytest.mark.parametrize("copy", [True, False])
62+
def test_astype_own_type(self, data, copy):
63+
# ensure that astype returns the original object for equal dtype and copy=False
64+
# https://github.com/pandas-dev/pandas/issues/28488
65+
result = data.astype(data.dtype, copy=copy)
66+
assert (result is data) is (not copy)
67+
self.assert_extension_array_equal(result, data)

pandas/tests/extension/decimal/array.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import numpy as np
88

99
from pandas.core.dtypes.base import ExtensionDtype
10-
from pandas.core.dtypes.common import pandas_dtype
10+
from pandas.core.dtypes.common import is_dtype_equal, pandas_dtype
1111

1212
import pandas as pd
1313
from pandas.api.extensions import no_default, register_extension_dtype
@@ -131,9 +131,12 @@ def copy(self):
131131
return type(self)(self._data.copy())
132132

133133
def astype(self, dtype, copy=True):
134+
if is_dtype_equal(dtype, self._dtype):
135+
if not copy:
136+
return self
134137
dtype = pandas_dtype(dtype)
135138
if isinstance(dtype, type(self.dtype)):
136-
return type(self)(self._data, context=dtype.context)
139+
return type(self)(self._data, copy=copy, context=dtype.context)
137140

138141
return super().astype(dtype, copy=copy)
139142

pandas/tests/extension/test_numpy.py

+2-14
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def test_take_series(self, data):
177177

178178
def test_loc_iloc_frame_single_dtype(self, data, request):
179179
npdtype = data.dtype.numpy_dtype
180-
if npdtype == object or npdtype == np.float64:
180+
if npdtype == object:
181181
# GH#33125
182182
mark = pytest.mark.xfail(
183183
reason="GH#33125 astype doesn't recognize data.dtype"
@@ -191,14 +191,6 @@ class TestGroupby(BaseNumPyTests, base.BaseGroupbyTests):
191191
def test_groupby_extension_apply(
192192
self, data_for_grouping, groupby_apply_op, request
193193
):
194-
# ValueError: Names should be list-like for a MultiIndex
195-
a = "a"
196-
is_identity = groupby_apply_op(a) is a
197-
if data_for_grouping.dtype.numpy_dtype == np.float64 and is_identity:
198-
mark = pytest.mark.xfail(
199-
reason="GH#33125 astype doesn't recognize data.dtype"
200-
)
201-
request.node.add_marker(mark)
202194
super().test_groupby_extension_apply(data_for_grouping, groupby_apply_op)
203195

204196

@@ -306,11 +298,7 @@ def test_arith_series_with_array(self, data, all_arithmetic_operators):
306298

307299

308300
class TestPrinting(BaseNumPyTests, base.BasePrintingTests):
309-
@pytest.mark.xfail(
310-
reason="GH#33125 PandasArray.astype does not recognize PandasDtype"
311-
)
312-
def test_series_repr(self, data):
313-
super().test_series_repr(data)
301+
pass
314302

315303

316304
@skip_nested

0 commit comments

Comments
 (0)