Skip to content

Commit 746e5ee

Browse files
authored
ENH: EA._from_scalars (#53089)
* ENH: BaseStringArray._from_scalars * WIP: EA._from_scalars * ENH: implement EA._from_scalars * Fix StringDtype/CategoricalDtype combine * mypy fixup
1 parent 32c9c8f commit 746e5ee

File tree

7 files changed

+91
-17
lines changed

7 files changed

+91
-17
lines changed

pandas/core/arrays/base.py

+33
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@
8686
AstypeArg,
8787
AxisInt,
8888
Dtype,
89+
DtypeObj,
8990
FillnaOptions,
9091
InterpolateOptions,
9192
NumpySorter,
@@ -293,6 +294,38 @@ def _from_sequence(cls, scalars, *, dtype: Dtype | None = None, copy: bool = Fal
293294
"""
294295
raise AbstractMethodError(cls)
295296

297+
@classmethod
298+
def _from_scalars(cls, scalars, *, dtype: DtypeObj) -> Self:
299+
"""
300+
Strict analogue to _from_sequence, allowing only sequences of scalars
301+
that should be specifically inferred to the given dtype.
302+
303+
Parameters
304+
----------
305+
scalars : sequence
306+
dtype : ExtensionDtype
307+
308+
Raises
309+
------
310+
TypeError or ValueError
311+
312+
Notes
313+
-----
314+
This is called in a try/except block when casting the result of a
315+
pointwise operation.
316+
"""
317+
try:
318+
return cls._from_sequence(scalars, dtype=dtype, copy=False)
319+
except (ValueError, TypeError):
320+
raise
321+
except Exception:
322+
warnings.warn(
323+
"_from_scalars should only raise ValueError or TypeError. "
324+
"Consider overriding _from_scalars where appropriate.",
325+
stacklevel=find_stack_level(),
326+
)
327+
raise
328+
296329
@classmethod
297330
def _from_sequence_of_strings(
298331
cls, strings, *, dtype: Dtype | None = None, copy: bool = False

pandas/core/arrays/categorical.py

+17
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@
101101
AstypeArg,
102102
AxisInt,
103103
Dtype,
104+
DtypeObj,
104105
NpDtype,
105106
Ordered,
106107
Self,
@@ -509,6 +510,22 @@ def _from_sequence(
509510
) -> Self:
510511
return cls(scalars, dtype=dtype, copy=copy)
511512

513+
@classmethod
514+
def _from_scalars(cls, scalars, *, dtype: DtypeObj) -> Self:
515+
if dtype is None:
516+
# The _from_scalars strictness doesn't make much sense in this case.
517+
raise NotImplementedError
518+
519+
res = cls._from_sequence(scalars, dtype=dtype)
520+
521+
# if there are any non-category elements in scalars, these will be
522+
# converted to NAs in res.
523+
mask = isna(scalars)
524+
if not (mask == res.isna()).all():
525+
# Some non-category element in scalars got converted to NA in res.
526+
raise ValueError
527+
return res
528+
512529
@overload
513530
def astype(self, dtype: npt.DTypeLike, copy: bool = ...) -> np.ndarray:
514531
...

pandas/core/arrays/datetimes.py

+9
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@
7777

7878
from pandas._typing import (
7979
DateTimeErrorChoices,
80+
DtypeObj,
8081
IntervalClosedType,
8182
Self,
8283
TimeAmbiguous,
@@ -266,6 +267,14 @@ def _scalar_type(self) -> type[Timestamp]:
266267
_freq: BaseOffset | None = None
267268
_default_dtype = DT64NS_DTYPE # used in TimeLikeOps.__init__
268269

270+
@classmethod
271+
def _from_scalars(cls, scalars, *, dtype: DtypeObj) -> Self:
272+
if lib.infer_dtype(scalars, skipna=True) not in ["datetime", "datetime64"]:
273+
# TODO: require any NAs be valid-for-DTA
274+
# TODO: if dtype is passed, check for tzawareness compat?
275+
raise ValueError
276+
return cls._from_sequence(scalars, dtype=dtype)
277+
269278
@classmethod
270279
def _validate_dtype(cls, values, dtype):
271280
# used in TimeLikeOps.__init__

pandas/core/arrays/string_.py

+8
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
from pandas._typing import (
5757
AxisInt,
5858
Dtype,
59+
DtypeObj,
5960
NumpySorter,
6061
NumpyValueArrayLike,
6162
Scalar,
@@ -253,6 +254,13 @@ def tolist(self):
253254
return [x.tolist() for x in self]
254255
return list(self.to_numpy())
255256

257+
@classmethod
258+
def _from_scalars(cls, scalars, dtype: DtypeObj) -> Self:
259+
if lib.infer_dtype(scalars, skipna=True) != "string":
260+
# TODO: require any NAs be valid-for-string
261+
raise ValueError
262+
return cls._from_sequence(scalars, dtype=dtype)
263+
256264

257265
# error: Definition of "_concat_same_type" in base class "NDArrayBacked" is
258266
# incompatible with definition in base class "ExtensionArray"

pandas/core/dtypes/cast.py

+12-14
Original file line numberDiff line numberDiff line change
@@ -464,16 +464,11 @@ def maybe_cast_pointwise_result(
464464
"""
465465

466466
if isinstance(dtype, ExtensionDtype):
467-
if not isinstance(dtype, (CategoricalDtype, DatetimeTZDtype)):
468-
# TODO: avoid this special-casing
469-
# We have to special case categorical so as not to upcast
470-
# things like counts back to categorical
471-
472-
cls = dtype.construct_array_type()
473-
if same_dtype:
474-
result = _maybe_cast_to_extension_array(cls, result, dtype=dtype)
475-
else:
476-
result = _maybe_cast_to_extension_array(cls, result)
467+
cls = dtype.construct_array_type()
468+
if same_dtype:
469+
result = _maybe_cast_to_extension_array(cls, result, dtype=dtype)
470+
else:
471+
result = _maybe_cast_to_extension_array(cls, result)
477472

478473
elif (numeric_only and dtype.kind in "iufcb") or not numeric_only:
479474
result = maybe_downcast_to_dtype(result, dtype)
@@ -498,11 +493,14 @@ def _maybe_cast_to_extension_array(
498493
-------
499494
ExtensionArray or obj
500495
"""
501-
from pandas.core.arrays.string_ import BaseStringArray
496+
result: ArrayLike
502497

503-
# Everything can be converted to StringArrays, but we may not want to convert
504-
if issubclass(cls, BaseStringArray) and lib.infer_dtype(obj) != "string":
505-
return obj
498+
if dtype is not None:
499+
try:
500+
result = cls._from_scalars(obj, dtype=dtype)
501+
except (TypeError, ValueError):
502+
return obj
503+
return result
506504

507505
try:
508506
result = cls._from_sequence(obj, dtype=dtype)

pandas/core/series.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,10 @@
7575
pandas_dtype,
7676
validate_all_hashable,
7777
)
78-
from pandas.core.dtypes.dtypes import ExtensionDtype
78+
from pandas.core.dtypes.dtypes import (
79+
CategoricalDtype,
80+
ExtensionDtype,
81+
)
7982
from pandas.core.dtypes.generic import ABCDataFrame
8083
from pandas.core.dtypes.inference import is_hashable
8184
from pandas.core.dtypes.missing import (
@@ -100,6 +103,7 @@
100103
from pandas.core.arrays.arrow import StructAccessor
101104
from pandas.core.arrays.categorical import CategoricalAccessor
102105
from pandas.core.arrays.sparse import SparseAccessor
106+
from pandas.core.arrays.string_ import StringDtype
103107
from pandas.core.construction import (
104108
extract_array,
105109
sanitize_array,
@@ -3377,7 +3381,12 @@ def combine(
33773381

33783382
# try_float=False is to match agg_series
33793383
npvalues = lib.maybe_convert_objects(new_values, try_float=False)
3380-
res_values = maybe_cast_pointwise_result(npvalues, self.dtype, same_dtype=False)
3384+
# same_dtype here is a kludge to avoid casting e.g. [True, False] to
3385+
# ["True", "False"]
3386+
same_dtype = isinstance(self.dtype, (StringDtype, CategoricalDtype))
3387+
res_values = maybe_cast_pointwise_result(
3388+
npvalues, self.dtype, same_dtype=same_dtype
3389+
)
33813390
return self._constructor(res_values, index=new_index, name=new_name, copy=False)
33823391

33833392
def combine_first(self, other) -> Series:

pandas/tests/resample/test_timedelta.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def test_resample_categorical_data_with_timedeltaindex():
103103
index=pd.TimedeltaIndex([0, 10], unit="s", freq="10s"),
104104
)
105105
expected = expected.reindex(["Group_obj", "Group"], axis=1)
106-
expected["Group"] = expected["Group_obj"]
106+
expected["Group"] = expected["Group_obj"].astype("category")
107107
tm.assert_frame_equal(result, expected)
108108

109109

0 commit comments

Comments
 (0)