|
16 | 16 | Iterator,
|
17 | 17 | Sequence,
|
18 | 18 | final,
|
19 |
| - overload, |
20 | 19 | )
|
21 | 20 |
|
22 | 21 | import numpy as np
|
|
57 | 56 | is_timedelta64_dtype,
|
58 | 57 | needs_i8_conversion,
|
59 | 58 | )
|
60 |
| -from pandas.core.dtypes.dtypes import ExtensionDtype |
61 | 59 | from pandas.core.dtypes.missing import (
|
62 | 60 | isna,
|
63 | 61 | maybe_fill,
|
|
70 | 68 | TimedeltaArray,
|
71 | 69 | )
|
72 | 70 | from pandas.core.arrays.boolean import BooleanDtype
|
73 |
| -from pandas.core.arrays.floating import ( |
74 |
| - Float64Dtype, |
75 |
| - FloatingDtype, |
76 |
| -) |
77 |
| -from pandas.core.arrays.integer import ( |
78 |
| - Int64Dtype, |
79 |
| - IntegerDtype, |
80 |
| -) |
| 71 | +from pandas.core.arrays.floating import FloatingDtype |
| 72 | +from pandas.core.arrays.integer import IntegerDtype |
81 | 73 | from pandas.core.arrays.masked import (
|
82 | 74 | BaseMaskedArray,
|
83 | 75 | BaseMaskedDtype,
|
@@ -277,41 +269,27 @@ def _get_out_dtype(self, dtype: np.dtype) -> np.dtype:
|
277 | 269 | out_dtype = "object"
|
278 | 270 | return np.dtype(out_dtype)
|
279 | 271 |
|
280 |
| - @overload |
281 | 272 | def _get_result_dtype(self, dtype: np.dtype) -> np.dtype:
|
282 |
| - ... # pragma: no cover |
283 |
| - |
284 |
| - @overload |
285 |
| - def _get_result_dtype(self, dtype: ExtensionDtype) -> ExtensionDtype: |
286 |
| - ... # pragma: no cover |
287 |
| - |
288 |
| - # TODO: general case implementation overridable by EAs. |
289 |
| - def _get_result_dtype(self, dtype: DtypeObj) -> DtypeObj: |
290 | 273 | """
|
291 | 274 | Get the desired dtype of a result based on the
|
292 | 275 | input dtype and how it was computed.
|
293 | 276 |
|
294 | 277 | Parameters
|
295 | 278 | ----------
|
296 |
| - dtype : np.dtype or ExtensionDtype |
297 |
| - Input dtype. |
| 279 | + dtype : np.dtype |
298 | 280 |
|
299 | 281 | Returns
|
300 | 282 | -------
|
301 |
| - np.dtype or ExtensionDtype |
| 283 | + np.dtype |
302 | 284 | The desired dtype of the result.
|
303 | 285 | """
|
304 | 286 | how = self.how
|
305 | 287 |
|
306 | 288 | if how in ["add", "cumsum", "sum", "prod"]:
|
307 | 289 | if dtype == np.dtype(bool):
|
308 | 290 | return np.dtype(np.int64)
|
309 |
| - elif isinstance(dtype, (BooleanDtype, IntegerDtype)): |
310 |
| - return Int64Dtype() |
311 | 291 | elif how in ["mean", "median", "var"]:
|
312 |
| - if isinstance(dtype, (BooleanDtype, IntegerDtype)): |
313 |
| - return Float64Dtype() |
314 |
| - elif is_float_dtype(dtype) or is_complex_dtype(dtype): |
| 292 | + if is_float_dtype(dtype) or is_complex_dtype(dtype): |
315 | 293 | return dtype
|
316 | 294 | elif is_numeric_dtype(dtype):
|
317 | 295 | return np.dtype(np.float64)
|
@@ -390,8 +368,18 @@ def _reconstruct_ea_result(
|
390 | 368 | Construct an ExtensionArray result from an ndarray result.
|
391 | 369 | """
|
392 | 370 |
|
393 |
| - if isinstance(values.dtype, (BaseMaskedDtype, StringDtype)): |
394 |
| - dtype = self._get_result_dtype(values.dtype) |
| 371 | + if isinstance(values.dtype, StringDtype): |
| 372 | + dtype = values.dtype |
| 373 | + cls = dtype.construct_array_type() |
| 374 | + return cls._from_sequence(res_values, dtype=dtype) |
| 375 | + |
| 376 | + elif isinstance(values.dtype, BaseMaskedDtype): |
| 377 | + new_dtype = self._get_result_dtype(values.dtype.numpy_dtype) |
| 378 | + # error: Incompatible types in assignment (expression has type |
| 379 | + # "BaseMaskedDtype", variable has type "StringDtype") |
| 380 | + dtype = BaseMaskedDtype.from_numpy_dtype( # type: ignore[assignment] |
| 381 | + new_dtype |
| 382 | + ) |
395 | 383 | cls = dtype.construct_array_type()
|
396 | 384 | return cls._from_sequence(res_values, dtype=dtype)
|
397 | 385 |
|
@@ -433,7 +421,8 @@ def _masked_ea_wrap_cython_operation(
|
433 | 421 | **kwargs,
|
434 | 422 | )
|
435 | 423 |
|
436 |
| - dtype = self._get_result_dtype(orig_values.dtype) |
| 424 | + new_dtype = self._get_result_dtype(orig_values.dtype.numpy_dtype) |
| 425 | + dtype = BaseMaskedDtype.from_numpy_dtype(new_dtype) |
437 | 426 | # TODO: avoid cast as res_values *should* already have the right
|
438 | 427 | # dtype; last attempt ran into trouble on 32bit linux build
|
439 | 428 | res_values = res_values.astype(dtype.type, copy=False)
|
|
0 commit comments