From fbf0a0672380e210d3cb3c527fa8045a204d81be Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Thu, 8 Feb 2018 15:01:28 -0600 Subject: [PATCH 1/4] API: Default ExtensionArray.astype (cherry picked from commit 943a915562b72bed147c857de927afa0daf31c1a) --- pandas/core/arrays/base.py | 30 +++++++++++++++++ pandas/tests/extension_arrays/test_common.py | 34 ++++++++++++++++++++ 2 files changed, 64 insertions(+) create mode 100644 pandas/tests/extension_arrays/test_common.py diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index 1556b653819a6..8c3d033dffba7 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -1,4 +1,6 @@ """An interface for extending pandas with custom arrays.""" +import numpy as np + from pandas.errors import AbstractMethodError _not_implemented_message = "{} does not implement {}." @@ -138,6 +140,34 @@ def nbytes(self): # ------------------------------------------------------------------------ # Additional Methods # ------------------------------------------------------------------------ + def astype(self, dtype, copy=True): + """Cast to a NumPy array with 'dtype'. + + The default implementation only allows casting to 'object' dtype. + + Parameters + ---------- + dtype : str or dtype + Typecode or data-type to which the array is cast. + copy : bool, default True + Whether to copy the data, even if not necessary. If False, + a copy is made only if the old dtype does not match the + new dtype. + + Returns + ------- + array : ndarray + NumPy ndarray with 'dtype' for its dtype. + """ + np_dtype = np.dtype(dtype) + + if np_dtype != 'object': + msg = ("{} can only be coerced to 'object' dtype, " + "not '{}'.").format(type(self).__name__, dtype) + raise ValueError(msg) + + return np.array(self, dtype=np_dtype, copy=copy) + def isna(self): # type: () -> np.ndarray """Boolean NumPy array indicating if each value is missing. diff --git a/pandas/tests/extension_arrays/test_common.py b/pandas/tests/extension_arrays/test_common.py new file mode 100644 index 0000000000000..7feb7fdf09ec6 --- /dev/null +++ b/pandas/tests/extension_arrays/test_common.py @@ -0,0 +1,34 @@ +import numpy as np + +import pandas.util.testing as tm +from pandas.core.arrays import ExtensionArray + + +class DummyArray(ExtensionArray): + + def __init__(self, data): + self.data = data + + def __array__(self, dtype): + return self.data + + +def test_astype(): + arr = DummyArray(np.array([1, 2, 3])) + expected = np.array([1, 2, 3], dtype=object) + + result = arr.astype(object) + tm.assert_numpy_array_equal(result, expected) + + result = arr.astype('object') + tm.assert_numpy_array_equal(result, expected) + + +def test_astype_raises(): + arr = DummyArray(np.array([1, 2, 3])) + + xpr = ("DummyArray can only be coerced to 'object' dtype, not " + "''") + + with tm.assert_raises_regex(ValueError, xpr): + arr.astype(int) From b20e12cae68dd86ff51597464045656763d369f7 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Thu, 8 Feb 2018 15:46:36 -0600 Subject: [PATCH 2/4] Py2 compat --- pandas/tests/extension_arrays/test_common.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pandas/tests/extension_arrays/test_common.py b/pandas/tests/extension_arrays/test_common.py index 7feb7fdf09ec6..f19754482b04f 100644 --- a/pandas/tests/extension_arrays/test_common.py +++ b/pandas/tests/extension_arrays/test_common.py @@ -27,8 +27,10 @@ def test_astype(): def test_astype_raises(): arr = DummyArray(np.array([1, 2, 3])) + # type int for py2 + # class int for py3 xpr = ("DummyArray can only be coerced to 'object' dtype, not " - "''") + "'<.* 'int'>'") with tm.assert_raises_regex(ValueError, xpr): arr.astype(int) From 87583dc82c3d29ade720b78ccea0e8cad99abe66 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Fri, 9 Feb 2018 07:12:13 -0600 Subject: [PATCH 3/4] Moved --- pandas/core/arrays/base.py | 11 +---------- pandas/tests/extension/__init__.py | 0 .../test_common.py | 18 ++++++++++-------- 3 files changed, 11 insertions(+), 18 deletions(-) create mode 100644 pandas/tests/extension/__init__.py rename pandas/tests/{extension_arrays => extension}/test_common.py (64%) diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index 8c3d033dffba7..553e1e0ac2066 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -143,8 +143,6 @@ def nbytes(self): def astype(self, dtype, copy=True): """Cast to a NumPy array with 'dtype'. - The default implementation only allows casting to 'object' dtype. - Parameters ---------- dtype : str or dtype @@ -159,14 +157,7 @@ def astype(self, dtype, copy=True): array : ndarray NumPy ndarray with 'dtype' for its dtype. """ - np_dtype = np.dtype(dtype) - - if np_dtype != 'object': - msg = ("{} can only be coerced to 'object' dtype, " - "not '{}'.").format(type(self).__name__, dtype) - raise ValueError(msg) - - return np.array(self, dtype=np_dtype, copy=copy) + return np.array(self, dtype=dtype, copy=copy) def isna(self): # type: () -> np.ndarray diff --git a/pandas/tests/extension/__init__.py b/pandas/tests/extension/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/pandas/tests/extension_arrays/test_common.py b/pandas/tests/extension/test_common.py similarity index 64% rename from pandas/tests/extension_arrays/test_common.py rename to pandas/tests/extension/test_common.py index f19754482b04f..4668154122e45 100644 --- a/pandas/tests/extension_arrays/test_common.py +++ b/pandas/tests/extension/test_common.py @@ -12,6 +12,10 @@ def __init__(self, data): def __array__(self, dtype): return self.data + @property + def dtype(self): + return self.data.dtype + def test_astype(): arr = DummyArray(np.array([1, 2, 3])) @@ -24,13 +28,11 @@ def test_astype(): tm.assert_numpy_array_equal(result, expected) -def test_astype_raises(): - arr = DummyArray(np.array([1, 2, 3])) +def test_astype_no_copy(): + arr = DummyArray(np.array([1, 2, 3], dtype=np.int64)) + result = arr.astype(arr.dtype, copy=False) - # type int for py2 - # class int for py3 - xpr = ("DummyArray can only be coerced to 'object' dtype, not " - "'<.* 'int'>'") + assert arr.data is result - with tm.assert_raises_regex(ValueError, xpr): - arr.astype(int) + result = arr.astype(arr.dtype) + assert arr.data is not result From d1362271bca8a7b183f3241e5c2f040c422118b8 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Fri, 9 Feb 2018 08:21:51 -0600 Subject: [PATCH 4/4] Moved dtypes --- pandas/tests/dtypes/test_dtypes.py | 32 +-------------------------- pandas/tests/extension/test_common.py | 29 ++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 31 deletions(-) diff --git a/pandas/tests/dtypes/test_dtypes.py b/pandas/tests/dtypes/test_dtypes.py index eca4dd4cf2106..d800a7b92b559 100644 --- a/pandas/tests/dtypes/test_dtypes.py +++ b/pandas/tests/dtypes/test_dtypes.py @@ -10,14 +10,12 @@ Series, Categorical, CategoricalIndex, IntervalIndex, date_range) from pandas.compat import string_types -from pandas.core.arrays import ExtensionArray from pandas.core.dtypes.dtypes import ( DatetimeTZDtype, PeriodDtype, - IntervalDtype, CategoricalDtype, ExtensionDtype) + IntervalDtype, CategoricalDtype) from pandas.core.dtypes.common import ( is_categorical_dtype, is_categorical, is_datetime64tz_dtype, is_datetimetz, - is_extension_array_dtype, is_period_dtype, is_period, is_dtype_equal, is_datetime64_ns_dtype, is_datetime64_dtype, is_interval_dtype, @@ -744,31 +742,3 @@ def test_categorical_categories(self): tm.assert_index_equal(c1.categories, pd.Index(['a', 'b'])) c1 = CategoricalDtype(CategoricalIndex(['a', 'b'])) tm.assert_index_equal(c1.categories, pd.Index(['a', 'b'])) - - -class DummyArray(ExtensionArray): - pass - - -class DummyDtype(ExtensionDtype): - pass - - -class TestExtensionArrayDtype(object): - - @pytest.mark.parametrize('values', [ - pd.Categorical([]), - pd.Categorical([]).dtype, - pd.Series(pd.Categorical([])), - DummyDtype(), - DummyArray(), - ]) - def test_is_extension_array_dtype(self, values): - assert is_extension_array_dtype(values) - - @pytest.mark.parametrize('values', [ - np.array([]), - pd.Series(np.array([])), - ]) - def test_is_not_extension_array_dtype(self, values): - assert not is_extension_array_dtype(values) diff --git a/pandas/tests/extension/test_common.py b/pandas/tests/extension/test_common.py index 4668154122e45..1f4582f687415 100644 --- a/pandas/tests/extension/test_common.py +++ b/pandas/tests/extension/test_common.py @@ -1,7 +1,15 @@ import numpy as np +import pytest +import pandas as pd import pandas.util.testing as tm from pandas.core.arrays import ExtensionArray +from pandas.core.dtypes.common import is_extension_array_dtype +from pandas.core.dtypes.dtypes import ExtensionDtype + + +class DummyDtype(ExtensionDtype): + pass class DummyArray(ExtensionArray): @@ -17,7 +25,28 @@ def dtype(self): return self.data.dtype +class TestExtensionArrayDtype(object): + + @pytest.mark.parametrize('values', [ + pd.Categorical([]), + pd.Categorical([]).dtype, + pd.Series(pd.Categorical([])), + DummyDtype(), + DummyArray(np.array([1, 2])), + ]) + def test_is_extension_array_dtype(self, values): + assert is_extension_array_dtype(values) + + @pytest.mark.parametrize('values', [ + np.array([]), + pd.Series(np.array([])), + ]) + def test_is_not_extension_array_dtype(self, values): + assert not is_extension_array_dtype(values) + + def test_astype(): + arr = DummyArray(np.array([1, 2, 3])) expected = np.array([1, 2, 3], dtype=object)