Skip to content

Commit 7246381

Browse files
TomAugspurgerharisbal
authored and
harisbal
committed
API: Default ExtensionArray.astype (pandas-dev#19604)
* API: Default ExtensionArray.astype (cherry picked from commit 943a915562b72bed147c857de927afa0daf31c1a) * Py2 compat * Moved * Moved dtypes
1 parent b5d4128 commit 7246381

File tree

4 files changed

+89
-31
lines changed

4 files changed

+89
-31
lines changed

pandas/core/arrays/base.py

+21
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
"""An interface for extending pandas with custom arrays."""
2+
import numpy as np
3+
24
from pandas.errors import AbstractMethodError
35

46
_not_implemented_message = "{} does not implement {}."
@@ -138,6 +140,25 @@ def nbytes(self):
138140
# ------------------------------------------------------------------------
139141
# Additional Methods
140142
# ------------------------------------------------------------------------
143+
def astype(self, dtype, copy=True):
144+
"""Cast to a NumPy array with 'dtype'.
145+
146+
Parameters
147+
----------
148+
dtype : str or dtype
149+
Typecode or data-type to which the array is cast.
150+
copy : bool, default True
151+
Whether to copy the data, even if not necessary. If False,
152+
a copy is made only if the old dtype does not match the
153+
new dtype.
154+
155+
Returns
156+
-------
157+
array : ndarray
158+
NumPy ndarray with 'dtype' for its dtype.
159+
"""
160+
return np.array(self, dtype=dtype, copy=copy)
161+
141162
def isna(self):
142163
# type: () -> np.ndarray
143164
"""Boolean NumPy array indicating if each value is missing.

pandas/tests/dtypes/test_dtypes.py

+1-31
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,12 @@
1010
Series, Categorical, CategoricalIndex, IntervalIndex, date_range)
1111

1212
from pandas.compat import string_types
13-
from pandas.core.arrays import ExtensionArray
1413
from pandas.core.dtypes.dtypes import (
1514
DatetimeTZDtype, PeriodDtype,
16-
IntervalDtype, CategoricalDtype, ExtensionDtype)
15+
IntervalDtype, CategoricalDtype)
1716
from pandas.core.dtypes.common import (
1817
is_categorical_dtype, is_categorical,
1918
is_datetime64tz_dtype, is_datetimetz,
20-
is_extension_array_dtype,
2119
is_period_dtype, is_period,
2220
is_dtype_equal, is_datetime64_ns_dtype,
2321
is_datetime64_dtype, is_interval_dtype,
@@ -744,31 +742,3 @@ def test_categorical_categories(self):
744742
tm.assert_index_equal(c1.categories, pd.Index(['a', 'b']))
745743
c1 = CategoricalDtype(CategoricalIndex(['a', 'b']))
746744
tm.assert_index_equal(c1.categories, pd.Index(['a', 'b']))
747-
748-
749-
class DummyArray(ExtensionArray):
750-
pass
751-
752-
753-
class DummyDtype(ExtensionDtype):
754-
pass
755-
756-
757-
class TestExtensionArrayDtype(object):
758-
759-
@pytest.mark.parametrize('values', [
760-
pd.Categorical([]),
761-
pd.Categorical([]).dtype,
762-
pd.Series(pd.Categorical([])),
763-
DummyDtype(),
764-
DummyArray(),
765-
])
766-
def test_is_extension_array_dtype(self, values):
767-
assert is_extension_array_dtype(values)
768-
769-
@pytest.mark.parametrize('values', [
770-
np.array([]),
771-
pd.Series(np.array([])),
772-
])
773-
def test_is_not_extension_array_dtype(self, values):
774-
assert not is_extension_array_dtype(values)

pandas/tests/extension/__init__.py

Whitespace-only changes.

pandas/tests/extension/test_common.py

+67
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import numpy as np
2+
import pytest
3+
4+
import pandas as pd
5+
import pandas.util.testing as tm
6+
from pandas.core.arrays import ExtensionArray
7+
from pandas.core.dtypes.common import is_extension_array_dtype
8+
from pandas.core.dtypes.dtypes import ExtensionDtype
9+
10+
11+
class DummyDtype(ExtensionDtype):
12+
pass
13+
14+
15+
class DummyArray(ExtensionArray):
16+
17+
def __init__(self, data):
18+
self.data = data
19+
20+
def __array__(self, dtype):
21+
return self.data
22+
23+
@property
24+
def dtype(self):
25+
return self.data.dtype
26+
27+
28+
class TestExtensionArrayDtype(object):
29+
30+
@pytest.mark.parametrize('values', [
31+
pd.Categorical([]),
32+
pd.Categorical([]).dtype,
33+
pd.Series(pd.Categorical([])),
34+
DummyDtype(),
35+
DummyArray(np.array([1, 2])),
36+
])
37+
def test_is_extension_array_dtype(self, values):
38+
assert is_extension_array_dtype(values)
39+
40+
@pytest.mark.parametrize('values', [
41+
np.array([]),
42+
pd.Series(np.array([])),
43+
])
44+
def test_is_not_extension_array_dtype(self, values):
45+
assert not is_extension_array_dtype(values)
46+
47+
48+
def test_astype():
49+
50+
arr = DummyArray(np.array([1, 2, 3]))
51+
expected = np.array([1, 2, 3], dtype=object)
52+
53+
result = arr.astype(object)
54+
tm.assert_numpy_array_equal(result, expected)
55+
56+
result = arr.astype('object')
57+
tm.assert_numpy_array_equal(result, expected)
58+
59+
60+
def test_astype_no_copy():
61+
arr = DummyArray(np.array([1, 2, 3], dtype=np.int64))
62+
result = arr.astype(arr.dtype, copy=False)
63+
64+
assert arr.data is result
65+
66+
result = arr.astype(arr.dtype)
67+
assert arr.data is not result

0 commit comments

Comments
 (0)