Skip to content

Commit 4ea8e57

Browse files
mroeschkepooja-subramaniam
authored andcommitted
ENH: Add dtype_backend to to_numeric (pandas-dev#50910)
1 parent 03b88cb commit 4ea8e57

File tree

3 files changed

+86
-12
lines changed

3 files changed

+86
-12
lines changed

doc/source/whatsnew/v2.0.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ to select the nullable dtypes implementation.
6161
* :func:`read_parquet`
6262
* :func:`read_orc`
6363
* :func:`read_feather`
64+
* :func:`to_numeric`
6465

6566

6667
And the following methods will also utilize the ``mode.dtype_backend`` option.

pandas/core/tools/numeric.py

+11
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
import numpy as np
66

7+
from pandas._config import get_option
8+
79
from pandas._libs import lib
810
from pandas._typing import (
911
DateTimeErrorChoices,
@@ -190,6 +192,9 @@ def to_numeric(
190192
values = values._data[~mask]
191193

192194
values_dtype = getattr(values, "dtype", None)
195+
if isinstance(values_dtype, pd.ArrowDtype):
196+
mask = values.isna()
197+
values = values.dropna().to_numpy()
193198
new_mask: np.ndarray | None = None
194199
if is_numeric_dtype(values_dtype):
195200
pass
@@ -258,6 +263,7 @@ def to_numeric(
258263
data[~mask] = values
259264

260265
from pandas.core.arrays import (
266+
ArrowExtensionArray,
261267
BooleanArray,
262268
FloatingArray,
263269
IntegerArray,
@@ -272,6 +278,11 @@ def to_numeric(
272278
klass = FloatingArray
273279
values = klass(data, mask)
274280

281+
if get_option("mode.dtype_backend") == "pyarrow" or isinstance(
282+
values_dtype, pd.ArrowDtype
283+
):
284+
values = ArrowExtensionArray(values.__arrow_array__())
285+
275286
if is_series:
276287
return arg._constructor(values, index=arg.index, name=arg.name)
277288
elif is_index:

pandas/tests/tools/test_to_numeric.py

+74-12
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
DataFrame,
1010
Index,
1111
Series,
12+
option_context,
1213
to_numeric,
1314
)
1415
import pandas._testing as tm
@@ -813,39 +814,86 @@ def test_to_numeric_use_nullable_dtypes(val, dtype):
813814

814815

815816
@pytest.mark.parametrize(
816-
"val, dtype", [(1, "Int64"), (1.5, "Float64"), (True, "boolean")]
817+
"val, dtype",
818+
[
819+
(1, "Int64"),
820+
(1.5, "Float64"),
821+
(True, "boolean"),
822+
(1, "int64[pyarrow]"),
823+
(1.5, "float64[pyarrow]"),
824+
(True, "bool[pyarrow]"),
825+
],
817826
)
818827
def test_to_numeric_use_nullable_dtypes_na(val, dtype):
819828
# GH#50505
829+
if "pyarrow" in dtype:
830+
pytest.importorskip("pyarrow")
831+
dtype_backend = "pyarrow"
832+
else:
833+
dtype_backend = "pandas"
820834
ser = Series([val, None], dtype=object)
821-
result = to_numeric(ser, use_nullable_dtypes=True)
835+
with option_context("mode.dtype_backend", dtype_backend):
836+
result = to_numeric(ser, use_nullable_dtypes=True)
822837
expected = Series([val, pd.NA], dtype=dtype)
823838
tm.assert_series_equal(result, expected)
824839

825840

826841
@pytest.mark.parametrize(
827842
"val, dtype, downcast",
828-
[(1, "Int8", "integer"), (1.5, "Float32", "float"), (1, "Int8", "signed")],
843+
[
844+
(1, "Int8", "integer"),
845+
(1.5, "Float32", "float"),
846+
(1, "Int8", "signed"),
847+
(1, "int8[pyarrow]", "integer"),
848+
(1.5, "float[pyarrow]", "float"),
849+
(1, "int8[pyarrow]", "signed"),
850+
],
829851
)
830852
def test_to_numeric_use_nullable_dtypes_downcasting(val, dtype, downcast):
831853
# GH#50505
854+
if "pyarrow" in dtype:
855+
pytest.importorskip("pyarrow")
856+
dtype_backend = "pyarrow"
857+
else:
858+
dtype_backend = "pandas"
832859
ser = Series([val, None], dtype=object)
833-
result = to_numeric(ser, use_nullable_dtypes=True, downcast=downcast)
860+
with option_context("mode.dtype_backend", dtype_backend):
861+
result = to_numeric(ser, use_nullable_dtypes=True, downcast=downcast)
834862
expected = Series([val, pd.NA], dtype=dtype)
835863
tm.assert_series_equal(result, expected)
836864

837865

838-
def test_to_numeric_use_nullable_dtypes_downcasting_uint():
866+
@pytest.mark.parametrize(
867+
"smaller, dtype_backend", [["UInt8", "pandas"], ["uint8[pyarrow]", "pyarrow"]]
868+
)
869+
def test_to_numeric_use_nullable_dtypes_downcasting_uint(smaller, dtype_backend):
839870
# GH#50505
871+
if dtype_backend == "pyarrow":
872+
pytest.importorskip("pyarrow")
840873
ser = Series([1, pd.NA], dtype="UInt64")
841-
result = to_numeric(ser, use_nullable_dtypes=True, downcast="unsigned")
842-
expected = Series([1, pd.NA], dtype="UInt8")
874+
with option_context("mode.dtype_backend", dtype_backend):
875+
result = to_numeric(ser, use_nullable_dtypes=True, downcast="unsigned")
876+
expected = Series([1, pd.NA], dtype=smaller)
843877
tm.assert_series_equal(result, expected)
844878

845879

846-
@pytest.mark.parametrize("dtype", ["Int64", "UInt64", "Float64", "boolean"])
880+
@pytest.mark.parametrize(
881+
"dtype",
882+
[
883+
"Int64",
884+
"UInt64",
885+
"Float64",
886+
"boolean",
887+
"int64[pyarrow]",
888+
"uint64[pyarrow]",
889+
"float64[pyarrow]",
890+
"bool[pyarrow]",
891+
],
892+
)
847893
def test_to_numeric_use_nullable_dtypes_already_nullable(dtype):
848894
# GH#50505
895+
if "pyarrow" in dtype:
896+
pytest.importorskip("pyarrow")
849897
ser = Series([1, pd.NA], dtype=dtype)
850898
result = to_numeric(ser, use_nullable_dtypes=True)
851899
expected = Series([1, pd.NA], dtype=dtype)
@@ -855,16 +903,30 @@ def test_to_numeric_use_nullable_dtypes_already_nullable(dtype):
855903
@pytest.mark.parametrize(
856904
"use_nullable_dtypes, dtype", [(True, "Float64"), (False, "float64")]
857905
)
858-
def test_to_numeric_use_nullable_dtypes_error(use_nullable_dtypes, dtype):
906+
@pytest.mark.parametrize("dtype_backend", ["pandas", "pyarrow"])
907+
def test_to_numeric_use_nullable_dtypes_error(
908+
use_nullable_dtypes, dtype, dtype_backend
909+
):
859910
# GH#50505
911+
if dtype_backend == "pyarrow":
912+
pytest.importorskip("pyarrow")
860913
ser = Series(["a", "b", ""])
861914
expected = ser.copy()
862915
with pytest.raises(ValueError, match="Unable to parse string"):
863-
to_numeric(ser, use_nullable_dtypes=use_nullable_dtypes)
916+
with option_context("mode.dtype_backend", dtype_backend):
917+
to_numeric(ser, use_nullable_dtypes=use_nullable_dtypes)
864918

865-
result = to_numeric(ser, use_nullable_dtypes=use_nullable_dtypes, errors="ignore")
919+
with option_context("mode.dtype_backend", dtype_backend):
920+
result = to_numeric(
921+
ser, use_nullable_dtypes=use_nullable_dtypes, errors="ignore"
922+
)
866923
tm.assert_series_equal(result, expected)
867924

868-
result = to_numeric(ser, use_nullable_dtypes=use_nullable_dtypes, errors="coerce")
925+
with option_context("mode.dtype_backend", dtype_backend):
926+
result = to_numeric(
927+
ser, use_nullable_dtypes=use_nullable_dtypes, errors="coerce"
928+
)
929+
if use_nullable_dtypes and dtype_backend == "pyarrow":
930+
dtype = "double[pyarrow]"
869931
expected = Series([np.nan, np.nan, np.nan], dtype=dtype)
870932
tm.assert_series_equal(result, expected)

0 commit comments

Comments
 (0)