|
32 | 32 | FrameOrSeries,
|
33 | 33 | Shape,
|
34 | 34 | final,
|
| 35 | + npt, |
35 | 36 | )
|
36 | 37 | from pandas.errors import AbstractMethodError
|
37 | 38 | from pandas.util._decorators import cache_readonly
|
@@ -341,95 +342,54 @@ def _ea_wrap_cython_operation(
|
341 | 342 | comp_ids=comp_ids,
|
342 | 343 | **kwargs,
|
343 | 344 | )
|
344 |
| - orig_values = values |
345 | 345 |
|
346 |
| - if isinstance(orig_values, (DatetimeArray, PeriodArray)): |
| 346 | + if isinstance(values, (DatetimeArray, PeriodArray, TimedeltaArray)): |
347 | 347 | # All of the functions implemented here are ordinal, so we can
|
348 | 348 | # operate on the tz-naive equivalents
|
349 |
| - npvalues = orig_values._ndarray.view("M8[ns]") |
350 |
| - res_values = self._cython_op_ndim_compat( |
351 |
| - npvalues, |
352 |
| - min_count=min_count, |
353 |
| - ngroups=ngroups, |
354 |
| - comp_ids=comp_ids, |
355 |
| - mask=None, |
356 |
| - **kwargs, |
357 |
| - ) |
358 |
| - if self.how in ["rank"]: |
359 |
| - # i.e. how in WrappedCythonOp.cast_blocklist, since |
360 |
| - # other cast_blocklist methods dont go through cython_operation |
361 |
| - # preserve float64 dtype |
362 |
| - return res_values |
363 |
| - |
364 |
| - res_values = res_values.view("i8") |
365 |
| - result = type(orig_values)(res_values, dtype=orig_values.dtype) |
366 |
| - return result |
367 |
| - |
368 |
| - elif isinstance(orig_values, TimedeltaArray): |
369 |
| - # We have an ExtensionArray but not ExtensionDtype |
370 |
| - res_values = self._cython_op_ndim_compat( |
371 |
| - orig_values._ndarray, |
372 |
| - min_count=min_count, |
373 |
| - ngroups=ngroups, |
374 |
| - comp_ids=comp_ids, |
375 |
| - mask=None, |
376 |
| - **kwargs, |
377 |
| - ) |
378 |
| - if self.how in ["rank"]: |
379 |
| - # i.e. how in WrappedCythonOp.cast_blocklist, since |
380 |
| - # other cast_blocklist methods dont go through cython_operation |
381 |
| - # preserve float64 dtype |
382 |
| - return res_values |
383 |
| - |
384 |
| - # otherwise res_values has the same dtype as original values |
385 |
| - return type(orig_values)(res_values) |
386 |
| - |
| 349 | + npvalues = values._ndarray.view("M8[ns]") |
387 | 350 | elif isinstance(values.dtype, (BooleanDtype, _IntegerDtype)):
|
388 | 351 | # IntegerArray or BooleanArray
|
389 | 352 | npvalues = values.to_numpy("float64", na_value=np.nan)
|
390 |
| - res_values = self._cython_op_ndim_compat( |
391 |
| - npvalues, |
392 |
| - min_count=min_count, |
393 |
| - ngroups=ngroups, |
394 |
| - comp_ids=comp_ids, |
395 |
| - mask=None, |
396 |
| - **kwargs, |
397 |
| - ) |
398 |
| - if self.how in ["rank"]: |
399 |
| - # i.e. how in WrappedCythonOp.cast_blocklist, since |
400 |
| - # other cast_blocklist methods dont go through cython_operation |
401 |
| - return res_values |
402 |
| - |
403 |
| - dtype = self._get_result_dtype(orig_values.dtype) |
404 |
| - cls = dtype.construct_array_type() |
405 |
| - return cls._from_sequence(res_values, dtype=dtype) |
406 |
| - |
407 | 353 | elif isinstance(values.dtype, FloatingDtype):
|
408 | 354 | # FloatingArray
|
409 |
| - npvalues = values.to_numpy( |
410 |
| - values.dtype.numpy_dtype, |
411 |
| - na_value=np.nan, |
412 |
| - ) |
413 |
| - res_values = self._cython_op_ndim_compat( |
414 |
| - npvalues, |
415 |
| - min_count=min_count, |
416 |
| - ngroups=ngroups, |
417 |
| - comp_ids=comp_ids, |
418 |
| - mask=None, |
419 |
| - **kwargs, |
| 355 | + npvalues = values.to_numpy(values.dtype.numpy_dtype, na_value=np.nan) |
| 356 | + else: |
| 357 | + raise NotImplementedError( |
| 358 | + f"function is not implemented for this dtype: {values.dtype}" |
420 | 359 | )
|
421 |
| - if self.how in ["rank"]: |
422 |
| - # i.e. how in WrappedCythonOp.cast_blocklist, since |
423 |
| - # other cast_blocklist methods dont go through cython_operation |
424 |
| - return res_values |
425 | 360 |
|
426 |
| - dtype = self._get_result_dtype(orig_values.dtype) |
| 361 | + res_values = self._cython_op_ndim_compat( |
| 362 | + npvalues, |
| 363 | + min_count=min_count, |
| 364 | + ngroups=ngroups, |
| 365 | + comp_ids=comp_ids, |
| 366 | + mask=None, |
| 367 | + **kwargs, |
| 368 | + ) |
| 369 | + |
| 370 | + if self.how in ["rank"]: |
| 371 | + # i.e. how in WrappedCythonOp.cast_blocklist, since |
| 372 | + # other cast_blocklist methods don't go through cython_operation |
| 373 | + return res_values |
| 374 | + |
| 375 | + return self._reconstruct_ea_result(values, res_values) |
| 376 | + |
| 377 | + def _reconstruct_ea_result(self, values, res_values): |
| 378 | + """ |
| 379 | + Construct an ExtensionArray result from an ndarray result. |
| 380 | + """ |
| 381 | + # TODO: allow EAs to override this logic |
| 382 | + |
| 383 | + if isinstance(values.dtype, (BooleanDtype, _IntegerDtype, FloatingDtype)): |
| 384 | + dtype = self._get_result_dtype(values.dtype) |
427 | 385 | cls = dtype.construct_array_type()
|
428 | 386 | return cls._from_sequence(res_values, dtype=dtype)
|
429 | 387 |
|
430 |
| - raise NotImplementedError( |
431 |
| - f"function is not implemented for this dtype: {values.dtype}" |
432 |
| - ) |
| 388 | + elif needs_i8_conversion(values.dtype): |
| 389 | + i8values = res_values.view("i8") |
| 390 | + return type(values)(i8values, dtype=values.dtype) |
| 391 | + |
| 392 | + raise NotImplementedError |
433 | 393 |
|
434 | 394 | @final
|
435 | 395 | def _masked_ea_wrap_cython_operation(
|
@@ -478,6 +438,8 @@ def _cython_op_ndim_compat(
|
478 | 438 | if values.ndim == 1:
|
479 | 439 | # expand to 2d, dispatch, then squeeze if appropriate
|
480 | 440 | values2d = values[None, :]
|
| 441 | + if mask is not None: |
| 442 | + mask = mask[None, :] |
481 | 443 | res = self._call_cython_op(
|
482 | 444 | values2d,
|
483 | 445 | min_count=min_count,
|
@@ -533,9 +495,8 @@ def _call_cython_op(
|
533 | 495 | values = ensure_float64(values)
|
534 | 496 |
|
535 | 497 | values = values.T
|
536 |
| - |
537 | 498 | if mask is not None:
|
538 |
| - mask = mask.reshape(values.shape, order="C") |
| 499 | + mask = mask.T |
539 | 500 |
|
540 | 501 | out_shape = self._get_output_shape(ngroups, values)
|
541 | 502 | func, values = self.get_cython_func_and_vals(values, is_numeric)
|
@@ -677,7 +638,7 @@ def __init__(
|
677 | 638 | sort: bool = True,
|
678 | 639 | group_keys: bool = True,
|
679 | 640 | mutated: bool = False,
|
680 |
| - indexer: np.ndarray | None = None, |
| 641 | + indexer: npt.NDArray[np.intp] | None = None, |
681 | 642 | dropna: bool = True,
|
682 | 643 | ):
|
683 | 644 | assert isinstance(axis, Index), axis
|
@@ -1268,7 +1229,13 @@ def _is_indexed_like(obj, axes, axis: int) -> bool:
|
1268 | 1229 |
|
1269 | 1230 |
|
1270 | 1231 | class DataSplitter(Generic[FrameOrSeries]):
|
1271 |
| - def __init__(self, data: FrameOrSeries, labels, ngroups: int, axis: int = 0): |
| 1232 | + def __init__( |
| 1233 | + self, |
| 1234 | + data: FrameOrSeries, |
| 1235 | + labels: npt.NDArray[np.intp], |
| 1236 | + ngroups: int, |
| 1237 | + axis: int = 0, |
| 1238 | + ): |
1272 | 1239 | self.data = data
|
1273 | 1240 | self.labels = ensure_platform_int(labels) # _should_ already be np.intp
|
1274 | 1241 | self.ngroups = ngroups
|
|
0 commit comments