16
16
import tempfile
17
17
import unittest
18
18
import unittest .mock as mock
19
+ import warnings
19
20
from test import client_context
20
21
from test .utils import AllowListEventListener , TestNullsBase
21
22
@@ -98,13 +99,24 @@ def test_aggregate_simple(self):
98
99
self .assertEqual (agg_cmd .command ["pipeline" ][0 ]["$project" ], projection )
99
100
self .assertEqual (agg_cmd .command ["pipeline" ][1 ]["$project" ], {"_id" : True , "data" : True })
100
101
102
+ def _assert_frames_equal (self , incoming , outgoing ):
103
+ for name in incoming .columns :
104
+ in_col = incoming [name ]
105
+ out_col = outgoing [name ]
106
+ # Object types may lose type information in a round trip.
107
+ # Integer types with missing values are converted to floating
108
+ # point in a round trip.
109
+ if str (out_col .dtype ) in ["object" , "float64" ]:
110
+ out_col = out_col .astype (in_col .dtype )
111
+ pd .testing .assert_series_equal (in_col , out_col )
112
+
101
113
def round_trip (self , data , schema , coll = None ):
102
114
if coll is None :
103
115
coll = self .coll
104
116
coll .drop ()
105
117
res = write (self .coll , data )
106
118
self .assertEqual (len (data ), res .raw_result ["insertedCount" ])
107
- pd . testing . assert_frame_equal (data , find_pandas_all (coll , {}, schema = schema ))
119
+ self . _assert_frames_equal (data , find_pandas_all (coll , {}, schema = schema ))
108
120
return res
109
121
110
122
def test_write_error (self ):
@@ -129,23 +141,35 @@ def _create_data(self):
129
141
if k .__name__ not in ("ObjectId" , "Decimal128" )
130
142
}
131
143
schema = {k : v .to_pandas_dtype () for k , v in arrow_schema .items ()}
144
+ schema ["Int64" ] = pd .Int64Dtype ()
145
+ schema ["int" ] = pd .Int32Dtype ()
132
146
schema ["str" ] = "U8"
133
147
schema ["datetime" ] = "datetime64[ns]"
134
148
135
149
data = pd .DataFrame (
136
150
data = {
137
- "Int64" : [i for i in range (2 )],
138
- "float" : [i for i in range (2 )],
139
- "int" : [i for i in range (2 )],
140
- "datetime" : [datetime .datetime (1970 + i , 1 , 1 ) for i in range (2 )],
141
- "str" : [f"a{ i } " for i in range (2 )],
142
- "bool" : [True , False ],
151
+ "Int64" : [i for i in range (2 )] + [ None ] ,
152
+ "float" : [i for i in range (2 )] + [ None ] ,
153
+ "int" : [i for i in range (2 )] + [ None ] ,
154
+ "datetime" : [datetime .datetime (1970 + i , 1 , 1 ) for i in range (2 )] + [ None ] ,
155
+ "str" : [f"a{ i } " for i in range (2 )] + [ None ] ,
156
+ "bool" : [True , False , None ],
143
157
}
144
158
).astype (schema )
145
159
return arrow_schema , data
146
160
147
161
def test_write_schema_validation (self ):
148
162
arrow_schema , data = self ._create_data ()
163
+
164
+ # Work around https://github.com/pandas-dev/pandas/issues/16248,
165
+ # Where pandas does not implement utcoffset for null timestamps.
166
+ def new_replace (k ):
167
+ if isinstance (k , pd .NaT .__class__ ):
168
+ return datetime .datetime (1970 , 1 , 1 )
169
+ return k .replace (tzinfo = None )
170
+
171
+ data ["datetime" ] = data .apply (lambda row : new_replace (row ["datetime" ]), axis = 1 )
172
+
149
173
self .round_trip (
150
174
data ,
151
175
Schema (arrow_schema ),
@@ -280,14 +304,12 @@ def test_csv(self):
280
304
_ , data = self ._create_data ()
281
305
with tempfile .NamedTemporaryFile (suffix = ".csv" ) as f :
282
306
f .close ()
283
- data .to_csv (f .name , index = False )
307
+ # May give RuntimeWarning due to the nulls.
308
+ with warnings .catch_warnings ():
309
+ warnings .simplefilter ("ignore" , RuntimeWarning )
310
+ data .to_csv (f .name , index = False , na_rep = "" )
284
311
out = pd .read_csv (f .name )
285
- for name in data .columns :
286
- col = data [name ]
287
- val = out [name ]
288
- if str (val .dtype ) == "object" :
289
- val = val .astype (col .dtype )
290
- pd .testing .assert_series_equal (col , val )
312
+ self ._assert_frames_equal (data , out )
291
313
292
314
293
315
class TestBSONTypes (PandasTestBase ):
0 commit comments