Skip to content

Commit e7538e6

Browse files
committed
Fixed: where
1 parent 7ec7351 commit e7538e6

File tree

7 files changed

+24
-26
lines changed

7 files changed

+24
-26
lines changed

pandas/core/arrays/datetimelike.py

+14
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,20 @@ def take(self, indices, allow_fill=False, fill_value=None):
600600

601601
return type(self)(new_values, dtype=self.dtype)
602602

603+
def where(self, cond, other):
604+
i8 = self.asi8
605+
if lib.is_scalar(other):
606+
if isna(other):
607+
other = iNaT
608+
elif isinstance(other, self._scalar_type):
609+
self._check_compatible_with(other)
610+
other = other.ordinal
611+
elif isinstance(other, type(self)):
612+
self._check_compatible_with(other)
613+
other = other.asi8
614+
result = np.where(cond, i8, other)
615+
return type(self)._simple_new(result, dtype=self.dtype)
616+
603617
@classmethod
604618
def _concat_same_type(cls, to_concat):
605619
dtypes = {x.dtype for x in to_concat}

pandas/core/arrays/datetimes.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,9 @@ class DatetimeArrayMixin(dtl.DatetimeLikeArrayMixin,
205205
_freq = None
206206

207207
@classmethod
208-
def _simple_new(cls, values, freq=None, tz=None):
208+
def _simple_new(cls, values, freq=None, tz=None, dtype=None):
209+
# TODO: can we make this signature just
210+
# values, dtype, freq?
209211
"""
210212
we require the we have a dtype compat for the values
211213
if we are passed a non-dtype compat, then coerce using the constructor
@@ -218,6 +220,8 @@ def _simple_new(cls, values, freq=None, tz=None):
218220
values = values.view('M8[ns]')
219221

220222
assert values.dtype == 'M8[ns]', values.dtype
223+
if tz is None and dtype:
224+
tz = getattr(dtype, 'tz')
221225

222226
result = object.__new__(cls)
223227
result._data = values

pandas/core/arrays/period.py

-17
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
import numpy as np
66

7-
from pandas._libs import lib
87
from pandas._libs.tslibs import NaT, iNaT, period as libperiod
98
from pandas._libs.tslibs.fields import isleapyear_arr
109
from pandas._libs.tslibs.period import (
@@ -348,22 +347,6 @@ def to_timestamp(self, freq=None, how='start'):
348347
# --------------------------------------------------------------------
349348
# Array-like / EA-Interface Methods
350349

351-
def where(self, cond, other):
352-
# TODO(DatetimeArray): move to DatetimeLikeArrayMixin
353-
# n.b. _ndarray_values candidate.
354-
i8 = self.asi8
355-
if lib.is_scalar(other):
356-
if isna(other):
357-
other = iNaT
358-
elif isinstance(other, Period):
359-
self._check_compatible_with(other)
360-
other = other.ordinal
361-
elif isinstance(other, type(self)):
362-
self._check_compatible_with(other)
363-
other = other.asi8
364-
result = np.where(cond, i8, other)
365-
return type(self)._simple_new(result, dtype=self.dtype)
366-
367350
def _formatter(self, boxed=False):
368351
if boxed:
369352
return str

pandas/core/internals/blocks.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
from pandas.core.dtypes.generic import (
3232
ABCDataFrame, ABCDatetimeIndex, ABCExtensionArray, ABCIndexClass,
3333
ABCSeries)
34-
from pandas.core.dtypes.inference import is_scalar
3534
from pandas.core.dtypes.missing import (
3635
_isna_compat, array_equivalent, is_null_datelike_scalar, isna, notna)
3736

@@ -3089,7 +3088,7 @@ def setitem(self, indexer, value):
30893088
return_object = (
30903089
(maybe_tz
30913090
and not timezones.tz_compare(self.values.tz, maybe_tz)) or
3092-
(is_scalar(value)
3091+
(lib.is_scalar(value)
30933092
and not isna(value)
30943093
and not value == tslib.iNaT
30953094
and not (isinstance(value, self.values._scalar_type) and

pandas/tests/frame/test_indexing.py

-1
Original file line numberDiff line numberDiff line change
@@ -3075,7 +3075,6 @@ def test_where_callable(self):
30753075
tm.assert_frame_equal(result,
30763076
(df + 2).where((df + 2) > 8, (df + 2) + 10))
30773077

3078-
@pytest.mark.xfail(reason="TODO-where", strict=False)
30793078
def test_where_tz_values(self, tz_naive_fixture):
30803079
df1 = DataFrame(DatetimeIndex(['20150101', '20150102', '20150103'],
30813080
tz=tz_naive_fixture),

pandas/tests/indexes/timedeltas/test_timedelta.py

-4
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,6 @@ def setup_method(self, method):
2828
def create_index(self):
2929
return pd.to_timedelta(range(5), unit='d') + pd.offsets.Hour(1)
3030

31-
@pytest.mark.skip(reason="TODO-where")
32-
def test_where(self, klass):
33-
return super().test_where(klass)
34-
3531
def test_numeric_compat(self):
3632
# Dummy method to override super's version; this test is now done
3733
# in test_arithmetic.py

pandas/tests/series/test_combine_concat.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,10 @@ def get_result_type(dtype, dtype2):
198198
]).dtype
199199
assert result.kind == expected
200200

201-
@pytest.mark.xfail(resson="TODO-where", strict=False)
201+
@pytest.mark.xfail(resson="TODO-where-internals", strict=False)
202+
# Something strange with internals shapes.
203+
# After reindexing in combine_first, our tz-block mananger is
204+
# (maybe?) in a bad state.
202205
def test_combine_first_dt_tz_values(self, tz_naive_fixture):
203206
ser1 = pd.Series(pd.DatetimeIndex(['20150101', '20150102', '20150103'],
204207
tz=tz_naive_fixture),

0 commit comments

Comments
 (0)