Skip to content

API: Default ExtensionArray.astype #19604

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""An interface for extending pandas with custom arrays."""
import numpy as np

from pandas.errors import AbstractMethodError

_not_implemented_message = "{} does not implement {}."
Expand Down Expand Up @@ -138,6 +140,34 @@ def nbytes(self):
# ------------------------------------------------------------------------
# Additional Methods
# ------------------------------------------------------------------------
def astype(self, dtype, copy=True):
"""Cast to a NumPy array with 'dtype'.

The default implementation only allows casting to 'object' dtype.

Parameters
----------
dtype : str or dtype
Typecode or data-type to which the array is cast.
copy : bool, default True
Whether to copy the data, even if not necessary. If False,
a copy is made only if the old dtype does not match the
new dtype.

Returns
-------
array : ndarray
NumPy ndarray with 'dtype' for its dtype.
"""
np_dtype = np.dtype(dtype)

if np_dtype != 'object':
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use is_object_dtype

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do you even need this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You mean the check for object? Fair point. If the underlying object supports conversion to whatever format, we should allow it.

In [1]: import pandas_ip as ip

In [2]: ip.IPAddress([1, 2, 3]).astype(object)
Out[2]:
array([IPv4Address('0.0.0.1'), IPv4Address('0.0.0.2'),
       IPv4Address('0.0.0.3')], dtype=object)

In [3]: ip.IPAddress([1, 2, 3]).astype(int)
Out[3]: array([1, 2, 3])

Which simplifies things nicely!

msg = ("{} can only be coerced to 'object' dtype, "
"not '{}'.").format(type(self).__name__, dtype)
raise ValueError(msg)

return np.array(self, dtype=np_dtype, copy=copy)

def isna(self):
# type: () -> np.ndarray
"""Boolean NumPy array indicating if each value is missing.
Expand Down
36 changes: 36 additions & 0 deletions pandas/tests/extension_arrays/test_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import numpy as np
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

put this in pandas/tests/extension/test_common.py you also need an __init__.py


import pandas.util.testing as tm
from pandas.core.arrays import ExtensionArray
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would also move the existing extension tests here



class DummyArray(ExtensionArray):

def __init__(self, data):
self.data = data

def __array__(self, dtype):
return self.data


def test_astype():
arr = DummyArray(np.array([1, 2, 3]))
expected = np.array([1, 2, 3], dtype=object)

result = arr.astype(object)
tm.assert_numpy_array_equal(result, expected)

result = arr.astype('object')
tm.assert_numpy_array_equal(result, expected)


def test_astype_raises():
arr = DummyArray(np.array([1, 2, 3]))

# type int for py2
# class int for py3
xpr = ("DummyArray can only be coerced to 'object' dtype, not "
"'<.* 'int'>'")

with tm.assert_raises_regex(ValueError, xpr):
arr.astype(int)