ENH: Add option to use nullable dtypes in read_csv #48776

Merged · 10 commits · Oct 7, 2022
doc/source/user_guide/io.rst (8 additions, 0 deletions)
@@ -197,6 +197,14 @@ dtype : Type name or dict of column -> type, default ``None``
Support for defaultdict was added. Specify a defaultdict as input where
the default determines the dtype of the columns which are not explicitly
listed.

use_nullable_dtypes : bool, default ``False``
Whether to use nullable dtypes as default when reading data. If set to
``True``, nullable dtypes are used for every dtype that has a nullable
implementation, even when no nulls are present.

.. versionadded:: 2.0

engine : {``'c'``, ``'python'``, ``'pyarrow'``}
Parser engine to use. The C and pyarrow engines are faster, while the python engine
is currently more feature-complete. Multithreading is currently only supported by
the pyarrow engine.
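A quick usage sketch of the behaviour documented above; the sample data and
the expected dtypes are illustrative, not taken from this PR's tests:

import io

import pandas as pd

data = "a,b,c\n1,x,true\n,y,false\n3,,true"

# Default behaviour: the integer column is upcast to float64 to make room
# for NaN, and the string column comes back as object.
df = pd.read_csv(io.StringIO(data))
print(df.dtypes)  # a -> float64, b -> object, c -> bool

# With use_nullable_dtypes=True, every column that has a nullable
# implementation uses it -- including c, which contains no NA at all.
df = pd.read_csv(io.StringIO(data), use_nullable_dtypes=True)
print(df.dtypes)  # expected: a -> Int64, b -> string, c -> boolean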
doc/source/whatsnew/v1.6.0.rst (1 addition, 0 deletions)
@@ -32,6 +32,7 @@ Other enhancements
- :meth:`.DataFrameGroupBy.quantile` and :meth:`.SeriesGroupBy.quantile` now preserve nullable dtypes instead of casting to numpy dtypes (:issue:`37493`)
- :meth:`Series.add_suffix`, :meth:`DataFrame.add_suffix`, :meth:`Series.add_prefix` and :meth:`DataFrame.add_prefix` support an ``axis`` argument. If ``axis`` is set, the default behaviour of which axis to consider can be overridden (:issue:`47819`)
- :func:`assert_frame_equal` now shows the first element where the DataFrames differ, analogously to ``pytest``'s output (:issue:`47910`)
- Added new argument ``use_nullable_dtypes`` to :func:`read_csv` to enable automatic conversion to nullable dtypes (:issue:`36712`)
- Added ``index`` parameter to :meth:`DataFrame.to_dict` (:issue:`46398`)
- Added metadata propagation for binary operators on :class:`DataFrame` (:issue:`28283`)
- :class:`.CategoricalConversionWarning`, :class:`.InvalidComparison`, :class:`.InvalidVersion`, :class:`.LossySetitemError`, and :class:`.NoBufferPresent` are now exposed in ``pandas.errors`` (:issue:`27656`)
pandas/_libs/parsers.pyx (13 additions, 3 deletions)
@@ -342,6 +342,7 @@ cdef class TextReader:
object index_col
object skiprows
object dtype
bint use_nullable_dtypes
object usecols
set unnamed_cols # set[str]

@@ -380,7 +381,8 @@
bint mangle_dupe_cols=True,
float_precision=None,
bint skip_blank_lines=True,
encoding_errors=b"strict"):
encoding_errors=b"strict",
use_nullable_dtypes=False):

# set encoding for native Python and C library
if isinstance(encoding_errors, str):
@@ -505,6 +507,7 @@
# - DtypeObj
# - dict[Any, DtypeObj]
self.dtype = dtype
self.use_nullable_dtypes = use_nullable_dtypes

# XXX
self.noconvert = set()
@@ -933,6 +936,7 @@ cdef class TextReader:
bint na_filter = 0
int64_t num_cols
dict result
bint use_nullable_dtypes

start = self.parser_start

@@ -1053,8 +1057,14 @@
self._free_na_set(na_hashset)

# don't try to upcast EAs
if na_count > 0 and not is_extension_array_dtype(col_dtype):
col_res = _maybe_upcast(col_res)
if (
na_count > 0 and not is_extension_array_dtype(col_dtype)
or self.use_nullable_dtypes
):
use_nullable_dtypes = self.use_nullable_dtypes and col_dtype is None
col_res = _maybe_upcast(
col_res, use_nullable_dtypes=use_nullable_dtypes
)

if col_res is None:
raise ParserError(f'Unable to parse column {i}')
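In plain Python, the idea behind the new branch is roughly the following (a
sketch only; ``_maybe_upcast`` itself lives in the parser internals and is
not reproduced here): when the flag is set and the user pinned no dtype, the
parsed buffer and its NA mask are wrapped in a masked extension array
instead of being upcast to float64.

import numpy as np

from pandas.arrays import BooleanArray, FloatingArray, IntegerArray

def maybe_upcast_sketch(values, na_mask, use_nullable_dtypes=False):
    # Classic path: integers cannot hold NaN, so upcast to float64.
    if not use_nullable_dtypes:
        if na_mask.any() and values.dtype.kind in "iu":
            values = values.astype(np.float64)
            values[na_mask] = np.nan
        return values
    # Nullable path: keep the parsed values as-is and carry the mask along.
    if values.dtype.kind in "iu":
        return IntegerArray(values, na_mask)
    if values.dtype.kind == "f":
        return FloatingArray(values, na_mask)
    if values.dtype.kind == "b":
        return BooleanArray(values, na_mask)
    return values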
pandas/io/parsers/base_parser.py (63 additions, 13 deletions)
@@ -15,6 +15,7 @@
Hashable,
Iterable,
List,
Literal,
Mapping,
Sequence,
Tuple,
@@ -50,6 +51,7 @@
is_dict_like,
is_dtype_equal,
is_extension_array_dtype,
is_float_dtype,
is_integer,
is_integer_dtype,
is_list_like,
@@ -61,8 +63,14 @@
from pandas.core.dtypes.dtypes import CategoricalDtype
from pandas.core.dtypes.missing import isna

from pandas import StringDtype
from pandas.core import algorithms
from pandas.core.arrays import Categorical
from pandas.core.arrays import (
BooleanArray,
Categorical,
FloatingArray,
IntegerArray,
)
from pandas.core.indexes.api import (
Index,
MultiIndex,
@@ -110,6 +118,7 @@ def __init__(self, kwds) -> None:

self.dtype = copy(kwds.get("dtype", None))
self.converters = kwds.get("converters")
self.use_nullable_dtypes = kwds.get("use_nullable_dtypes", False)

self.true_values = kwds.get("true_values")
self.false_values = kwds.get("false_values")
@@ -508,7 +517,7 @@ def _agg_index(self, index, try_parse_dates: bool = True) -> Index:
)

arr, _ = self._infer_types(
arr, col_na_values | col_na_fvalues, try_num_bool
arr, col_na_values | col_na_fvalues, cast_type is None, try_num_bool
)
arrays.append(arr)

@@ -574,7 +583,10 @@ def _convert_to_ndarrays(
values = lib.map_infer_mask(values, conv_f, mask)

cvals, na_count = self._infer_types(
values, set(col_na_values) | col_na_fvalues, try_num_bool=False
values,
set(col_na_values) | col_na_fvalues,
cast_type is None,
try_num_bool=False,
)
else:
is_ea = is_extension_array_dtype(cast_type)
@@ -585,14 +597,14 @@

# general type inference and conversion
cvals, na_count = self._infer_types(
values, set(col_na_values) | col_na_fvalues, try_num_bool
values,
set(col_na_values) | col_na_fvalues,
cast_type is None,
try_num_bool,
)

# type specified in dtype param or cast_type is an EA
if cast_type and (
not is_dtype_equal(cvals, cast_type)
or is_extension_array_dtype(cast_type)
):
if cast_type and (not is_dtype_equal(cvals, cast_type) or is_ea):
if not is_ea and na_count > 0:
try:
if is_bool_dtype(cast_type):
@@ -679,14 +691,17 @@ def _set(x) -> int:

return noconvert_columns

def _infer_types(self, values, na_values, try_num_bool: bool = True):
def _infer_types(
self, values, na_values, no_dtype_specified, try_num_bool: bool = True
):
"""
Infer types of values, possibly casting

Parameters
----------
values : ndarray
na_values : set
no_dtype_specified : bool
True if no dtype was explicitly specified for the column, in which case
inference is free to convert to a nullable dtype.
try_num_bool : bool, default True
try to cast values to numeric (first preference) or boolean

@@ -707,28 +722,62 @@ def _infer_types(self, values, na_values, try_num_bool: bool = True):
np.putmask(values, mask, np.nan)
return values, na_count

use_nullable_dtypes: Literal[True] | Literal[False] = (
self.use_nullable_dtypes and no_dtype_specified
)
result: ArrayLike

if try_num_bool and is_object_dtype(values.dtype):
# exclude e.g. DatetimeIndex here
try:
result, _ = lib.maybe_convert_numeric(values, na_values, False)
result, result_mask = lib.maybe_convert_numeric(
values,
na_values,
False,
convert_to_masked_nullable=use_nullable_dtypes,
)
except (ValueError, TypeError):
# e.g. encountering datetime string gets ValueError
# TypeError can be raised in floatify
na_count = parsers.sanitize_objects(values, na_values)
result = values
na_count = parsers.sanitize_objects(result, na_values)
else:
na_count = isna(result).sum()
if use_nullable_dtypes:
if result_mask is None:
result_mask = np.zeros(result.shape, dtype=np.bool_)

if result_mask.all():
result = IntegerArray(
np.ones(result_mask.shape, dtype=np.int64), result_mask
)
elif is_integer_dtype(result):
result = IntegerArray(result, result_mask)
elif is_bool_dtype(result):
result = BooleanArray(result, result_mask)
elif is_float_dtype(result):
result = FloatingArray(result, result_mask)

na_count = result_mask.sum()
else:
na_count = isna(result).sum()
else:
result = values
if values.dtype == np.object_:
na_count = parsers.sanitize_objects(values, na_values)

if result.dtype == np.object_ and try_num_bool:
result, _ = libops.maybe_convert_bool(
result, bool_mask = libops.maybe_convert_bool(
np.asarray(values),
true_values=self.true_values,
false_values=self.false_values,
convert_to_masked_nullable=use_nullable_dtypes,
)
if result.dtype == np.bool_ and use_nullable_dtypes:
if bool_mask is None:
bool_mask = np.zeros(result.shape, dtype=np.bool_)
result = BooleanArray(result, bool_mask)
elif result.dtype == np.object_ and use_nullable_dtypes:
result = StringDtype().construct_array_type()._from_sequence(values)

Review comment (Member): Could you test what happens when the string pyarrow global config is true?

Reply (Member Author): Done

return result, na_count
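
The masked construction above boils down to pairing a values buffer with a
boolean NA mask. The ``result_mask.all()`` special case exists because an
all-NA column carries no usable values, so any buffer of the right length
(here ``np.ones``) will do. Illustrated with the public constructors:

import numpy as np

from pandas.arrays import IntegerArray

values = np.array([1, 0, 3], dtype=np.int64)  # 0 is a placeholder under the mask
mask = np.array([False, True, False])         # True marks an NA position

print(IntegerArray(values, mask))  # [1, <NA>, 3], dtype: Int64

# All-NA column, mirroring the result_mask.all() branch above:
print(IntegerArray(np.ones(3, dtype=np.int64), np.ones(3, dtype=np.bool_)))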

@@ -1146,6 +1195,7 @@ def converter(*date_cols):
"on_bad_lines": ParserBase.BadLineHandleMethod.ERROR,
"error_bad_lines": None,
"warn_bad_lines": None,
"use_nullable_dtypes": False,
}


pandas/io/parsers/readers.py (17 additions, 0 deletions)
@@ -427,6 +427,13 @@

.. versionadded:: 1.2

use_nullable_dtypes : bool, default False
Whether to use nullable dtypes as default when reading data. If set to
True, nullable dtypes are used for every dtype that has a nullable
implementation, even when no nulls are present.

.. versionadded:: 2.0

Returns
-------
DataFrame or TextFileReader
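
One interaction worth noting, implied by the ``cast_type is None`` and
``col_dtype is None`` guards in the parser changes: a dtype requested
explicitly for a column takes precedence over the nullable inference. A
sketch (column names and the expected output are illustrative):

import io

import pandas as pd

data = io.StringIO("a,b\n1,4\n2,5\n3,6")

df = pd.read_csv(
    data,
    dtype={"a": "float64"},    # explicitly requested, so left as float64
    use_nullable_dtypes=True,  # only applies to the unspecified column b
)
print(df.dtypes)  # expected: a -> float64, b -> Int64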
@@ -669,6 +676,7 @@ def read_csv(
memory_map: bool = ...,
float_precision: Literal["high", "legacy"] | None = ...,
storage_options: StorageOptions = ...,
use_nullable_dtypes: bool = ...,
) -> TextFileReader:
...

@@ -729,6 +737,7 @@ def read_csv(
memory_map: bool = ...,
float_precision: Literal["high", "legacy"] | None = ...,
storage_options: StorageOptions = ...,
use_nullable_dtypes: bool = ...,
) -> TextFileReader:
...

@@ -789,6 +798,7 @@ def read_csv(
memory_map: bool = ...,
float_precision: Literal["high", "legacy"] | None = ...,
storage_options: StorageOptions = ...,
use_nullable_dtypes: bool = ...,
) -> DataFrame:
...

@@ -849,6 +859,7 @@ def read_csv(
memory_map: bool = ...,
float_precision: Literal["high", "legacy"] | None = ...,
storage_options: StorageOptions = ...,
use_nullable_dtypes: bool = ...,
) -> DataFrame | TextFileReader:
...

@@ -928,6 +939,7 @@ def read_csv(
memory_map: bool = False,
float_precision: Literal["high", "legacy"] | None = None,
storage_options: StorageOptions = None,
use_nullable_dtypes: bool = False,
) -> DataFrame | TextFileReader:
# locals() should never be modified
kwds = locals().copy()
@@ -1008,6 +1020,7 @@ def read_table(
memory_map: bool = ...,
float_precision: str | None = ...,
storage_options: StorageOptions = ...,
use_nullable_dtypes: bool = ...,
) -> TextFileReader:
...

@@ -1068,6 +1081,7 @@ def read_table(
memory_map: bool = ...,
float_precision: str | None = ...,
storage_options: StorageOptions = ...,
use_nullable_dtypes: bool = ...,
) -> TextFileReader:
...

@@ -1128,6 +1142,7 @@ def read_table(
memory_map: bool = ...,
float_precision: str | None = ...,
storage_options: StorageOptions = ...,
use_nullable_dtypes: bool = ...,
) -> DataFrame:
...

@@ -1188,6 +1203,7 @@ def read_table(
memory_map: bool = ...,
float_precision: str | None = ...,
storage_options: StorageOptions = ...,
use_nullable_dtypes: bool = ...,
) -> DataFrame | TextFileReader:
...

@@ -1267,6 +1283,7 @@ def read_table(
memory_map: bool = False,
float_precision: str | None = None,
storage_options: StorageOptions = None,
use_nullable_dtypes: bool = False,
) -> DataFrame | TextFileReader:
# locals() should never be modified
kwds = locals().copy()