8
8
TYPE_CHECKING ,
9
9
Any ,
10
10
TypeVar ,
11
+ cast ,
12
+ overload ,
11
13
)
12
14
13
15
import numpy as np
14
16
15
17
from pandas ._libs .hashtable import object_hash
16
18
from pandas ._typing import (
17
19
DtypeObj ,
20
+ npt ,
18
21
type_t ,
19
22
)
20
23
from pandas .errors import AbstractMethodError
29
32
from pandas .core .arrays import ExtensionArray
30
33
31
34
# To parameterize on same ExtensionDtype
32
- E = TypeVar ("E " , bound = "ExtensionDtype" )
35
+ ExtensionDtypeT = TypeVar ("ExtensionDtypeT " , bound = "ExtensionDtype" )
33
36
34
37
35
38
class ExtensionDtype :
@@ -206,7 +209,9 @@ def construct_array_type(cls) -> type_t[ExtensionArray]:
206
209
raise AbstractMethodError (cls )
207
210
208
211
@classmethod
209
- def construct_from_string (cls , string : str ):
212
+ def construct_from_string (
213
+ cls : type_t [ExtensionDtypeT ], string : str
214
+ ) -> ExtensionDtypeT :
210
215
r"""
211
216
Construct this type from a string.
212
217
@@ -368,7 +373,7 @@ def _can_hold_na(self) -> bool:
368
373
return True
369
374
370
375
371
- def register_extension_dtype (cls : type [ E ]) -> type [ E ]:
376
+ def register_extension_dtype (cls : type_t [ ExtensionDtypeT ]) -> type_t [ ExtensionDtypeT ]:
372
377
"""
373
378
Register an ExtensionType with pandas as class decorator.
374
379
@@ -409,9 +414,9 @@ class Registry:
409
414
"""
410
415
411
416
def __init__ (self ):
412
- self .dtypes : list [type [ExtensionDtype ]] = []
417
+ self .dtypes : list [type_t [ExtensionDtype ]] = []
413
418
414
- def register (self , dtype : type [ExtensionDtype ]) -> None :
419
+ def register (self , dtype : type_t [ExtensionDtype ]) -> None :
415
420
"""
416
421
Parameters
417
422
----------
@@ -422,22 +427,46 @@ def register(self, dtype: type[ExtensionDtype]) -> None:
422
427
423
428
self .dtypes .append (dtype )
424
429
425
- def find (self , dtype : type [ExtensionDtype ] | str ) -> type [ExtensionDtype ] | None :
430
+ @overload
431
+ def find (self , dtype : type_t [ExtensionDtypeT ]) -> type_t [ExtensionDtypeT ]:
432
+ ...
433
+
434
+ @overload
435
+ def find (self , dtype : ExtensionDtypeT ) -> ExtensionDtypeT :
436
+ ...
437
+
438
+ @overload
439
+ def find (self , dtype : str ) -> ExtensionDtype | None :
440
+ ...
441
+
442
+ @overload
443
+ def find (
444
+ self , dtype : npt .DTypeLike
445
+ ) -> type_t [ExtensionDtype ] | ExtensionDtype | None :
446
+ ...
447
+
448
+ def find (
449
+ self , dtype : type_t [ExtensionDtype ] | ExtensionDtype | npt .DTypeLike
450
+ ) -> type_t [ExtensionDtype ] | ExtensionDtype | None :
426
451
"""
427
452
Parameters
428
453
----------
429
- dtype : Type[ ExtensionDtype] or str
454
+ dtype : ExtensionDtype class or instance or str or numpy dtype or python type
430
455
431
456
Returns
432
457
-------
433
458
return the first matching dtype, otherwise return None
434
459
"""
435
460
if not isinstance (dtype , str ):
436
- dtype_type = dtype
461
+ dtype_type : type_t
437
462
if not isinstance (dtype , type ):
438
463
dtype_type = type (dtype )
464
+ else :
465
+ dtype_type = dtype
439
466
if issubclass (dtype_type , ExtensionDtype ):
440
- return dtype
467
+ # cast needed here as mypy doesn't know we have figured
468
+ # out it is an ExtensionDtype or type_t[ExtensionDtype]
469
+ return cast ("ExtensionDtype | type_t[ExtensionDtype]" , dtype )
441
470
442
471
return None
443
472
0 commit comments