|
29 | 29 | is_categorical_dtype,
|
30 | 30 | pandas_dtype,
|
31 | 31 | )
|
32 |
| -from pandas.core.dtypes.concat import union_categoricals |
33 |
| -from pandas.core.dtypes.dtypes import ExtensionDtype |
| 32 | +from pandas.core.dtypes.concat import ( |
| 33 | + concat_compat, |
| 34 | + union_categoricals, |
| 35 | +) |
34 | 36 |
|
35 | 37 | from pandas.core.indexes.api import ensure_index_from_sequences
|
36 | 38 |
|
@@ -378,43 +380,15 @@ def _concatenate_chunks(chunks: list[dict[int, ArrayLike]]) -> dict:
|
378 | 380 | arrs = [chunk.pop(name) for chunk in chunks]
|
379 | 381 | # Check each arr for consistent types.
|
380 | 382 | dtypes = {a.dtype for a in arrs}
|
381 |
| - # TODO: shouldn't we exclude all EA dtypes here? |
382 |
| - numpy_dtypes = {x for x in dtypes if not is_categorical_dtype(x)} |
383 |
| - if len(numpy_dtypes) > 1: |
384 |
| - # error: Argument 1 to "find_common_type" has incompatible type |
385 |
| - # "Set[Any]"; expected "Sequence[Union[dtype[Any], None, type, |
386 |
| - # _SupportsDType, str, Union[Tuple[Any, int], Tuple[Any, |
387 |
| - # Union[int, Sequence[int]]], List[Any], _DTypeDict, Tuple[Any, Any]]]]" |
388 |
| - common_type = np.find_common_type( |
389 |
| - numpy_dtypes, # type: ignore[arg-type] |
390 |
| - [], |
391 |
| - ) |
392 |
| - if common_type == np.dtype(object): |
393 |
| - warning_columns.append(str(name)) |
| 383 | + non_cat_dtypes = {x for x in dtypes if not is_categorical_dtype(x)} |
394 | 384 |
|
395 | 385 | dtype = dtypes.pop()
|
396 | 386 | if is_categorical_dtype(dtype):
|
397 | 387 | result[name] = union_categoricals(arrs, sort_categories=False)
|
398 | 388 | else:
|
399 |
| - if isinstance(dtype, ExtensionDtype): |
400 |
| - # TODO: concat_compat? |
401 |
| - array_type = dtype.construct_array_type() |
402 |
| - # error: Argument 1 to "_concat_same_type" of "ExtensionArray" |
403 |
| - # has incompatible type "List[Union[ExtensionArray, ndarray]]"; |
404 |
| - # expected "Sequence[ExtensionArray]" |
405 |
| - result[name] = array_type._concat_same_type( |
406 |
| - arrs # type: ignore[arg-type] |
407 |
| - ) |
408 |
| - else: |
409 |
| - # error: Argument 1 to "concatenate" has incompatible |
410 |
| - # type "List[Union[ExtensionArray, ndarray[Any, Any]]]" |
411 |
| - # ; expected "Union[_SupportsArray[dtype[Any]], |
412 |
| - # Sequence[_SupportsArray[dtype[Any]]], |
413 |
| - # Sequence[Sequence[_SupportsArray[dtype[Any]]]], |
414 |
| - # Sequence[Sequence[Sequence[_SupportsArray[dtype[Any]]]]] |
415 |
| - # , Sequence[Sequence[Sequence[Sequence[ |
416 |
| - # _SupportsArray[dtype[Any]]]]]]]" |
417 |
| - result[name] = np.concatenate(arrs) # type: ignore[arg-type] |
| 389 | + result[name] = concat_compat(arrs) |
| 390 | + if len(non_cat_dtypes) > 1 and result[name].dtype == np.dtype(object): |
| 391 | + warning_columns.append(str(name)) |
418 | 392 |
|
419 | 393 | if warning_columns:
|
420 | 394 | warning_names = ",".join(warning_columns)
|
|
0 commit comments