|
14 | 14 | timedelta_range,
|
15 | 15 | )
|
16 | 16 | import pandas._testing as tm
|
17 |
| -from pandas.core.arrays import ( |
18 |
| - DatetimeArray, |
19 |
| - PeriodArray, |
20 |
| - TimedeltaArray, |
21 |
| -) |
22 | 17 | from pandas.core.arrays.categorical import CategoricalAccessor
|
23 | 18 | from pandas.core.indexes.accessors import Properties
|
24 | 19 |
|
@@ -163,86 +158,84 @@ def test_categorical_delegations(self):
|
163 | 158 | )
|
164 | 159 | tm.assert_series_equal(result, expected)
|
165 | 160 |
|
166 |
| - def test_dt_accessor_api_for_categorical(self): |
| 161 | + @pytest.mark.parametrize( |
| 162 | + "idx", |
| 163 | + [ |
| 164 | + date_range("1/1/2015", periods=5), |
| 165 | + date_range("1/1/2015", periods=5, tz="MET"), |
| 166 | + period_range("1/1/2015", freq="D", periods=5), |
| 167 | + timedelta_range("1 days", "10 days"), |
| 168 | + ], |
| 169 | + ) |
| 170 | + def test_dt_accessor_api_for_categorical(self, idx): |
167 | 171 | # https://github.com/pandas-dev/pandas/issues/10661
|
168 | 172 |
|
169 |
| - s_dr = Series(date_range("1/1/2015", periods=5, tz="MET")) |
170 |
| - c_dr = s_dr.astype("category") |
171 |
| - |
172 |
| - s_pr = Series(period_range("1/1/2015", freq="D", periods=5)) |
173 |
| - c_pr = s_pr.astype("category") |
174 |
| - |
175 |
| - s_tdr = Series(timedelta_range("1 days", "10 days")) |
176 |
| - c_tdr = s_tdr.astype("category") |
| 173 | + ser = Series(idx) |
| 174 | + cat = ser.astype("category") |
177 | 175 |
|
178 | 176 | # only testing field (like .day)
|
179 | 177 | # and bool (is_month_start)
|
180 |
| - get_ops = lambda x: x._datetimelike_ops |
181 |
| - |
182 |
| - test_data = [ |
183 |
| - ("Datetime", get_ops(DatetimeArray), s_dr, c_dr), |
184 |
| - ("Period", get_ops(PeriodArray), s_pr, c_pr), |
185 |
| - ("Timedelta", get_ops(TimedeltaArray), s_tdr, c_tdr), |
186 |
| - ] |
| 178 | + attr_names = type(ser._values)._datetimelike_ops |
187 | 179 |
|
188 |
| - assert isinstance(c_dr.dt, Properties) |
| 180 | + assert isinstance(cat.dt, Properties) |
189 | 181 |
|
190 | 182 | special_func_defs = [
|
191 | 183 | ("strftime", ("%Y-%m-%d",), {}),
|
192 |
| - ("tz_convert", ("EST",), {}), |
193 | 184 | ("round", ("D",), {}),
|
194 | 185 | ("floor", ("D",), {}),
|
195 | 186 | ("ceil", ("D",), {}),
|
196 | 187 | ("asfreq", ("D",), {}),
|
197 |
| - # FIXME: don't leave commented-out |
198 |
| - # ('tz_localize', ("UTC",), {}), |
199 | 188 | ]
|
| 189 | + if idx.dtype == "M8[ns]": |
| 190 | + # exclude dt64tz since that is already localized and would raise |
| 191 | + tup = ("tz_localize", ("UTC",), {}) |
| 192 | + special_func_defs.append(tup) |
| 193 | + elif idx.dtype.kind == "M": |
| 194 | + # exclude dt64 since that is not localized so would raise |
| 195 | + tup = ("tz_convert", ("EST",), {}) |
| 196 | + special_func_defs.append(tup) |
| 197 | + |
200 | 198 | _special_func_names = [f[0] for f in special_func_defs]
|
201 | 199 |
|
202 |
| - # the series is already localized |
203 |
| - _ignore_names = ["tz_localize", "components"] |
204 |
| - |
205 |
| - for name, attr_names, s, c in test_data: |
206 |
| - func_names = [ |
207 |
| - f |
208 |
| - for f in dir(s.dt) |
209 |
| - if not ( |
210 |
| - f.startswith("_") |
211 |
| - or f in attr_names |
212 |
| - or f in _special_func_names |
213 |
| - or f in _ignore_names |
214 |
| - ) |
215 |
| - ] |
216 |
| - |
217 |
| - func_defs = [(f, (), {}) for f in func_names] |
218 |
| - for f_def in special_func_defs: |
219 |
| - if f_def[0] in dir(s.dt): |
220 |
| - func_defs.append(f_def) |
221 |
| - |
222 |
| - for func, args, kwargs in func_defs: |
223 |
| - with warnings.catch_warnings(): |
224 |
| - if func == "to_period": |
225 |
| - # dropping TZ |
226 |
| - warnings.simplefilter("ignore", UserWarning) |
227 |
| - res = getattr(c.dt, func)(*args, **kwargs) |
228 |
| - exp = getattr(s.dt, func)(*args, **kwargs) |
229 |
| - |
230 |
| - tm.assert_equal(res, exp) |
231 |
| - |
232 |
| - for attr in attr_names: |
233 |
| - if attr in ["week", "weekofyear"]: |
234 |
| - # GH#33595 Deprecate week and weekofyear |
235 |
| - continue |
236 |
| - res = getattr(c.dt, attr) |
237 |
| - exp = getattr(s.dt, attr) |
238 |
| - |
239 |
| - if isinstance(res, DataFrame): |
240 |
| - tm.assert_frame_equal(res, exp) |
241 |
| - elif isinstance(res, Series): |
242 |
| - tm.assert_series_equal(res, exp) |
243 |
| - else: |
244 |
| - tm.assert_almost_equal(res, exp) |
| 200 | + _ignore_names = ["components", "tz_localize", "tz_convert"] |
| 201 | + |
| 202 | + func_names = [ |
| 203 | + fname |
| 204 | + for fname in dir(ser.dt) |
| 205 | + if not ( |
| 206 | + fname.startswith("_") |
| 207 | + or fname in attr_names |
| 208 | + or fname in _special_func_names |
| 209 | + or fname in _ignore_names |
| 210 | + ) |
| 211 | + ] |
| 212 | + |
| 213 | + func_defs = [(fname, (), {}) for fname in func_names] |
| 214 | + |
| 215 | + for f_def in special_func_defs: |
| 216 | + if f_def[0] in dir(ser.dt): |
| 217 | + func_defs.append(f_def) |
| 218 | + |
| 219 | + for func, args, kwargs in func_defs: |
| 220 | + with warnings.catch_warnings(): |
| 221 | + if func == "to_period": |
| 222 | + # dropping TZ |
| 223 | + warnings.simplefilter("ignore", UserWarning) |
| 224 | + res = getattr(cat.dt, func)(*args, **kwargs) |
| 225 | + exp = getattr(ser.dt, func)(*args, **kwargs) |
| 226 | + |
| 227 | + tm.assert_equal(res, exp) |
| 228 | + |
| 229 | + for attr in attr_names: |
| 230 | + if attr in ["week", "weekofyear"]: |
| 231 | + # GH#33595 Deprecate week and weekofyear |
| 232 | + continue |
| 233 | + res = getattr(cat.dt, attr) |
| 234 | + exp = getattr(ser.dt, attr) |
| 235 | + |
| 236 | + tm.assert_equal(res, exp) |
245 | 237 |
|
| 238 | + def test_dt_accessor_api_for_categorical_invalid(self): |
246 | 239 | invalid = Series([1, 2, 3]).astype("category")
|
247 | 240 | msg = "Can only use .dt accessor with datetimelike"
|
248 | 241 |
|
|
0 commit comments