Skip to content

Commit 8a2942c

Browse files
authored
REF: share astype code in MaskedArray (#38490)
1 parent dee5603 commit 8a2942c

File tree

5 files changed

+41
-54
lines changed

5 files changed

+41
-54
lines changed

pandas/core/arrays/boolean.py

+4-17
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,14 @@
1010

1111
from pandas.core.dtypes.common import (
1212
is_bool_dtype,
13-
is_extension_array_dtype,
1413
is_float,
1514
is_float_dtype,
1615
is_integer_dtype,
1716
is_list_like,
1817
is_numeric_dtype,
1918
pandas_dtype,
2019
)
21-
from pandas.core.dtypes.dtypes import register_extension_dtype
20+
from pandas.core.dtypes.dtypes import ExtensionDtype, register_extension_dtype
2221
from pandas.core.dtypes.missing import isna
2322

2423
from pandas.core import ops
@@ -372,34 +371,22 @@ def astype(self, dtype, copy: bool = True) -> ArrayLike:
372371
if incompatible type with an BooleanDtype, equivalent of same_kind
373372
casting
374373
"""
375-
from pandas.core.arrays.string_ import StringDtype
376-
377374
dtype = pandas_dtype(dtype)
378375

379-
if isinstance(dtype, BooleanDtype):
380-
values, mask = coerce_to_array(self, copy=copy)
381-
if not copy:
382-
return self
383-
else:
384-
return BooleanArray(values, mask, copy=False)
385-
elif isinstance(dtype, StringDtype):
386-
return dtype.construct_array_type()._from_sequence(self, copy=False)
376+
if isinstance(dtype, ExtensionDtype):
377+
return super().astype(dtype, copy)
387378

388379
if is_bool_dtype(dtype):
389380
# astype_nansafe converts np.nan to True
390381
if self._hasna:
391382
raise ValueError("cannot convert float NaN to bool")
392383
else:
393384
return self._data.astype(dtype, copy=copy)
394-
if is_extension_array_dtype(dtype) and is_integer_dtype(dtype):
395-
from pandas.core.arrays import IntegerArray
396385

397-
return IntegerArray(
398-
self._data.astype(dtype.numpy_dtype), self._mask.copy(), copy=False
399-
)
400386
# for integer, error if there are missing values
401387
if is_integer_dtype(dtype) and self._hasna:
402388
raise ValueError("cannot convert NA to integer")
389+
403390
# for float dtype, ensure we use np.nan before casting (numpy cannot
404391
# deal with pd.NA)
405392
na_value = self._na_value

pandas/core/arrays/floating.py

+3-18
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,13 @@
1919
is_object_dtype,
2020
pandas_dtype,
2121
)
22-
from pandas.core.dtypes.dtypes import register_extension_dtype
22+
from pandas.core.dtypes.dtypes import ExtensionDtype, register_extension_dtype
2323
from pandas.core.dtypes.missing import isna
2424

2525
from pandas.core import ops
2626
from pandas.core.ops import invalid_comparison
2727
from pandas.core.tools.numeric import to_numeric
2828

29-
from .masked import BaseMaskedDtype
3029
from .numeric import NumericArray, NumericDtype
3130

3231

@@ -332,24 +331,10 @@ def astype(self, dtype, copy: bool = True) -> ArrayLike:
332331
if incompatible type with an FloatingDtype, equivalent of same_kind
333332
casting
334333
"""
335-
from pandas.core.arrays.string_ import StringArray, StringDtype
336-
337334
dtype = pandas_dtype(dtype)
338335

339-
# if the dtype is exactly the same, we can fastpath
340-
if self.dtype == dtype:
341-
# return the same object for copy=False
342-
return self.copy() if copy else self
343-
# if we are astyping to another nullable masked dtype, we can fastpath
344-
if isinstance(dtype, BaseMaskedDtype):
345-
# TODO deal with NaNs
346-
data = self._data.astype(dtype.numpy_dtype, copy=copy)
347-
# mask is copied depending on whether the data was copied, and
348-
# not directly depending on the `copy` keyword
349-
mask = self._mask if data is self._data else self._mask.copy()
350-
return dtype.construct_array_type()(data, mask, copy=False)
351-
elif isinstance(dtype, StringDtype):
352-
return StringArray._from_sequence(self, copy=False)
336+
if isinstance(dtype, ExtensionDtype):
337+
return super().astype(dtype, copy=copy)
353338

354339
# coerce
355340
if is_float_dtype(dtype):

pandas/core/arrays/integer.py

+3-17
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from pandas.compat.numpy import function as nv
1010
from pandas.util._decorators import cache_readonly
1111

12-
from pandas.core.dtypes.base import register_extension_dtype
12+
from pandas.core.dtypes.base import ExtensionDtype, register_extension_dtype
1313
from pandas.core.dtypes.common import (
1414
is_bool_dtype,
1515
is_datetime64_dtype,
@@ -390,24 +390,10 @@ def astype(self, dtype, copy: bool = True) -> ArrayLike:
390390
if incompatible type with an IntegerDtype, equivalent of same_kind
391391
casting
392392
"""
393-
from pandas.core.arrays.masked import BaseMaskedDtype
394-
from pandas.core.arrays.string_ import StringDtype
395-
396393
dtype = pandas_dtype(dtype)
397394

398-
# if the dtype is exactly the same, we can fastpath
399-
if self.dtype == dtype:
400-
# return the same object for copy=False
401-
return self.copy() if copy else self
402-
# if we are astyping to another nullable masked dtype, we can fastpath
403-
if isinstance(dtype, BaseMaskedDtype):
404-
data = self._data.astype(dtype.numpy_dtype, copy=copy)
405-
# mask is copied depending on whether the data was copied, and
406-
# not directly depending on the `copy` keyword
407-
mask = self._mask if data is self._data else self._mask.copy()
408-
return dtype.construct_array_type()(data, mask, copy=False)
409-
elif isinstance(dtype, StringDtype):
410-
return dtype.construct_array_type()._from_sequence(self, copy=False)
395+
if isinstance(dtype, ExtensionDtype):
396+
return super().astype(dtype, copy=copy)
411397

412398
# coerce
413399
if is_float_dtype(dtype):

pandas/core/arrays/masked.py

+27-1
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,18 @@
55
import numpy as np
66

77
from pandas._libs import lib, missing as libmissing
8-
from pandas._typing import Scalar
8+
from pandas._typing import ArrayLike, Dtype, Scalar
99
from pandas.errors import AbstractMethodError
1010
from pandas.util._decorators import cache_readonly, doc
1111

1212
from pandas.core.dtypes.base import ExtensionDtype
1313
from pandas.core.dtypes.common import (
14+
is_dtype_equal,
1415
is_integer,
1516
is_object_dtype,
1617
is_scalar,
1718
is_string_dtype,
19+
pandas_dtype,
1820
)
1921
from pandas.core.dtypes.missing import isna, notna
2022

@@ -229,6 +231,30 @@ def to_numpy(
229231
data = self._data.astype(dtype, copy=copy)
230232
return data
231233

234+
def astype(self, dtype: Dtype, copy: bool = True) -> ArrayLike:
235+
dtype = pandas_dtype(dtype)
236+
237+
if is_dtype_equal(dtype, self.dtype):
238+
if copy:
239+
return self.copy()
240+
return self
241+
242+
# if we are astyping to another nullable masked dtype, we can fastpath
243+
if isinstance(dtype, BaseMaskedDtype):
244+
# TODO deal with NaNs for FloatingArray case
245+
data = self._data.astype(dtype.numpy_dtype, copy=copy)
246+
# mask is copied depending on whether the data was copied, and
247+
# not directly depending on the `copy` keyword
248+
mask = self._mask if data is self._data else self._mask.copy()
249+
cls = dtype.construct_array_type()
250+
return cls(data, mask, copy=False)
251+
252+
if isinstance(dtype, ExtensionDtype):
253+
eacls = dtype.construct_array_type()
254+
return eacls._from_sequence(self, dtype=dtype, copy=copy)
255+
256+
raise NotImplementedError("subclass must implement astype to np.dtype")
257+
232258
__array_priority__ = 1000 # higher than ndarray so ops dispatch to us
233259

234260
def __array__(self, dtype=None) -> np.ndarray:

pandas/core/arrays/string_.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from pandas.core.dtypes.common import (
1111
is_array_like,
1212
is_bool_dtype,
13+
is_dtype_equal,
1314
is_integer_dtype,
1415
is_object_dtype,
1516
is_string_dtype,
@@ -285,10 +286,12 @@ def __setitem__(self, key, value):
285286

286287
def astype(self, dtype, copy=True):
287288
dtype = pandas_dtype(dtype)
288-
if isinstance(dtype, StringDtype):
289+
290+
if is_dtype_equal(dtype, self.dtype):
289291
if copy:
290292
return self.copy()
291293
return self
294+
292295
elif isinstance(dtype, _IntegerDtype):
293296
arr = self._ndarray.copy()
294297
mask = self.isna()

0 commit comments

Comments
 (0)