From 057586c8b705e89d0dec04fbb94c40f8e9768865 Mon Sep 17 00:00:00 2001
From: Brock
Date: Wed, 3 May 2023 13:26:17 -0700
Subject: [PATCH 1/5] ENH: BaseStringArray._from_scalars

---
 pandas/core/arrays/string_.py |  8 ++++++++
 pandas/core/dtypes/cast.py    | 11 +++++++----
 2 files changed, 15 insertions(+), 4 deletions(-)

diff --git a/pandas/core/arrays/string_.py b/pandas/core/arrays/string_.py
index c9dc20cf93ddd..677c8c4fff091 100644
--- a/pandas/core/arrays/string_.py
+++ b/pandas/core/arrays/string_.py
@@ -57,6 +57,7 @@
         NumpySorter,
         NumpyValueArrayLike,
         Scalar,
+        Self,
         npt,
         type_t,
     )
@@ -228,6 +229,13 @@ def tolist(self):
             return [x.tolist() for x in self]
         return list(self.to_numpy())
 
+    @classmethod
+    def _from_scalars(cls, scalars, dtype=None) -> Self:
+        if lib.infer_dtype(scalars, skipna=True) != "string":
+            # TODO: require any NAs be valid-for-string
+            raise ValueError
+        return cls._from_sequence(scalars, dtype=dtype)
+
     # error: Definition of "_concat_same_type" in base class "NDArrayBacked" is
     # incompatible with definition in base class "ExtensionArray"
diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py
index 6dabb866b8f5c..83b0260645aec 100644
--- a/pandas/core/dtypes/cast.py
+++ b/pandas/core/dtypes/cast.py
@@ -489,11 +489,14 @@ def _maybe_cast_to_extension_array(
     -------
     ExtensionArray or obj
     """
-    from pandas.core.arrays.string_ import BaseStringArray
-
-    # Everything can be converted to StringArrays, but we may not want to convert
-    if issubclass(cls, BaseStringArray) and lib.infer_dtype(obj) != "string":
-        return obj
+    if hasattr(cls, "_from_scalars"):
+        # TODO: get this everywhere!
+        try:
+            result = cls._from_scalars(obj, dtype=dtype)
+        except (TypeError, ValueError):
+            return obj
+        return result
 
     try:
         result = cls._from_sequence(obj, dtype=dtype)

From edc8b9b5befc9e6d73ba3b260fc6a367b8a1404e Mon Sep 17 00:00:00 2001
From: Brock
Date: Thu, 4 May 2023 16:03:26 -0700
Subject: [PATCH 2/5] WIP: EA._from_scalars

---
 pandas/core/arrays/categorical.py       | 16 ++++++++++++++++
 pandas/core/arrays/datetimes.py         |  8 ++++++++
 pandas/core/dtypes/cast.py              | 20 +++++++++-----------
 pandas/tests/resample/test_timedelta.py |  2 +-
 4 files changed, 34 insertions(+), 12 deletions(-)

diff --git a/pandas/core/arrays/categorical.py b/pandas/core/arrays/categorical.py
index a4447bffed5f5..69d0dd25d8ff0 100644
--- a/pandas/core/arrays/categorical.py
+++ b/pandas/core/arrays/categorical.py
@@ -494,6 +494,22 @@ def _from_sequence(
     ) -> Self:
         return cls(scalars, dtype=dtype, copy=copy)
 
+    @classmethod
+    def _from_scalars(cls, scalars, dtype=None):
+        if dtype is None:
+            # The _from_scalars strictness doesn't make much sense in this case.
+            raise NotImplementedError
+
+        res = cls._from_sequence(scalars, dtype=dtype)
+
+        # if there are any non-category elements in scalars, these will be
+        # converted to NAs in res.
+        mask = isna(scalars)
+        if not (mask == res.isna()).all():
+            # Some non-category element in scalars got converted to NA in res.
+            raise ValueError
+        return res
+
     @overload
     def astype(self, dtype: npt.DTypeLike, copy: bool = ...) -> np.ndarray:
         ...
diff --git a/pandas/core/arrays/datetimes.py b/pandas/core/arrays/datetimes.py
index 126a70a930065..4c757ae26b8b8 100644
--- a/pandas/core/arrays/datetimes.py
+++ b/pandas/core/arrays/datetimes.py
@@ -258,6 +258,14 @@ def _scalar_type(self) -> type[Timestamp]:
     _freq: BaseOffset | None = None
     _default_dtype = DT64NS_DTYPE  # used in TimeLikeOps.__init__
 
+    @classmethod
+    def _from_scalars(cls, scalars, dtype=None):
+        if lib.infer_dtype(scalars, skipna=True) not in ["datetime", "datetime64"]:
+            # TODO: require any NAs be valid-for-DTA
+            # TODO: if dtype is passed, check for tzawareness compat?
+            raise ValueError
+        return cls._from_sequence(scalars, dtype=dtype)
+
     @classmethod
     def _validate_dtype(cls, values, dtype):
         # used in TimeLikeOps.__init__
diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py
index 83b0260645aec..4367791895ab2 100644
--- a/pandas/core/dtypes/cast.py
+++ b/pandas/core/dtypes/cast.py
@@ -455,16 +455,11 @@ def maybe_cast_pointwise_result(
     """
 
     if isinstance(dtype, ExtensionDtype):
-        if not isinstance(dtype, (CategoricalDtype, DatetimeTZDtype)):
-            # TODO: avoid this special-casing
-            # We have to special case categorical so as not to upcast
-            # things like counts back to categorical
-
-            cls = dtype.construct_array_type()
-            if same_dtype:
-                result = _maybe_cast_to_extension_array(cls, result, dtype=dtype)
-            else:
-                result = _maybe_cast_to_extension_array(cls, result)
+        cls = dtype.construct_array_type()
+        if same_dtype:
+            result = _maybe_cast_to_extension_array(cls, result, dtype=dtype)
+        else:
+            result = _maybe_cast_to_extension_array(cls, result)
 
     elif (numeric_only and dtype.kind in "iufcb") or not numeric_only:
         result = maybe_downcast_to_dtype(result, dtype)
@@ -494,7 +489,10 @@ def _maybe_cast_to_extension_array(
         # TODO: get this everywhere!
         try:
             result = cls._from_scalars(obj, dtype=dtype)
-        except (TypeError, ValueError):
+        except (TypeError, ValueError, NotImplementedError):
+            # TODO: document that _from_scalars should only raise ValueError
+            # or TypeError; NotImplementedError is here until we decide what
+            # to do for Categorical.
             return obj
         return result
 
diff --git a/pandas/tests/resample/test_timedelta.py b/pandas/tests/resample/test_timedelta.py
index 8b6e757c0a46a..0b7280b3b3d05 100644
--- a/pandas/tests/resample/test_timedelta.py
+++ b/pandas/tests/resample/test_timedelta.py
@@ -103,7 +103,7 @@ def test_resample_categorical_data_with_timedeltaindex():
         index=pd.TimedeltaIndex([0, 10], unit="s", freq="10s"),
     )
     expected = expected.reindex(["Group_obj", "Group"], axis=1)
-    expected["Group"] = expected["Group_obj"]
+    expected["Group"] = expected["Group_obj"].astype("category")
     tm.assert_frame_equal(result, expected)

From e9853c8c50d2df7a98b4e596ca35589b36d12ef6 Mon Sep 17 00:00:00 2001
From: Brock
Date: Mon, 15 May 2023 16:15:08 -0700
Subject: [PATCH 3/5] ENH: implement EA._from_scalars

---
 pandas/core/arrays/base.py        | 35 +++++++++++++++++++++++++++
 pandas/core/arrays/categorical.py |  2 +-
 pandas/core/arrays/datetimes.py   |  3 ++-
 pandas/core/arrays/string_.py     |  3 ++-
 pandas/core/dtypes/cast.py        |  3 +--
 5 files changed, 41 insertions(+), 5 deletions(-)

diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py
index 27eb7994d3ccb..a902bb5bb45e9 100644
--- a/pandas/core/arrays/base.py
+++ b/pandas/core/arrays/base.py
@@ -20,6 +20,7 @@
     cast,
     overload,
 )
+import warnings
 
 import numpy as np
 
@@ -32,6 +33,7 @@
     Substitution,
     cache_readonly,
 )
+from pandas.util._exceptions import find_stack_level
 from pandas.util._validators import (
     validate_bool_kwarg,
     validate_fillna_kwargs,
 )
@@ -77,6 +79,7 @@
         AstypeArg,
         AxisInt,
         Dtype,
+        DtypeObj,
         FillnaOptions,
         NumpySorter,
         NumpyValueArrayLike,
@@ -262,6 +265,38 @@ def _from_sequence(cls, scalars, *, dtype: Dtype | None = None, copy: bool = Fal
         """
         raise AbstractMethodError(cls)
 
+    @classmethod
+    def _from_scalars(cls, scalars, *, dtype: DtypeObj) -> Self:
+        """
+        Strict analogue to _from_sequence, allowing only sequences of scalars
+        that should be specifically inferred to the given dtype.
+
+        Parameters
+        ----------
+        scalars : sequence
+        dtype : ExtensionDtype
+
+        Raises
+        ------
+        TypeError or ValueError
+
+        Notes
+        -----
+        This is called in a try/except block when casting the result of a
+        pointwise operation.
+        """
+        try:
+            return cls._from_sequence(scalars, dtype=dtype, copy=False)
+        except (ValueError, TypeError):
+            raise
+        except Exception:
+            warnings.warn(
+                "_from_scalars should only raise ValueError or TypeError. "
+                "Consider overriding _from_scalars where appropriate.",
+                stacklevel=find_stack_level(),
+            )
+            raise
+
     @classmethod
     def _from_sequence_of_strings(
         cls, strings, *, dtype: Dtype | None = None, copy: bool = False
diff --git a/pandas/core/arrays/categorical.py b/pandas/core/arrays/categorical.py
index 15a5fe4cbbbe2..e0c2794642f56 100644
--- a/pandas/core/arrays/categorical.py
+++ b/pandas/core/arrays/categorical.py
@@ -495,7 +495,7 @@ def _from_sequence(
         return cls(scalars, dtype=dtype, copy=copy)
 
     @classmethod
-    def _from_scalars(cls, scalars, dtype=None):
+    def _from_scalars(cls, scalars, *, dtype):
         if dtype is None:
             # The _from_scalars strictness doesn't make much sense in this case.
             raise NotImplementedError
diff --git a/pandas/core/arrays/datetimes.py b/pandas/core/arrays/datetimes.py
index a6648ac79814e..08a56e5c5aa10 100644
--- a/pandas/core/arrays/datetimes.py
+++ b/pandas/core/arrays/datetimes.py
@@ -74,6 +74,7 @@
 if TYPE_CHECKING:
     from pandas._typing import (
         DateTimeErrorChoices,
+        DtypeObj,
         IntervalClosedType,
         Self,
         TimeAmbiguous,
@@ -258,7 +259,7 @@ def _scalar_type(self) -> type[Timestamp]:
     _default_dtype = DT64NS_DTYPE  # used in TimeLikeOps.__init__
 
     @classmethod
-    def _from_scalars(cls, scalars, dtype=None):
+    def _from_scalars(cls, scalars, *, dtype: DtypeObj):
         if lib.infer_dtype(scalars, skipna=True) not in ["datetime", "datetime64"]:
             # TODO: require any NAs be valid-for-DTA
             # TODO: if dtype is passed, check for tzawareness compat?
             raise ValueError
         return cls._from_sequence(scalars, dtype=dtype)
diff --git a/pandas/core/arrays/string_.py b/pandas/core/arrays/string_.py
index 677c8c4fff091..a2ebd865f40cf 100644
--- a/pandas/core/arrays/string_.py
+++ b/pandas/core/arrays/string_.py
@@ -54,6 +54,7 @@
     from pandas._typing import (
         AxisInt,
         Dtype,
+        DtypeObj,
         NumpySorter,
         NumpyValueArrayLike,
         Scalar,
@@ -230,7 +231,7 @@ def tolist(self):
         return list(self.to_numpy())
 
     @classmethod
-    def _from_scalars(cls, scalars, dtype=None) -> Self:
+    def _from_scalars(cls, scalars, dtype: DtypeObj) -> Self:
         if lib.infer_dtype(scalars, skipna=True) != "string":
             # TODO: require any NAs be valid-for-string
             raise ValueError
         return cls._from_sequence(scalars, dtype=dtype)
diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py
index c81e665b0da4b..c325dfc7edba8 100644
--- a/pandas/core/dtypes/cast.py
+++ b/pandas/core/dtypes/cast.py
@@ -485,8 +485,7 @@ def _maybe_cast_to_extension_array(
     ExtensionArray or obj
     """
-    if hasattr(cls, "_from_scalars"):
-        # TODO: get this everywhere!
+    if dtype is not None:
         try:
             result = cls._from_scalars(obj, dtype=dtype)
         except (TypeError, ValueError, NotImplementedError):

From 8c1cfcefa7d660c8825922a210133bbcbd551e1b Mon Sep 17 00:00:00 2001
From: Brock
Date: Mon, 10 Jul 2023 14:41:56 -0700
Subject: [PATCH 4/5] Fix StringDtype/CategoricalDtype combine

---
 pandas/core/arrays/categorical.py | 3 ++-
 pandas/core/arrays/datetimes.py   | 2 +-
 pandas/core/dtypes/cast.py        | 5 +----
 pandas/core/series.py             | 9 ++++++++-
 4 files changed, 12 insertions(+), 7 deletions(-)

diff --git a/pandas/core/arrays/categorical.py b/pandas/core/arrays/categorical.py
index 16bbf57abc29c..30739158c9fce 100644
--- a/pandas/core/arrays/categorical.py
+++ b/pandas/core/arrays/categorical.py
@@ -98,6 +98,7 @@
         AstypeArg,
         AxisInt,
         Dtype,
+        DtypeObj,
         NpDtype,
         Ordered,
         Self,
@@ -507,7 +508,7 @@ def _from_sequence(
         return cls(scalars, dtype=dtype, copy=copy)
 
     @classmethod
-    def _from_scalars(cls, scalars, *, dtype):
+    def _from_scalars(cls, scalars, *, dtype: DtypeObj) -> Self:
         if dtype is None:
             # The _from_scalars strictness doesn't make much sense in this case.
             raise NotImplementedError
diff --git a/pandas/core/arrays/datetimes.py b/pandas/core/arrays/datetimes.py
index e68b427cbf3cd..fc59036c6bf31 100644
--- a/pandas/core/arrays/datetimes.py
+++ b/pandas/core/arrays/datetimes.py
@@ -267,7 +267,7 @@ def _scalar_type(self) -> type[Timestamp]:
     _default_dtype = DT64NS_DTYPE  # used in TimeLikeOps.__init__
 
     @classmethod
-    def _from_scalars(cls, scalars, *, dtype: DtypeObj):
+    def _from_scalars(cls, scalars, *, dtype: DtypeObj) -> Self:
         if lib.infer_dtype(scalars, skipna=True) not in ["datetime", "datetime64"]:
             # TODO: require any NAs be valid-for-DTA
             # TODO: if dtype is passed, check for tzawareness compat?
diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py
index 072753b050c96..d47ed3123c85c 100644
--- a/pandas/core/dtypes/cast.py
+++ b/pandas/core/dtypes/cast.py
@@ -491,10 +491,7 @@ def _maybe_cast_to_extension_array(
     if dtype is not None:
         try:
             result = cls._from_scalars(obj, dtype=dtype)
-        except (TypeError, ValueError, NotImplementedError):
-            # TODO: document that _from_scalars should only raise ValueError
-            # or TypeError; NotImplementedError is here until we decide what
-            # to do for Categorical.
+        except (TypeError, ValueError):
            return obj
         return result
 
diff --git a/pandas/core/series.py b/pandas/core/series.py
index 2fc926d7e43d1..ba1e348a4801f 100644
--- a/pandas/core/series.py
+++ b/pandas/core/series.py
@@ -74,6 +74,7 @@
 )
 from pandas.core.dtypes.dtypes import (
     ArrowDtype,
+    CategoricalDtype,
     ExtensionDtype,
 )
 from pandas.core.dtypes.generic import ABCDataFrame
@@ -99,6 +100,7 @@
 from pandas.core.arrays import ExtensionArray
 from pandas.core.arrays.categorical import CategoricalAccessor
 from pandas.core.arrays.sparse import SparseAccessor
+from pandas.core.arrays.string_ import StringDtype
 from pandas.core.construction import (
     extract_array,
     sanitize_array,
@@ -3298,7 +3300,12 @@ def combine(
         # try_float=False is to match agg_series
         npvalues = lib.maybe_convert_objects(new_values, try_float=False)
-        res_values = maybe_cast_pointwise_result(npvalues, self.dtype, same_dtype=False)
+        # same_dtype here is a kludge to avoid casting e.g. [True, False] to
+        # ["True", "False"]
+        same_dtype = isinstance(self.dtype, (StringDtype, CategoricalDtype))
+        res_values = maybe_cast_pointwise_result(
+            npvalues, self.dtype, same_dtype=same_dtype
+        )
         return self._constructor(res_values, index=new_index, name=new_name, copy=False)
 
     def combine_first(self, other) -> Series:

From 52381413e022b87549dd13202611854c9a202f8b Mon Sep 17 00:00:00 2001
From: Brock
Date: Mon, 10 Jul 2023 18:21:16 -0700
Subject: [PATCH 5/5] mypy fixup

---
 pandas/core/dtypes/cast.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py
index 9db8b02e60c22..dbad0e65dd4f2 100644
--- a/pandas/core/dtypes/cast.py
+++ b/pandas/core/dtypes/cast.py
@@ -490,6 +490,7 @@ def _maybe_cast_to_extension_array(
     -------
     ExtensionArray or obj
     """
+    result: ArrayLike
     if dtype is not None:
         try:
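
Illustrative usage sketch, not part of the patches above: assuming this series is applied on top of pandas main, the strict constructor and the Series.combine change are expected to behave roughly as below. StringArray, StringDtype, and _from_scalars are the pandas-internal names used in these diffs; the outcomes shown are the intent of the series, not guaranteed behavior of any particular released version.

    import pandas as pd
    from pandas.core.arrays.string_ import StringArray

    # Sketch assuming the patches in this series are applied; _from_scalars is a
    # private API and not stable across versions.

    # _from_scalars accepts scalars already inferable as strings (NAs allowed) ...
    StringArray._from_scalars(["a", "b", pd.NA], dtype=pd.StringDtype())

    # ... and raises for anything else, so a boolean pointwise result is not
    # silently stringified by _maybe_cast_to_extension_array.
    try:
        StringArray._from_scalars([True, False], dtype=pd.StringDtype())
    except ValueError:
        pass  # the caller falls back to the object-dtype result

    # With the combine() change, a categorical input keeps its dtype when every
    # pointwise result is a valid category.
    ser = pd.Series(["a", "b"], dtype="category")
    out = ser.combine(ser, lambda x, y: x)
    assert out.dtype == ser.dtype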