Skip to content

Commit 2a983bb

Browse files
authored
BUG: read_csv for arrow with mismatching dtypes does not work (#51976)
* BUG: read_csv for arrow with mismatching dtypes does not work
* Rename var
1 parent f08871d commit 2a983bb

File tree

2 files changed

+44
-31
lines changed

2 files changed

+44
-31
lines changed

pandas/io/parsers/c_parser_wrapper.py

+8-31
Original file line number | Diff line number | Diff line change
@@ -23,8 +23,10 @@
2323
is_categorical_dtype,
2424
pandas_dtype,
2525
)
26-
from pandas.core.dtypes.concat import union_categoricals
27-
from pandas.core.dtypes.dtypes import ExtensionDtype
26+
from pandas.core.dtypes.concat import (
27+
concat_compat,
28+
union_categoricals,
29+
)
2830

2931
from pandas.core.indexes.api import ensure_index_from_sequences
3032

@@ -379,40 +381,15 @@ def _concatenate_chunks(chunks: list[dict[int, ArrayLike]]) -> dict:
379381
arrs = [chunk.pop(name) for chunk in chunks]
380382
# Check each arr for consistent types.
381383
dtypes = {a.dtype for a in arrs}
382-
# TODO: shouldn't we exclude all EA dtypes here?
383-
numpy_dtypes = {x for x in dtypes if not is_categorical_dtype(x)}
384-
if len(numpy_dtypes) > 1:
385-
# error: Argument 1 to "find_common_type" has incompatible type
386-
# "Set[Any]"; expected "Sequence[Union[dtype[Any], None, type,
387-
# _SupportsDType, str, Union[Tuple[Any, int], Tuple[Any,
388-
# Union[int, Sequence[int]]], List[Any], _DTypeDict, Tuple[Any, Any]]]]"
389-
common_type = np.find_common_type(
390-
numpy_dtypes, # type: ignore[arg-type]
391-
[],
392-
)
393-
if common_type == np.dtype(object):
394-
warning_columns.append(str(name))
384+
non_cat_dtypes = {x for x in dtypes if not is_categorical_dtype(x)}
395385

396386
dtype = dtypes.pop()
397387
if is_categorical_dtype(dtype):
398388
result[name] = union_categoricals(arrs, sort_categories=False)
399-
elif isinstance(dtype, ExtensionDtype):
400-
# TODO: concat_compat?
401-
array_type = dtype.construct_array_type()
402-
# error: Argument 1 to "_concat_same_type" of "ExtensionArray"
403-
# has incompatible type "List[Union[ExtensionArray, ndarray]]";
404-
# expected "Sequence[ExtensionArray]"
405-
result[name] = array_type._concat_same_type(arrs) # type: ignore[arg-type]
406389
else:
407-
# error: Argument 1 to "concatenate" has incompatible
408-
# type "List[Union[ExtensionArray, ndarray[Any, Any]]]"
409-
# ; expected "Union[_SupportsArray[dtype[Any]],
410-
# Sequence[_SupportsArray[dtype[Any]]],
411-
# Sequence[Sequence[_SupportsArray[dtype[Any]]]],
412-
# Sequence[Sequence[Sequence[_SupportsArray[dtype[Any]]]]]
413-
# , Sequence[Sequence[Sequence[Sequence[
414-
# _SupportsArray[dtype[Any]]]]]]]"
415-
result[name] = np.concatenate(arrs) # type: ignore[arg-type]
390+
result[name] = concat_compat(arrs)
391+
if len(non_cat_dtypes) > 1 and result[name].dtype == np.dtype(object):
392+
warning_columns.append(str(name))
416393

417394
if warning_columns:
418395
warning_names = ",".join(warning_columns)
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,36 @@
1+
import numpy as np
2+
import pytest
3+
4+
from pandas.errors import DtypeWarning
5+
6+
import pandas._testing as tm
7+
from pandas.core.arrays import ArrowExtensionArray
8+
9+
from pandas.io.parsers.c_parser_wrapper import _concatenate_chunks
10+
11+
12+
def test_concatenate_chunks_pyarrow():
13+
# GH#51876
14+
pa = pytest.importorskip("pyarrow")
15+
chunks = [
16+
{0: ArrowExtensionArray(pa.array([1.5, 2.5]))},
17+
{0: ArrowExtensionArray(pa.array([1, 2]))},
18+
]
19+
result = _concatenate_chunks(chunks)
20+
expected = ArrowExtensionArray(pa.array([1.5, 2.5, 1.0, 2.0]))
21+
tm.assert_extension_array_equal(result[0], expected)
22+
23+
24+
def test_concatenate_chunks_pyarrow_strings():
25+
# GH#51876
26+
pa = pytest.importorskip("pyarrow")
27+
chunks = [
28+
{0: ArrowExtensionArray(pa.array([1.5, 2.5]))},
29+
{0: ArrowExtensionArray(pa.array(["a", "b"]))},
30+
]
31+
with tm.assert_produces_warning(DtypeWarning, match="have mixed types"):
32+
result = _concatenate_chunks(chunks)
33+
expected = np.concatenate(
34+
[np.array([1.5, 2.5], dtype=object), np.array(["a", "b"])]
35+
)
36+
tm.assert_numpy_array_equal(result[0], expected)

0 commit comments

Comments
 (0)