
Commit 3238381

TYP: Use Self for type checking (pandas/core/internals/)
1 parent 707dda0 commit 3238381
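
The commit replaces the module-level TypeVar bound to BaseArrayManager with the Self type, imported alongside the other typing aliases, so methods that return an instance of the calling (sub)class can annotate that directly instead of threading a type variable through every signature. Below is a minimal, self-contained sketch of the before/after pattern; the OldManager/NewManager classes are hypothetical illustrations, not pandas code:

# Sketch only: illustrates the TypeVar -> Self change, not the actual pandas classes.
import sys
from typing import TypeVar

if sys.version_info >= (3, 11):
    from typing import Self
else:
    from typing_extensions import Self  # backport package, assumed installed

T = TypeVar("T", bound="OldManager")


class OldManager:
    # Before: a bound TypeVar must appear on both self and the return type.
    def copy(self: T) -> T:
        return type(self)()


class NewManager:
    # After: Self always means "the class this method was called on".
    def copy(self) -> Self:
        return type(self)()


class SubManager(NewManager):
    pass


# A type checker infers SubManager, not NewManager, for SubManager().copy().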

File tree

4 files changed: +128 -82 lines changed

.startup.ipy

+48
@@ -0,0 +1,48 @@
+>>> from itertools import product
+>>> import numpy as np
+>>> import pandas as pd
+>>> from pandas.core.reshape.concat import _Concatenator
+>>>
+>>> def manual_concat(df_list: list[pd.DataFrame]) -> pd.DataFrame:
+...     columns = [col for df in df_list for col in df.columns]
+...     columns = list(dict.fromkeys(columns))
+...     index = np.hstack([df.index.values for df in df_list])
+...     df_list = [df.reindex(columns=columns) for df in df_list]
+...     values = np.vstack([df.values for df in df_list])
+...     return pd.DataFrame(values, index=index, columns=columns, dtype=df_list[0].dtypes[0])
+>>>
+>>> def compare_frames(df_list: list[pd.DataFrame]) -> None:
+...     concat_df = pd.concat(df_list)
+...     manual_df = manual_concat(df_list)
+...     if not concat_df.equals(manual_df):
+...         raise ValueError("different concatenations!")
+>>>
+>>> def make_dataframes(num_dfs, num_idx, num_cols, dtype=np.int32, drop_column=False) -> list[pd.DataFrame]:
+...     values = np.random.randint(-100, 100, size=[num_idx, num_cols])
+...     index = [f"i{i}" for i in range(num_idx)]
+...     columns = np.random.choice([f"c{i}" for i in range(num_cols)], num_cols, replace=False)
+...     df = pd.DataFrame(values, index=index, columns=columns, dtype=dtype)
+...
+...     df_list = []
+...     for i in range(num_dfs):
+...         new_df = df.copy()
+...         if drop_column:
+...             label = new_df.columns[i]
+...             new_df = new_df.drop(label, axis=1)
+...         df_list.append(new_df)
+...     return df_list
+>>>
+>>> test_data = [ # num_idx, num_cols, num_dfs
+...     [100, 1_000, 3],
+... ]
+>>> for i, (num_idx, num_cols, num_dfs) in enumerate(test_data):
+...     print(f"\n{i}: {num_dfs=}, {num_idx=}, {num_cols=}")
+...     df_list = make_dataframes(num_dfs, num_idx, num_cols, drop_column=False)
+...     df_list_dropped = make_dataframes(num_dfs, num_idx, num_cols, drop_column=True)
+...     print("manual:")
+...     %timeit manual_concat(df_list)
+...     compare_frames(df_list)
+...     for use_dropped in [False, True]:
+...         print(f"pd.concat: {use_dropped=}")
+...         this_df_list = df_list if not use_dropped else df_list_dropped
+...         %timeit pd.concat(this_df_list)
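
The %timeit lines above are IPython magics, so the script presumably runs inside an IPython session (e.g. as a startup file). For a rough equivalent outside IPython, here is a sketch using the standard-library timeit module, assuming the manual_concat and make_dataframes functions defined above are in scope:

import timeit

import pandas as pd

# Assumes manual_concat and make_dataframes from the script above are defined.
df_list = make_dataframes(num_dfs=3, num_idx=100, num_cols=1_000)
print("manual:   ", timeit.timeit(lambda: manual_concat(df_list), number=20))
print("pd.concat:", timeit.timeit(lambda: pd.concat(df_list), number=20))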

pandas/core/internals/array_manager.py

+34 -34

@@ -9,7 +9,6 @@
     Callable,
     Hashable,
     Literal,
-    TypeVar,
 )

 import numpy as np
@@ -92,10 +91,10 @@
     ArrayLike,
     AxisInt,
     DtypeObj,
+    Self,
     QuantileInterpolation,
     npt,
 )
-T = TypeVar("T", bound="BaseArrayManager")


 class BaseArrayManager(DataManager):
@@ -131,7 +130,7 @@ def __init__(
     ) -> None:
         raise NotImplementedError

-    def make_empty(self: T, axes=None) -> T:
+    def make_empty(self, axes=None) -> Self:
         """Return an empty ArrayManager with the items axis of len 0 (no columns)"""
         if axes is None:
             axes = [self.axes[1:], Index([])]
@@ -195,11 +194,11 @@ def __repr__(self) -> str:
         return output

     def apply(
-        self: T,
+        self,
         f,
         align_keys: list[str] | None = None,
         **kwargs,
-    ) -> T:
+    ) -> Self:
         """
         Iterate over the arrays, collect and create a new ArrayManager.

@@ -257,8 +256,8 @@ def apply(
         return type(self)(result_arrays, new_axes)  # type: ignore[arg-type]

     def apply_with_block(
-        self: T, f, align_keys=None, swap_axis: bool = True, **kwargs
-    ) -> T:
+        self, f, align_keys=None, swap_axis: bool = True, **kwargs
+    ) -> Self:
         # switch axis to follow BlockManager logic
         if swap_axis and "axis" in kwargs and self.ndim == 2:
             kwargs["axis"] = 1 if kwargs["axis"] == 0 else 0
@@ -311,7 +310,7 @@ def apply_with_block(

         return type(self)(result_arrays, self._axes)

-    def where(self: T, other, cond, align: bool) -> T:
+    def where(self, other, cond, align: bool) -> Self:
         if align:
             align_keys = ["other", "cond"]
         else:
@@ -325,13 +324,13 @@ def where(self: T, other, cond, align: bool) -> T:
             cond=cond,
         )

-    def round(self: T, decimals: int, using_cow: bool = False) -> T:
+    def round(self, decimals: int, using_cow: bool = False) -> Self:
         return self.apply_with_block("round", decimals=decimals, using_cow=using_cow)

-    def setitem(self: T, indexer, value) -> T:
+    def setitem(self, indexer, value) -> Self:
         return self.apply_with_block("setitem", indexer=indexer, value=value)

-    def putmask(self: T, mask, new, align: bool = True) -> T:
+    def putmask(self, mask, new, align: bool = True) -> Self:
         if align:
             align_keys = ["new", "mask"]
         else:
@@ -345,14 +344,14 @@ def putmask(self: T, mask, new, align: bool = True) -> T:
             new=new,
         )

-    def diff(self: T, n: int, axis: AxisInt) -> T:
+    def diff(self, n: int, axis: AxisInt) -> Self:
         assert self.ndim == 2 and axis == 0  # caller ensures
         return self.apply(algos.diff, n=n, axis=axis)

-    def interpolate(self: T, **kwargs) -> T:
+    def interpolate(self, **kwargs) -> Self:
         return self.apply_with_block("interpolate", swap_axis=False, **kwargs)

-    def shift(self: T, periods: int, axis: AxisInt, fill_value) -> T:
+    def shift(self, periods: int, axis: AxisInt, fill_value) -> Self:
         if fill_value is lib.no_default:
             fill_value = None

@@ -364,7 +363,7 @@ def shift(self: T, periods: int, axis: AxisInt, fill_value) -> T:
             "shift", periods=periods, axis=axis, fill_value=fill_value
         )

-    def fillna(self: T, value, limit, inplace: bool, downcast) -> T:
+    def fillna(self, value, limit, inplace: bool, downcast) -> Self:
         if limit is not None:
             # Do this validation even if we go through one of the no-op paths
             limit = libalgos.validate_limit(None, limit=limit)
@@ -373,13 +372,13 @@ def fillna(self: T, value, limit, inplace: bool, downcast) -> T:
             "fillna", value=value, limit=limit, inplace=inplace, downcast=downcast
         )

-    def astype(self: T, dtype, copy: bool | None = False, errors: str = "raise") -> T:
+    def astype(self, dtype, copy: bool | None = False, errors: str = "raise") -> Self:
         if copy is None:
             copy = True

         return self.apply(astype_array_safe, dtype=dtype, copy=copy, errors=errors)

-    def convert(self: T, copy: bool | None) -> T:
+    def convert(self, copy: bool | None) -> Self:
         if copy is None:
             copy = True

@@ -402,10 +401,10 @@ def _convert(arr):

         return self.apply(_convert)

-    def replace_regex(self: T, **kwargs) -> T:
+    def replace_regex(self, **kwargs) -> Self:
         return self.apply_with_block("_replace_regex", **kwargs)

-    def replace(self: T, to_replace, value, inplace: bool) -> T:
+    def replace(self, to_replace, value, inplace: bool) -> Self:
         inplace = validate_bool_kwarg(inplace, "inplace")
         assert np.ndim(value) == 0, value
         # TODO "replace" is right now implemented on the blocks, we should move
@@ -415,12 +414,12 @@ def replace(self: T, to_replace, value, inplace: bool) -> T:
         )

     def replace_list(
-        self: T,
+        self,
         src_list: list[Any],
         dest_list: list[Any],
         inplace: bool = False,
         regex: bool = False,
-    ) -> T:
+    ) -> Self:
         """do a list replace"""
         inplace = validate_bool_kwarg(inplace, "inplace")

@@ -432,7 +431,7 @@ def replace_list(
             regex=regex,
         )

-    def to_native_types(self: T, **kwargs) -> T:
+    def to_native_types(self, **kwargs) -> Self:
         return self.apply(to_native_types, **kwargs)

     @property
@@ -458,7 +457,7 @@ def is_view(self) -> bool:
     def is_single_block(self) -> bool:
         return len(self.arrays) == 1

-    def _get_data_subset(self: T, predicate: Callable) -> T:
+    def _get_data_subset(self, predicate: Callable) -> Self:
         indices = [i for i, arr in enumerate(self.arrays) if predicate(arr)]
         arrays = [self.arrays[i] for i in indices]
         # TODO copy?
@@ -469,7 +468,7 @@ def _get_data_subset(self: T, predicate: Callable) -> T:
         new_axes = [self._axes[0], new_cols]
         return type(self)(arrays, new_axes, verify_integrity=False)

-    def get_bool_data(self: T, copy: bool = False) -> T:
+    def get_bool_data(self, copy: bool = False) -> Self:
         """
         Select columns that are bool-dtype and object-dtype columns that are all-bool.

@@ -480,7 +479,7 @@ def get_bool_data(self: T, copy: bool = False) -> T:
         """
         return self._get_data_subset(lambda x: x.dtype == np.dtype(bool))

-    def get_numeric_data(self: T, copy: bool = False) -> T:
+    def get_numeric_data(self, copy: bool = False) -> Self:
         """
         Select columns that have a numeric dtype.

@@ -494,7 +493,7 @@ def get_numeric_data(self: T, copy: bool = False) -> T:
             or getattr(arr.dtype, "_is_numeric", False)
         )

-    def copy(self: T, deep: bool | Literal["all"] | None = True) -> T:
+    def copy(self, deep: bool | Literal["all"] | None = True) -> Self:
         """
         Make deep or shallow copy of ArrayManager

@@ -531,7 +530,7 @@ def copy_func(ax):
         return type(self)(new_arrays, new_axes, verify_integrity=False)

     def reindex_indexer(
-        self: T,
+        self,
         new_axis,
         indexer,
         axis: AxisInt,
@@ -542,7 +541,7 @@ def reindex_indexer(
         only_slice: bool = False,
         # ArrayManager specific keywords
         use_na_proxy: bool = False,
-    ) -> T:
+    ) -> Self:
         axis = self._normalize_axis(axis)
         return self._reindex_indexer(
             new_axis,
@@ -555,15 +554,15 @@ def reindex_indexer(
         )

     def _reindex_indexer(
-        self: T,
+        self,
         new_axis,
         indexer: npt.NDArray[np.intp] | None,
         axis: AxisInt,
         fill_value=None,
         allow_dups: bool = False,
         copy: bool | None = True,
         use_na_proxy: bool = False,
-    ) -> T:
+    ) -> Self:
         """
         Parameters
         ----------
@@ -634,11 +633,12 @@ def _reindex_indexer(
         return type(self)(new_arrays, new_axes, verify_integrity=False)

     def take(
-        self: T,
+        self,
         indexer: npt.NDArray[np.intp],
         axis: AxisInt = 1,
         verify: bool = True,
-    ) -> T:
+        convert_indices: bool = True,
+    ) -> Self:
         """
         Take items along any axis.
         """
@@ -926,7 +926,7 @@ def idelete(self, indexer) -> ArrayManager:
     # --------------------------------------------------------------------
     # Array-wise Operation

-    def grouped_reduce(self: T, func: Callable) -> T:
+    def grouped_reduce(self, func: Callable) -> Self:
         """
         Apply grouped reduction function columnwise, returning a new ArrayManager.

@@ -965,7 +965,7 @@ def grouped_reduce(self: T, func: Callable) -> T:
         # expected "List[Union[ndarray, ExtensionArray]]"
         return type(self)(result_arrays, [index, columns])  # type: ignore[arg-type]

-    def reduce(self: T, func: Callable) -> T:
+    def reduce(self, func: Callable) -> Self:
         """
         Apply reduction function column-wise, returning a single-row ArrayManager.