Skip to content

Commit 1731852

Browse files
committed
BUG: GroupBy.get_group raises ValueError when group key contains NaT
1 parent f9db166 commit 1731852

File tree

11 files changed

+114
-14
lines changed

11 files changed

+114
-14
lines changed

doc/source/groupby.rst

+3-3
Original file line numberDiff line numberDiff line change
@@ -784,11 +784,11 @@ will be (silently) dropped. Thus, this does not pose any problems:
784784
785785
df.groupby('A').std()
786786
787-
NA group handling
787+
NA and NaT group handling
788788
~~~~~~~~~~~~~~~~~
789789

790-
If there are any NaN values in the grouping key, these will be automatically
791-
excluded. So there will never be an "NA group". This was not the case in older
790+
If there are any NaN or NaT values in the grouping key, these will be automatically
791+
excluded. So there will never be an "NA group" or "NaT group". This was not the case in older
792792
versions of pandas, but users were generally discarding the NA group anyway
793793
(and supporting it was an implementation headache).
794794

doc/source/whatsnew/v0.17.0.txt

+5
Original file line numberDiff line numberDiff line change
@@ -68,3 +68,8 @@ Bug Fixes
6868

6969

7070

71+
72+
73+
- Bug in GroupBy.get_group raises ValueError when group key contains NaT (:issue:`6992`)
74+
75+

pandas/algos.pyx

+4-2
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ cdef extern from "src/headers/math.h":
6161
int signbit(double)
6262

6363
from pandas import lib
64+
from pandas import tslib
65+
cdef object NaT = tslib.NaT
6466

6567
include "skiplist.pyx"
6668

@@ -2010,7 +2012,7 @@ def groupby_indices(ndarray values):
20102012
k = labels[i]
20112013

20122014
# was NaN
2013-
if k == -1:
2015+
if k == -1 or k is NaT:
20142016
continue
20152017

20162018
loc = seen[k]
@@ -2043,7 +2045,7 @@ def group_labels(ndarray[object] values):
20432045
val = values[i]
20442046

20452047
# is NaN
2046-
if val != val:
2048+
if val != val or val is NaT:
20472049
labels[i] = -1
20482050
continue
20492051

pandas/core/groupby.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,11 @@ def convert(key, s):
426426
return Timestamp(key).asm8
427427
return key
428428

429-
sample = next(iter(self.indices))
429+
if len(self.indices) > 0:
430+
sample = next(iter(self.indices))
431+
else:
432+
sample = None # Dummy sample
433+
430434
if isinstance(sample, tuple):
431435
if not isinstance(name, tuple):
432436
msg = ("must supply a tuple to get_group with multiple"

pandas/src/generate_code.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
3838
cimport util
3939
from util cimport is_array, _checknull, _checknan, get_nat
40+
cimport tslib
41+
from tslib cimport _checknull_with_np_nat
4042
4143
cdef int64_t iNaT = get_nat()
4244
@@ -673,7 +675,7 @@ def groupby_%(name)s(ndarray[%(c_type)s] index, ndarray labels):
673675
for i in range(length):
674676
key = util.get_value_1d(labels, i)
675677
676-
if _checknull(key):
678+
if _checknull_with_np_nat(key):
677679
continue
678680
679681
idx = index[i]

pandas/src/generated.pyx

+9-6
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,11 @@ ctypedef unsigned char UChar
2828

2929
cimport util
3030
from util cimport is_array, _checknull, _checknan, get_nat
31+
cimport tslib
32+
from tslib cimport _checknull_with_np_nat
3133

3234
cdef int64_t iNaT = get_nat()
35+
np_NaT = np.datetime64('NaT')
3336

3437
# import datetime C API
3538
PyDateTime_IMPORT
@@ -2096,7 +2099,7 @@ def groupby_float64(ndarray[float64_t] index, ndarray labels):
20962099
for i in range(length):
20972100
key = util.get_value_1d(labels, i)
20982101

2099-
if _checknull(key):
2102+
if _checknull_with_np_nat(key):
21002103
continue
21012104

21022105
idx = index[i]
@@ -2124,7 +2127,7 @@ def groupby_float32(ndarray[float32_t] index, ndarray labels):
21242127
for i in range(length):
21252128
key = util.get_value_1d(labels, i)
21262129

2127-
if _checknull(key):
2130+
if _checknull_with_np_nat(key):
21282131
continue
21292132

21302133
idx = index[i]
@@ -2152,7 +2155,7 @@ def groupby_object(ndarray[object] index, ndarray labels):
21522155
for i in range(length):
21532156
key = util.get_value_1d(labels, i)
21542157

2155-
if _checknull(key):
2158+
if _checknull_with_np_nat(key):
21562159
continue
21572160

21582161
idx = index[i]
@@ -2180,7 +2183,7 @@ def groupby_int32(ndarray[int32_t] index, ndarray labels):
21802183
for i in range(length):
21812184
key = util.get_value_1d(labels, i)
21822185

2183-
if _checknull(key):
2186+
if _checknull_with_np_nat(key):
21842187
continue
21852188

21862189
idx = index[i]
@@ -2208,7 +2211,7 @@ def groupby_int64(ndarray[int64_t] index, ndarray labels):
22082211
for i in range(length):
22092212
key = util.get_value_1d(labels, i)
22102213

2211-
if _checknull(key):
2214+
if _checknull_with_np_nat(key):
22122215
continue
22132216

22142217
idx = index[i]
@@ -2236,7 +2239,7 @@ def groupby_bool(ndarray[uint8_t] index, ndarray labels):
22362239
for i in range(length):
22372240
key = util.get_value_1d(labels, i)
22382241

2239-
if _checknull(key):
2242+
if _checknull_with_np_nat(key):
22402243
continue
22412244

22422245
idx = index[i]

pandas/tests/test_groupby.py

+43-1
Original file line numberDiff line numberDiff line change
@@ -699,7 +699,6 @@ def test_get_group(self):
699699
expected = wp.reindex(major=[x for x in wp.major_axis if x.month == 1])
700700
assert_panel_equal(gp, expected)
701701

702-
703702
# GH 5267
704703
# be datelike friendly
705704
df = DataFrame({'DATE' : pd.to_datetime(['10-Oct-2013', '10-Oct-2013', '10-Oct-2013',
@@ -2837,6 +2836,49 @@ def test_groupby_list_infer_array_like(self):
28372836
result = df.groupby(['foo', 'bar']).mean()
28382837
expected = df.groupby([df['foo'], df['bar']]).mean()[['val']]
28392838

2839+
def test_groupby_nat_exclude(self):
2840+
# GH 6992
2841+
df = pd.DataFrame({'values': np.random.randn(8),
2842+
'dt': [np.nan, pd.Timestamp('2013-01-01'), np.nan, pd.Timestamp('2013-02-01'),
2843+
np.nan, pd.Timestamp('2013-02-01'), np.nan, pd.Timestamp('2013-01-01')],
2844+
'str': [np.nan, 'a', np.nan, 'a',
2845+
np.nan, 'a', np.nan, 'b']})
2846+
grouped = df.groupby('dt')
2847+
2848+
expected = [[1, 7], [3, 5]]
2849+
keys = sorted(grouped.groups.keys())
2850+
self.assertEqual(len(keys), 2)
2851+
for k, e in zip(keys, expected):
2852+
# grouped.groups keys are np.datetime64 with system tz
2853+
# not to be affected by tz, only compare values
2854+
self.assertEqual(grouped.groups[k], e)
2855+
2856+
# confirm obj is not filtered
2857+
tm.assert_frame_equal(grouped.grouper.groupings[0].obj, df)
2858+
self.assertEqual(grouped.ngroups, 2)
2859+
expected = {Timestamp('2013-01-01 00:00:00'): np.array([1, 7]),
2860+
Timestamp('2013-02-01 00:00:00'): np.array([3, 5])}
2861+
for k in grouped.indices:
2862+
self.assert_numpy_array_equal(grouped.indices[k], expected[k])
2863+
2864+
tm.assert_frame_equal(grouped.get_group(Timestamp('2013-01-01')), df.iloc[[1, 7]])
2865+
tm.assert_frame_equal(grouped.get_group(Timestamp('2013-02-01')), df.iloc[[3, 5]])
2866+
2867+
self.assertRaises(KeyError, grouped.get_group, pd.NaT)
2868+
2869+
nan_df = DataFrame({'nan': [np.nan, np.nan, np.nan],
2870+
'nat': [pd.NaT, pd.NaT, pd.NaT]})
2871+
self.assertEqual(nan_df['nan'].dtype, 'float64')
2872+
self.assertEqual(nan_df['nat'].dtype, 'datetime64[ns]')
2873+
2874+
for key in ['nan', 'nat']:
2875+
grouped = nan_df.groupby(key)
2876+
self.assertEqual(grouped.groups, {})
2877+
self.assertEqual(grouped.ngroups, 0)
2878+
self.assertEqual(grouped.indices, {})
2879+
self.assertRaises(KeyError, grouped.get_group, np.nan)
2880+
self.assertRaises(KeyError, grouped.get_group, pd.NaT)
2881+
28402882
def test_dictify(self):
28412883
dict(iter(self.df.groupby('A')))
28422884
dict(iter(self.df.groupby(['A', 'B'])))

pandas/tests/test_index.py

+28
Original file line numberDiff line numberDiff line change
@@ -1858,6 +1858,34 @@ def test_ufunc_compat(self):
18581858
expected = Float64Index(np.sin(np.arange(5,dtype='int64')))
18591859
tm.assert_index_equal(result, expected)
18601860

1861+
def test_index_groupby(self):
1862+
int_idx = Index(range(6))
1863+
float_idx = Index(np.arange(0, 0.6, 0.1))
1864+
obj_idx = Index('A B C D E F'.split())
1865+
dt_idx = pd.date_range('2013-01-01', freq='M', periods=6)
1866+
1867+
for idx in [int_idx, float_idx, obj_idx, dt_idx]:
1868+
to_groupby = np.array([1, 2, np.nan, np.nan, 2, 1])
1869+
self.assertEqual(idx.groupby(to_groupby),
1870+
{1.0: [idx[0], idx[5]], 2.0: [idx[1], idx[4]]})
1871+
self.assertEqual(idx.groupby(to_groupby, dropna=False),
1872+
{np.nan:[idx[2], idx[3]], 1.0: [idx[0], idx[5]], 2.0: [idx[1], idx[4]]})
1873+
1874+
to_groupby = Index([datetime(2011, 11, 1), datetime(2011, 12, 1),
1875+
pd.NaT, pd.NaT,
1876+
datetime(2011, 12, 1), datetime(2011, 11, 1)], tz='UTC').values
1877+
1878+
ex_keys = pd.tslib.datetime_to_datetime64(np.array([Timestamp('2011-11-01'), Timestamp('2011-12-01')]))
1879+
expected = {ex_keys[0][0]: [idx[0], idx[5]], ex_keys[0][1]: [idx[1], idx[4]]}
1880+
self.assertEqual(idx.groupby(to_groupby), expected)
1881+
1882+
ex_keys = pd.tslib.datetime_to_datetime64(np.array([pd.NaT, Timestamp('2011-11-01'),
1883+
Timestamp('2011-12-01')]))
1884+
expected = {ex_keys[0][0]: [idx[2], idx[3]],
1885+
ex_keys[0][1]: [idx[0], idx[5]],
1886+
ex_keys[0][2]: [idx[1], idx[4]]}
1887+
self.assertEqual(idx.groupby(to_groupby, dropna=False), expected)
1888+
18611889

18621890
class TestFloat64Index(Numeric, tm.TestCase):
18631891
_holder = Float64Index

pandas/tseries/tests/test_timeseries.py

+6
Original file line numberDiff line numberDiff line change
@@ -2061,6 +2061,12 @@ def test_pickle(self):
20612061
self.assertTrue(idx_p[1] is NaT)
20622062
self.assertTrue(idx_p[2] == idx[2])
20632063

2064+
def test_indexing_doesnt_change_class(self):
2065+
idx = Index([1, 2, 3, 'a', 'b', 'c'])
2066+
2067+
self.assertTrue(idx[1:3].identical(pd.Index([2, 3], dtype=np.object_)))
2068+
self.assertTrue(idx[[0,1]].identical(pd.Index([1, 2], dtype=np.object_)))
2069+
20642070

20652071
def _simple_ts(start, end, freq='D'):
20662072
rng = date_range(start, end, freq=freq)

pandas/tslib.pxd

+1
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@ cdef bint _is_utc(object)
77
cdef bint _is_tzlocal(object)
88
cdef object _get_dst_info(object)
99
cdef bint _nat_scalar_rules[6]
10+
cdef bint _checknull_with_np_nat(object)

pandas/tslib.pyx

+7
Original file line numberDiff line numberDiff line change
@@ -646,6 +646,13 @@ cdef inline bint _checknull_with_nat(object val):
646646
return val is None or (
647647
PyFloat_Check(val) and val != val) or val is NaT
648648

649+
650+
cdef inline bint _checknull_with_np_nat(object val):
651+
""" utility to check if a value is a nat or not """
652+
return val is None or (
653+
PyFloat_Check(val) and val != val) or val == np_NaT
654+
655+
649656
cdef inline bint _cmp_nat_dt(_NaT lhs, _Timestamp rhs, int op) except -1:
650657
return _nat_scalar_rules[op]
651658

0 commit comments

Comments
 (0)