-
-
Notifications
You must be signed in to change notification settings - Fork 18.5k
Convert ArrowExtensionArray to proper NumPy dtype #56290
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
eeca9e3
455df48
3bbfab5
c3659f1
7b1122f
513ad7f
eaf4211
a9fb2ac
dd19a6f
393cb33
0edcc4e
9dd2505
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
from __future__ import annotations | ||
|
||
import numpy as np | ||
|
||
from pandas._libs import lib | ||
from pandas.errors import LossySetitemError | ||
|
||
from pandas.core.dtypes.cast import np_can_hold_element | ||
from pandas.core.dtypes.common import is_numeric_dtype | ||
|
||
|
||
def _to_numpy_dtype_inference(arr, dtype, na_value, hasna): | ||
if dtype is None and is_numeric_dtype(arr.dtype): | ||
dtype_given = False | ||
if hasna: | ||
if arr.dtype.kind == "b": | ||
dtype = object | ||
else: | ||
if arr.dtype.kind in "iu": | ||
dtype = np.dtype(np.float64) | ||
else: | ||
dtype = arr.dtype.numpy_dtype | ||
if na_value is lib.no_default: | ||
na_value = np.nan | ||
else: | ||
dtype = arr.dtype.numpy_dtype | ||
elif dtype is not None: | ||
dtype = np.dtype(dtype) | ||
dtype_given = True | ||
else: | ||
dtype_given = True | ||
|
||
if na_value is lib.no_default: | ||
na_value = arr.dtype.na_value | ||
|
||
if not dtype_given and hasna: | ||
try: | ||
np_can_hold_element(dtype, na_value) | ||
except LossySetitemError: | ||
dtype = object | ||
return dtype, na_value |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -266,6 +266,16 @@ def data_for_twos(data): | |
|
||
|
||
class TestArrowArray(base.ExtensionTests): | ||
@pytest.mark.parametrize("na_action", [None, "ignore"]) | ||
def test_map(self, data_missing, na_action): | ||
result = data_missing.map(lambda x: x, na_action=na_action) | ||
if data_missing.dtype == "float32[pyarrow]": | ||
# map roundtrips through objects, which converts to float64 | ||
expected = data_missing.to_numpy(dtype="float64", na_value=np.nan) | ||
else: | ||
expected = data_missing.to_numpy() | ||
tm.assert_numpy_array_equal(result, expected) | ||
|
||
def test_astype_str(self, data, request): | ||
pa_dtype = data.dtype.pyarrow_dtype | ||
if pa.types.is_binary(pa_dtype): | ||
|
@@ -1489,7 +1499,7 @@ def test_to_numpy_with_defaults(data): | |
else: | ||
expected = np.array(data._pa_array) | ||
|
||
if data._hasna: | ||
if data._hasna and not is_numeric_dtype(data.dtype): | ||
expected = expected.astype(object) | ||
expected[pd.isna(data)] = pd.NA | ||
|
||
|
@@ -1501,8 +1511,8 @@ def test_to_numpy_int_with_na(): | |
data = [1, None] | ||
arr = pd.array(data, dtype="int64[pyarrow]") | ||
result = arr.to_numpy() | ||
expected = np.array([1, pd.NA], dtype=object) | ||
assert isinstance(result[0], int) | ||
expected = np.array([1, np.nan]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this mean that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm a little surprised it doesn't break more tests as I think there are still a number of places that try to round-trip through numpy There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah same |
||
assert isinstance(result[0], float) | ||
tm.assert_numpy_array_equal(result, expected) | ||
|
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -58,6 +58,7 @@ | |
"_iLocIndexer", | ||
# TODO(3.0): GH#55043 - remove upon removal of ArrayManager | ||
"_get_option", | ||
"_to_numpy_dtype_inference", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since the original method is in core, probably can remove the leading There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
} | ||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you type this method?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done