Skip to content

Commit a3c50a6

Browse files
TomAugspurgerjreback
authored andcommitted
API: dispatch to EA.astype (#22343)
1 parent 83be235 commit a3c50a6

File tree

7 files changed

+85
-20
lines changed

7 files changed

+85
-20
lines changed

doc/source/whatsnew/v0.24.0.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -446,7 +446,7 @@ ExtensionType Changes
446446
- Bug in :meth:`Series.get` for ``Series`` using ``ExtensionArray`` and integer index (:issue:`21257`)
447447
- :meth:`Series.combine()` works correctly with :class:`~pandas.api.extensions.ExtensionArray` inside of :class:`Series` (:issue:`20825`)
448448
- :meth:`Series.combine()` with scalar argument now works for any function type (:issue:`21248`)
449-
-
449+
- :meth:`Series.astype` and :meth:`DataFrame.astype` now dispatch to :meth:`ExtensionArray.astype` (:issue:`21185:`).
450450

451451
.. _whatsnew_0240.api.incompatibilities:
452452

pandas/core/arrays/integer.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from pandas.compat import u, range
99
from pandas.compat import set_function_name
1010

11+
from pandas.core.dtypes.cast import astype_nansafe
1112
from pandas.core.dtypes.generic import ABCSeries, ABCIndexClass
1213
from pandas.core.dtypes.common import (
1314
is_integer, is_scalar, is_float,
@@ -410,7 +411,7 @@ def astype(self, dtype, copy=True):
410411

411412
# coerce
412413
data = self._coerce_to_ndarray()
413-
return data.astype(dtype=dtype, copy=False)
414+
return astype_nansafe(data, dtype, copy=None)
414415

415416
@property
416417
def _ndarray_values(self):

pandas/core/dtypes/cast.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -647,7 +647,16 @@ def conv(r, dtype):
647647

648648
def astype_nansafe(arr, dtype, copy=True):
649649
""" return a view if copy is False, but
650-
need to be very careful as the result shape could change! """
650+
need to be very careful as the result shape could change!
651+
652+
Parameters
653+
----------
654+
arr : ndarray
655+
dtype : np.dtype
656+
copy : bool, default True
657+
If False, a view will be attempted but may fail, if
658+
e.g. the itemsizes don't align.
659+
"""
651660

652661
# dispatch on extension dtype if needed
653662
if is_extension_array_dtype(dtype):
@@ -733,8 +742,10 @@ def astype_nansafe(arr, dtype, copy=True):
733742
FutureWarning, stacklevel=5)
734743
dtype = np.dtype(dtype.name + "[ns]")
735744

736-
if copy:
745+
if copy or is_object_dtype(arr) or is_object_dtype(dtype):
746+
# Explicit copy, or required since NumPy can't view from / to object.
737747
return arr.astype(dtype, copy=True)
748+
738749
return arr.view(dtype)
739750

740751

pandas/core/internals/blocks.py

+15-12
Original file line numberDiff line numberDiff line change
@@ -637,22 +637,25 @@ def _astype(self, dtype, copy=False, errors='raise', values=None,
637637
# force the copy here
638638
if values is None:
639639

640-
if issubclass(dtype.type,
641-
(compat.text_type, compat.string_types)):
640+
if self.is_extension:
641+
values = self.values.astype(dtype)
642+
else:
643+
if issubclass(dtype.type,
644+
(compat.text_type, compat.string_types)):
642645

643-
# use native type formatting for datetime/tz/timedelta
644-
if self.is_datelike:
645-
values = self.to_native_types()
646+
# use native type formatting for datetime/tz/timedelta
647+
if self.is_datelike:
648+
values = self.to_native_types()
646649

647-
# astype formatting
648-
else:
649-
values = self.get_values()
650+
# astype formatting
651+
else:
652+
values = self.get_values()
650653

651-
else:
652-
values = self.get_values(dtype=dtype)
654+
else:
655+
values = self.get_values(dtype=dtype)
653656

654-
# _astype_nansafe works fine with 1-d only
655-
values = astype_nansafe(values.ravel(), dtype, copy=True)
657+
# _astype_nansafe works fine with 1-d only
658+
values = astype_nansafe(values.ravel(), dtype, copy=True)
656659

657660
# TODO(extension)
658661
# should we make this attribute?

pandas/tests/extension/decimal/array.py

+24-4
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,17 @@ class DecimalDtype(ExtensionDtype):
1515
name = 'decimal'
1616
na_value = decimal.Decimal('NaN')
1717

18+
def __init__(self, context=None):
19+
self.context = context or decimal.getcontext()
20+
21+
def __eq__(self, other):
22+
if isinstance(other, type(self)):
23+
return self.context == other.context
24+
return super(DecimalDtype, self).__eq__(other)
25+
26+
def __repr__(self):
27+
return 'DecimalDtype(context={})'.format(self.context)
28+
1829
@classmethod
1930
def construct_array_type(cls):
2031
"""Return the array type associated with this dtype
@@ -35,13 +46,12 @@ def construct_from_string(cls, string):
3546

3647

3748
class DecimalArray(ExtensionArray, ExtensionScalarOpsMixin):
38-
dtype = DecimalDtype()
3949

40-
def __init__(self, values, dtype=None, copy=False):
50+
def __init__(self, values, dtype=None, copy=False, context=None):
4151
for val in values:
42-
if not isinstance(val, self.dtype.type):
52+
if not isinstance(val, decimal.Decimal):
4353
raise TypeError("All values must be of type " +
44-
str(self.dtype.type))
54+
str(decimal.Decimal))
4555
values = np.asarray(values, dtype=object)
4656

4757
self._data = values
@@ -51,6 +61,11 @@ def __init__(self, values, dtype=None, copy=False):
5161
# those aliases are currently not working due to assumptions
5262
# in internal code (GH-20735)
5363
# self._values = self.values = self.data
64+
self._dtype = DecimalDtype(context)
65+
66+
@property
67+
def dtype(self):
68+
return self._dtype
5469

5570
@classmethod
5671
def _from_sequence(cls, scalars, dtype=None, copy=False):
@@ -82,6 +97,11 @@ def copy(self, deep=False):
8297
return type(self)(self._data.copy())
8398
return type(self)(self)
8499

100+
def astype(self, dtype, copy=True):
101+
if isinstance(dtype, type(self.dtype)):
102+
return type(self)(self._data, context=dtype.context)
103+
return super(DecimalArray, self).astype(dtype, copy)
104+
85105
def __setitem__(self, key, value):
86106
if pd.api.types.is_list_like(value):
87107
value = [decimal.Decimal(v) for v in value]

pandas/tests/extension/decimal/test_decimal.py

+21
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,27 @@ def test_dataframe_constructor_with_dtype():
205205
tm.assert_frame_equal(result, expected)
206206

207207

208+
@pytest.mark.parametrize("frame", [True, False])
209+
def test_astype_dispatches(frame):
210+
# This is a dtype-specific test that ensures Series[decimal].astype
211+
# gets all the way through to ExtensionArray.astype
212+
# Designing a reliable smoke test that works for arbitrary data types
213+
# is difficult.
214+
data = pd.Series(DecimalArray([decimal.Decimal(2)]), name='a')
215+
ctx = decimal.Context()
216+
ctx.prec = 5
217+
218+
if frame:
219+
data = data.to_frame()
220+
221+
result = data.astype(DecimalDtype(ctx))
222+
223+
if frame:
224+
result = result['a']
225+
226+
assert result.dtype.context.prec == ctx.prec
227+
228+
208229
class TestArithmeticOps(BaseDecimal, base.BaseArithmeticOpsTests):
209230

210231
def check_opname(self, s, op_name, other, exc=None):

pandas/tests/extension/integer/test_integer.py

+9
Original file line numberDiff line numberDiff line change
@@ -766,6 +766,15 @@ def test_cross_type_arithmetic():
766766
tm.assert_series_equal(result, expected)
767767

768768

769+
def test_astype_nansafe():
770+
# https://github.com/pandas-dev/pandas/pull/22343
771+
arr = IntegerArray([np.nan, 1, 2], dtype="Int8")
772+
773+
with tm.assert_raises_regex(
774+
ValueError, 'cannot convert float NaN to integer'):
775+
arr.astype('uint32')
776+
777+
769778
# TODO(jreback) - these need testing / are broken
770779

771780
# shift

0 commit comments

Comments
 (0)