diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index fe40bc42887c4..cf81e6f173bdd 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -297,7 +297,9 @@ def trans(x): return result -def maybe_cast_result(result, obj: "Series", numeric_only: bool = False, how: str = ""): +def maybe_cast_result( + result: ArrayLike, obj: "Series", numeric_only: bool = False, how: str = "" +) -> ArrayLike: """ Try casting result to a different type if appropriate @@ -320,19 +322,20 @@ def maybe_cast_result(result, obj: "Series", numeric_only: bool = False, how: st dtype = obj.dtype dtype = maybe_cast_result_dtype(dtype, how) - if not is_scalar(result): - if ( - is_extension_array_dtype(dtype) - and not is_categorical_dtype(dtype) - and dtype.kind != "M" - ): - # We have to special case categorical so as not to upcast - # things like counts back to categorical - cls = dtype.construct_array_type() - result = maybe_cast_to_extension_array(cls, result, dtype=dtype) + assert not is_scalar(result) + + if ( + is_extension_array_dtype(dtype) + and not is_categorical_dtype(dtype) + and dtype.kind != "M" + ): + # We have to special case categorical so as not to upcast + # things like counts back to categorical + cls = dtype.construct_array_type() + result = maybe_cast_to_extension_array(cls, result, dtype=dtype) - elif numeric_only and is_numeric_dtype(dtype) or not numeric_only: - result = maybe_downcast_to_dtype(result, dtype) + elif numeric_only and is_numeric_dtype(dtype) or not numeric_only: + result = maybe_downcast_to_dtype(result, dtype) return result