Skip to content

Commit 2edfb9d

Browse files
simonjayhawkinsJulianWgs
authored andcommitted
[ArrowStringArray] fix test_astype_int, test_astype_float (pandas-dev#41018)
1 parent 7be93b8 commit 2edfb9d

File tree

3 files changed

+50
-22
lines changed

3 files changed

+50
-22
lines changed

pandas/core/arrays/string_arrow.py

+30-4
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,17 @@
3131
from pandas.util._decorators import doc
3232
from pandas.util._validators import validate_fillna_kwargs
3333

34+
from pandas.core.dtypes.base import ExtensionDtype
3435
from pandas.core.dtypes.common import (
3536
is_array_like,
3637
is_bool_dtype,
38+
is_dtype_equal,
3739
is_integer,
3840
is_integer_dtype,
3941
is_object_dtype,
4042
is_scalar,
4143
is_string_dtype,
44+
pandas_dtype,
4245
)
4346
from pandas.core.dtypes.dtypes import register_extension_dtype
4447
from pandas.core.dtypes.missing import isna
@@ -48,6 +51,7 @@
4851
from pandas.core.arrays.base import ExtensionArray
4952
from pandas.core.arrays.boolean import BooleanDtype
5053
from pandas.core.arrays.integer import Int64Dtype
54+
from pandas.core.arrays.numeric import NumericDtype
5155
from pandas.core.arrays.string_ import StringDtype
5256
from pandas.core.indexers import (
5357
check_array_indexer,
@@ -285,10 +289,14 @@ def to_numpy( # type: ignore[override]
285289
"""
286290
# TODO: copy argument is ignored
287291

288-
if na_value is lib.no_default:
289-
na_value = self._dtype.na_value
290-
result = self._data.__array__(dtype=dtype)
291-
result[isna(result)] = na_value
292+
result = np.array(self._data, dtype=dtype)
293+
if self._data.null_count > 0:
294+
if na_value is lib.no_default:
295+
if dtype and np.issubdtype(dtype, np.floating):
296+
return result
297+
na_value = self._dtype.na_value
298+
mask = self.isna()
299+
result[mask] = na_value
292300
return result
293301

294302
def __len__(self) -> int:
@@ -732,6 +740,24 @@ def value_counts(self, dropna: bool = True) -> Series:
732740

733741
return Series(counts, index=index).astype("Int64")
734742

743+
def astype(self, dtype, copy=True):
744+
dtype = pandas_dtype(dtype)
745+
746+
if is_dtype_equal(dtype, self.dtype):
747+
if copy:
748+
return self.copy()
749+
return self
750+
751+
elif isinstance(dtype, NumericDtype):
752+
data = self._data.cast(pa.from_numpy_dtype(dtype.numpy_dtype))
753+
return dtype.__from_arrow__(data)
754+
755+
elif isinstance(dtype, ExtensionDtype):
756+
cls = dtype.construct_array_type()
757+
return cls._from_sequence(self, dtype=dtype, copy=copy)
758+
759+
return super().astype(dtype, copy)
760+
735761
# ------------------------------------------------------------------------
736762
# String methods interface
737763

pandas/tests/arrays/string_/test_string.py

+15-16
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
Tests for the str accessors are in pandas/tests/strings/test_string_array.py
44
"""
55

6+
import re
7+
68
import numpy as np
79
import pytest
810

@@ -325,32 +327,29 @@ def test_from_sequence_no_mutate(copy, cls, request):
325327
tm.assert_numpy_array_equal(nan_arr, expected)
326328

327329

328-
def test_astype_int(dtype, request):
329-
if dtype == "arrow_string":
330-
reason = "Cannot interpret 'Int64Dtype()' as a data type"
331-
mark = pytest.mark.xfail(raises=TypeError, reason=reason)
332-
request.node.add_marker(mark)
330+
def test_astype_int(dtype):
331+
arr = pd.array(["1", "2", "3"], dtype=dtype)
332+
result = arr.astype("int64")
333+
expected = np.array([1, 2, 3], dtype="int64")
334+
tm.assert_numpy_array_equal(result, expected)
335+
336+
arr = pd.array(["1", pd.NA, "3"], dtype=dtype)
337+
msg = re.escape("int() argument must be a string, a bytes-like object or a number")
338+
with pytest.raises(TypeError, match=msg):
339+
arr.astype("int64")
340+
333341

342+
def test_astype_nullable_int(dtype):
334343
arr = pd.array(["1", pd.NA, "3"], dtype=dtype)
335344

336345
result = arr.astype("Int64")
337346
expected = pd.array([1, pd.NA, 3], dtype="Int64")
338347
tm.assert_extension_array_equal(result, expected)
339348

340349

341-
def test_astype_float(dtype, any_float_allowed_nullable_dtype, request):
350+
def test_astype_float(dtype, any_float_allowed_nullable_dtype):
342351
# Don't compare arrays (37974)
343-
344-
if dtype == "arrow_string":
345-
if any_float_allowed_nullable_dtype in {"Float32", "Float64"}:
346-
reason = "Cannot interpret 'Float32Dtype()' as a data type"
347-
else:
348-
reason = "float() argument must be a string or a number, not 'NAType'"
349-
mark = pytest.mark.xfail(raises=TypeError, reason=reason)
350-
request.node.add_marker(mark)
351-
352352
ser = pd.Series(["1.1", pd.NA, "3.3"], dtype=dtype)
353-
354353
result = ser.astype(any_float_allowed_nullable_dtype)
355354
expected = pd.Series([1.1, np.nan, 3.3], dtype=any_float_allowed_nullable_dtype)
356355
tm.assert_series_equal(result, expected)

pandas/tests/series/methods/test_astype.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,9 @@ class TestAstypeString:
379379
# currently no way to parse IntervalArray from a list of strings
380380
],
381381
)
382-
def test_astype_string_to_extension_dtype_roundtrip(self, data, dtype, request):
382+
def test_astype_string_to_extension_dtype_roundtrip(
383+
self, data, dtype, request, nullable_string_dtype
384+
):
383385
if dtype == "boolean" or (
384386
dtype in ("period[M]", "datetime64[ns]", "timedelta64[ns]") and NaT in data
385387
):
@@ -389,7 +391,8 @@ def test_astype_string_to_extension_dtype_roundtrip(self, data, dtype, request):
389391
request.node.add_marker(mark)
390392
# GH-40351
391393
s = Series(data, dtype=dtype)
392-
tm.assert_series_equal(s, s.astype("string").astype(dtype))
394+
result = s.astype(nullable_string_dtype).astype(dtype)
395+
tm.assert_series_equal(result, s)
393396

394397

395398
class TestAstypeCategorical:

0 commit comments

Comments
 (0)