@@ -254,12 +254,21 @@ def test_read_dta4(self, file):
254
254
)
255
255
256
256
# these are all categoricals
257
- expected = pd .concat (
258
- [expected [col ].astype ("category" ) for col in expected ], axis = 1
259
- )
257
+ for col in expected :
258
+ orig = expected [col ].copy ()
259
+
260
+ categories = np .asarray (expected ["fully_labeled" ][orig .notna ()])
261
+ if col == "incompletely_labeled" :
262
+ categories = orig
263
+
264
+ cat = orig .astype ("category" )._values
265
+ cat = cat .set_categories (categories , ordered = True )
266
+ cat .categories .rename (None , inplace = True )
267
+
268
+ expected [col ] = cat
260
269
261
270
# stata doesn't save .category metadata
262
- tm .assert_frame_equal (parsed , expected , check_categorical = False )
271
+ tm .assert_frame_equal (parsed , expected )
263
272
264
273
# File containing strls
265
274
def test_read_dta12 (self ):
@@ -952,19 +961,27 @@ def test_categorical_writing(self, version):
952
961
original = pd .concat (
953
962
[original [col ].astype ("category" ) for col in original ], axis = 1
954
963
)
964
+ expected .index .name = "index"
955
965
956
966
expected ["incompletely_labeled" ] = expected ["incompletely_labeled" ].apply (str )
957
967
expected ["unlabeled" ] = expected ["unlabeled" ].apply (str )
958
- expected = pd .concat (
959
- [expected [col ].astype ("category" ) for col in expected ], axis = 1
960
- )
961
- expected .index .name = "index"
968
+ for col in expected :
969
+ orig = expected [col ].copy ()
970
+
971
+ cat = orig .astype ("category" )._values
972
+ cat = cat .as_ordered ()
973
+ if col == "unlabeled" :
974
+ cat = cat .set_categories (orig , ordered = True )
975
+
976
+ cat .categories .rename (None , inplace = True )
977
+
978
+ expected [col ] = cat
962
979
963
980
with tm .ensure_clean () as path :
964
981
original .to_stata (path , version = version )
965
982
written_and_read_again = self .read_dta (path )
966
983
res = written_and_read_again .set_index ("index" )
967
- tm .assert_frame_equal (res , expected , check_categorical = False )
984
+ tm .assert_frame_equal (res , expected )
968
985
969
986
def test_categorical_warnings_and_errors (self ):
970
987
# Warning for non-string labels
@@ -1056,9 +1073,11 @@ def test_categorical_sorting(self, file):
1056
1073
parsed .index = np .arange (parsed .shape [0 ])
1057
1074
codes = [- 1 , - 1 , 0 , 1 , 1 , 1 , 2 , 2 , 3 , 4 ]
1058
1075
categories = ["Poor" , "Fair" , "Good" , "Very good" , "Excellent" ]
1059
- cat = pd .Categorical .from_codes (codes = codes , categories = categories )
1076
+ cat = pd .Categorical .from_codes (
1077
+ codes = codes , categories = categories , ordered = True
1078
+ )
1060
1079
expected = pd .Series (cat , name = "srh" )
1061
- tm .assert_series_equal (expected , parsed ["srh" ], check_categorical = False )
1080
+ tm .assert_series_equal (expected , parsed ["srh" ])
1062
1081
1063
1082
@pytest .mark .parametrize ("file" , ["dta19_115" , "dta19_117" ])
1064
1083
def test_categorical_ordering (self , file ):
0 commit comments