|
23 | 23 | is_categorical_dtype,
|
24 | 24 | pandas_dtype,
|
25 | 25 | )
|
26 |
| -from pandas.core.dtypes.concat import union_categoricals |
27 |
| -from pandas.core.dtypes.dtypes import ExtensionDtype |
| 26 | +from pandas.core.dtypes.concat import ( |
| 27 | + concat_compat, |
| 28 | + union_categoricals, |
| 29 | +) |
28 | 30 |
|
29 | 31 | from pandas.core.indexes.api import ensure_index_from_sequences
|
30 | 32 |
|
@@ -379,40 +381,15 @@ def _concatenate_chunks(chunks: list[dict[int, ArrayLike]]) -> dict:
|
379 | 381 | arrs = [chunk.pop(name) for chunk in chunks]
|
380 | 382 | # Check each arr for consistent types.
|
381 | 383 | dtypes = {a.dtype for a in arrs}
|
382 |
| - # TODO: shouldn't we exclude all EA dtypes here? |
383 |
| - numpy_dtypes = {x for x in dtypes if not is_categorical_dtype(x)} |
384 |
| - if len(numpy_dtypes) > 1: |
385 |
| - # error: Argument 1 to "find_common_type" has incompatible type |
386 |
| - # "Set[Any]"; expected "Sequence[Union[dtype[Any], None, type, |
387 |
| - # _SupportsDType, str, Union[Tuple[Any, int], Tuple[Any, |
388 |
| - # Union[int, Sequence[int]]], List[Any], _DTypeDict, Tuple[Any, Any]]]]" |
389 |
| - common_type = np.find_common_type( |
390 |
| - numpy_dtypes, # type: ignore[arg-type] |
391 |
| - [], |
392 |
| - ) |
393 |
| - if common_type == np.dtype(object): |
394 |
| - warning_columns.append(str(name)) |
| 384 | + non_cat_dtypes = {x for x in dtypes if not is_categorical_dtype(x)} |
395 | 385 |
|
396 | 386 | dtype = dtypes.pop()
|
397 | 387 | if is_categorical_dtype(dtype):
|
398 | 388 | result[name] = union_categoricals(arrs, sort_categories=False)
|
399 |
| - elif isinstance(dtype, ExtensionDtype): |
400 |
| - # TODO: concat_compat? |
401 |
| - array_type = dtype.construct_array_type() |
402 |
| - # error: Argument 1 to "_concat_same_type" of "ExtensionArray" |
403 |
| - # has incompatible type "List[Union[ExtensionArray, ndarray]]"; |
404 |
| - # expected "Sequence[ExtensionArray]" |
405 |
| - result[name] = array_type._concat_same_type(arrs) # type: ignore[arg-type] |
406 | 389 | else:
|
407 |
| - # error: Argument 1 to "concatenate" has incompatible |
408 |
| - # type "List[Union[ExtensionArray, ndarray[Any, Any]]]" |
409 |
| - # ; expected "Union[_SupportsArray[dtype[Any]], |
410 |
| - # Sequence[_SupportsArray[dtype[Any]]], |
411 |
| - # Sequence[Sequence[_SupportsArray[dtype[Any]]]], |
412 |
| - # Sequence[Sequence[Sequence[_SupportsArray[dtype[Any]]]]] |
413 |
| - # , Sequence[Sequence[Sequence[Sequence[ |
414 |
| - # _SupportsArray[dtype[Any]]]]]]]" |
415 |
| - result[name] = np.concatenate(arrs) # type: ignore[arg-type] |
| 390 | + result[name] = concat_compat(arrs) |
| 391 | + if len(non_cat_dtypes) > 1 and result[name].dtype == np.dtype(object): |
| 392 | + warning_columns.append(str(name)) |
416 | 393 |
|
417 | 394 | if warning_columns:
|
418 | 395 | warning_names = ",".join(warning_columns)
|
|
0 commit comments