Skip to content

Commit ea21483

Browse files
phofl authored and noatamir committed
ENH: Add option to use nullable dtypes in read_csv (pandas-dev#48776)
* ENH: Add option to use nullable dtypes in read_csv
* Finish implementation
* Update
* Fix mypy
* Add tests and fix call
* Fix typo
1 parent c71a809 commit ea21483

File tree

6 files changed

+183
-16
lines changed

6 files changed

+183
-16
lines changed

doc/source/user_guide/io.rst

+8
Original file line number | Diff line number | Diff line change
@@ -197,6 +197,14 @@ dtype : Type name or dict of column -> type, default ``None``
197197
Support for defaultdict was added. Specify a defaultdict as input where
198198
the default determines the dtype of the columns which are not explicitly
199199
listed.
200+
201+
use_nullable_dtypes : bool = False
202+
Whether or not to use nullable dtypes as default when reading data. If
203+
set to True, nullable dtypes are used for all dtypes that have a nullable
204+
implementation, even if no nulls are present.
205+
206+
.. versionadded:: 2.0
207+
200208
engine : {``'c'``, ``'python'``, ``'pyarrow'``}
201209
Parser engine to use. The C and pyarrow engines are faster, while the python engine
202210
is currently more feature-complete. Multithreading is currently only supported by

doc/source/whatsnew/v1.6.0.rst

+1
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,7 @@ Other enhancements
3232
- :meth:`.DataFrameGroupBy.quantile` and :meth:`.SeriesGroupBy.quantile` now preserve nullable dtypes instead of casting to numpy dtypes (:issue:`37493`)
3333
- :meth:`Series.add_suffix`, :meth:`DataFrame.add_suffix`, :meth:`Series.add_prefix` and :meth:`DataFrame.add_prefix` support an ``axis`` argument. If ``axis`` is set, the default behaviour of which axis to consider can be overwritten (:issue:`47819`)
3434
- :func:`assert_frame_equal` now shows the first element where the DataFrames differ, analogously to ``pytest``'s output (:issue:`47910`)
35+
- Added new argument ``use_nullable_dtypes`` to :func:`read_csv` to enable automatic conversion to nullable dtypes (:issue:`36712`)
3536
- Added ``index`` parameter to :meth:`DataFrame.to_dict` (:issue:`46398`)
3637
- Added metadata propagation for binary operators on :class:`DataFrame` (:issue:`28283`)
3738
- :class:`.CategoricalConversionWarning`, :class:`.InvalidComparison`, :class:`.InvalidVersion`, :class:`.LossySetitemError`, and :class:`.NoBufferPresent` are now exposed in ``pandas.errors`` (:issue:`27656`)

pandas/_libs/parsers.pyx

+13-3
Original file line number | Diff line number | Diff line change
@@ -342,6 +342,7 @@ cdef class TextReader:
342342
object index_col
343343
object skiprows
344344
object dtype
345+
bint use_nullable_dtypes
345346
object usecols
346347
set unnamed_cols # set[str]
347348

@@ -380,7 +381,8 @@ cdef class TextReader:
380381
bint mangle_dupe_cols=True,
381382
float_precision=None,
382383
bint skip_blank_lines=True,
383-
encoding_errors=b"strict"):
384+
encoding_errors=b"strict",
385+
use_nullable_dtypes=False):
384386

385387
# set encoding for native Python and C library
386388
if isinstance(encoding_errors, str):
@@ -505,6 +507,7 @@ cdef class TextReader:
505507
# - DtypeObj
506508
# - dict[Any, DtypeObj]
507509
self.dtype = dtype
510+
self.use_nullable_dtypes = use_nullable_dtypes
508511

509512
# XXX
510513
self.noconvert = set()
@@ -933,6 +936,7 @@ cdef class TextReader:
933936
bint na_filter = 0
934937
int64_t num_cols
935938
dict result
939+
bint use_nullable_dtypes
936940

937941
start = self.parser_start
938942

@@ -1053,8 +1057,14 @@ cdef class TextReader:
10531057
self._free_na_set(na_hashset)
10541058

10551059
# don't try to upcast EAs
1056-
if na_count > 0 and not is_extension_array_dtype(col_dtype):
1057-
col_res = _maybe_upcast(col_res)
1060+
if (
1061+
na_count > 0 and not is_extension_array_dtype(col_dtype)
1062+
or self.use_nullable_dtypes
1063+
):
1064+
use_nullable_dtypes = self.use_nullable_dtypes and col_dtype is None
1065+
col_res = _maybe_upcast(
1066+
col_res, use_nullable_dtypes=use_nullable_dtypes
1067+
)
10581068

10591069
if col_res is None:
10601070
raise ParserError(f'Unable to parse column {i}')

pandas/io/parsers/base_parser.py

+63-13
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
Hashable,
1616
Iterable,
1717
List,
18+
Literal,
1819
Mapping,
1920
Sequence,
2021
Tuple,
@@ -50,6 +51,7 @@
5051
is_dict_like,
5152
is_dtype_equal,
5253
is_extension_array_dtype,
54+
is_float_dtype,
5355
is_integer,
5456
is_integer_dtype,
5557
is_list_like,
@@ -61,8 +63,14 @@
6163
from pandas.core.dtypes.dtypes import CategoricalDtype
6264
from pandas.core.dtypes.missing import isna
6365

66+
from pandas import StringDtype
6467
from pandas.core import algorithms
65-
from pandas.core.arrays import Categorical
68+
from pandas.core.arrays import (
69+
BooleanArray,
70+
Categorical,
71+
FloatingArray,
72+
IntegerArray,
73+
)
6674
from pandas.core.indexes.api import (
6775
Index,
6876
MultiIndex,
@@ -110,6 +118,7 @@ def __init__(self, kwds) -> None:
110118

111119
self.dtype = copy(kwds.get("dtype", None))
112120
self.converters = kwds.get("converters")
121+
self.use_nullable_dtypes = kwds.get("use_nullable_dtypes", False)
113122

114123
self.true_values = kwds.get("true_values")
115124
self.false_values = kwds.get("false_values")
@@ -508,7 +517,7 @@ def _agg_index(self, index, try_parse_dates: bool = True) -> Index:
508517
)
509518

510519
arr, _ = self._infer_types(
511-
arr, col_na_values | col_na_fvalues, try_num_bool
520+
arr, col_na_values | col_na_fvalues, cast_type is None, try_num_bool
512521
)
513522
arrays.append(arr)
514523

@@ -574,7 +583,10 @@ def _convert_to_ndarrays(
574583
values = lib.map_infer_mask(values, conv_f, mask)
575584

576585
cvals, na_count = self._infer_types(
577-
values, set(col_na_values) | col_na_fvalues, try_num_bool=False
586+
values,
587+
set(col_na_values) | col_na_fvalues,
588+
cast_type is None,
589+
try_num_bool=False,
578590
)
579591
else:
580592
is_ea = is_extension_array_dtype(cast_type)
@@ -585,14 +597,14 @@ def _convert_to_ndarrays(
585597

586598
# general type inference and conversion
587599
cvals, na_count = self._infer_types(
588-
values, set(col_na_values) | col_na_fvalues, try_num_bool
600+
values,
601+
set(col_na_values) | col_na_fvalues,
602+
cast_type is None,
603+
try_num_bool,
589604
)
590605

591606
# type specified in dtype param or cast_type is an EA
592-
if cast_type and (
593-
not is_dtype_equal(cvals, cast_type)
594-
or is_extension_array_dtype(cast_type)
595-
):
607+
if cast_type and (not is_dtype_equal(cvals, cast_type) or is_ea):
596608
if not is_ea and na_count > 0:
597609
try:
598610
if is_bool_dtype(cast_type):
@@ -679,14 +691,17 @@ def _set(x) -> int:
679691

680692
return noconvert_columns
681693

682-
def _infer_types(self, values, na_values, try_num_bool: bool = True):
694+
def _infer_types(
695+
self, values, na_values, no_dtype_specified, try_num_bool: bool = True
696+
):
683697
"""
684698
Infer types of values, possibly casting
685699
686700
Parameters
687701
----------
688702
values : ndarray
689703
na_values : set
704+
no_dtype_specified: Specifies if we want to cast explicitly
690705
try_num_bool : bool, default try
691706
try to cast values to numeric (first preference) or boolean
692707
@@ -707,28 +722,62 @@ def _infer_types(self, values, na_values, try_num_bool: bool = True):
707722
np.putmask(values, mask, np.nan)
708723
return values, na_count
709724

725+
use_nullable_dtypes: Literal[True] | Literal[False] = (
726+
self.use_nullable_dtypes and no_dtype_specified
727+
)
728+
result: ArrayLike
729+
710730
if try_num_bool and is_object_dtype(values.dtype):
711731
# exclude e.g DatetimeIndex here
712732
try:
713-
result, _ = lib.maybe_convert_numeric(values, na_values, False)
733+
result, result_mask = lib.maybe_convert_numeric(
734+
values,
735+
na_values,
736+
False,
737+
convert_to_masked_nullable=use_nullable_dtypes,
738+
)
714739
except (ValueError, TypeError):
715740
# e.g. encountering datetime string gets ValueError
716741
# TypeError can be raised in floatify
742+
na_count = parsers.sanitize_objects(values, na_values)
717743
result = values
718-
na_count = parsers.sanitize_objects(result, na_values)
719744
else:
720-
na_count = isna(result).sum()
745+
if use_nullable_dtypes:
746+
if result_mask is None:
747+
result_mask = np.zeros(result.shape, dtype=np.bool_)
748+
749+
if result_mask.all():
750+
result = IntegerArray(
751+
np.ones(result_mask.shape, dtype=np.int64), result_mask
752+
)
753+
elif is_integer_dtype(result):
754+
result = IntegerArray(result, result_mask)
755+
elif is_bool_dtype(result):
756+
result = BooleanArray(result, result_mask)
757+
elif is_float_dtype(result):
758+
result = FloatingArray(result, result_mask)
759+
760+
na_count = result_mask.sum()
761+
else:
762+
na_count = isna(result).sum()
721763
else:
722764
result = values
723765
if values.dtype == np.object_:
724766
na_count = parsers.sanitize_objects(values, na_values)
725767

726768
if result.dtype == np.object_ and try_num_bool:
727-
result, _ = libops.maybe_convert_bool(
769+
result, bool_mask = libops.maybe_convert_bool(
728770
np.asarray(values),
729771
true_values=self.true_values,
730772
false_values=self.false_values,
773+
convert_to_masked_nullable=use_nullable_dtypes,
731774
)
775+
if result.dtype == np.bool_ and use_nullable_dtypes:
776+
if bool_mask is None:
777+
bool_mask = np.zeros(result.shape, dtype=np.bool_)
778+
result = BooleanArray(result, bool_mask)
779+
elif result.dtype == np.object_ and use_nullable_dtypes:
780+
result = StringDtype().construct_array_type()._from_sequence(values)
732781

733782
return result, na_count
734783

@@ -1146,6 +1195,7 @@ def converter(*date_cols):
11461195
"on_bad_lines": ParserBase.BadLineHandleMethod.ERROR,
11471196
"error_bad_lines": None,
11481197
"warn_bad_lines": None,
1198+
"use_nullable_dtypes": False,
11491199
}
11501200

11511201

pandas/io/parsers/readers.py

+17
Original file line number | Diff line number | Diff line change
@@ -427,6 +427,13 @@
427427
428428
.. versionadded:: 1.2
429429
430+
use_nullable_dtypes : bool = False
431+
Whether or not to use nullable dtypes as default when reading data. If
432+
set to True, nullable dtypes are used for all dtypes that have a nullable
433+
implementation, even if no nulls are present.
434+
435+
.. versionadded:: 2.0
436+
430437
Returns
431438
-------
432439
DataFrame or TextFileReader
@@ -669,6 +676,7 @@ def read_csv(
669676
memory_map: bool = ...,
670677
float_precision: Literal["high", "legacy"] | None = ...,
671678
storage_options: StorageOptions = ...,
679+
use_nullable_dtypes: bool = ...,
672680
) -> TextFileReader:
673681
...
674682

@@ -729,6 +737,7 @@ def read_csv(
729737
memory_map: bool = ...,
730738
float_precision: Literal["high", "legacy"] | None = ...,
731739
storage_options: StorageOptions = ...,
740+
use_nullable_dtypes: bool = ...,
732741
) -> TextFileReader:
733742
...
734743

@@ -789,6 +798,7 @@ def read_csv(
789798
memory_map: bool = ...,
790799
float_precision: Literal["high", "legacy"] | None = ...,
791800
storage_options: StorageOptions = ...,
801+
use_nullable_dtypes: bool = ...,
792802
) -> DataFrame:
793803
...
794804

@@ -849,6 +859,7 @@ def read_csv(
849859
memory_map: bool = ...,
850860
float_precision: Literal["high", "legacy"] | None = ...,
851861
storage_options: StorageOptions = ...,
862+
use_nullable_dtypes: bool = ...,
852863
) -> DataFrame | TextFileReader:
853864
...
854865

@@ -928,6 +939,7 @@ def read_csv(
928939
memory_map: bool = False,
929940
float_precision: Literal["high", "legacy"] | None = None,
930941
storage_options: StorageOptions = None,
942+
use_nullable_dtypes: bool = False,
931943
) -> DataFrame | TextFileReader:
932944
# locals() should never be modified
933945
kwds = locals().copy()
@@ -1008,6 +1020,7 @@ def read_table(
10081020
memory_map: bool = ...,
10091021
float_precision: str | None = ...,
10101022
storage_options: StorageOptions = ...,
1023+
use_nullable_dtypes: bool = ...,
10111024
) -> TextFileReader:
10121025
...
10131026

@@ -1068,6 +1081,7 @@ def read_table(
10681081
memory_map: bool = ...,
10691082
float_precision: str | None = ...,
10701083
storage_options: StorageOptions = ...,
1084+
use_nullable_dtypes: bool = ...,
10711085
) -> TextFileReader:
10721086
...
10731087

@@ -1128,6 +1142,7 @@ def read_table(
11281142
memory_map: bool = ...,
11291143
float_precision: str | None = ...,
11301144
storage_options: StorageOptions = ...,
1145+
use_nullable_dtypes: bool = ...,
11311146
) -> DataFrame:
11321147
...
11331148

@@ -1188,6 +1203,7 @@ def read_table(
11881203
memory_map: bool = ...,
11891204
float_precision: str | None = ...,
11901205
storage_options: StorageOptions = ...,
1206+
use_nullable_dtypes: bool = ...,
11911207
) -> DataFrame | TextFileReader:
11921208
...
11931209

@@ -1267,6 +1283,7 @@ def read_table(
12671283
memory_map: bool = False,
12681284
float_precision: str | None = None,
12691285
storage_options: StorageOptions = None,
1286+
use_nullable_dtypes: bool = False,
12701287
) -> DataFrame | TextFileReader:
12711288
# locals() should never be modified
12721289
kwds = locals().copy()

0 commit comments

Comments
 (0)