Skip to content

Commit fbf0a06

Browse files
committed
API: Default ExtensionArray.astype
(cherry picked from commit 943a915562b72bed147c857de927afa0daf31c1a)
1 parent b835127 commit fbf0a06

File tree

2 files changed

+64
-0
lines changed

2 files changed

+64
-0
lines changed

pandas/core/arrays/base.py

+30
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
"""An interface for extending pandas with custom arrays."""
2+
import numpy as np
3+
24
from pandas.errors import AbstractMethodError
35

46
_not_implemented_message = "{} does not implement {}."
@@ -138,6 +140,34 @@ def nbytes(self):
138140
# ------------------------------------------------------------------------
139141
# Additional Methods
140142
# ------------------------------------------------------------------------
143+
def astype(self, dtype, copy=True):
144+
"""Cast to a NumPy array with 'dtype'.
145+
146+
The default implementation only allows casting to 'object' dtype.
147+
148+
Parameters
149+
----------
150+
dtype : str or dtype
151+
Typecode or data-type to which the array is cast.
152+
copy : bool, default True
153+
Whether to copy the data, even if not necessary. If False,
154+
a copy is made only if the old dtype does not match the
155+
new dtype.
156+
157+
Returns
158+
-------
159+
array : ndarray
160+
NumPy ndarray with 'dtype' for its dtype.
161+
"""
162+
np_dtype = np.dtype(dtype)
163+
164+
if np_dtype != 'object':
165+
msg = ("{} can only be coerced to 'object' dtype, "
166+
"not '{}'.").format(type(self).__name__, dtype)
167+
raise ValueError(msg)
168+
169+
return np.array(self, dtype=np_dtype, copy=copy)
170+
141171
def isna(self):
142172
# type: () -> np.ndarray
143173
"""Boolean NumPy array indicating if each value is missing.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import numpy as np
2+
3+
import pandas.util.testing as tm
4+
from pandas.core.arrays import ExtensionArray
5+
6+
7+
class DummyArray(ExtensionArray):
8+
9+
def __init__(self, data):
10+
self.data = data
11+
12+
def __array__(self, dtype):
13+
return self.data
14+
15+
16+
def test_astype():
17+
arr = DummyArray(np.array([1, 2, 3]))
18+
expected = np.array([1, 2, 3], dtype=object)
19+
20+
result = arr.astype(object)
21+
tm.assert_numpy_array_equal(result, expected)
22+
23+
result = arr.astype('object')
24+
tm.assert_numpy_array_equal(result, expected)
25+
26+
27+
def test_astype_raises():
28+
arr = DummyArray(np.array([1, 2, 3]))
29+
30+
xpr = ("DummyArray can only be coerced to 'object' dtype, not "
31+
"'<class 'int'>'")
32+
33+
with tm.assert_raises_regex(ValueError, xpr):
34+
arr.astype(int)

0 commit comments

Comments
 (0)