@@ -175,40 +175,56 @@ def test_where_other(self):
175
175
def test_where_invalid_dtypes (self ):
176
176
dti = date_range ("20130101" , periods = 3 , tz = "US/Eastern" )
177
177
178
- i2 = Index ([pd .NaT , pd .NaT ] + dti [2 :].tolist ())
178
+ tail = dti [2 :].tolist ()
179
+ i2 = Index ([pd .NaT , pd .NaT ] + tail )
179
180
180
- msg = "value should be a 'Timestamp', 'NaT', or array of those. Got"
181
- msg2 = "Cannot compare tz-naive and tz-aware datetime-like objects"
182
- with pytest .raises (TypeError , match = msg2 ):
183
- # passing tz-naive ndarray to tzaware DTI
184
- dti .where (notna (i2 ), i2 .values )
181
+ mask = notna (i2 )
185
182
186
- with pytest .raises (TypeError , match = msg2 ):
187
- # passing tz-aware DTI to tznaive DTI
188
- dti .tz_localize (None ).where (notna (i2 ), i2 )
183
+ # passing tz-naive ndarray to tzaware DTI
184
+ result = dti .where (mask , i2 .values )
185
+ expected = Index ([pd .NaT .asm8 , pd .NaT .asm8 ] + tail , dtype = object )
186
+ tm .assert_index_equal (result , expected )
189
187
190
- with pytest .raises (TypeError , match = msg ):
191
- dti .where (notna (i2 ), i2 .tz_localize (None ).to_period ("D" ))
188
+ # passing tz-aware DTI to tznaive DTI
189
+ naive = dti .tz_localize (None )
190
+ result = naive .where (mask , i2 )
191
+ expected = Index ([i2 [0 ], i2 [1 ]] + naive [2 :].tolist (), dtype = object )
192
+ tm .assert_index_equal (result , expected )
192
193
193
- with pytest .raises (TypeError , match = msg ):
194
- dti .where (notna (i2 ), i2 .asi8 .view ("timedelta64[ns]" ))
194
+ pi = i2 .tz_localize (None ).to_period ("D" )
195
+ result = dti .where (mask , pi )
196
+ expected = Index ([pi [0 ], pi [1 ]] + tail , dtype = object )
197
+ tm .assert_index_equal (result , expected )
195
198
196
- with pytest .raises (TypeError , match = msg ):
197
- dti .where (notna (i2 ), i2 .asi8 )
199
+ tda = i2 .asi8 .view ("timedelta64[ns]" )
200
+ result = dti .where (mask , tda )
201
+ expected = Index ([tda [0 ], tda [1 ]] + tail , dtype = object )
202
+ assert isinstance (expected [0 ], np .timedelta64 )
203
+ tm .assert_index_equal (result , expected )
198
204
199
- with pytest .raises (TypeError , match = msg ):
200
- # non-matching scalar
201
- dti .where (notna (i2 ), pd .Timedelta (days = 4 ))
205
+ result = dti .where (mask , i2 .asi8 )
206
+ expected = Index ([pd .NaT .value , pd .NaT .value ] + tail , dtype = object )
207
+ assert isinstance (expected [0 ], int )
208
+ tm .assert_index_equal (result , expected )
209
+
210
+ # non-matching scalar
211
+ td = pd .Timedelta (days = 4 )
212
+ result = dti .where (mask , td )
213
+ expected = Index ([td , td ] + tail , dtype = object )
214
+ assert expected [0 ] is td
215
+ tm .assert_index_equal (result , expected )
202
216
203
217
def test_where_mismatched_nat (self , tz_aware_fixture ):
204
218
tz = tz_aware_fixture
205
219
dti = date_range ("2013-01-01" , periods = 3 , tz = tz )
206
220
cond = np .array ([True , False , True ])
207
221
208
- msg = "value should be a 'Timestamp', 'NaT', or array of those. Got"
209
- with pytest .raises (TypeError , match = msg ):
210
- # wrong-dtyped NaT
211
- dti .where (cond , np .timedelta64 ("NaT" , "ns" ))
222
+ tdnat = np .timedelta64 ("NaT" , "ns" )
223
+ expected = Index ([dti [0 ], tdnat , dti [2 ]], dtype = object )
224
+ assert expected [1 ] is tdnat
225
+
226
+ result = dti .where (cond , tdnat )
227
+ tm .assert_index_equal (result , expected )
212
228
213
229
def test_where_tz (self ):
214
230
i = date_range ("20130101" , periods = 3 , tz = "US/Eastern" )
0 commit comments