Skip to content

Commit 1273bc9

Browse files
authored
ENH: Add use_nullable_dtypes in csv internals (#48403)
* ENH: Add use_nullable_dtypes in csv internals * Add tests * Fix mypy * Add comment * Add pyarrow test * Fix float32 * Fix float32 * Add contextmanager
1 parent 81b5f1d commit 1273bc9

File tree

2 files changed

+159
-8
lines changed

2 files changed

+159
-8
lines changed

pandas/_libs/parsers.pyx

+51-8
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,13 @@ import warnings
1515

1616
from pandas.util._exceptions import find_stack_level
1717

18+
from pandas import StringDtype
19+
from pandas.core.arrays import (
20+
BooleanArray,
21+
FloatingArray,
22+
IntegerArray,
23+
)
24+
1825
cimport cython
1926
from cpython.bytes cimport (
2027
PyBytes_AsString,
@@ -1378,18 +1385,53 @@ STR_NA_VALUES = {
13781385
_NA_VALUES = _ensure_encoded(list(STR_NA_VALUES))
13791386

13801387

1381-
def _maybe_upcast(arr):
1382-
"""
1388+
def _maybe_upcast(arr, use_nullable_dtypes: bool = False):
1389+
"""Sets nullable dtypes or upcasts if nans are present.
13831390
1391+
Upcast, if use_nullable_dtypes is false and nans are present so that the
1392+
current dtype can not hold the na value. We use nullable dtypes if the
1393+
flag is true for every array.
1394+
1395+
Parameters
1396+
----------
1397+
arr: ndarray
1398+
Numpy array that is potentially being upcast.
1399+
1400+
use_nullable_dtypes: bool, default False
1401+
If true, we cast to the associated nullable dtypes.
1402+
1403+
Returns
1404+
-------
1405+
The casted array.
13841406
"""
1407+
na_value = na_values[arr.dtype]
1408+
13851409
if issubclass(arr.dtype.type, np.integer):
1386-
na_value = na_values[arr.dtype]
1387-
arr = arr.astype(float)
1388-
np.putmask(arr, arr == na_value, np.nan)
1410+
mask = arr == na_value
1411+
1412+
if use_nullable_dtypes:
1413+
arr = IntegerArray(arr, mask)
1414+
else:
1415+
arr = arr.astype(float)
1416+
np.putmask(arr, mask, np.nan)
1417+
13891418
elif arr.dtype == np.bool_:
1390-
mask = arr.view(np.uint8) == na_values[np.uint8]
1391-
arr = arr.astype(object)
1392-
np.putmask(arr, mask, np.nan)
1419+
mask = arr.view(np.uint8) == na_value
1420+
1421+
if use_nullable_dtypes:
1422+
arr = BooleanArray(arr, mask)
1423+
else:
1424+
arr = arr.astype(object)
1425+
np.putmask(arr, mask, np.nan)
1426+
1427+
elif issubclass(arr.dtype.type, float) or arr.dtype.type == np.float32:
1428+
if use_nullable_dtypes:
1429+
mask = np.isnan(arr)
1430+
arr = FloatingArray(arr, mask)
1431+
1432+
elif arr.dtype == np.object_:
1433+
if use_nullable_dtypes:
1434+
arr = StringDtype().construct_array_type()._from_sequence(arr)
13931435

13941436
return arr
13951437

@@ -1985,6 +2027,7 @@ def _compute_na_values():
19852027
uint16info = np.iinfo(np.uint16)
19862028
uint8info = np.iinfo(np.uint8)
19872029
na_values = {
2030+
np.float32: np.nan,
19882031
np.float64: np.nan,
19892032
np.int64: int64info.min,
19902033
np.int32: int32info.min,

pandas/tests/io/parser/test_upcast.py

+108
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
import numpy as np
2+
import pytest
3+
4+
from pandas._libs.parsers import ( # type: ignore[attr-defined]
5+
_maybe_upcast,
6+
na_values,
7+
)
8+
import pandas.util._test_decorators as td
9+
10+
import pandas as pd
11+
from pandas import NA
12+
import pandas._testing as tm
13+
from pandas.core.arrays import (
14+
ArrowStringArray,
15+
BooleanArray,
16+
FloatingArray,
17+
IntegerArray,
18+
StringArray,
19+
)
20+
21+
22+
def test_maybe_upcast(any_real_numpy_dtype):
23+
# GH#36712
24+
25+
dtype = np.dtype(any_real_numpy_dtype)
26+
na_value = na_values[dtype]
27+
arr = np.array([1, 2, na_value], dtype=dtype)
28+
result = _maybe_upcast(arr, use_nullable_dtypes=True)
29+
30+
expected_mask = np.array([False, False, True])
31+
if issubclass(dtype.type, np.integer):
32+
expected = IntegerArray(arr, mask=expected_mask)
33+
else:
34+
expected = FloatingArray(arr, mask=expected_mask)
35+
36+
tm.assert_extension_array_equal(result, expected)
37+
38+
39+
def test_maybe_upcast_no_na(any_real_numpy_dtype):
40+
# GH#36712
41+
if any_real_numpy_dtype == "float32":
42+
pytest.skip()
43+
44+
arr = np.array([1, 2, 3], dtype=any_real_numpy_dtype)
45+
result = _maybe_upcast(arr, use_nullable_dtypes=True)
46+
47+
expected_mask = np.array([False, False, False])
48+
if issubclass(np.dtype(any_real_numpy_dtype).type, np.integer):
49+
expected = IntegerArray(arr, mask=expected_mask)
50+
else:
51+
expected = FloatingArray(arr, mask=expected_mask)
52+
53+
tm.assert_extension_array_equal(result, expected)
54+
55+
56+
def test_maybe_upcaste_bool():
57+
# GH#36712
58+
dtype = np.bool_
59+
na_value = na_values[dtype]
60+
arr = np.array([True, False, na_value], dtype="uint8").view(dtype)
61+
result = _maybe_upcast(arr, use_nullable_dtypes=True)
62+
63+
expected_mask = np.array([False, False, True])
64+
expected = BooleanArray(arr, mask=expected_mask)
65+
tm.assert_extension_array_equal(result, expected)
66+
67+
68+
def test_maybe_upcaste_bool_no_nan():
69+
# GH#36712
70+
dtype = np.bool_
71+
arr = np.array([True, False, False], dtype="uint8").view(dtype)
72+
result = _maybe_upcast(arr, use_nullable_dtypes=True)
73+
74+
expected_mask = np.array([False, False, False])
75+
expected = BooleanArray(arr, mask=expected_mask)
76+
tm.assert_extension_array_equal(result, expected)
77+
78+
79+
def test_maybe_upcaste_all_nan():
80+
# GH#36712
81+
dtype = np.int64
82+
na_value = na_values[dtype]
83+
arr = np.array([na_value, na_value], dtype=dtype)
84+
result = _maybe_upcast(arr, use_nullable_dtypes=True)
85+
86+
expected_mask = np.array([True, True])
87+
expected = IntegerArray(arr, mask=expected_mask)
88+
tm.assert_extension_array_equal(result, expected)
89+
90+
91+
@td.skip_if_no("pyarrow")
92+
@pytest.mark.parametrize("storage", ["pyarrow", "python"])
93+
@pytest.mark.parametrize("val", [na_values[np.object_], "c"])
94+
def test_maybe_upcast_object(val, storage):
95+
# GH#36712
96+
import pyarrow as pa
97+
98+
with pd.option_context("mode.string_storage", storage):
99+
arr = np.array(["a", "b", val], dtype=np.object_)
100+
result = _maybe_upcast(arr, use_nullable_dtypes=True)
101+
102+
if storage == "python":
103+
exp_val = "c" if val == "c" else NA
104+
expected = StringArray(np.array(["a", "b", exp_val], dtype=np.object_))
105+
else:
106+
exp_val = "c" if val == "c" else None
107+
expected = ArrowStringArray(pa.array(["a", "b", exp_val]))
108+
tm.assert_extension_array_equal(result, expected)

0 commit comments

Comments
 (0)