Skip to content

Commit 41d3c96

Browse files
Licht-TNo-Stream
authored andcommitted
BUG: Implement PeriodEngine to fix PeriodIndex truncate bug (pandas-dev#17755)
1 parent b68eb34 commit 41d3c96

File tree

6 files changed

+289
-6
lines changed

6 files changed

+289
-6
lines changed

doc/source/whatsnew/v0.22.0.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ Conversion
100100
Indexing
101101
^^^^^^^^
102102

103-
-
103+
- Bug in :func:`PeriodIndex.truncate` which raises ``TypeError`` when ``PeriodIndex`` is monotonic (:issue:`17717`)
104104
-
105105
-
106106

pandas/_libs/index.pyx

+52-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ from tslib cimport _to_i8
1717

1818
from hashtable cimport HashTable
1919

20-
from pandas._libs import algos, hashtable as _hash
20+
from pandas._libs import algos, period as periodlib, hashtable as _hash
2121
from pandas._libs.tslib import Timestamp, Timedelta
2222
from datetime import datetime, timedelta
2323

@@ -270,13 +270,16 @@ cdef class IndexEngine:
270270

271271
values = self._get_index_values()
272272
self.mapping = self._make_hash_table(len(values))
273-
self.mapping.map_locations(values)
273+
self._call_map_locations(values)
274274

275275
if len(self.mapping) == len(values):
276276
self.unique = 1
277277

278278
self.need_unique_check = 0
279279

280+
cpdef _call_map_locations(self, values):
281+
self.mapping.map_locations(values)
282+
280283
def clear_mapping(self):
281284
self.mapping = None
282285
self.need_monotonic_check = 1
@@ -490,6 +493,53 @@ cdef class TimedeltaEngine(DatetimeEngine):
490493
cdef _get_box_dtype(self):
491494
return 'm8[ns]'
492495

496+
497+
cdef class PeriodEngine(Int64Engine):
498+
499+
cdef _get_index_values(self):
500+
return super(PeriodEngine, self).vgetter()
501+
502+
cpdef _call_map_locations(self, values):
503+
super(PeriodEngine, self)._call_map_locations(values.view('i8'))
504+
505+
def _call_monotonic(self, values):
506+
return super(PeriodEngine, self)._call_monotonic(values.view('i8'))
507+
508+
def get_indexer(self, values):
509+
cdef ndarray[int64_t, ndim=1] ordinals
510+
511+
super(PeriodEngine, self)._ensure_mapping_populated()
512+
513+
freq = super(PeriodEngine, self).vgetter().freq
514+
ordinals = periodlib.extract_ordinals(values, freq)
515+
516+
return self.mapping.lookup(ordinals)
517+
518+
def get_pad_indexer(self, other, limit=None):
519+
freq = super(PeriodEngine, self).vgetter().freq
520+
ordinal = periodlib.extract_ordinals(other, freq)
521+
522+
return algos.pad_int64(self._get_index_values(),
523+
np.asarray(ordinal), limit=limit)
524+
525+
def get_backfill_indexer(self, other, limit=None):
526+
freq = super(PeriodEngine, self).vgetter().freq
527+
ordinal = periodlib.extract_ordinals(other, freq)
528+
529+
return algos.backfill_int64(self._get_index_values(),
530+
np.asarray(ordinal), limit=limit)
531+
532+
def get_indexer_non_unique(self, targets):
533+
freq = super(PeriodEngine, self).vgetter().freq
534+
ordinal = periodlib.extract_ordinals(targets, freq)
535+
ordinal_array = np.asarray(ordinal)
536+
537+
return super(PeriodEngine, self).get_indexer_non_unique(ordinal_array)
538+
539+
cdef _get_index_values_for_bool_indexer(self):
540+
return self._get_index_values().view('i8')
541+
542+
493543
cpdef convert_scalar(ndarray arr, object value):
494544
# we don't turn integers
495545
# into datetimes/timedeltas

pandas/_libs/index_class_helper.pxi.in

+4-1
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ cdef class {{name}}Engine(IndexEngine):
6666
raise KeyError(val)
6767
{{endif}}
6868

69-
values = self._get_index_values()
69+
values = self._get_index_values_for_bool_indexer()
7070
n = len(values)
7171

7272
result = np.empty(n, dtype=bool)
@@ -86,6 +86,9 @@ cdef class {{name}}Engine(IndexEngine):
8686
return last_true
8787

8888
return result
89+
90+
cdef _get_index_values_for_bool_indexer(self):
91+
return self._get_index_values()
8992
{{endif}}
9093

9194
{{endfor}}

pandas/core/indexes/period.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
import pandas.tseries.offsets as offsets
3232

3333
from pandas._libs.lib import infer_dtype
34-
from pandas._libs import tslib, period
34+
from pandas._libs import tslib, period, index as libindex
3535
from pandas._libs.period import (Period, IncompatibleFrequency,
3636
get_period_field_arr, _validate_end_alias,
3737
_quarter_to_myear)
@@ -192,6 +192,8 @@ class PeriodIndex(DatelikeOps, DatetimeIndexOpsMixin, Int64Index):
192192

193193
freq = None
194194

195+
_engine_type = libindex.PeriodEngine
196+
195197
__eq__ = _period_index_cmp('__eq__')
196198
__ne__ = _period_index_cmp('__ne__', nat_result=True)
197199
__lt__ = _period_index_cmp('__lt__')
@@ -275,6 +277,10 @@ def __new__(cls, data=None, ordinal=None, freq=None, start=None, end=None,
275277
data = period.extract_ordinals(data, freq)
276278
return cls._from_ordinals(data, name=name, freq=freq)
277279

280+
@cache_readonly
281+
def _engine(self):
282+
return self._engine_type(lambda: self, len(self))
283+
278284
@classmethod
279285
def _generate_range(cls, start, end, periods, freq, fields):
280286
if freq is not None:

pandas/tests/indexes/period/test_indexing.py

+195-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pandas as pd
77
from pandas.util import testing as tm
88
from pandas.compat import lrange
9-
from pandas._libs import tslib
9+
from pandas._libs import tslib, tslibs
1010
from pandas import (PeriodIndex, Series, DatetimeIndex,
1111
period_range, Period)
1212

@@ -310,3 +310,197 @@ def test_take_fill_value(self):
310310

311311
with pytest.raises(IndexError):
312312
idx.take(np.array([1, -5]))
313+
314+
def test_get_loc(self):
315+
# GH 17717
316+
p0 = pd.Period('2017-09-01')
317+
p1 = pd.Period('2017-09-02')
318+
p2 = pd.Period('2017-09-03')
319+
320+
# get the location of p1/p2 from
321+
# monotonic increasing PeriodIndex with non-duplicate
322+
idx0 = pd.PeriodIndex([p0, p1, p2])
323+
expected_idx1_p1 = 1
324+
expected_idx1_p2 = 2
325+
326+
assert idx0.get_loc(p1) == expected_idx1_p1
327+
assert idx0.get_loc(str(p1)) == expected_idx1_p1
328+
assert idx0.get_loc(p2) == expected_idx1_p2
329+
assert idx0.get_loc(str(p2)) == expected_idx1_p2
330+
331+
pytest.raises(tslibs.parsing.DateParseError, idx0.get_loc, 'foo')
332+
pytest.raises(KeyError, idx0.get_loc, 1.1)
333+
pytest.raises(TypeError, idx0.get_loc, idx0)
334+
335+
# get the location of p1/p2 from
336+
# monotonic increasing PeriodIndex with duplicate
337+
idx1 = pd.PeriodIndex([p1, p1, p2])
338+
expected_idx1_p1 = slice(0, 2)
339+
expected_idx1_p2 = 2
340+
341+
assert idx1.get_loc(p1) == expected_idx1_p1
342+
assert idx1.get_loc(str(p1)) == expected_idx1_p1
343+
assert idx1.get_loc(p2) == expected_idx1_p2
344+
assert idx1.get_loc(str(p2)) == expected_idx1_p2
345+
346+
pytest.raises(tslibs.parsing.DateParseError, idx1.get_loc, 'foo')
347+
pytest.raises(KeyError, idx1.get_loc, 1.1)
348+
pytest.raises(TypeError, idx1.get_loc, idx1)
349+
350+
# get the location of p1/p2 from
351+
# non-monotonic increasing/decreasing PeriodIndex with duplicate
352+
idx2 = pd.PeriodIndex([p2, p1, p2])
353+
expected_idx2_p1 = 1
354+
expected_idx2_p2 = np.array([True, False, True])
355+
356+
assert idx2.get_loc(p1) == expected_idx2_p1
357+
assert idx2.get_loc(str(p1)) == expected_idx2_p1
358+
tm.assert_numpy_array_equal(idx2.get_loc(p2), expected_idx2_p2)
359+
tm.assert_numpy_array_equal(idx2.get_loc(str(p2)), expected_idx2_p2)
360+
361+
def test_is_monotonic_increasing(self):
362+
# GH 17717
363+
p0 = pd.Period('2017-09-01')
364+
p1 = pd.Period('2017-09-02')
365+
p2 = pd.Period('2017-09-03')
366+
367+
idx_inc0 = pd.PeriodIndex([p0, p1, p2])
368+
idx_inc1 = pd.PeriodIndex([p0, p1, p1])
369+
idx_dec0 = pd.PeriodIndex([p2, p1, p0])
370+
idx_dec1 = pd.PeriodIndex([p2, p1, p1])
371+
idx = pd.PeriodIndex([p1, p2, p0])
372+
373+
assert idx_inc0.is_monotonic_increasing
374+
assert idx_inc1.is_monotonic_increasing
375+
assert not idx_dec0.is_monotonic_increasing
376+
assert not idx_dec1.is_monotonic_increasing
377+
assert not idx.is_monotonic_increasing
378+
379+
def test_is_monotonic_decreasing(self):
380+
# GH 17717
381+
p0 = pd.Period('2017-09-01')
382+
p1 = pd.Period('2017-09-02')
383+
p2 = pd.Period('2017-09-03')
384+
385+
idx_inc0 = pd.PeriodIndex([p0, p1, p2])
386+
idx_inc1 = pd.PeriodIndex([p0, p1, p1])
387+
idx_dec0 = pd.PeriodIndex([p2, p1, p0])
388+
idx_dec1 = pd.PeriodIndex([p2, p1, p1])
389+
idx = pd.PeriodIndex([p1, p2, p0])
390+
391+
assert not idx_inc0.is_monotonic_decreasing
392+
assert not idx_inc1.is_monotonic_decreasing
393+
assert idx_dec0.is_monotonic_decreasing
394+
assert idx_dec1.is_monotonic_decreasing
395+
assert not idx.is_monotonic_decreasing
396+
397+
def test_is_unique(self):
398+
# GH 17717
399+
p0 = pd.Period('2017-09-01')
400+
p1 = pd.Period('2017-09-02')
401+
p2 = pd.Period('2017-09-03')
402+
403+
idx0 = pd.PeriodIndex([p0, p1, p2])
404+
assert idx0.is_unique
405+
406+
idx1 = pd.PeriodIndex([p1, p1, p2])
407+
assert not idx1.is_unique
408+
409+
def test_contains(self):
410+
# GH 17717
411+
p0 = pd.Period('2017-09-01')
412+
p1 = pd.Period('2017-09-02')
413+
p2 = pd.Period('2017-09-03')
414+
p3 = pd.Period('2017-09-04')
415+
416+
ps0 = [p0, p1, p2]
417+
idx0 = pd.PeriodIndex(ps0)
418+
419+
for p in ps0:
420+
assert idx0.contains(p)
421+
assert p in idx0
422+
423+
assert idx0.contains(str(p))
424+
assert str(p) in idx0
425+
426+
assert idx0.contains('2017-09-01 00:00:01')
427+
assert '2017-09-01 00:00:01' in idx0
428+
429+
assert idx0.contains('2017-09')
430+
assert '2017-09' in idx0
431+
432+
assert not idx0.contains(p3)
433+
assert p3 not in idx0
434+
435+
def test_get_value(self):
436+
# GH 17717
437+
p0 = pd.Period('2017-09-01')
438+
p1 = pd.Period('2017-09-02')
439+
p2 = pd.Period('2017-09-03')
440+
441+
idx0 = pd.PeriodIndex([p0, p1, p2])
442+
input0 = np.array([1, 2, 3])
443+
expected0 = 2
444+
445+
result0 = idx0.get_value(input0, p1)
446+
assert result0 == expected0
447+
448+
idx1 = pd.PeriodIndex([p1, p1, p2])
449+
input1 = np.array([1, 2, 3])
450+
expected1 = np.array([1, 2])
451+
452+
result1 = idx1.get_value(input1, p1)
453+
tm.assert_numpy_array_equal(result1, expected1)
454+
455+
idx2 = pd.PeriodIndex([p1, p2, p1])
456+
input2 = np.array([1, 2, 3])
457+
expected2 = np.array([1, 3])
458+
459+
result2 = idx2.get_value(input2, p1)
460+
tm.assert_numpy_array_equal(result2, expected2)
461+
462+
def test_get_indexer(self):
463+
# GH 17717
464+
p1 = pd.Period('2017-09-01')
465+
p2 = pd.Period('2017-09-04')
466+
p3 = pd.Period('2017-09-07')
467+
468+
tp0 = pd.Period('2017-08-31')
469+
tp1 = pd.Period('2017-09-02')
470+
tp2 = pd.Period('2017-09-05')
471+
tp3 = pd.Period('2017-09-09')
472+
473+
idx = pd.PeriodIndex([p1, p2, p3])
474+
475+
tm.assert_numpy_array_equal(idx.get_indexer(idx),
476+
np.array([0, 1, 2], dtype=np.intp))
477+
478+
target = pd.PeriodIndex([tp0, tp1, tp2, tp3])
479+
tm.assert_numpy_array_equal(idx.get_indexer(target, 'pad'),
480+
np.array([-1, 0, 1, 2], dtype=np.intp))
481+
tm.assert_numpy_array_equal(idx.get_indexer(target, 'backfill'),
482+
np.array([0, 1, 2, -1], dtype=np.intp))
483+
tm.assert_numpy_array_equal(idx.get_indexer(target, 'nearest'),
484+
np.array([0, 0, 1, 2], dtype=np.intp))
485+
486+
res = idx.get_indexer(target, 'nearest',
487+
tolerance=pd.Timedelta('1 day'))
488+
tm.assert_numpy_array_equal(res,
489+
np.array([0, 0, 1, -1], dtype=np.intp))
490+
491+
def test_get_indexer_non_unique(self):
492+
# GH 17717
493+
p1 = pd.Period('2017-09-02')
494+
p2 = pd.Period('2017-09-03')
495+
p3 = pd.Period('2017-09-04')
496+
p4 = pd.Period('2017-09-05')
497+
498+
idx1 = pd.PeriodIndex([p1, p2, p1])
499+
idx2 = pd.PeriodIndex([p2, p1, p3, p4])
500+
501+
result = idx1.get_indexer_non_unique(idx2)
502+
expected_indexer = np.array([1, 0, 2, -1, -1], dtype=np.int64)
503+
expected_missing = np.array([2, 3], dtype=np.int64)
504+
505+
tm.assert_numpy_array_equal(result[0], expected_indexer)
506+
tm.assert_numpy_array_equal(result[1], expected_missing)

pandas/tests/series/test_period.py

+30
Original file line numberDiff line numberDiff line change
@@ -249,3 +249,33 @@ def test_align_series(self):
249249
msg = "Input has different freq=D from PeriodIndex\\(freq=A-DEC\\)"
250250
with tm.assert_raises_regex(period.IncompatibleFrequency, msg):
251251
ts + ts.asfreq('D', how="end")
252+
253+
def test_truncate(self):
254+
# GH 17717
255+
idx1 = pd.PeriodIndex([
256+
pd.Period('2017-09-02'),
257+
pd.Period('2017-09-02'),
258+
pd.Period('2017-09-03')
259+
])
260+
series1 = pd.Series([1, 2, 3], index=idx1)
261+
result1 = series1.truncate(after='2017-09-02')
262+
263+
expected_idx1 = pd.PeriodIndex([
264+
pd.Period('2017-09-02'),
265+
pd.Period('2017-09-02')
266+
])
267+
tm.assert_series_equal(result1, pd.Series([1, 2], index=expected_idx1))
268+
269+
idx2 = pd.PeriodIndex([
270+
pd.Period('2017-09-03'),
271+
pd.Period('2017-09-02'),
272+
pd.Period('2017-09-03')
273+
])
274+
series2 = pd.Series([1, 2, 3], index=idx2)
275+
result2 = series2.truncate(after='2017-09-02')
276+
277+
expected_idx2 = pd.PeriodIndex([
278+
pd.Period('2017-09-03'),
279+
pd.Period('2017-09-02')
280+
])
281+
tm.assert_series_equal(result2, pd.Series([1, 2], index=expected_idx2))

0 commit comments

Comments
 (0)