Commit 60e7706

ARROW-134 Cannot encode pandas NA objects (#118)
* ARROW-134 Cannot encode pandas NA objects
* address review
* try to fix csv test
* handle runtimewarning
* address review
* cleanup
1 parent f7111d6 commit 60e7706
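
For context, the commit teaches the BSON encoder how to handle the `pandas.NA` missing-value sentinel by registering a custom type encoder. The following is a minimal standalone sketch (not part of the diff) of the failure mode and the codec-based fix; `PandasNAEncoder` is an illustrative name, the commit's own codec is `_PandasNACodec`:

```python
import bson
import pandas as pd
from bson.codec_options import CodecOptions, TypeEncoder, TypeRegistry

# Without a codec, bson does not know how to serialize the NA sentinel
# and should raise InvalidDocument.
try:
    bson.encode({"value": pd.NA})
except bson.errors.InvalidDocument as exc:
    print("cannot encode:", exc)

# The commit's approach: a TypeEncoder that maps NA to None (BSON null).
class PandasNAEncoder(TypeEncoder):
    @property
    def python_type(self):
        return type(pd.NA)

    def transform_python(self, value):
        return None

opts = CodecOptions(type_registry=TypeRegistry([PandasNAEncoder()]))
raw = bson.encode({"value": pd.NA}, codec_options=opts)
print(bson.decode(raw))  # {'value': None}
```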

File tree

2 files changed: +59 additions, -18 deletions


bindings/python/pymongoarrow/api.py

Lines changed: 23 additions & 4 deletions
@@ -16,6 +16,7 @@
 import numpy as np
 import pymongo.errors
 from bson import encode
+from bson.codec_options import TypeEncoder, TypeRegistry
 from bson.raw_bson import RawBSONDocument
 from pyarrow import Schema as ArrowSchema
 from pyarrow import Table
@@ -26,9 +27,10 @@
     ndarray = None
 
 try:
-    from pandas import DataFrame
+    from pandas import NA, DataFrame
 except ImportError:
     DataFrame = None
+    NA = None
 
 from pymongo.bulk import BulkWriteError
 from pymongo.common import MAX_WRITE_BATCH_SIZE
@@ -316,6 +318,18 @@ def _tabular_generator(tabular):
         return
 
 
+class _PandasNACodec(TypeEncoder):
+    """A custom type codec for Pandas NA objects."""
+
+    @property
+    def python_type(self):
+        return NA.__class__
+
+    def transform_python(self, _):
+        """Transform an NA object into 'None'"""
+        return None
+
+
 def write(collection, tabular):
     """Write data from `tabular` into the given MongoDB `collection`.
 
@@ -352,6 +366,13 @@ def write(collection, tabular):
         )
 
     tabular_gen = _tabular_generator(tabular)
+
+    # Handle Pandas NA objects.
+    codec_options = collection.codec_options
+    if DataFrame is not None:
+        type_registry = TypeRegistry([_PandasNACodec()])
+        codec_options = codec_options.with_options(type_registry=type_registry)
+
     while cur_offset < tab_size:
         cur_size = 0
         cur_batch = []
@@ -361,9 +382,7 @@
             and len(cur_batch) <= _MAX_WRITE_BATCH_SIZE
             and cur_offset + i < tab_size
         ):
-            enc_tab = RawBSONDocument(
-                encode(next(tabular_gen), codec_options=collection.codec_options)
-            )
+            enc_tab = RawBSONDocument(encode(next(tabular_gen), codec_options=codec_options))
             cur_batch.append(enc_tab)
             cur_size += len(enc_tab.raw)
             i += 1
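
With the codec registered inside `write`, frames containing missing values can be inserted. A hedged usage sketch follows; the client, database, and collection names are assumptions, not part of the commit:

```python
import pandas as pd
from pymongo import MongoClient
from pymongoarrow.api import write

coll = MongoClient().test_db.test_coll  # assumed connection and names

df = pd.DataFrame(
    {
        "Int64": pd.array([1, 2, None], dtype="Int64"),  # contains pd.NA
        "str": ["a", "b", None],
    }
)

# Before this commit, encoding the pd.NA entries raised InvalidDocument;
# with the _PandasNACodec they are transformed to None (BSON null).
result = write(coll, df)
print(result.raw_result["insertedCount"])  # 3
```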

bindings/python/test/test_pandas.py

Lines changed: 36 additions & 14 deletions
@@ -16,6 +16,7 @@
 import tempfile
 import unittest
 import unittest.mock as mock
+import warnings
 from test import client_context
 from test.utils import AllowListEventListener, TestNullsBase
 
@@ -98,13 +99,24 @@ def test_aggregate_simple(self):
         self.assertEqual(agg_cmd.command["pipeline"][0]["$project"], projection)
         self.assertEqual(agg_cmd.command["pipeline"][1]["$project"], {"_id": True, "data": True})
 
+    def _assert_frames_equal(self, incoming, outgoing):
+        for name in incoming.columns:
+            in_col = incoming[name]
+            out_col = outgoing[name]
+            # Object types may lose type information in a round trip.
+            # Integer types with missing values are converted to floating
+            # point in a round trip.
+            if str(out_col.dtype) in ["object", "float64"]:
+                out_col = out_col.astype(in_col.dtype)
+            pd.testing.assert_series_equal(in_col, out_col)
+
     def round_trip(self, data, schema, coll=None):
         if coll is None:
             coll = self.coll
             coll.drop()
         res = write(self.coll, data)
         self.assertEqual(len(data), res.raw_result["insertedCount"])
-        pd.testing.assert_frame_equal(data, find_pandas_all(coll, {}, schema=schema))
+        self._assert_frames_equal(data, find_pandas_all(coll, {}, schema=schema))
         return res
 
     def test_write_error(self):
@@ -129,23 +141,35 @@ def _create_data(self):
             if k.__name__ not in ("ObjectId", "Decimal128")
         }
         schema = {k: v.to_pandas_dtype() for k, v in arrow_schema.items()}
+        schema["Int64"] = pd.Int64Dtype()
+        schema["int"] = pd.Int32Dtype()
         schema["str"] = "U8"
         schema["datetime"] = "datetime64[ns]"
 
         data = pd.DataFrame(
             data={
-                "Int64": [i for i in range(2)],
-                "float": [i for i in range(2)],
-                "int": [i for i in range(2)],
-                "datetime": [datetime.datetime(1970 + i, 1, 1) for i in range(2)],
-                "str": [f"a{i}" for i in range(2)],
-                "bool": [True, False],
+                "Int64": [i for i in range(2)] + [None],
+                "float": [i for i in range(2)] + [None],
+                "int": [i for i in range(2)] + [None],
+                "datetime": [datetime.datetime(1970 + i, 1, 1) for i in range(2)] + [None],
+                "str": [f"a{i}" for i in range(2)] + [None],
+                "bool": [True, False, None],
             }
         ).astype(schema)
         return arrow_schema, data
 
     def test_write_schema_validation(self):
         arrow_schema, data = self._create_data()
+
+        # Work around https://github.com/pandas-dev/pandas/issues/16248,
+        # Where pandas does not implement utcoffset for null timestamps.
+        def new_replace(k):
+            if isinstance(k, pd.NaT.__class__):
+                return datetime.datetime(1970, 1, 1)
+            return k.replace(tzinfo=None)
+
+        data["datetime"] = data.apply(lambda row: new_replace(row["datetime"]), axis=1)
+
         self.round_trip(
             data,
             Schema(arrow_schema),
@@ -280,14 +304,12 @@ def test_csv(self):
         _, data = self._create_data()
         with tempfile.NamedTemporaryFile(suffix=".csv") as f:
             f.close()
-            data.to_csv(f.name, index=False)
+            # May give RuntimeWarning due to the nulls.
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore", RuntimeWarning)
+                data.to_csv(f.name, index=False, na_rep="")
             out = pd.read_csv(f.name)
-            for name in data.columns:
-                col = data[name]
-                val = out[name]
-                if str(val.dtype) == "object":
-                    val = val.astype(col.dtype)
-                pd.testing.assert_series_equal(col, val)
+            self._assert_frames_equal(data, out)
 
 
 class TestBSONTypes(PandasTestBase):
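
The `_assert_frames_equal` helper exists because a round trip through an untyped representation drops nullable-integer information. A standalone pandas illustration of that caveat (no MongoDB involved, assuming a recent pandas with the nullable `Int64` dtype):

```python
import io
import pandas as pd

incoming = pd.DataFrame({"Int64": pd.array([0, 1, None], dtype="Int64")})

# Round trip through CSV, as in test_csv: integers with missing values
# come back as float64.
buf = io.StringIO()
incoming.to_csv(buf, index=False, na_rep="")
buf.seek(0)
outgoing = pd.read_csv(buf)
print(outgoing["Int64"].dtype)  # float64

# Cast back to the incoming dtype before comparing, mirroring the helper.
restored = outgoing["Int64"].astype(incoming["Int64"].dtype)
pd.testing.assert_series_equal(incoming["Int64"], restored)
```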
