ENH: Add option to use nullable dtypes in read_csv #48776

Merged (10 commits, Oct 7, 2022)
8 changes: 8 additions & 0 deletions doc/source/user_guide/io.rst
@@ -197,6 +197,14 @@ dtype : Type name or dict of column -> type, default ``None``
Support for defaultdict was added. Specify a defaultdict as input where
the default determines the dtype of the columns which are not explicitly
listed.

use_nullable_dtypes : bool = False
Whether or not to use nullable dtypes as the default when reading data. If
set to True, nullable dtypes are used for all dtypes that have a nullable
implementation, even if no nulls are present.

.. versionadded:: 2.0

engine : {``'c'``, ``'python'``, ``'pyarrow'``}
Parser engine to use. The C and pyarrow engines are faster, while the python engine
is currently more feature-complete. Multithreading is currently only supported by
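For readers skimming the diff, here is a short usage sketch of the documented behavior (an editorial addition, not part of the PR; assumes a pandas build that includes this change):

```python
import io

import pandas as pd

data = "a,b,c\n1,2.5,x\n4,5.5,y\n"
df = pd.read_csv(io.StringIO(data), use_nullable_dtypes=True)
print(df.dtypes)
# a      Int64
# b    Float64
# c     string
# Every column gets its nullable (extension) dtype even though
# no value is actually missing.
```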
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.6.0.rst
@@ -32,6 +32,7 @@ Other enhancements
- :meth:`.DataFrameGroupBy.quantile` and :meth:`.SeriesGroupBy.quantile` now preserve nullable dtypes instead of casting to numpy dtypes (:issue:`37493`)
- :meth:`Series.add_suffix`, :meth:`DataFrame.add_suffix`, :meth:`Series.add_prefix` and :meth:`DataFrame.add_prefix` support an ``axis`` argument. If ``axis`` is set, the default behaviour of which axis to consider can be overwritten (:issue:`47819`)
- :func:`assert_frame_equal` now shows the first element where the DataFrames differ, analogously to ``pytest``'s output (:issue:`47910`)
- Added new argument ``use_nullable_dtypes`` to :func:`read_csv` to enable automatic conversion to nullable dtypes (:issue:`36712`)
- Added ``index`` parameter to :meth:`DataFrame.to_dict` (:issue:`46398`)
- Added metadata propagation for binary operators on :class:`DataFrame` (:issue:`28283`)
- :class:`.CategoricalConversionWarning`, :class:`.InvalidComparison`, :class:`.InvalidVersion`, :class:`.LossySetitemError`, and :class:`.NoBufferPresent` are now exposed in ``pandas.errors`` (:issue:`27656`)
16 changes: 13 additions & 3 deletions pandas/_libs/parsers.pyx
@@ -342,6 +342,7 @@ cdef class TextReader:
object index_col
object skiprows
object dtype
bint use_nullable_dtypes
object usecols
set unnamed_cols # set[str]

@@ -380,7 +381,8 @@
bint mangle_dupe_cols=True,
float_precision=None,
bint skip_blank_lines=True,
encoding_errors=b"strict"):
encoding_errors=b"strict",
use_nullable_dtypes=False):

# set encoding for native Python and C library
if isinstance(encoding_errors, str):
@@ -505,6 +507,7 @@ cdef class TextReader:
# - DtypeObj
# - dict[Any, DtypeObj]
self.dtype = dtype
self.use_nullable_dtypes = use_nullable_dtypes

# XXX
self.noconvert = set()
@@ -933,6 +936,7 @@ cdef class TextReader:
bint na_filter = 0
int64_t num_cols
dict result
bint use_nullable_dtypes

start = self.parser_start

@@ -1053,8 +1057,14 @@
self._free_na_set(na_hashset)

# don't try to upcast EAs
if na_count > 0 and not is_extension_array_dtype(col_dtype):
col_res = _maybe_upcast(col_res)
if (
na_count > 0 and not is_extension_array_dtype(col_dtype)
or self.use_nullable_dtypes
):
use_nullable_dtypes = self.use_nullable_dtypes and col_dtype is None
col_res = _maybe_upcast(
col_res, use_nullable_dtypes=use_nullable_dtypes
)

if col_res is None:
raise ParserError(f'Unable to parse column {i}')
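The condition above is easy to misread, so here is a hedged plain-Python sketch of the per-column decision it encodes (illustrative names; the real logic is Cython in parsers.pyx):

```python
from pandas.api.types import is_extension_array_dtype

def upcast_decision(na_count: int, col_dtype, use_nullable_dtypes: bool):
    # _maybe_upcast runs when missing values were seen in a non-extension
    # column, or unconditionally when the new flag is on.
    run_upcast = (
        na_count > 0 and not is_extension_array_dtype(col_dtype)
    ) or use_nullable_dtypes
    # A nullable result is only requested when the user did not pin an
    # explicit dtype for this column.
    request_nullable = use_nullable_dtypes and col_dtype is None
    return run_upcast, request_nullable
```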
67 changes: 54 additions & 13 deletions pandas/io/parsers/base_parser.py
@@ -15,6 +15,7 @@
Hashable,
Iterable,
List,
Literal,
Mapping,
Sequence,
Tuple,
@@ -50,6 +51,7 @@
is_dict_like,
is_dtype_equal,
is_extension_array_dtype,
is_float_dtype,
is_integer,
is_integer_dtype,
is_list_like,
@@ -61,8 +63,14 @@
from pandas.core.dtypes.dtypes import CategoricalDtype
from pandas.core.dtypes.missing import isna

from pandas import StringDtype
from pandas.core import algorithms
from pandas.core.arrays import Categorical
from pandas.core.arrays import (
BooleanArray,
Categorical,
FloatingArray,
IntegerArray,
)
from pandas.core.indexes.api import (
Index,
MultiIndex,
@@ -110,6 +118,7 @@ def __init__(self, kwds) -> None:

self.dtype = copy(kwds.get("dtype", None))
self.converters = kwds.get("converters")
self.use_nullable_dtypes = kwds.get("use_nullable_dtypes", False)

self.true_values = kwds.get("true_values")
self.false_values = kwds.get("false_values")
@@ -508,7 +517,7 @@ def _agg_index(self, index, try_parse_dates: bool = True) -> Index:
)

arr, _ = self._infer_types(
arr, col_na_values | col_na_fvalues, try_num_bool
arr, col_na_values | col_na_fvalues, cast_type, try_num_bool
)
arrays.append(arr)

@@ -574,7 +583,10 @@ def _convert_to_ndarrays(
values = lib.map_infer_mask(values, conv_f, mask)

cvals, na_count = self._infer_types(
values, set(col_na_values) | col_na_fvalues, try_num_bool=False
values,
set(col_na_values) | col_na_fvalues,
cast_type,
try_num_bool=False,
)
else:
is_ea = is_extension_array_dtype(cast_type)
@@ -585,14 +597,11 @@

# general type inference and conversion
cvals, na_count = self._infer_types(
values, set(col_na_values) | col_na_fvalues, try_num_bool
values, set(col_na_values) | col_na_fvalues, cast_type, try_num_bool
)

# type specified in dtype param or cast_type is an EA
if cast_type and (
not is_dtype_equal(cvals, cast_type)
or is_extension_array_dtype(cast_type)
):
if cast_type and (not is_dtype_equal(cvals, cast_type) or is_ea):
if not is_ea and na_count > 0:
try:
if is_bool_dtype(cast_type):
@@ -679,14 +688,15 @@ def _set(x) -> int:

return noconvert_columns

def _infer_types(self, values, na_values, try_num_bool: bool = True):
def _infer_types(self, values, na_values, cast_type, try_num_bool: bool = True):
"""
Infer types of values, possibly casting

Parameters
----------
values : ndarray
na_values : set
cast_type: Specifies if we want to cast explicitly
mroeschke (Member), Sep 30, 2022: Could we make this bool? Looks like we only need to check that it's not None?

Author: Changed

try_num_bool : bool, default True
try to cast values to numeric (first preference) or boolean

@@ -707,28 +717,58 @@ def _infer_types(self, values, na_values, try_num_bool: bool = True):
np.putmask(values, mask, np.nan)
return values, na_count

use_nullable_dtypes: Literal[True] | Literal[False] = (
self.use_nullable_dtypes and cast_type is None
)
result: ArrayLike

if try_num_bool and is_object_dtype(values.dtype):
# exclude e.g DatetimeIndex here
try:
result, _ = lib.maybe_convert_numeric(values, na_values, False)
result, result_mask = lib.maybe_convert_numeric(
values,
na_values,
False,
convert_to_masked_nullable=use_nullable_dtypes,
)
except (ValueError, TypeError):
# e.g. encountering datetime string gets ValueError
# TypeError can be raised in floatify
na_count = parsers.sanitize_objects(values, na_values)
result = values
na_count = parsers.sanitize_objects(result, na_values)
else:
na_count = isna(result).sum()
if use_nullable_dtypes:
if result_mask is None:
result_mask = np.zeros(result.shape, dtype=np.bool_)

if is_integer_dtype(result):
result = IntegerArray(result, result_mask)
elif is_bool_dtype(result):
result = BooleanArray(result, result_mask)
elif is_float_dtype(result):
result = FloatingArray(result, result_mask)

na_count = result_mask.sum()
else:
na_count = isna(result).sum()
else:
result = values
if values.dtype == np.object_:
na_count = parsers.sanitize_objects(values, na_values)

if result.dtype == np.object_ and try_num_bool:
result, _ = libops.maybe_convert_bool(
result, bool_mask = libops.maybe_convert_bool(
np.asarray(values),
true_values=self.true_values,
false_values=self.false_values,
convert_to_masked_nullable=use_nullable_dtypes,
)
if result.dtype == np.bool_ and use_nullable_dtypes:
if bool_mask is None:
bool_mask = np.zeros(result.shape, dtype=np.bool_)
result = BooleanArray(result, bool_mask)
elif result.dtype == np.object_ and use_nullable_dtypes:
result = StringDtype().construct_array_type()._from_sequence(values)
Reviewer (Member): Could you test what happens when the string pyarrow global config is true?

Author: Done
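As context for that exchange, a hedged sketch of what the global config controls (requires pyarrow to be installed):

```python
import pandas as pd

# StringDtype() without an explicit storage argument reads the global
# "mode.string_storage" option, so construct_array_type() can resolve to
# either the Python-backed or the pyarrow-backed string array.
with pd.option_context("mode.string_storage", "pyarrow"):
    print(pd.StringDtype().construct_array_type().__name__)  # ArrowStringArray

with pd.option_context("mode.string_storage", "python"):
    print(pd.StringDtype().construct_array_type().__name__)  # StringArray
```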


return result, na_count
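A hedged sketch of the values-plus-mask plumbing used above: with convert_to_masked_nullable enabled, the conversion routines hand back a plain ndarray and a boolean mask, and _infer_types wraps the pair in the matching extension array (na_count then falls out as mask.sum()):

```python
import numpy as np

from pandas.arrays import BooleanArray, FloatingArray, IntegerArray

values = np.array([1, 3, 0], dtype="int64")  # payload under the mask is arbitrary
mask = np.array([False, False, True])        # True marks a missing entry

ints = IntegerArray(values, mask)            # -> [1, 3, <NA>], dtype Int64
print(int(mask.sum()))                       # 1, i.e. the column's na_count

# FloatingArray and BooleanArray are built from the same two-array pattern:
floats = FloatingArray(np.array([2.5, 0.0]), np.array([False, True]))
bools = BooleanArray(np.array([True, False]), np.array([False, True]))
print(floats.dtype, bools.dtype)             # Float64 boolean
```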

@@ -1146,6 +1186,7 @@ def converter(*date_cols):
"on_bad_lines": ParserBase.BadLineHandleMethod.ERROR,
"error_bad_lines": None,
"warn_bad_lines": None,
"use_nullable_dtypes": False,
}


17 changes: 17 additions & 0 deletions pandas/io/parsers/readers.py
@@ -427,6 +427,13 @@

.. versionadded:: 1.2

use_nullable_dtypes : bool = False
Whether or not to use nullable dtypes as the default when reading data. If
set to True, nullable dtypes are used for all dtypes that have a nullable
implementation, even if no nulls are present.

.. versionadded:: 2.0

Returns
-------
DataFrame or TextFileReader
@@ -669,6 +676,7 @@ def read_csv(
memory_map: bool = ...,
float_precision: Literal["high", "legacy"] | None = ...,
storage_options: StorageOptions = ...,
use_nullable_dtypes: bool = ...,
) -> TextFileReader:
...

@@ -729,6 +737,7 @@ def read_csv(
memory_map: bool = ...,
float_precision: Literal["high", "legacy"] | None = ...,
storage_options: StorageOptions = ...,
use_nullable_dtypes: bool = ...,
) -> TextFileReader:
...

@@ -789,6 +798,7 @@ def read_csv(
memory_map: bool = ...,
float_precision: Literal["high", "legacy"] | None = ...,
storage_options: StorageOptions = ...,
use_nullable_dtypes: bool = ...,
) -> DataFrame:
...

@@ -849,6 +859,7 @@ def read_csv(
memory_map: bool = ...,
float_precision: Literal["high", "legacy"] | None = ...,
storage_options: StorageOptions = ...,
use_nullable_dtypes: bool = ...,
) -> DataFrame | TextFileReader:
...

@@ -928,6 +939,7 @@ def read_csv(
memory_map: bool = False,
float_precision: Literal["high", "legacy"] | None = None,
storage_options: StorageOptions = None,
use_nullable_dtypes: bool = False,
) -> DataFrame | TextFileReader:
# locals() should never be modified
kwds = locals().copy()
@@ -1008,6 +1020,7 @@ def read_table(
memory_map: bool = ...,
float_precision: str | None = ...,
storage_options: StorageOptions = ...,
use_nullable_dtypes: bool = ...,
) -> TextFileReader:
...

@@ -1068,6 +1081,7 @@ def read_table(
memory_map: bool = ...,
float_precision: str | None = ...,
storage_options: StorageOptions = ...,
use_nullable_dtypes: bool = ...,
) -> TextFileReader:
...

@@ -1128,6 +1142,7 @@ def read_table(
memory_map: bool = ...,
float_precision: str | None = ...,
storage_options: StorageOptions = ...,
use_nullable_dtypes: bool = ...,
) -> DataFrame:
...

@@ -1188,6 +1203,7 @@ def read_table(
memory_map: bool = ...,
float_precision: str | None = ...,
storage_options: StorageOptions = ...,
use_nullable_dtypes: bool = ...,
) -> DataFrame | TextFileReader:
...

@@ -1267,6 +1283,7 @@ def read_table(
memory_map: bool = False,
float_precision: str | None = None,
storage_options: StorageOptions = None,
use_nullable_dtypes: bool = False,
) -> DataFrame | TextFileReader:
# locals() should never be modified
kwds = locals().copy()
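An editorial aside on why every overload and both implementations gain the same keyword: read_csv and read_table capture their arguments with locals().copy() and forward them wholesale to the shared parser, so a new option only needs a default in each signature. A hedged sketch of that pattern (hypothetical names, not the pandas source):

```python
def _shared_read(path, kwds):
    # stand-in for the common parser entry point
    print(sorted(kwds))

def read_csv_like(path, *, sep=",", use_nullable_dtypes=False):
    kwds = locals().copy()   # snapshots every keyword, new ones included
    del kwds["path"]         # the positional argument travels separately
    return _shared_read(path, kwds)

read_csv_like("data.csv")    # prints ['sep', 'use_nullable_dtypes']
```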
42 changes: 42 additions & 0 deletions pandas/tests/io/parser/dtypes/test_dtypes_basic.py
@@ -385,3 +385,45 @@ def test_dtypes_defaultdict_invalid(all_parsers):
parser = all_parsers
with pytest.raises(TypeError, match="not understood"):
parser.read_csv(StringIO(data), dtype=dtype)


def test_use_nullabla_dtypes(all_parsers):
Reviewer (Member): nit: typo here and below.

Author: Thx, fixed

# GH#36712

parser = all_parsers

data = """a,b,c,d,e,f,g,h,i
1,2.5,True,a,,,,,12-31-2019
Reviewer (Member): Could you add a column here where both rows have an empty value?

Author: Added; it casts to Int64 now in both cases. The better question is what we actually want here, since an all-missing column could be anything.

3,4.5,False,b,6,7.5,True,a,12-31-2019
"""
result = parser.read_csv(
StringIO(data), use_nullable_dtypes=True, parse_dates=["i"]
Reviewer (Member): Can you parametrize for use_nullable_dtypes = True/False here and for the other tests?

Author: No, this would be impossible to understand if parametrized; the expected frame looks completely different. I could add a new test in theory, but it would not bring much value, since we already test all possible cases with numpy dtypes.

Reviewer (Member): OK, thanks for checking.

)
expected = DataFrame(
{
"a": pd.Series([1, 3], dtype="Int64"),
"b": pd.Series([2.5, 4.5], dtype="Float64"),
"c": pd.Series([True, False], dtype="boolean"),
"d": pd.Series(["a", "b"], dtype="string"),
"e": pd.Series([pd.NA, 6], dtype="Int64"),
"f": pd.Series([pd.NA, 7.5], dtype="Float64"),
"g": pd.Series([pd.NA, True], dtype="boolean"),
"h": pd.Series([pd.NA, "a"], dtype="string"),
"i": pd.Series([Timestamp("2019-12-31")] * 2),
}
)
tm.assert_frame_equal(result, expected)


def test_use_nullabla_dtypes_and_dtype(all_parsers):
# GH#36712

parser = all_parsers

data = """a,b
1,2.5
,
"""
result = parser.read_csv(StringIO(data), use_nullable_dtypes=True, dtype="float64")
expected = DataFrame({"a": [1.0, np.nan], "b": [2.5, np.nan]})
tm.assert_frame_equal(result, expected)
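
This last test pins down the precedence rule: an explicit dtype wins over use_nullable_dtypes. A hedged usage sketch of the same observable behavior:

```python
import io

import pandas as pd

data = "a,b\n1,2.5\n,\n"
df = pd.read_csv(io.StringIO(data), use_nullable_dtypes=True, dtype="float64")
print(df.dtypes)  # a and b are both float64: the explicit dtype wins
print(df.isna())  # missing cells are NaN, not pd.NA
```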