|
5 | 5 | import numpy as np
|
6 | 6 |
|
7 | 7 | from pandas._libs import lib, missing as libmissing
|
8 |
| -from pandas._typing import Scalar |
| 8 | +from pandas._typing import ArrayLike, Dtype, Scalar |
9 | 9 | from pandas.errors import AbstractMethodError
|
10 | 10 | from pandas.util._decorators import cache_readonly, doc
|
11 | 11 |
|
12 | 12 | from pandas.core.dtypes.base import ExtensionDtype
|
13 | 13 | from pandas.core.dtypes.common import (
|
| 14 | + is_dtype_equal, |
14 | 15 | is_integer,
|
15 | 16 | is_object_dtype,
|
16 | 17 | is_scalar,
|
17 | 18 | is_string_dtype,
|
| 19 | + pandas_dtype, |
18 | 20 | )
|
19 | 21 | from pandas.core.dtypes.missing import isna, notna
|
20 | 22 |
|
@@ -229,6 +231,30 @@ def to_numpy(
|
229 | 231 | data = self._data.astype(dtype, copy=copy)
|
230 | 232 | return data
|
231 | 233 |
|
| 234 | + def astype(self, dtype: Dtype, copy: bool = True) -> ArrayLike: |
| 235 | + dtype = pandas_dtype(dtype) |
| 236 | + |
| 237 | + if is_dtype_equal(dtype, self.dtype): |
| 238 | + if copy: |
| 239 | + return self.copy() |
| 240 | + return self |
| 241 | + |
| 242 | + # if we are astyping to another nullable masked dtype, we can fastpath |
| 243 | + if isinstance(dtype, BaseMaskedDtype): |
| 244 | + # TODO deal with NaNs for FloatingArray case |
| 245 | + data = self._data.astype(dtype.numpy_dtype, copy=copy) |
| 246 | + # mask is copied depending on whether the data was copied, and |
| 247 | + # not directly depending on the `copy` keyword |
| 248 | + mask = self._mask if data is self._data else self._mask.copy() |
| 249 | + cls = dtype.construct_array_type() |
| 250 | + return cls(data, mask, copy=False) |
| 251 | + |
| 252 | + if isinstance(dtype, ExtensionDtype): |
| 253 | + eacls = dtype.construct_array_type() |
| 254 | + return eacls._from_sequence(self, dtype=dtype, copy=copy) |
| 255 | + |
| 256 | + raise NotImplementedError("subclass must implement astype to np.dtype") |
| 257 | + |
232 | 258 | __array_priority__ = 1000 # higher than ndarray so ops dispatch to us
|
233 | 259 |
|
234 | 260 | def __array__(self, dtype=None) -> np.ndarray:
|
|
0 commit comments