Skip to content

Commit 27ae7d1

Browse files
CLN: add typing for dtype arg in core/arrays (GH38808) (#38886)
1 parent 581bf70 commit 27ae7d1

File tree

8 files changed

+134
-53
lines changed

8 files changed

+134
-53
lines changed

pandas/core/arrays/interval.py

+29-13
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
import operator
22
from operator import le, lt
33
import textwrap
4-
from typing import Sequence, Type, TypeVar
4+
from typing import Optional, Sequence, Type, TypeVar, cast
55

66
import numpy as np
77

@@ -14,7 +14,7 @@
1414
intervals_to_interval_bounds,
1515
)
1616
from pandas._libs.missing import NA
17-
from pandas._typing import ArrayLike
17+
from pandas._typing import ArrayLike, Dtype, NpDtype
1818
from pandas.compat.numpy import function as nv
1919
from pandas.util._decorators import Appender
2020

@@ -170,7 +170,7 @@ def __new__(
170170
cls,
171171
data,
172172
closed=None,
173-
dtype=None,
173+
dtype: Optional[Dtype] = None,
174174
copy: bool = False,
175175
verify_integrity: bool = True,
176176
):
@@ -212,7 +212,13 @@ def __new__(
212212

213213
@classmethod
214214
def _simple_new(
215-
cls, left, right, closed=None, copy=False, dtype=None, verify_integrity=True
215+
cls,
216+
left,
217+
right,
218+
closed=None,
219+
copy=False,
220+
dtype: Optional[Dtype] = None,
221+
verify_integrity=True,
216222
):
217223
result = IntervalMixin.__new__(cls)
218224

@@ -223,12 +229,14 @@ def _simple_new(
223229
if dtype is not None:
224230
# GH 19262: dtype must be an IntervalDtype to override inferred
225231
dtype = pandas_dtype(dtype)
226-
if not is_interval_dtype(dtype):
232+
if is_interval_dtype(dtype):
233+
dtype = cast(IntervalDtype, dtype)
234+
if dtype.subtype is not None:
235+
left = left.astype(dtype.subtype)
236+
right = right.astype(dtype.subtype)
237+
else:
227238
msg = f"dtype must be an IntervalDtype, got {dtype}"
228239
raise TypeError(msg)
229-
elif dtype.subtype is not None:
230-
left = left.astype(dtype.subtype)
231-
right = right.astype(dtype.subtype)
232240

233241
# coerce dtypes to match if needed
234242
if is_float_dtype(left) and is_integer_dtype(right):
@@ -279,7 +287,9 @@ def _simple_new(
279287
return result
280288

281289
@classmethod
282-
def _from_sequence(cls, scalars, *, dtype=None, copy=False):
290+
def _from_sequence(
291+
cls, scalars, *, dtype: Optional[Dtype] = None, copy: bool = False
292+
):
283293
return cls(scalars, dtype=dtype, copy=copy)
284294

285295
@classmethod
@@ -338,7 +348,9 @@ def _from_factorized(cls, values, original):
338348
),
339349
}
340350
)
341-
def from_breaks(cls, breaks, closed="right", copy=False, dtype=None):
351+
def from_breaks(
352+
cls, breaks, closed="right", copy: bool = False, dtype: Optional[Dtype] = None
353+
):
342354
breaks = maybe_convert_platform_interval(breaks)
343355

344356
return cls.from_arrays(breaks[:-1], breaks[1:], closed, copy=copy, dtype=dtype)
@@ -407,7 +419,9 @@ def from_breaks(cls, breaks, closed="right", copy=False, dtype=None):
407419
),
408420
}
409421
)
410-
def from_arrays(cls, left, right, closed="right", copy=False, dtype=None):
422+
def from_arrays(
423+
cls, left, right, closed="right", copy=False, dtype: Optional[Dtype] = None
424+
):
411425
left = maybe_convert_platform_interval(left)
412426
right = maybe_convert_platform_interval(right)
413427

@@ -464,7 +478,9 @@ def from_arrays(cls, left, right, closed="right", copy=False, dtype=None):
464478
),
465479
}
466480
)
467-
def from_tuples(cls, data, closed="right", copy=False, dtype=None):
481+
def from_tuples(
482+
cls, data, closed="right", copy=False, dtype: Optional[Dtype] = None
483+
):
468484
if len(data):
469485
left, right = [], []
470486
else:
@@ -1277,7 +1293,7 @@ def is_non_overlapping_monotonic(self):
12771293
# ---------------------------------------------------------------------
12781294
# Conversion
12791295

1280-
def __array__(self, dtype=None) -> np.ndarray:
1296+
def __array__(self, dtype: Optional[NpDtype] = None) -> np.ndarray:
12811297
"""
12821298
Return the IntervalArray's data as a numpy array of Interval
12831299
objects (with dtype='object')

pandas/core/arrays/masked.py

+6-3
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
import numpy as np
66

77
from pandas._libs import lib, missing as libmissing
8-
from pandas._typing import ArrayLike, Dtype, Scalar
8+
from pandas._typing import ArrayLike, Dtype, NpDtype, Scalar
99
from pandas.errors import AbstractMethodError
1010
from pandas.util._decorators import cache_readonly, doc
1111

@@ -147,7 +147,10 @@ def __invert__(self: BaseMaskedArrayT) -> BaseMaskedArrayT:
147147
return type(self)(~self._data, self._mask)
148148

149149
def to_numpy(
150-
self, dtype=None, copy: bool = False, na_value: Scalar = lib.no_default
150+
self,
151+
dtype: Optional[NpDtype] = None,
152+
copy: bool = False,
153+
na_value: Scalar = lib.no_default,
151154
) -> np.ndarray:
152155
"""
153156
Convert to a NumPy Array.
@@ -257,7 +260,7 @@ def astype(self, dtype: Dtype, copy: bool = True) -> ArrayLike:
257260

258261
__array_priority__ = 1000 # higher than ndarray so ops dispatch to us
259262

260-
def __array__(self, dtype=None) -> np.ndarray:
263+
def __array__(self, dtype: Optional[NpDtype] = None) -> np.ndarray:
261264
"""
262265
the array interface, return my values
263266
We return an object array here to preserve our scalar values

pandas/core/arrays/numpy_.py

+60-12
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,11 @@
11
import numbers
2-
from typing import Tuple, Type, Union
2+
from typing import Optional, Tuple, Type, Union
33

44
import numpy as np
55
from numpy.lib.mixins import NDArrayOperatorsMixin
66

77
from pandas._libs import lib
8-
from pandas._typing import Scalar
8+
from pandas._typing import Dtype, NpDtype, Scalar
99
from pandas.compat.numpy import function as nv
1010

1111
from pandas.core.dtypes.dtypes import ExtensionDtype
@@ -38,7 +38,7 @@ class PandasDtype(ExtensionDtype):
3838

3939
_metadata = ("_dtype",)
4040

41-
def __init__(self, dtype: object):
41+
def __init__(self, dtype: Optional[NpDtype]):
4242
self._dtype = np.dtype(dtype)
4343

4444
def __repr__(self) -> str:
@@ -173,7 +173,7 @@ def __init__(self, values: Union[np.ndarray, "PandasArray"], copy: bool = False)
173173

174174
@classmethod
175175
def _from_sequence(
176-
cls, scalars, *, dtype=None, copy: bool = False
176+
cls, scalars, *, dtype: Optional[Dtype] = None, copy: bool = False
177177
) -> "PandasArray":
178178
if isinstance(dtype, PandasDtype):
179179
dtype = dtype._dtype
@@ -200,7 +200,7 @@ def dtype(self) -> PandasDtype:
200200
# ------------------------------------------------------------------------
201201
# NumPy Array Interface
202202

203-
def __array__(self, dtype=None) -> np.ndarray:
203+
def __array__(self, dtype: Optional[NpDtype] = None) -> np.ndarray:
204204
return np.asarray(self._ndarray, dtype=dtype)
205205

206206
_HANDLED_TYPES = (np.ndarray, numbers.Number)
@@ -311,7 +311,15 @@ def prod(self, *, axis=None, skipna=True, min_count=0, **kwargs) -> Scalar:
311311
)
312312
return self._wrap_reduction_result(axis, result)
313313

314-
def mean(self, *, axis=None, dtype=None, out=None, keepdims=False, skipna=True):
314+
def mean(
315+
self,
316+
*,
317+
axis=None,
318+
dtype: Optional[NpDtype] = None,
319+
out=None,
320+
keepdims=False,
321+
skipna=True,
322+
):
315323
nv.validate_mean((), {"dtype": dtype, "out": out, "keepdims": keepdims})
316324
result = nanops.nanmean(self._ndarray, axis=axis, skipna=skipna)
317325
return self._wrap_reduction_result(axis, result)
@@ -326,7 +334,14 @@ def median(
326334
return self._wrap_reduction_result(axis, result)
327335

328336
def std(
329-
self, *, axis=None, dtype=None, out=None, ddof=1, keepdims=False, skipna=True
337+
self,
338+
*,
339+
axis=None,
340+
dtype: Optional[NpDtype] = None,
341+
out=None,
342+
ddof=1,
343+
keepdims=False,
344+
skipna=True,
330345
):
331346
nv.validate_stat_ddof_func(
332347
(), {"dtype": dtype, "out": out, "keepdims": keepdims}, fname="std"
@@ -335,7 +350,14 @@ def std(
335350
return self._wrap_reduction_result(axis, result)
336351

337352
def var(
338-
self, *, axis=None, dtype=None, out=None, ddof=1, keepdims=False, skipna=True
353+
self,
354+
*,
355+
axis=None,
356+
dtype: Optional[NpDtype] = None,
357+
out=None,
358+
ddof=1,
359+
keepdims=False,
360+
skipna=True,
339361
):
340362
nv.validate_stat_ddof_func(
341363
(), {"dtype": dtype, "out": out, "keepdims": keepdims}, fname="var"
@@ -344,22 +366,45 @@ def var(
344366
return self._wrap_reduction_result(axis, result)
345367

346368
def sem(
347-
self, *, axis=None, dtype=None, out=None, ddof=1, keepdims=False, skipna=True
369+
self,
370+
*,
371+
axis=None,
372+
dtype: Optional[NpDtype] = None,
373+
out=None,
374+
ddof=1,
375+
keepdims=False,
376+
skipna=True,
348377
):
349378
nv.validate_stat_ddof_func(
350379
(), {"dtype": dtype, "out": out, "keepdims": keepdims}, fname="sem"
351380
)
352381
result = nanops.nansem(self._ndarray, axis=axis, skipna=skipna, ddof=ddof)
353382
return self._wrap_reduction_result(axis, result)
354383

355-
def kurt(self, *, axis=None, dtype=None, out=None, keepdims=False, skipna=True):
384+
def kurt(
385+
self,
386+
*,
387+
axis=None,
388+
dtype: Optional[NpDtype] = None,
389+
out=None,
390+
keepdims=False,
391+
skipna=True,
392+
):
356393
nv.validate_stat_ddof_func(
357394
(), {"dtype": dtype, "out": out, "keepdims": keepdims}, fname="kurt"
358395
)
359396
result = nanops.nankurt(self._ndarray, axis=axis, skipna=skipna)
360397
return self._wrap_reduction_result(axis, result)
361398

362-
def skew(self, *, axis=None, dtype=None, out=None, keepdims=False, skipna=True):
399+
def skew(
400+
self,
401+
*,
402+
axis=None,
403+
dtype: Optional[NpDtype] = None,
404+
out=None,
405+
keepdims=False,
406+
skipna=True,
407+
):
363408
nv.validate_stat_ddof_func(
364409
(), {"dtype": dtype, "out": out, "keepdims": keepdims}, fname="skew"
365410
)
@@ -370,7 +415,10 @@ def skew(self, *, axis=None, dtype=None, out=None, keepdims=False, skipna=True):
370415
# Additional Methods
371416

372417
def to_numpy(
373-
self, dtype=None, copy: bool = False, na_value=lib.no_default
418+
self,
419+
dtype: Optional[NpDtype] = None,
420+
copy: bool = False,
421+
na_value=lib.no_default,
374422
) -> np.ndarray:
375423
result = np.asarray(self._ndarray, dtype=dtype)
376424

pandas/core/arrays/period.py

+8-5
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@
2727
get_period_field_arr,
2828
period_asfreq_arr,
2929
)
30-
from pandas._typing import AnyArrayLike, Dtype
30+
from pandas._typing import AnyArrayLike, Dtype, NpDtype
3131
from pandas.util._decorators import cache_readonly, doc
3232

3333
from pandas.core.dtypes.common import (
@@ -160,7 +160,7 @@ class PeriodArray(PeriodMixin, dtl.DatelikeOps):
160160
# --------------------------------------------------------------------
161161
# Constructors
162162

163-
def __init__(self, values, dtype=None, freq=None, copy=False):
163+
def __init__(self, values, dtype: Optional[Dtype] = None, freq=None, copy=False):
164164
freq = validate_dtype_freq(dtype, freq)
165165

166166
if freq is not None:
@@ -187,7 +187,10 @@ def __init__(self, values, dtype=None, freq=None, copy=False):
187187

188188
@classmethod
189189
def _simple_new(
190-
cls, values: np.ndarray, freq: Optional[BaseOffset] = None, dtype=None
190+
cls,
191+
values: np.ndarray,
192+
freq: Optional[BaseOffset] = None,
193+
dtype: Optional[Dtype] = None,
191194
) -> "PeriodArray":
192195
# alias for PeriodArray.__init__
193196
assertion_msg = "Should be numpy array of type i8"
@@ -221,7 +224,7 @@ def _from_sequence(
221224

222225
@classmethod
223226
def _from_sequence_of_strings(
224-
cls, strings, *, dtype=None, copy=False
227+
cls, strings, *, dtype: Optional[Dtype] = None, copy=False
225228
) -> "PeriodArray":
226229
return cls._from_sequence(strings, dtype=dtype, copy=copy)
227230

@@ -302,7 +305,7 @@ def freq(self) -> BaseOffset:
302305
"""
303306
return self.dtype.freq
304307

305-
def __array__(self, dtype=None) -> np.ndarray:
308+
def __array__(self, dtype: Optional[NpDtype] = None) -> np.ndarray:
306309
if dtype == "i8":
307310
return self.asi8
308311
elif dtype == bool:

0 commit comments

Comments (0)