Skip to content

Commit 50ed3ea

Browse files
committed
API: tests for NaT accessors
closes pandas-dev#15781
1 parent 79581ff commit 50ed3ea

File tree

3 files changed

+89
-56
lines changed

3 files changed

+89
-56
lines changed

pandas/_libs/tslib.pyx

+36-12
Original file line numberDiff line numberDiff line change
@@ -849,6 +849,30 @@ class NaTType(_NaT):
849849
def is_leap_year(self):
850850
return False
851851

852+
@property
853+
def is_month_start(self):
854+
return False
855+
856+
@property
857+
def is_quarter_start(self):
858+
return False
859+
860+
@property
861+
def is_year_start(self):
862+
return False
863+
864+
@property
865+
def is_month_end(self):
866+
return False
867+
868+
@property
869+
def is_quarter_end(self):
870+
return False
871+
872+
@property
873+
def is_year_end(self):
874+
return False
875+
852876
def __rdiv__(self, other):
853877
return _nat_rdivide_op(self, other)
854878

@@ -4810,7 +4834,7 @@ def get_start_end_field(ndarray[int64_t] dtindex, object field,
48104834
if field == 'is_month_start':
48114835
if is_business:
48124836
for i in range(count):
4813-
if dtindex[i] == NPY_NAT: out[i] = -1; continue
4837+
if dtindex[i] == NPY_NAT: out[i] = 0; continue
48144838

48154839
pandas_datetime_to_datetimestruct(
48164840
dtindex[i], PANDAS_FR_ns, &dts)
@@ -4823,7 +4847,7 @@ def get_start_end_field(ndarray[int64_t] dtindex, object field,
48234847
return out.view(bool)
48244848
else:
48254849
for i in range(count):
4826-
if dtindex[i] == NPY_NAT: out[i] = -1; continue
4850+
if dtindex[i] == NPY_NAT: out[i] = 0; continue
48274851

48284852
pandas_datetime_to_datetimestruct(
48294853
dtindex[i], PANDAS_FR_ns, &dts)
@@ -4836,7 +4860,7 @@ def get_start_end_field(ndarray[int64_t] dtindex, object field,
48364860
elif field == 'is_month_end':
48374861
if is_business:
48384862
for i in range(count):
4839-
if dtindex[i] == NPY_NAT: out[i] = -1; continue
4863+
if dtindex[i] == NPY_NAT: out[i] = 0; continue
48404864

48414865
pandas_datetime_to_datetimestruct(
48424866
dtindex[i], PANDAS_FR_ns, &dts)
@@ -4854,7 +4878,7 @@ def get_start_end_field(ndarray[int64_t] dtindex, object field,
48544878
return out.view(bool)
48554879
else:
48564880
for i in range(count):
4857-
if dtindex[i] == NPY_NAT: out[i] = -1; continue
4881+
if dtindex[i] == NPY_NAT: out[i] = 0; continue
48584882

48594883
pandas_datetime_to_datetimestruct(
48604884
dtindex[i], PANDAS_FR_ns, &dts)
@@ -4871,7 +4895,7 @@ def get_start_end_field(ndarray[int64_t] dtindex, object field,
48714895
elif field == 'is_quarter_start':
48724896
if is_business:
48734897
for i in range(count):
4874-
if dtindex[i] == NPY_NAT: out[i] = -1; continue
4898+
if dtindex[i] == NPY_NAT: out[i] = 0; continue
48754899

48764900
pandas_datetime_to_datetimestruct(
48774901
dtindex[i], PANDAS_FR_ns, &dts)
@@ -4885,7 +4909,7 @@ def get_start_end_field(ndarray[int64_t] dtindex, object field,
48854909
return out.view(bool)
48864910
else:
48874911
for i in range(count):
4888-
if dtindex[i] == NPY_NAT: out[i] = -1; continue
4912+
if dtindex[i] == NPY_NAT: out[i] = 0; continue
48894913

48904914
pandas_datetime_to_datetimestruct(
48914915
dtindex[i], PANDAS_FR_ns, &dts)
@@ -4898,7 +4922,7 @@ def get_start_end_field(ndarray[int64_t] dtindex, object field,
48984922
elif field == 'is_quarter_end':
48994923
if is_business:
49004924
for i in range(count):
4901-
if dtindex[i] == NPY_NAT: out[i] = -1; continue
4925+
if dtindex[i] == NPY_NAT: out[i] = 0; continue
49024926

49034927
pandas_datetime_to_datetimestruct(
49044928
dtindex[i], PANDAS_FR_ns, &dts)
@@ -4917,7 +4941,7 @@ def get_start_end_field(ndarray[int64_t] dtindex, object field,
49174941
return out.view(bool)
49184942
else:
49194943
for i in range(count):
4920-
if dtindex[i] == NPY_NAT: out[i] = -1; continue
4944+
if dtindex[i] == NPY_NAT: out[i] = 0; continue
49214945

49224946
pandas_datetime_to_datetimestruct(
49234947
dtindex[i], PANDAS_FR_ns, &dts)
@@ -4934,7 +4958,7 @@ def get_start_end_field(ndarray[int64_t] dtindex, object field,
49344958
elif field == 'is_year_start':
49354959
if is_business:
49364960
for i in range(count):
4937-
if dtindex[i] == NPY_NAT: out[i] = -1; continue
4961+
if dtindex[i] == NPY_NAT: out[i] = 0; continue
49384962

49394963
pandas_datetime_to_datetimestruct(
49404964
dtindex[i], PANDAS_FR_ns, &dts)
@@ -4948,7 +4972,7 @@ def get_start_end_field(ndarray[int64_t] dtindex, object field,
49484972
return out.view(bool)
49494973
else:
49504974
for i in range(count):
4951-
if dtindex[i] == NPY_NAT: out[i] = -1; continue
4975+
if dtindex[i] == NPY_NAT: out[i] = 0; continue
49524976

49534977
pandas_datetime_to_datetimestruct(
49544978
dtindex[i], PANDAS_FR_ns, &dts)
@@ -4961,7 +4985,7 @@ def get_start_end_field(ndarray[int64_t] dtindex, object field,
49614985
elif field == 'is_year_end':
49624986
if is_business:
49634987
for i in range(count):
4964-
if dtindex[i] == NPY_NAT: out[i] = -1; continue
4988+
if dtindex[i] == NPY_NAT: out[i] = 0; continue
49654989

49664990
pandas_datetime_to_datetimestruct(
49674991
dtindex[i], PANDAS_FR_ns, &dts)
@@ -4980,7 +5004,7 @@ def get_start_end_field(ndarray[int64_t] dtindex, object field,
49805004
return out.view(bool)
49815005
else:
49825006
for i in range(count):
4983-
if dtindex[i] == NPY_NAT: out[i] = -1; continue
5007+
if dtindex[i] == NPY_NAT: out[i] = 0; continue
49845008

49855009
pandas_datetime_to_datetimestruct(
49865010
dtindex[i], PANDAS_FR_ns, &dts)

pandas/tests/scalar/test_timestamp.py

+30-27
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,14 @@
2727

2828
class TestTimestamp(tm.TestCase):
2929

30+
# all accessors fields
31+
_field_accessors = ['year', 'month', 'day', 'hour', 'minute', 'second',
32+
'microsecond', 'nanosecond', 'dayofweek', 'quarter',
33+
'dayofyear', 'week', 'daysinmonth', 'days_in_month']
34+
_bool_accessors = ['is_month_start', 'is_quarter_start', 'is_year_start',
35+
'is_month_end', 'is_quarter_end', 'is_year_end',
36+
'is_leap_year']
37+
3038
def test_constructor(self):
3139
base_str = '2014-07-01 09:00'
3240
base_dt = datetime(2014, 7, 1, 9)
@@ -579,47 +587,42 @@ def check(value, equal):
579587
def test_nat_fields(self):
580588
# GH 10050
581589
ts = Timestamp('NaT')
582-
self.assertTrue(np.isnan(ts.year))
583-
self.assertTrue(np.isnan(ts.month))
584-
self.assertTrue(np.isnan(ts.day))
585-
self.assertTrue(np.isnan(ts.hour))
586-
self.assertTrue(np.isnan(ts.minute))
587-
self.assertTrue(np.isnan(ts.second))
588-
self.assertTrue(np.isnan(ts.microsecond))
589-
self.assertTrue(np.isnan(ts.nanosecond))
590-
self.assertTrue(np.isnan(ts.dayofweek))
591-
self.assertTrue(np.isnan(ts.quarter))
592-
self.assertTrue(np.isnan(ts.dayofyear))
593-
self.assertTrue(np.isnan(ts.week))
594-
self.assertTrue(np.isnan(ts.daysinmonth))
595-
self.assertTrue(np.isnan(ts.days_in_month))
590+
591+
for field in self._field_accessors:
592+
593+
result = getattr(NaT, field)
594+
assert np.isnan(result)
595+
596+
result = getattr(ts, field)
597+
assert np.isnan(result)
598+
599+
for field in self._bool_accessors:
600+
601+
result = getattr(NaT, field)
602+
assert not result
603+
604+
result = getattr(ts, field)
605+
assert not result
596606

597607
def test_nat_vector_field_access(self):
598608
idx = DatetimeIndex(['1/1/2000', None, None, '1/4/2000'])
599609

600-
fields = ['year', 'quarter', 'month', 'day', 'hour', 'minute',
601-
'second', 'microsecond', 'nanosecond', 'week', 'dayofyear',
602-
'days_in_month', 'is_leap_year']
603-
604-
for field in fields:
610+
for field in self._field_accessors:
605611
result = getattr(idx, field)
606612
expected = [getattr(x, field) for x in idx]
607613
self.assert_numpy_array_equal(result, np.array(expected))
608614

609615
s = pd.Series(idx)
610616

611-
for field in fields:
617+
for field in self._field_accessors:
612618
result = getattr(s.dt, field)
613619
expected = [getattr(x, field) for x in idx]
614620
self.assert_series_equal(result, pd.Series(expected))
615621

616-
def test_nat_scalar_field_access(self):
617-
fields = ['year', 'quarter', 'month', 'day', 'hour', 'minute',
618-
'second', 'microsecond', 'nanosecond', 'week', 'dayofyear',
619-
'days_in_month', 'daysinmonth', 'dayofweek', 'weekday_name']
620-
for field in fields:
621-
result = getattr(NaT, field)
622-
self.assertTrue(np.isnan(result))
622+
for field in self._bool_accessors:
623+
result = getattr(s.dt, field)
624+
expected = [getattr(x, field) for x in idx]
625+
self.assert_series_equal(result, pd.Series(expected))
623626

624627
def test_NaT_methods(self):
625628
# GH 9513

pandas/tseries/index.py

+23-17
Original file line numberDiff line numberDiff line change
@@ -64,21 +64,25 @@ def f(self):
6464
if self.tz is not utc:
6565
values = self._local_timestamps()
6666

67-
if field in ['is_month_start', 'is_month_end',
68-
'is_quarter_start', 'is_quarter_end',
69-
'is_year_start', 'is_year_end']:
70-
month_kw = (self.freq.kwds.get('startingMonth',
71-
self.freq.kwds.get('month', 12))
72-
if self.freq else 12)
73-
74-
result = libts.get_start_end_field(values, field, self.freqstr,
75-
month_kw)
76-
elif field in ['weekday_name']:
67+
if field in self._bool_ops:
68+
if field in ['is_month_start', 'is_month_end',
69+
'is_quarter_start', 'is_quarter_end',
70+
'is_year_start', 'is_year_end']:
71+
month_kw = (self.freq.kwds.get('startingMonth',
72+
self.freq.kwds.get('month', 12))
73+
if self.freq else 12)
74+
75+
result = libts.get_start_end_field(values, field, self.freqstr,
76+
month_kw)
77+
else:
78+
result = libts.get_date_field(values, field)
79+
80+
# these return a boolean by-definition
81+
return result
82+
83+
if field in self._object_ops:
7784
result = libts.get_date_name_field(values, field)
7885
return self._maybe_mask_results(result)
79-
elif field in ['is_leap_year']:
80-
# no need to mask NaT
81-
return libts.get_date_field(values, field)
8286
else:
8387
result = libts.get_date_field(values, field)
8488

@@ -227,14 +231,16 @@ def _join_i8_wrapper(joinf, **kwargs):
227231
offset = None
228232
_comparables = ['name', 'freqstr', 'tz']
229233
_attributes = ['name', 'freq', 'tz']
234+
_bool_ops = ['is_month_start', 'is_month_end',
235+
'is_quarter_start', 'is_quarter_end', 'is_year_start',
236+
'is_year_end', 'is_leap_year']
237+
_object_ops = ['weekday_name', 'tz', 'freq']
230238
_datetimelike_ops = ['year', 'month', 'day', 'hour', 'minute', 'second',
231239
'weekofyear', 'week', 'dayofweek', 'weekday',
232240
'dayofyear', 'quarter', 'days_in_month',
233241
'daysinmonth', 'date', 'time', 'microsecond',
234-
'nanosecond', 'is_month_start', 'is_month_end',
235-
'is_quarter_start', 'is_quarter_end', 'is_year_start',
236-
'is_year_end', 'tz', 'freq', 'weekday_name',
237-
'is_leap_year']
242+
'nanosecond'] + _object_ops + _bool_ops
243+
238244
_is_numeric_dtype = False
239245
_infer_as_myclass = True
240246

0 commit comments

Comments
 (0)