@@ -3712,6 +3712,81 @@ def test_str_accessor_api_for_categorical(self):
3712
3712
invalid .str
3713
3713
self .assertFalse (hasattr (invalid , 'str' ))
3714
3714
3715
+ def test_dt_accessor_api_for_categorical (self ):
3716
+ # https://github.com/pydata/pandas/issues/10661
3717
+ from pandas .tseries .common import Properties
3718
+ from pandas .tseries .index import date_range , DatetimeIndex
3719
+ from pandas .tseries .period import period_range , PeriodIndex
3720
+ from pandas .tseries .tdi import timedelta_range , TimedeltaIndex
3721
+
3722
+ s_dr = Series (date_range ('1/1/2015' , periods = 5 , tz = "MET" ))
3723
+ c_dr = s_dr .astype ("category" )
3724
+
3725
+ s_pr = Series (period_range ('1/1/2015' , freq = 'D' , periods = 5 ))
3726
+ c_pr = s_pr .astype ("category" )
3727
+
3728
+ s_tdr = Series (timedelta_range ('1 days' ,'10 days' ))
3729
+ c_tdr = s_tdr .astype ("category" )
3730
+
3731
+ test_data = [
3732
+ ("Datetime" , DatetimeIndex ._datetimelike_ops , s_dr , c_dr ),
3733
+ ("Period" , PeriodIndex ._datetimelike_ops , s_pr , c_pr ),
3734
+ ("Timedelta" , TimedeltaIndex ._datetimelike_ops , s_tdr , c_tdr )]
3735
+
3736
+ self .assertIsInstance (c_dr .dt , Properties )
3737
+
3738
+ special_func_defs = [
3739
+ ('strftime' , ("%Y-%m-%d" ,), {}),
3740
+ ('tz_convert' , ("EST" ,), {}),
3741
+ #('tz_localize', ("UTC",), {}),
3742
+ ]
3743
+ _special_func_names = [f [0 ] for f in special_func_defs ]
3744
+
3745
+ # the series is already localized
3746
+ _ignore_names = ['tz_localize' ]
3747
+
3748
+ for name , attr_names , s , c in test_data :
3749
+ func_names = [f for f in dir (s .dt ) if not (f .startswith ("_" ) or
3750
+ f in attr_names or
3751
+ f in _special_func_names or
3752
+ f in _ignore_names )]
3753
+
3754
+ func_defs = [(f , (), {}) for f in func_names ]
3755
+ for f_def in special_func_defs :
3756
+ if f_def [0 ] in dir (s .dt ):
3757
+ func_defs .append (f_def )
3758
+
3759
+ for func , args , kwargs in func_defs :
3760
+ res = getattr (c .dt , func )(* args , ** kwargs )
3761
+ exp = getattr (s .dt , func )(* args , ** kwargs )
3762
+
3763
+ if isinstance (res , pd .DataFrame ):
3764
+ tm .assert_frame_equal (res , exp )
3765
+ elif isinstance (res , pd .Series ):
3766
+ tm .assert_series_equal (res , exp )
3767
+ else :
3768
+ tm .assert_numpy_array_equal (res , exp )
3769
+
3770
+ for attr in attr_names :
3771
+ try :
3772
+ res = getattr (c .dt , attr )
3773
+ exp = getattr (s .dt , attr )
3774
+ except Exception as e :
3775
+ print (name , attr )
3776
+ raise e
3777
+
3778
+ if isinstance (res , pd .DataFrame ):
3779
+ tm .assert_frame_equal (res , exp )
3780
+ elif isinstance (res , pd .Series ):
3781
+ tm .assert_series_equal (res , exp )
3782
+ else :
3783
+ tm .assert_numpy_array_equal (res , exp )
3784
+
3785
+ invalid = Series ([1 ,2 ,3 ]).astype ('category' )
3786
+ with tm .assertRaisesRegexp (AttributeError , "Can only use .dt accessor with datetimelike" ):
3787
+ invalid .dt
3788
+ self .assertFalse (hasattr (invalid , 'str' ))
3789
+
3715
3790
def test_pickle_v0_14_1 (self ):
3716
3791
3717
3792
# we have the name warning
0 commit comments