Skip to content

Commit 185a654

Browse files
authored
BUG: scalar indexing on 2D DTA/TDA/PA (#33342)
1 parent 5f2cdf8 commit 185a654

File tree

3 files changed

+9
-7
lines changed

3 files changed

+9
-7
lines changed

pandas/core/arrays/datetimelike.py

+2
Original file line numberDiff line numberDiff line change
@@ -574,6 +574,8 @@ def __getitem__(self, key):
574574
freq = self.freq
575575

576576
result = getitem(key)
577+
if lib.is_scalar(result):
578+
return self._box_func(result)
577579
return self._simple_new(result, dtype=self.dtype, freq=freq)
578580

579581
def __setitem__(

pandas/core/internals/blocks.py

+2-7
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import numpy as np
88

9-
from pandas._libs import NaT, Timestamp, algos as libalgos, lib, writers
9+
from pandas._libs import NaT, algos as libalgos, lib, writers
1010
import pandas._libs.internals as libinternals
1111
from pandas._libs.tslibs import Timedelta, conversion
1212
from pandas._libs.tslibs.timezones import tz_compare
@@ -2009,12 +2009,7 @@ def array_values(self):
20092009
def iget(self, key):
20102010
# GH#31649 we need to wrap scalars in Timestamp/Timedelta
20112011
# TODO(EA2D): this can be removed if we ever have 2D EA
2012-
result = super().iget(key)
2013-
if isinstance(result, np.datetime64):
2014-
result = Timestamp(result)
2015-
elif isinstance(result, np.timedelta64):
2016-
result = Timedelta(result)
2017-
return result
2012+
return self.array_values().reshape(self.shape)[key]
20182013

20192014
def shift(self, periods, axis=0, fill_value=None):
20202015
# TODO(EA2D) this is unnecessary if these blocks are backed by 2D EAs

pandas/tests/arrays/test_datetimelike.py

+5
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,11 @@ def test_getitem_2d(self, arr1d):
222222
result = arr2d[:3, 0]
223223
tm.assert_equal(result, expected)
224224

225+
# Scalar lookup
226+
result = arr2d[-1, 0]
227+
expected = arr1d[-1]
228+
assert result == expected
229+
225230
def test_setitem(self):
226231
data = np.arange(10, dtype="i8") * 24 * 3600 * 10 ** 9
227232
arr = self.array_cls(data, freq="D")

0 commit comments

Comments
 (0)