8
8
import numpy as np
9
9
import pytest
10
10
11
- from pandas ._config import using_string_dtype
12
-
13
- from pandas .compat import HAS_PYARROW
14
11
from pandas .compat .numpy import (
15
12
np_version_gt2 ,
16
13
np_version_gte1p24 ,
37
34
concat ,
38
35
date_range ,
39
36
interval_range ,
37
+ isna ,
40
38
period_range ,
41
39
timedelta_range ,
42
40
)
@@ -564,14 +562,16 @@ def test_append_timedelta_does_not_cast(self, td, using_infer_string, request):
564
562
tm .assert_series_equal (ser , expected )
565
563
assert isinstance (ser ["td" ], Timedelta )
566
564
567
- @pytest .mark .xfail (using_string_dtype (), reason = "TODO(infer_string)" )
568
565
def test_setitem_with_expansion_type_promotion (self ):
569
566
# GH#12599
570
567
ser = Series (dtype = object )
571
568
ser ["a" ] = Timestamp ("2016-01-01" )
572
569
ser ["b" ] = 3.0
573
570
ser ["c" ] = "foo"
574
- expected = Series ([Timestamp ("2016-01-01" ), 3.0 , "foo" ], index = ["a" , "b" , "c" ])
571
+ expected = Series (
572
+ [Timestamp ("2016-01-01" ), 3.0 , "foo" ],
573
+ index = Index (["a" , "b" , "c" ], dtype = object ),
574
+ )
575
575
tm .assert_series_equal (ser , expected )
576
576
577
577
def test_setitem_not_contained (self , string_series ):
@@ -850,11 +850,6 @@ def test_mask_key(self, obj, key, expected, warn, val, indexer_sli):
850
850
indexer_sli (obj )[mask ] = val
851
851
tm .assert_series_equal (obj , expected )
852
852
853
- @pytest .mark .xfail (
854
- using_string_dtype () and not HAS_PYARROW ,
855
- reason = "TODO(infer_string)" ,
856
- strict = False ,
857
- )
858
853
def test_series_where (self , obj , key , expected , warn , val , is_inplace ):
859
854
mask = np .zeros (obj .shape , dtype = bool )
860
855
mask [key ] = True
@@ -870,6 +865,11 @@ def test_series_where(self, obj, key, expected, warn, val, is_inplace):
870
865
obj = obj .copy ()
871
866
arr = obj ._values
872
867
868
+ if obj .dtype == "string" and not (isinstance (val , str ) or isna (val )):
869
+ with pytest .raises (TypeError , match = "Invalid value" ):
870
+ obj .where (~ mask , val )
871
+ return
872
+
873
873
res = obj .where (~ mask , val )
874
874
875
875
if val is NA and res .dtype == object :
@@ -882,29 +882,27 @@ def test_series_where(self, obj, key, expected, warn, val, is_inplace):
882
882
883
883
self ._check_inplace (is_inplace , orig , arr , obj )
884
884
885
- @pytest .mark .xfail (using_string_dtype (), reason = "TODO(infer_string)" , strict = False )
886
- def test_index_where (self , obj , key , expected , warn , val , using_infer_string ):
885
+ def test_index_where (self , obj , key , expected , warn , val ):
887
886
mask = np .zeros (obj .shape , dtype = bool )
888
887
mask [key ] = True
889
888
890
- if using_infer_string and obj .dtype == object :
889
+ if obj .dtype == "string" and not ( isinstance ( val , str ) or isna ( val )) :
891
890
with pytest .raises (TypeError , match = "Invalid value" ):
892
- Index (obj ).where (~ mask , val )
891
+ Index (obj , dtype = obj . dtype ).where (~ mask , val )
893
892
else :
894
- res = Index (obj ).where (~ mask , val )
893
+ res = Index (obj , dtype = obj . dtype ).where (~ mask , val )
895
894
expected_idx = Index (expected , dtype = expected .dtype )
896
895
tm .assert_index_equal (res , expected_idx )
897
896
898
- @pytest .mark .xfail (using_string_dtype (), reason = "TODO(infer_string)" , strict = False )
899
- def test_index_putmask (self , obj , key , expected , warn , val , using_infer_string ):
897
+ def test_index_putmask (self , obj , key , expected , warn , val ):
900
898
mask = np .zeros (obj .shape , dtype = bool )
901
899
mask [key ] = True
902
900
903
- if using_infer_string and obj .dtype == object :
901
+ if obj .dtype == "string" and not ( isinstance ( val , str ) or isna ( val )) :
904
902
with pytest .raises (TypeError , match = "Invalid value" ):
905
- Index (obj ).putmask (mask , val )
903
+ Index (obj , dtype = obj . dtype ).putmask (mask , val )
906
904
else :
907
- res = Index (obj ).putmask (mask , val )
905
+ res = Index (obj , dtype = obj . dtype ).putmask (mask , val )
908
906
tm .assert_index_equal (res , Index (expected , dtype = expected .dtype ))
909
907
910
908
0 commit comments