Skip to content

Commit b9dd4fa

Browse files
authored
Backport PR pandas-dev#51976 on branch 2.0.x (BUG: read_csv for arrow with mismatching dtypes does not work) (pandas-dev#51995)
BUG: read_csv for arrow with mismatching dtypes does not work (pandas-dev#51976)
* BUG: read_csv for arrow with mismatching dtypes does not work
* Rename var
1 parent 4eb55ed commit b9dd4fa

File tree

2 files changed

+44
-34
lines changed

2 files changed

+44
-34
lines changed

pandas/io/parsers/c_parser_wrapper.py

+8 -34
Original file line number | Diff line number | Diff line change
@@ -29,8 +29,10 @@
2929
is_categorical_dtype,
3030
pandas_dtype,
3131
)
32-
from pandas.core.dtypes.concat import union_categoricals
33-
from pandas.core.dtypes.dtypes import ExtensionDtype
32+
from pandas.core.dtypes.concat import (
33+
concat_compat,
34+
union_categoricals,
35+
)
3436

3537
from pandas.core.indexes.api import ensure_index_from_sequences
3638

@@ -378,43 +380,15 @@ def _concatenate_chunks(chunks: list[dict[int, ArrayLike]]) -> dict:
378380
arrs = [chunk.pop(name) for chunk in chunks]
379381
# Check each arr for consistent types.
380382
dtypes = {a.dtype for a in arrs}
381-
# TODO: shouldn't we exclude all EA dtypes here?
382-
numpy_dtypes = {x for x in dtypes if not is_categorical_dtype(x)}
383-
if len(numpy_dtypes) > 1:
384-
# error: Argument 1 to "find_common_type" has incompatible type
385-
# "Set[Any]"; expected "Sequence[Union[dtype[Any], None, type,
386-
# _SupportsDType, str, Union[Tuple[Any, int], Tuple[Any,
387-
# Union[int, Sequence[int]]], List[Any], _DTypeDict, Tuple[Any, Any]]]]"
388-
common_type = np.find_common_type(
389-
numpy_dtypes, # type: ignore[arg-type]
390-
[],
391-
)
392-
if common_type == np.dtype(object):
393-
warning_columns.append(str(name))
383+
non_cat_dtypes = {x for x in dtypes if not is_categorical_dtype(x)}
394384

395385
dtype = dtypes.pop()
396386
if is_categorical_dtype(dtype):
397387
result[name] = union_categoricals(arrs, sort_categories=False)
398388
else:
399-
if isinstance(dtype, ExtensionDtype):
400-
# TODO: concat_compat?
401-
array_type = dtype.construct_array_type()
402-
# error: Argument 1 to "_concat_same_type" of "ExtensionArray"
403-
# has incompatible type "List[Union[ExtensionArray, ndarray]]";
404-
# expected "Sequence[ExtensionArray]"
405-
result[name] = array_type._concat_same_type(
406-
arrs # type: ignore[arg-type]
407-
)
408-
else:
409-
# error: Argument 1 to "concatenate" has incompatible
410-
# type "List[Union[ExtensionArray, ndarray[Any, Any]]]"
411-
# ; expected "Union[_SupportsArray[dtype[Any]],
412-
# Sequence[_SupportsArray[dtype[Any]]],
413-
# Sequence[Sequence[_SupportsArray[dtype[Any]]]],
414-
# Sequence[Sequence[Sequence[_SupportsArray[dtype[Any]]]]]
415-
# , Sequence[Sequence[Sequence[Sequence[
416-
# _SupportsArray[dtype[Any]]]]]]]"
417-
result[name] = np.concatenate(arrs) # type: ignore[arg-type]
389+
result[name] = concat_compat(arrs)
390+
if len(non_cat_dtypes) > 1 and result[name].dtype == np.dtype(object):
391+
warning_columns.append(str(name))
418392

419393
if warning_columns:
420394
warning_names = ",".join(warning_columns)
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,36 @@
1+
import numpy as np
2+
import pytest
3+
4+
from pandas.errors import DtypeWarning
5+
6+
import pandas._testing as tm
7+
from pandas.core.arrays import ArrowExtensionArray
8+
9+
from pandas.io.parsers.c_parser_wrapper import _concatenate_chunks
10+
11+
12+
def test_concatenate_chunks_pyarrow():
13+
# GH#51876
14+
pa = pytest.importorskip("pyarrow")
15+
chunks = [
16+
{0: ArrowExtensionArray(pa.array([1.5, 2.5]))},
17+
{0: ArrowExtensionArray(pa.array([1, 2]))},
18+
]
19+
result = _concatenate_chunks(chunks)
20+
expected = ArrowExtensionArray(pa.array([1.5, 2.5, 1.0, 2.0]))
21+
tm.assert_extension_array_equal(result[0], expected)
22+
23+
24+
def test_concatenate_chunks_pyarrow_strings():
25+
# GH#51876
26+
pa = pytest.importorskip("pyarrow")
27+
chunks = [
28+
{0: ArrowExtensionArray(pa.array([1.5, 2.5]))},
29+
{0: ArrowExtensionArray(pa.array(["a", "b"]))},
30+
]
31+
with tm.assert_produces_warning(DtypeWarning, match="have mixed types"):
32+
result = _concatenate_chunks(chunks)
33+
expected = np.concatenate(
34+
[np.array([1.5, 2.5], dtype=object), np.array(["a", "b"])]
35+
)
36+
tm.assert_numpy_array_equal(result[0], expected)

0 commit comments

Comments
 (0)