Skip to content

Commit 5f15709

Browse files
jbrockmendel and yeshsurya
authored and committed
REF: move union_categoricals call outside of cython (pandas-dev#40964)
1 parent f99ef5f commit 5f15709

File tree

4 files changed

+121
-69
lines changed

4 files changed

+121
-69
lines changed

pandas/_libs/parsers.pyi

+1-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@ class TextReader:
5858
true_values=...,
5959
false_values=...,
6060
allow_leading_cols: bool = ...,
61-
low_memory: bool = ...,
6261
skiprows=...,
6362
skipfooter: int = ..., # int64_t
6463
verbose: bool = ...,
@@ -75,3 +74,4 @@ class TextReader:
7574
def close(self) -> None: ...
7675

7776
def read(self, rows: int | None = ...) -> dict[int, ArrayLike]: ...
77+
def read_low_memory(self, rows: int | None) -> list[dict[int, ArrayLike]]: ...

pandas/_libs/parsers.pyx

+15-65
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,6 @@ from pandas._libs.khash cimport (
9494
)
9595

9696
from pandas.errors import (
97-
DtypeWarning,
9897
EmptyDataError,
9998
ParserError,
10099
ParserWarning,
@@ -108,9 +107,7 @@ from pandas.core.dtypes.common import (
108107
is_float_dtype,
109108
is_integer_dtype,
110109
is_object_dtype,
111-
pandas_dtype,
112110
)
113-
from pandas.core.dtypes.concat import union_categoricals
114111

115112
cdef:
116113
float64_t INF = <float64_t>np.inf
@@ -317,7 +314,7 @@ cdef class TextReader:
317314

318315
cdef public:
319316
int64_t leading_cols, table_width, skipfooter, buffer_lines
320-
bint allow_leading_cols, mangle_dupe_cols, low_memory
317+
bint allow_leading_cols, mangle_dupe_cols
321318
bint delim_whitespace
322319
object delimiter # bytes or str
323320
object converters
@@ -362,7 +359,6 @@ cdef class TextReader:
362359
true_values=None,
363360
false_values=None,
364361
bint allow_leading_cols=True,
365-
bint low_memory=False,
366362
skiprows=None,
367363
skipfooter=0, # int64_t
368364
bint verbose=False,
@@ -479,7 +475,6 @@ cdef class TextReader:
479475
self.na_filter = na_filter
480476

481477
self.verbose = verbose
482-
self.low_memory = low_memory
483478

484479
if float_precision == "round_trip":
485480
# see gh-15140
@@ -492,12 +487,10 @@ cdef class TextReader:
492487
raise ValueError(f'Unrecognized float_precision option: '
493488
f'{float_precision}')
494489

495-
if isinstance(dtype, dict):
496-
dtype = {k: pandas_dtype(dtype[k])
497-
for k in dtype}
498-
elif dtype is not None:
499-
dtype = pandas_dtype(dtype)
500-
490+
# Caller is responsible for ensuring we have one of
491+
# - None
492+
# - DtypeObj
493+
# - dict[Any, DtypeObj]
501494
self.dtype = dtype
502495

503496
# XXX
@@ -774,17 +767,18 @@ cdef class TextReader:
774767
"""
775768
rows=None --> read all rows
776769
"""
777-
if self.low_memory:
778-
# Conserve intermediate space
779-
columns = self._read_low_memory(rows)
780-
else:
781-
# Don't care about memory usage
782-
columns = self._read_rows(rows, 1)
770+
# Don't care about memory usage
771+
columns = self._read_rows(rows, 1)
783772

784773
return columns
785774

786-
# -> dict[int, "ArrayLike"]
787-
cdef _read_low_memory(self, rows):
775+
def read_low_memory(self, rows: int | None)-> list[dict[int, "ArrayLike"]]:
776+
"""
777+
rows=None --> read all rows
778+
"""
779+
# Conserve intermediate space
780+
# Caller is responsible for concatenating chunks,
781+
# see c_parser_wrapper._concatenate_chunks
788782
cdef:
789783
size_t rows_read = 0
790784
list chunks = []
@@ -819,8 +813,7 @@ cdef class TextReader:
819813
if len(chunks) == 0:
820814
raise StopIteration
821815

822-
# destructive to chunks
823-
return _concatenate_chunks(chunks)
816+
return chunks
824817

825818
cdef _tokenize_rows(self, size_t nrows):
826819
cdef:
@@ -1908,49 +1901,6 @@ cdef raise_parser_error(object base, parser_t *parser):
19081901
raise ParserError(message)
19091902

19101903

1911-
# chunks: list[dict[int, "ArrayLike"]]
1912-
# -> dict[int, "ArrayLike"]
1913-
def _concatenate_chunks(list chunks) -> dict:
1914-
cdef:
1915-
list names = list(chunks[0].keys())
1916-
object name
1917-
list warning_columns = []
1918-
object warning_names
1919-
object common_type
1920-
1921-
result = {}
1922-
for name in names:
1923-
arrs = [chunk.pop(name) for chunk in chunks]
1924-
# Check each arr for consistent types.
1925-
dtypes = {a.dtype for a in arrs}
1926-
numpy_dtypes = {x for x in dtypes if not is_categorical_dtype(x)}
1927-
if len(numpy_dtypes) > 1:
1928-
common_type = np.find_common_type(numpy_dtypes, [])
1929-
if common_type == object:
1930-
warning_columns.append(str(name))
1931-
1932-
dtype = dtypes.pop()
1933-
if is_categorical_dtype(dtype):
1934-
sort_categories = isinstance(dtype, str)
1935-
result[name] = union_categoricals(arrs,
1936-
sort_categories=sort_categories)
1937-
else:
1938-
if is_extension_array_dtype(dtype):
1939-
array_type = dtype.construct_array_type()
1940-
result[name] = array_type._concat_same_type(arrs)
1941-
else:
1942-
result[name] = np.concatenate(arrs)
1943-
1944-
if warning_columns:
1945-
warning_names = ','.join(warning_columns)
1946-
warning_message = " ".join([
1947-
f"Columns ({warning_names}) have mixed types."
1948-
f"Specify dtype option on import or set low_memory=False."
1949-
])
1950-
warnings.warn(warning_message, DtypeWarning, stacklevel=8)
1951-
return result
1952-
1953-
19541904
# ----------------------------------------------------------------------
19551905
# NA values
19561906
def _compute_na_values():

pandas/io/parsers/c_parser_wrapper.py

+100-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,22 @@
1+
from __future__ import annotations
2+
3+
import warnings
4+
5+
import numpy as np
6+
17
import pandas._libs.parsers as parsers
2-
from pandas._typing import FilePathOrBuffer
8+
from pandas._typing import (
9+
ArrayLike,
10+
FilePathOrBuffer,
11+
)
12+
from pandas.errors import DtypeWarning
13+
14+
from pandas.core.dtypes.common import (
15+
is_categorical_dtype,
16+
pandas_dtype,
17+
)
18+
from pandas.core.dtypes.concat import union_categoricals
19+
from pandas.core.dtypes.dtypes import ExtensionDtype
320

421
from pandas.core.indexes.api import ensure_index_from_sequences
522

@@ -10,12 +27,16 @@
1027

1128

1229
class CParserWrapper(ParserBase):
30+
low_memory: bool
31+
1332
def __init__(self, src: FilePathOrBuffer, **kwds):
1433
self.kwds = kwds
1534
kwds = kwds.copy()
1635

1736
ParserBase.__init__(self, kwds)
1837

38+
self.low_memory = kwds.pop("low_memory", False)
39+
1940
# #2442
2041
# error: Cannot determine type of 'index_col'
2142
kwds["allow_leading_cols"] = (
@@ -31,6 +52,7 @@ def __init__(self, src: FilePathOrBuffer, **kwds):
3152
for key in ("storage_options", "encoding", "memory_map", "compression"):
3253
kwds.pop(key, None)
3354

55+
kwds["dtype"] = ensure_dtype_objs(kwds.get("dtype", None))
3456
try:
3557
self._reader = parsers.TextReader(self.handles.handle, **kwds)
3658
except Exception:
@@ -187,7 +209,13 @@ def set_error_bad_lines(self, status):
187209

188210
def read(self, nrows=None):
189211
try:
190-
data = self._reader.read(nrows)
212+
if self.low_memory:
213+
chunks = self._reader.read_low_memory(nrows)
214+
# destructive to chunks
215+
data = _concatenate_chunks(chunks)
216+
217+
else:
218+
data = self._reader.read(nrows)
191219
except StopIteration:
192220
# error: Cannot determine type of '_first_chunk'
193221
if self._first_chunk: # type: ignore[has-type]
@@ -294,7 +322,76 @@ def _get_index_names(self):
294322

295323
return names, idx_names
296324

297-
def _maybe_parse_dates(self, values, index, try_parse_dates=True):
325+
def _maybe_parse_dates(self, values, index: int, try_parse_dates=True):
298326
if try_parse_dates and self._should_parse_dates(index):
299327
values = self._date_conv(values)
300328
return values
329+
330+
331+
def _concatenate_chunks(chunks: list[dict[int, ArrayLike]]) -> dict:
332+
"""
333+
Concatenate chunks of data read with low_memory=True.
334+
335+
The tricky part is handling Categoricals, where different chunks
336+
may have different inferred categories.
337+
"""
338+
names = list(chunks[0].keys())
339+
warning_columns = []
340+
341+
result = {}
342+
for name in names:
343+
arrs = [chunk.pop(name) for chunk in chunks]
344+
# Check each arr for consistent types.
345+
dtypes = {a.dtype for a in arrs}
346+
# TODO: shouldn't we exclude all EA dtypes here?
347+
numpy_dtypes = {x for x in dtypes if not is_categorical_dtype(x)}
348+
if len(numpy_dtypes) > 1:
349+
# error: Argument 1 to "find_common_type" has incompatible type
350+
# "Set[Any]"; expected "Sequence[Union[dtype[Any], None, type,
351+
# _SupportsDType, str, Union[Tuple[Any, int], Tuple[Any,
352+
# Union[int, Sequence[int]]], List[Any], _DTypeDict, Tuple[Any, Any]]]]"
353+
common_type = np.find_common_type(
354+
numpy_dtypes, # type: ignore[arg-type]
355+
[],
356+
)
357+
if common_type == object:
358+
warning_columns.append(str(name))
359+
360+
dtype = dtypes.pop()
361+
if is_categorical_dtype(dtype):
362+
result[name] = union_categoricals(arrs, sort_categories=False)
363+
else:
364+
if isinstance(dtype, ExtensionDtype):
365+
# TODO: concat_compat?
366+
array_type = dtype.construct_array_type()
367+
# error: Argument 1 to "_concat_same_type" of "ExtensionArray"
368+
# has incompatible type "List[Union[ExtensionArray, ndarray]]";
369+
# expected "Sequence[ExtensionArray]"
370+
result[name] = array_type._concat_same_type(
371+
arrs # type: ignore[arg-type]
372+
)
373+
else:
374+
result[name] = np.concatenate(arrs)
375+
376+
if warning_columns:
377+
warning_names = ",".join(warning_columns)
378+
warning_message = " ".join(
379+
[
380+
f"Columns ({warning_names}) have mixed types."
381+
f"Specify dtype option on import or set low_memory=False."
382+
]
383+
)
384+
warnings.warn(warning_message, DtypeWarning, stacklevel=8)
385+
return result
386+
387+
388+
def ensure_dtype_objs(dtype):
    """
    Ensure we have either None, a dtype object, or a dictionary mapping to
    dtype objects.
    """
    # Pass None through untouched; otherwise normalize every user-supplied
    # dtype specification (string, numpy type, ...) via pandas_dtype.
    if dtype is None:
        return None
    if isinstance(dtype, dict):
        return {col: pandas_dtype(spec) for col, spec in dtype.items()}
    return pandas_dtype(dtype)

pandas/tests/io/parser/test_textreader.py

+5
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
TextFileReader,
2222
read_csv,
2323
)
24+
from pandas.io.parsers.c_parser_wrapper import ensure_dtype_objs
2425

2526

2627
class TestTextReader:
@@ -206,6 +207,8 @@ def test_numpy_string_dtype(self):
206207
aaaaa,5"""
207208

208209
def _make_reader(**kwds):
210+
if "dtype" in kwds:
211+
kwds["dtype"] = ensure_dtype_objs(kwds["dtype"])
209212
return TextReader(StringIO(data), delimiter=",", header=None, **kwds)
210213

211214
reader = _make_reader(dtype="S5,i4")
@@ -233,6 +236,8 @@ def test_pass_dtype(self):
233236
4,d"""
234237

235238
def _make_reader(**kwds):
239+
if "dtype" in kwds:
240+
kwds["dtype"] = ensure_dtype_objs(kwds["dtype"])
236241
return TextReader(StringIO(data), delimiter=",", **kwds)
237242

238243
reader = _make_reader(dtype={"one": "u1", 1: "S1"})

0 commit comments

Comments
 (0)