Skip to content

Commit d9992fc

Browse files
authored
feat: Add Arrow types for efficient JSON data representation in pyarrow (#312)
* feat: add ArrowJSONtype to extend pyarrow for JSONDtype * nit * add JSONArrowScalar * fix cover
1 parent b6c1428 commit d9992fc

File tree

4 files changed

+166
-7
lines changed

4 files changed

+166
-7
lines changed

db_dtypes/__init__.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@
3030

3131
from db_dtypes import core
3232
from db_dtypes.version import __version__
33-
from . import _versions_helpers
3433

34+
from . import _versions_helpers
3535

3636
date_dtype_name = "dbdate"
3737
time_dtype_name = "dbtime"
@@ -50,7 +50,7 @@
5050
# To use JSONArray and JSONDtype, you'll need Pandas 1.5.0 or later. With the removal
5151
# of Python 3.7 compatibility, the minimum Pandas version will be updated to 1.5.0.
5252
if packaging.version.Version(pandas.__version__) >= packaging.version.Version("1.5.0"):
53-
from db_dtypes.json import JSONArray, JSONDtype
53+
from db_dtypes.json import JSONArray, JSONArrowScalar, JSONArrowType, JSONDtype
5454
else:
5555
JSONArray = None
5656
JSONDtype = None
@@ -374,6 +374,8 @@ def __sub__(self, other):
374374
"DateDtype",
375375
"JSONDtype",
376376
"JSONArray",
377+
"JSONArrowType",
378+
"JSONArrowScalar",
377379
"TimeArray",
378380
"TimeDtype",
379381
]

db_dtypes/json.py

+40
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,10 @@ def construct_array_type(cls):
6464
"""Return the array type associated with this dtype."""
6565
return JSONArray
6666

67+
def __from_arrow__(self, array: pa.Array | pa.ChunkedArray) -> JSONArray:
68+
"""Convert the pyarrow array to the extension array."""
69+
return JSONArray(array)
70+
6771

6872
class JSONArray(arrays.ArrowExtensionArray):
6973
"""Extension array that handles BigQuery JSON data, leveraging a string-based
@@ -92,6 +96,10 @@ def __init__(self, values) -> None:
9296
else:
9397
raise NotImplementedError(f"Unsupported pandas version: {pd.__version__}")
9498

99+
def __arrow_array__(self, type=None):
100+
"""Convert to an arrow array. This is required for pyarrow extension."""
101+
return pa.array(self.pa_data, type=JSONArrowType())
102+
95103
@classmethod
96104
def _box_pa(
97105
cls, value, pa_type: pa.DataType | None = None
@@ -208,6 +216,8 @@ def __getitem__(self, item):
208216
value = self.pa_data[item]
209217
if isinstance(value, pa.ChunkedArray):
210218
return type(self)(value)
219+
elif isinstance(value, pa.ExtensionScalar):
220+
return value.as_py()
211221
else:
212222
scalar = JSONArray._deserialize_json(value.as_py())
213223
if scalar is None:
@@ -244,3 +254,33 @@ def __array__(self, dtype=None, copy: bool | None = None) -> np.ndarray:
244254
result[mask] = self._dtype.na_value
245255
result[~mask] = data[~mask].pa_data.to_numpy()
246256
return result
257+
258+
259+
class JSONArrowScalar(pa.ExtensionScalar):
260+
def as_py(self):
261+
return JSONArray._deserialize_json(self.value.as_py() if self.value else None)
262+
263+
264+
class JSONArrowType(pa.ExtensionType):
265+
"""Arrow extension type for the `dbjson` Pandas extension type."""
266+
267+
def __init__(self) -> None:
268+
super().__init__(pa.string(), "dbjson")
269+
270+
def __arrow_ext_serialize__(self) -> bytes:
271+
return b""
272+
273+
@classmethod
274+
def __arrow_ext_deserialize__(cls, storage_type, serialized) -> JSONArrowType:
275+
return JSONArrowType()
276+
277+
def to_pandas_dtype(self):
278+
return JSONDtype()
279+
280+
def __arrow_ext_scalar_class__(self):
281+
return JSONArrowScalar
282+
283+
284+
# Register the type to be included in RecordBatches, sent over IPC and received in
285+
# another Python process.
286+
pa.register_extension_type(JSONArrowType())

tests/compliance/json/test_json_compliance.py

-4
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,6 @@
2222
import pytest
2323

2424

25-
class TestJSONArrayAccumulate(base.BaseAccumulateTests):
26-
pass
27-
28-
2925
class TestJSONArrayCasting(base.BaseCastingTests):
3026
def test_astype_str(self, data):
3127
# Use `json.dumps(str)` instead of passing `str(obj)` directly to the super method.

tests/unit/test_json.py

+122-1
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@
1313
# limitations under the License.
1414

1515
import json
16+
import math
1617

1718
import numpy as np
1819
import pandas as pd
20+
import pyarrow as pa
1921
import pytest
2022

2123
import db_dtypes
@@ -36,7 +38,7 @@
3638
"null_field": None,
3739
"order": {
3840
"items": ["book", "pen", "computer"],
39-
"total": 15.99,
41+
"total": 15,
4042
"address": {"street": "123 Main St", "city": "Anytown"},
4143
},
4244
},
@@ -114,3 +116,122 @@ def test_as_numpy_array():
114116
]
115117
)
116118
pd._testing.assert_equal(result, expected)
119+
120+
121+
def test_json_arrow_array():
122+
data = db_dtypes.JSONArray._from_sequence(JSON_DATA.values())
123+
assert isinstance(data.__arrow_array__(), pa.ExtensionArray)
124+
125+
126+
def test_json_arrow_storage_type():
127+
arrow_json_type = db_dtypes.JSONArrowType()
128+
assert arrow_json_type.extension_name == "dbjson"
129+
assert pa.types.is_string(arrow_json_type.storage_type)
130+
131+
132+
def test_json_arrow_constructors():
133+
data = [
134+
json.dumps(value, sort_keys=True, separators=(",", ":"))
135+
for value in JSON_DATA.values()
136+
]
137+
storage_array = pa.array(data, type=pa.string())
138+
139+
arr_1 = db_dtypes.JSONArrowType().wrap_array(storage_array)
140+
assert isinstance(arr_1, pa.ExtensionArray)
141+
142+
arr_2 = pa.ExtensionArray.from_storage(db_dtypes.JSONArrowType(), storage_array)
143+
assert isinstance(arr_2, pa.ExtensionArray)
144+
145+
assert arr_1 == arr_2
146+
147+
148+
def test_json_arrow_to_pandas():
149+
data = [
150+
json.dumps(value, sort_keys=True, separators=(",", ":"))
151+
for value in JSON_DATA.values()
152+
]
153+
arr = pa.array(data, type=db_dtypes.JSONArrowType())
154+
155+
s = arr.to_pandas()
156+
assert isinstance(s.dtypes, db_dtypes.JSONDtype)
157+
assert s[0]
158+
assert s[1] == 100
159+
assert math.isclose(s[2], 0.98)
160+
assert s[3] == "hello world"
161+
assert math.isclose(s[4][0], 0.1)
162+
assert math.isclose(s[4][1], 0.2)
163+
assert s[5] == {
164+
"null_field": None,
165+
"order": {
166+
"items": ["book", "pen", "computer"],
167+
"total": 15,
168+
"address": {"street": "123 Main St", "city": "Anytown"},
169+
},
170+
}
171+
assert pd.isna(s[6])
172+
173+
174+
def test_json_arrow_to_pylist():
175+
data = [
176+
json.dumps(value, sort_keys=True, separators=(",", ":"))
177+
for value in JSON_DATA.values()
178+
]
179+
arr = pa.array(data, type=db_dtypes.JSONArrowType())
180+
181+
s = arr.to_pylist()
182+
assert isinstance(s, list)
183+
assert s[0]
184+
assert s[1] == 100
185+
assert math.isclose(s[2], 0.98)
186+
assert s[3] == "hello world"
187+
assert math.isclose(s[4][0], 0.1)
188+
assert math.isclose(s[4][1], 0.2)
189+
assert s[5] == {
190+
"null_field": None,
191+
"order": {
192+
"items": ["book", "pen", "computer"],
193+
"total": 15,
194+
"address": {"street": "123 Main St", "city": "Anytown"},
195+
},
196+
}
197+
assert s[6] is None
198+
199+
200+
def test_json_arrow_record_batch():
201+
data = [
202+
json.dumps(value, sort_keys=True, separators=(",", ":"))
203+
for value in JSON_DATA.values()
204+
]
205+
arr = pa.array(data, type=db_dtypes.JSONArrowType())
206+
batch = pa.RecordBatch.from_arrays([arr], ["json_col"])
207+
sink = pa.BufferOutputStream()
208+
209+
with pa.RecordBatchStreamWriter(sink, batch.schema) as writer:
210+
writer.write_batch(batch)
211+
212+
buf = sink.getvalue()
213+
214+
with pa.ipc.open_stream(buf) as reader:
215+
result = reader.read_all()
216+
217+
json_col = result.column("json_col")
218+
assert isinstance(json_col.type, db_dtypes.JSONArrowType)
219+
220+
s = json_col.to_pylist()
221+
222+
assert isinstance(s, list)
223+
assert s[0]
224+
assert s[1] == 100
225+
assert math.isclose(s[2], 0.98)
226+
assert s[3] == "hello world"
227+
assert math.isclose(s[4][0], 0.1)
228+
assert math.isclose(s[4][1], 0.2)
229+
assert s[5] == {
230+
"null_field": None,
231+
"order": {
232+
"items": ["book", "pen", "computer"],
233+
"total": 15,
234+
"address": {"street": "123 Main St", "city": "Anytown"},
235+
},
236+
}
237+
assert s[6] is None

0 commit comments

Comments
 (0)