Skip to content

Commit 96101f6

Browse files
jorisvandenbossche authored and rhshadrach committed
ENH: general concat with ExtensionArrays through find_common_type (pandas-dev#33607)
1 parent 35d738d commit 96101f6

File tree

15 files changed

+209
-131
lines changed

15 files changed

+209
-131
lines changed

doc/source/whatsnew/v1.1.0.rst

+3
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,9 @@ Backwards incompatible API changes
251251
- :meth:`DataFrame.at` and :meth:`Series.at` will raise a ``TypeError`` instead of a ``ValueError`` if an incompatible key is passed, and ``KeyError`` if a missing key is passed, matching the behavior of ``.loc[]`` (:issue:`31722`)
252252
- Passing an integer dtype other than ``int64`` to ``np.array(period_index, dtype=...)`` will now raise ``TypeError`` instead of incorrectly using ``int64`` (:issue:`32255`)
253253
- Passing an invalid ``fill_value`` to :meth:`Categorical.take` raises a ``ValueError`` instead of ``TypeError`` (:issue:`33660`)
254+
- Combining a ``Categorical`` that has integer categories and contains missing values
255+
with a float dtype column in operations such as :func:`concat` or :meth:`~DataFrame.append`
256+
will now result in a float column instead of an object dtyped column (:issue:`33607`)
254257

255258
``MultiIndex.get_indexer`` interprets `method` argument differently
256259
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

pandas/core/arrays/base.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -1004,7 +1004,7 @@ def _concat_same_type(
10041004
cls, to_concat: Sequence["ExtensionArray"]
10051005
) -> "ExtensionArray":
10061006
"""
1007-
Concatenate multiple array.
1007+
Concatenate multiple arrays of this dtype.
10081008
10091009
Parameters
10101010
----------
@@ -1014,6 +1014,11 @@ def _concat_same_type(
10141014
-------
10151015
ExtensionArray
10161016
"""
1017+
# Implementer note: this method will only be called with a sequence of
1018+
# ExtensionArrays of this class and with the same dtype as self. This
1019+
# should allow "easy" concatenation (no upcasting needed), and result
1020+
# in a new ExtensionArray of the same dtype.
1021+
# Note: this strict behaviour is only guaranteed starting with pandas 1.1
10171022
raise AbstractMethodError(cls)
10181023

10191024
# The _can_hold_na attribute is set to True so that pandas internals

pandas/core/arrays/categorical.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -2296,9 +2296,9 @@ def _can_hold_na(self):
22962296

22972297
@classmethod
22982298
def _concat_same_type(self, to_concat):
2299-
from pandas.core.dtypes.concat import concat_categorical
2299+
from pandas.core.dtypes.concat import union_categoricals
23002300

2301-
return concat_categorical(to_concat)
2301+
return union_categoricals(to_concat)
23022302

23032303
def isin(self, values):
23042304
"""

pandas/core/arrays/integer.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import numbers
2-
from typing import TYPE_CHECKING, Tuple, Type, Union
2+
from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union
33
import warnings
44

55
import numpy as np
66

77
from pandas._libs import lib, missing as libmissing
8-
from pandas._typing import ArrayLike
8+
from pandas._typing import ArrayLike, DtypeObj
99
from pandas.compat import set_function_name
1010
from pandas.compat.numpy import function as nv
1111
from pandas.util._decorators import cache_readonly
@@ -96,6 +96,17 @@ def construct_array_type(cls) -> Type["IntegerArray"]:
9696
"""
9797
return IntegerArray
9898

99+
def _get_common_dtype(self, dtypes: List[DtypeObj]) -> Optional[DtypeObj]:
100+
# for now only handle other integer types
101+
if not all(isinstance(t, _IntegerDtype) for t in dtypes):
102+
return None
103+
np_dtype = np.find_common_type(
104+
[t.numpy_dtype for t in dtypes], [] # type: ignore
105+
)
106+
if np.issubdtype(np_dtype, np.integer):
107+
return _dtypes[str(np_dtype)]
108+
return None
109+
99110
def __from_arrow__(
100111
self, array: Union["pyarrow.Array", "pyarrow.ChunkedArray"]
101112
) -> "IntegerArray":

pandas/core/arrays/sparse/array.py

+1-21
Original file line numberDiff line numberDiff line change
@@ -952,27 +952,7 @@ def copy(self):
952952

953953
@classmethod
954954
def _concat_same_type(cls, to_concat):
955-
fill_values = [x.fill_value for x in to_concat]
956-
957-
fill_value = fill_values[0]
958-
959-
# np.nan isn't a singleton, so we may end up with multiple
960-
# NaNs here, so we ignore tha all NA case too.
961-
if not (len(set(fill_values)) == 1 or isna(fill_values).all()):
962-
warnings.warn(
963-
"Concatenating sparse arrays with multiple fill "
964-
f"values: '{fill_values}'. Picking the first and "
965-
"converting the rest.",
966-
PerformanceWarning,
967-
stacklevel=6,
968-
)
969-
keep = to_concat[0]
970-
to_concat2 = [keep]
971-
972-
for arr in to_concat[1:]:
973-
to_concat2.append(cls(np.asarray(arr), fill_value=fill_value))
974-
975-
to_concat = to_concat2
955+
fill_value = to_concat[0].fill_value
976956

977957
values = []
978958
length = 0

pandas/core/arrays/sparse/dtype.py

+24-2
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
"""Sparse Dtype"""
22

33
import re
4-
from typing import TYPE_CHECKING, Any, Tuple, Type
4+
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Type
5+
import warnings
56

67
import numpy as np
78

8-
from pandas._typing import Dtype
9+
from pandas._typing import Dtype, DtypeObj
10+
from pandas.errors import PerformanceWarning
911

1012
from pandas.core.dtypes.base import ExtensionDtype
1113
from pandas.core.dtypes.cast import astype_nansafe
@@ -352,3 +354,23 @@ def _subtype_with_str(self):
352354
if isinstance(self.fill_value, str):
353355
return type(self.fill_value)
354356
return self.subtype
357+
358+
def _get_common_dtype(self, dtypes: List[DtypeObj]) -> Optional[DtypeObj]:
359+
360+
fill_values = [x.fill_value for x in dtypes if isinstance(x, SparseDtype)]
361+
fill_value = fill_values[0]
362+
363+
# np.nan isn't a singleton, so we may end up with multiple
364+
# NaNs here, so we ignore the all-NA case too.
365+
if not (len(set(fill_values)) == 1 or isna(fill_values).all()):
366+
warnings.warn(
367+
"Concatenating sparse arrays with multiple fill "
368+
f"values: '{fill_values}'. Picking the first and "
369+
"converting the rest.",
370+
PerformanceWarning,
371+
stacklevel=6,
372+
)
373+
374+
# TODO also handle other non-numpy dtypes
375+
np_dtypes = [x.subtype if isinstance(x, SparseDtype) else x for x in dtypes]
376+
return SparseDtype(np.find_common_type(np_dtypes, []), fill_value=fill_value)

pandas/core/dtypes/base.py

+31-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import numpy as np
88

9+
from pandas._typing import DtypeObj
910
from pandas.errors import AbstractMethodError
1011

1112
from pandas.core.dtypes.generic import ABCDataFrame, ABCIndexClass, ABCSeries
@@ -33,11 +34,12 @@ class ExtensionDtype:
3334
* type
3435
* name
3536
36-
The following attributes influence the behavior of the dtype in
37+
The following attributes and methods influence the behavior of the dtype in
3738
pandas operations
3839
3940
* _is_numeric
4041
* _is_boolean
42+
* _get_common_dtype
4143
4244
Optionally one can override construct_array_type for construction
4345
with the name of this dtype via the Registry. See
@@ -322,3 +324,31 @@ def _is_boolean(self) -> bool:
322324
bool
323325
"""
324326
return False
327+
328+
def _get_common_dtype(self, dtypes: List[DtypeObj]) -> Optional[DtypeObj]:
329+
"""
330+
Return the common dtype, if one exists.
331+
332+
Used in `find_common_type` implementation. This is for example used
333+
to determine the resulting dtype in a concat operation.
334+
335+
If no common dtype exists, return None (which gives the other dtypes
336+
the chance to determine a common dtype). If all dtypes in the list
337+
return None, then the common dtype will be "object" dtype (this means
338+
it is never needed to return "object" dtype from this method itself).
339+
340+
Parameters
341+
----------
342+
dtypes : list of dtypes
343+
The dtypes for which to determine a common dtype. This is a list
344+
of np.dtype or ExtensionDtype instances.
345+
346+
Returns
347+
-------
348+
Common dtype (np.dtype or ExtensionDtype) or None
349+
"""
350+
if len(set(dtypes)) == 1:
351+
# only itself
352+
return self
353+
else:
354+
return None

pandas/core/dtypes/cast.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"""
44

55
from datetime import date, datetime, timedelta
6-
from typing import TYPE_CHECKING, Any, Optional, Tuple, Type
6+
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Type
77

88
import numpy as np
99

@@ -1423,7 +1423,7 @@ def maybe_cast_to_datetime(value, dtype, errors: str = "raise"):
14231423
return value
14241424

14251425

1426-
def find_common_type(types):
1426+
def find_common_type(types: List[DtypeObj]) -> DtypeObj:
14271427
"""
14281428
Find a common data type among the given dtypes.
14291429
@@ -1450,8 +1450,16 @@ def find_common_type(types):
14501450
if all(is_dtype_equal(first, t) for t in types[1:]):
14511451
return first
14521452

1453+
# get unique types (dict.fromkeys is used as order-preserving set())
1454+
types = list(dict.fromkeys(types).keys())
1455+
14531456
if any(isinstance(t, ExtensionDtype) for t in types):
1454-
return np.object
1457+
for t in types:
1458+
if isinstance(t, ExtensionDtype):
1459+
res = t._get_common_dtype(types)
1460+
if res is not None:
1461+
return res
1462+
return np.dtype("object")
14551463

14561464
# take lowest unit
14571465
if all(is_datetime64_dtype(t) for t in types):

0 commit comments

Comments
 (0)