12
12
from typing import (
13
13
TYPE_CHECKING ,
14
14
Any ,
15
+ Callable ,
15
16
Dict ,
16
17
List ,
17
18
Optional ,
@@ -2045,15 +2046,19 @@ def convert(self, values: np.ndarray, nan_rep, encoding: str, errors: str):
2045
2046
if self .freq is not None :
2046
2047
kwargs ["freq" ] = _ensure_decoded (self .freq )
2047
2048
2049
+ factory : Union [Type [Index ], Type [DatetimeIndex ]] = Index
2050
+ if is_datetime64_dtype (values .dtype ) or is_datetime64tz_dtype (values .dtype ):
2051
+ factory = DatetimeIndex
2052
+
2048
2053
# making an Index instance could throw a number of different errors
2049
2054
try :
2050
- new_pd_index = Index (values , ** kwargs )
2055
+ new_pd_index = factory (values , ** kwargs )
2051
2056
except ValueError :
2052
2057
# if the output freq is different that what we recorded,
2053
2058
# it should be None (see also 'doc example part 2')
2054
2059
if "freq" in kwargs :
2055
2060
kwargs ["freq" ] = None
2056
- new_pd_index = Index (values , ** kwargs )
2061
+ new_pd_index = factory (values , ** kwargs )
2057
2062
2058
2063
new_pd_index = _set_tz (new_pd_index , self .tz )
2059
2064
return new_pd_index , new_pd_index
@@ -2736,8 +2741,14 @@ def _alias_to_class(self, alias):
2736
2741
return alias
2737
2742
return self ._reverse_index_map .get (alias , Index )
2738
2743
2739
- def _get_index_factory (self , klass ):
2740
- if klass == DatetimeIndex :
2744
+ def _get_index_factory (self , attrs ):
2745
+ index_class = self ._alias_to_class (
2746
+ _ensure_decoded (getattr (attrs , "index_class" , "" ))
2747
+ )
2748
+
2749
+ factory : Callable
2750
+
2751
+ if index_class == DatetimeIndex :
2741
2752
2742
2753
def f (values , freq = None , tz = None ):
2743
2754
# data are already in UTC, localize and convert if tz present
@@ -2747,16 +2758,34 @@ def f(values, freq=None, tz=None):
2747
2758
result = result .tz_localize ("UTC" ).tz_convert (tz )
2748
2759
return result
2749
2760
2750
- return f
2751
- elif klass == PeriodIndex :
2761
+ factory = f
2762
+ elif index_class == PeriodIndex :
2752
2763
2753
2764
def f (values , freq = None , tz = None ):
2754
2765
parr = PeriodArray ._simple_new (values , freq = freq )
2755
2766
return PeriodIndex ._simple_new (parr , name = None )
2756
2767
2757
- return f
2768
+ factory = f
2769
+ else :
2770
+ factory = index_class
2771
+
2772
+ kwargs = {}
2773
+ if "freq" in attrs :
2774
+ kwargs ["freq" ] = attrs ["freq" ]
2775
+ if index_class is Index :
2776
+ # DTI/PI would be gotten by _alias_to_class
2777
+ factory = TimedeltaIndex
2778
+
2779
+ if "tz" in attrs :
2780
+ if isinstance (attrs ["tz" ], bytes ):
2781
+ # created by python2
2782
+ kwargs ["tz" ] = attrs ["tz" ].decode ("utf-8" )
2783
+ else :
2784
+ # created by python3
2785
+ kwargs ["tz" ] = attrs ["tz" ]
2786
+ assert index_class is DatetimeIndex # just checking
2758
2787
2759
- return klass
2788
+ return factory , kwargs
2760
2789
2761
2790
def validate_read (self , columns , where ):
2762
2791
"""
@@ -2928,22 +2957,8 @@ def read_index_node(
2928
2957
name = _ensure_str (node ._v_attrs .name )
2929
2958
name = _ensure_decoded (name )
2930
2959
2931
- index_class = self ._alias_to_class (
2932
- _ensure_decoded (getattr (node ._v_attrs , "index_class" , "" ))
2933
- )
2934
- factory = self ._get_index_factory (index_class )
2935
-
2936
- kwargs = {}
2937
- if "freq" in node ._v_attrs :
2938
- kwargs ["freq" ] = node ._v_attrs ["freq" ]
2939
-
2940
- if "tz" in node ._v_attrs :
2941
- if isinstance (node ._v_attrs ["tz" ], bytes ):
2942
- # created by python2
2943
- kwargs ["tz" ] = node ._v_attrs ["tz" ].decode ("utf-8" )
2944
- else :
2945
- # created by python3
2946
- kwargs ["tz" ] = node ._v_attrs ["tz" ]
2960
+ attrs = node ._v_attrs
2961
+ factory , kwargs = self ._get_index_factory (attrs )
2947
2962
2948
2963
if kind == "date" :
2949
2964
index = factory (
0 commit comments