Skip to content

Commit 5402ea5

Browse files
jbrockmendeljreback
authored andcommitted
REF: make PeriodIndex.get_value wrap PeriodIndex.get_loc (#31318)
1 parent 3c76318 commit 5402ea5

File tree

2 files changed

+56
-52
lines changed

2 files changed

+56
-52
lines changed

pandas/core/indexes/period.py

+13-33
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from datetime import datetime, timedelta
2-
from typing import Any
2+
from typing import TYPE_CHECKING, Any
33
import weakref
44

55
import numpy as np
@@ -20,6 +20,7 @@
2020
is_integer_dtype,
2121
is_list_like,
2222
is_object_dtype,
23+
is_scalar,
2324
pandas_dtype,
2425
)
2526

@@ -33,6 +34,7 @@
3334
import pandas.core.common as com
3435
import pandas.core.indexes.base as ibase
3536
from pandas.core.indexes.base import (
37+
InvalidIndexError,
3638
_index_shared_docs,
3739
ensure_index,
3840
maybe_extract_name,
@@ -52,6 +54,8 @@
5254
_index_doc_kwargs = dict(ibase._index_doc_kwargs)
5355
_index_doc_kwargs.update(dict(target_klass="PeriodIndex or list of Periods"))
5456

57+
if TYPE_CHECKING:
58+
from pandas import Series
5559

5660
# --- Period index sketch
5761

@@ -479,43 +483,16 @@ def inferred_type(self) -> str:
479483
# indexing
480484
return "period"
481485

482-
def get_value(self, series, key):
486+
def get_value(self, series: "Series", key):
483487
"""
484488
Fast lookup of value from 1-dimensional ndarray. Only use this if you
485489
know what you're doing
486490
"""
487491
if is_integer(key):
488-
return series.iat[key]
489-
490-
if isinstance(key, str):
491-
try:
492-
loc = self._get_string_slice(key)
493-
return series[loc]
494-
except (TypeError, ValueError, OverflowError):
495-
pass
496-
497-
asdt, reso = parse_time_string(key, self.freq)
498-
grp = resolution.Resolution.get_freq_group(reso)
499-
freqn = resolution.get_freq_group(self.freq)
500-
501-
# _get_string_slice will handle cases where grp < freqn
502-
assert grp >= freqn
503-
504-
if grp == freqn:
505-
key = Period(asdt, freq=self.freq)
506-
loc = self.get_loc(key)
507-
return series.iloc[loc]
508-
else:
509-
raise KeyError(key)
510-
511-
elif isinstance(key, Period) or key is NaT:
512-
ordinal = key.ordinal if key is not NaT else NaT.value
513-
loc = self._engine.get_loc(ordinal)
514-
return series[loc]
515-
516-
# slice, PeriodIndex, np.ndarray, List[Period]
517-
value = Index.get_value(self, series, key)
518-
return com.maybe_box(self, value, series, key)
492+
loc = key
493+
else:
494+
loc = self.get_loc(key)
495+
return self._get_values_for_loc(series, loc)
519496

520497
@Appender(_index_shared_docs["get_indexer"] % _index_doc_kwargs)
521498
def get_indexer(self, target, method=None, limit=None, tolerance=None):
@@ -571,6 +548,9 @@ def get_loc(self, key, method=None, tolerance=None):
571548
If key is listlike or otherwise not hashable.
572549
"""
573550

551+
if not is_scalar(key):
552+
raise InvalidIndexError(key)
553+
574554
if isinstance(key, str):
575555

576556
try:

pandas/tests/indexes/period/test_indexing.py

+43-19
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from datetime import datetime, timedelta
2+
import re
23

34
import numpy as np
45
import pytest
@@ -8,6 +9,7 @@
89
import pandas as pd
910
from pandas import DatetimeIndex, Period, PeriodIndex, Series, notna, period_range
1011
import pandas._testing as tm
12+
from pandas.core.indexes.base import InvalidIndexError
1113

1214

1315
class TestGetItem:
@@ -408,11 +410,7 @@ def test_get_loc(self):
408410
with pytest.raises(KeyError, match=r"^1\.1$"):
409411
idx0.get_loc(1.1)
410412

411-
msg = (
412-
r"'PeriodIndex\(\['2017-09-01', '2017-09-02', '2017-09-03'\], "
413-
r"dtype='period\[D\]', freq='D'\)' is an invalid key"
414-
)
415-
with pytest.raises(TypeError, match=msg):
413+
with pytest.raises(InvalidIndexError, match=re.escape(str(idx0))):
416414
idx0.get_loc(idx0)
417415

418416
# get the location of p1/p2 from
@@ -433,11 +431,7 @@ def test_get_loc(self):
433431
with pytest.raises(KeyError, match=r"^1\.1$"):
434432
idx1.get_loc(1.1)
435433

436-
msg = (
437-
r"'PeriodIndex\(\['2017-09-02', '2017-09-02', '2017-09-03'\], "
438-
r"dtype='period\[D\]', freq='D'\)' is an invalid key"
439-
)
440-
with pytest.raises(TypeError, match=msg):
434+
with pytest.raises(InvalidIndexError, match=re.escape(str(idx1))):
441435
idx1.get_loc(idx1)
442436

443437
# get the location of p1/p2 from
@@ -461,16 +455,46 @@ def test_get_loc_integer(self):
461455
with pytest.raises(KeyError, match="46"):
462456
pi2.get_loc(46)
463457

458+
@pytest.mark.parametrize("freq", ["H", "D"])
459+
def test_get_value_datetime_hourly(self, freq):
460+
# get_loc and get_value should treat datetime objects symmetrically
461+
dti = pd.date_range("2016-01-01", periods=3, freq="MS")
462+
pi = dti.to_period(freq)
463+
ser = pd.Series(range(7, 10), index=pi)
464+
465+
ts = dti[0]
466+
467+
assert pi.get_loc(ts) == 0
468+
assert pi.get_value(ser, ts) == 7
469+
assert ser[ts] == 7
470+
assert ser.loc[ts] == 7
471+
472+
ts2 = ts + pd.Timedelta(hours=3)
473+
if freq == "H":
474+
with pytest.raises(KeyError, match="2016-01-01 03:00"):
475+
pi.get_loc(ts2)
476+
with pytest.raises(KeyError, match="2016-01-01 03:00"):
477+
pi.get_value(ser, ts2)
478+
with pytest.raises(KeyError, match="2016-01-01 03:00"):
479+
ser[ts2]
480+
with pytest.raises(KeyError, match="2016-01-01 03:00"):
481+
ser.loc[ts2]
482+
else:
483+
assert pi.get_loc(ts2) == 0
484+
assert pi.get_value(ser, ts2) == 7
485+
assert ser[ts2] == 7
486+
assert ser.loc[ts2] == 7
487+
464488
def test_get_value_integer(self):
465489
dti = pd.date_range("2016-01-01", periods=3)
466490
pi = dti.to_period("D")
467491
ser = pd.Series(range(3), index=pi)
468-
with pytest.raises(IndexError, match="is out of bounds for axis 0 with size 3"):
492+
with pytest.raises(IndexError, match="index out of bounds"):
469493
pi.get_value(ser, 16801)
470494

471495
pi2 = dti.to_period("Y") # duplicates, ordinals are all 46
472496
ser2 = pd.Series(range(3), index=pi2)
473-
with pytest.raises(IndexError, match="is out of bounds for axis 0 with size 3"):
497+
with pytest.raises(IndexError, match="index out of bounds"):
474498
pi2.get_value(ser2, 46)
475499

476500
def test_is_monotonic_increasing(self):
@@ -544,25 +568,25 @@ def test_get_value(self):
544568
p2 = pd.Period("2017-09-03")
545569

546570
idx0 = pd.PeriodIndex([p0, p1, p2])
547-
input0 = np.array([1, 2, 3])
571+
input0 = pd.Series(np.array([1, 2, 3]), index=idx0)
548572
expected0 = 2
549573

550574
result0 = idx0.get_value(input0, p1)
551575
assert result0 == expected0
552576

553577
idx1 = pd.PeriodIndex([p1, p1, p2])
554-
input1 = np.array([1, 2, 3])
555-
expected1 = np.array([1, 2])
578+
input1 = pd.Series(np.array([1, 2, 3]), index=idx1)
579+
expected1 = input1.iloc[[0, 1]]
556580

557581
result1 = idx1.get_value(input1, p1)
558-
tm.assert_numpy_array_equal(result1, expected1)
582+
tm.assert_series_equal(result1, expected1)
559583

560584
idx2 = pd.PeriodIndex([p1, p2, p1])
561-
input2 = np.array([1, 2, 3])
562-
expected2 = np.array([1, 3])
585+
input2 = pd.Series(np.array([1, 2, 3]), index=idx2)
586+
expected2 = input2.iloc[[0, 2]]
563587

564588
result2 = idx2.get_value(input2, p1)
565-
tm.assert_numpy_array_equal(result2, expected2)
589+
tm.assert_series_equal(result2, expected2)
566590

567591
def test_get_indexer(self):
568592
# GH 17717

0 commit comments

Comments
 (0)