Skip to content

feat: Add Arrow types for efficient JSON data representation in pyarrow #312

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions db_dtypes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@

from db_dtypes import core
from db_dtypes.version import __version__
from . import _versions_helpers

from . import _versions_helpers

date_dtype_name = "dbdate"
time_dtype_name = "dbtime"
Expand All @@ -50,7 +50,7 @@
# To use JSONArray and JSONDtype, you'll need Pandas 1.5.0 or later. With the removal
# of Python 3.7 compatibility, the minimum Pandas version will be updated to 1.5.0.
if packaging.version.Version(pandas.__version__) >= packaging.version.Version("1.5.0"):
from db_dtypes.json import JSONArray, JSONDtype
from db_dtypes.json import JSONArray, JSONArrowScalar, JSONArrowType, JSONDtype
else:
JSONArray = None
JSONDtype = None
Expand Down Expand Up @@ -374,6 +374,8 @@ def __sub__(self, other):
"DateDtype",
"JSONDtype",
"JSONArray",
"JSONArrowType",
"JSONArrowScalar",
"TimeArray",
"TimeDtype",
]
50 changes: 49 additions & 1 deletion db_dtypes/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ def construct_array_type(cls):
"""Return the array type associated with this dtype."""
return JSONArray

def __from_arrow__(self, array: pa.Array | pa.ChunkedArray) -> JSONArray:
    """Build a :class:`JSONArray` from pyarrow data.

    Hook invoked by pandas when materializing Arrow-backed data as this
    extension dtype.
    """
    wrapped = JSONArray(array)
    return wrapped


class JSONArray(arrays.ArrowExtensionArray):
"""Extension array that handles BigQuery JSON data, leveraging a string-based
Expand Down Expand Up @@ -92,6 +96,10 @@ def __init__(self, values) -> None:
else:
raise NotImplementedError(f"Unsupported pandas version: {pd.__version__}")

def __arrow_array__(self):
    """Expose the backing pyarrow array (pyarrow conversion protocol)."""
    backing = self.pa_data
    return backing

@classmethod
def _box_pa(
cls, value, pa_type: pa.DataType | None = None
Expand Down Expand Up @@ -151,7 +159,12 @@ def _serialize_json(value):
def _deserialize_json(value):
    """Convert a serialized JSON string back into its Python value.

    NA values pass through unchanged. A string that is not valid JSON is
    deliberately returned as-is rather than raising, so plain strings stored
    alongside JSON documents survive the round trip.
    """
    if pd.isna(value):
        return value
    try:
        return json.loads(value)
    except json.JSONDecodeError:
        # Not a JSON document -- treat it as a regular string.
        return value

Expand Down Expand Up @@ -208,6 +221,8 @@ def __getitem__(self, item):
value = self.pa_data[item]
if isinstance(value, pa.ChunkedArray):
return type(self)(value)
elif isinstance(value, pa.ExtensionScalar):
return value.as_py()
else:
scalar = JSONArray._deserialize_json(value.as_py())
if scalar is None:
Expand Down Expand Up @@ -244,3 +259,36 @@ def __array__(self, dtype=None, copy: bool | None = None) -> np.ndarray:
result[mask] = self._dtype.na_value
result[~mask] = data[~mask].pa_data.to_numpy()
return result


class JSONArrowScalar(pa.ExtensionScalar):
    """Scalar for the dbjson extension type; decodes JSON on Python access."""

    def as_py(self):
        # A null scalar carries no usable storage value; normalize it to None
        # before deserializing so NA handling stays consistent.
        raw = self.value.as_py() if self.value else None
        return JSONArray._deserialize_json(raw)


class JSONArrowType(pa.ExtensionType):
    """Arrow extension type for the `dbjson` pandas extension dtype.

    JSON documents are stored as their string serialization: the storage
    type is ``pa.string()`` and the extension name is ``dbjson``.
    """

    def __init__(self) -> None:
        super().__init__(pa.string(), "dbjson")

    def __arrow_ext_serialize__(self) -> bytes:
        # The type has no parameters, so there is no state to serialize.
        return b""

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized) -> JSONArrowType:
        return cls()

    def __hash__(self) -> int:
        return hash(str(self))

    def to_pandas_dtype(self):
        # Pairs this Arrow type with its pandas extension dtype counterpart.
        return JSONDtype()

    def __arrow_ext_scalar_class__(self):
        return JSONArrowScalar


# Register the extension type with pyarrow's global registry so that dbjson
# columns embedded in RecordBatches survive IPC: a batch serialized in one
# Python process deserializes in another with the extension type intact.
pa.register_extension_type(JSONArrowType())
4 changes: 0 additions & 4 deletions tests/compliance/json/test_json_compliance.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,6 @@
import pytest


class TestJSONArrayAccumulate(base.BaseAccumulateTests):
    """Run the pandas accumulate compliance suite without overrides."""


class TestJSONArrayCasting(base.BaseCastingTests):
def test_astype_str(self, data):
# Use `json.dumps(str)` instead of passing `str(obj)` directly to the super method.
Expand Down
118 changes: 117 additions & 1 deletion tests/unit/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
# limitations under the License.

import json
import math

import numpy as np
import pandas as pd
import pyarrow as pa
import pytest

import db_dtypes
Expand All @@ -36,7 +38,7 @@
"null_field": None,
"order": {
"items": ["book", "pen", "computer"],
"total": 15.99,
"total": 15,
"address": {"street": "123 Main St", "city": "Anytown"},
},
},
Expand Down Expand Up @@ -114,3 +116,117 @@ def test_as_numpy_array():
]
)
pd._testing.assert_equal(result, expected)


def test_json_arrow_storage_type():
    """The dbjson extension type reports the expected name and storage type."""
    json_type = db_dtypes.JSONArrowType()
    assert json_type.extension_name == "dbjson"
    assert pa.types.is_string(json_type.storage_type)


def test_json_arrow_constructors():
    """wrap_array and ExtensionArray.from_storage yield equal extension arrays."""
    serialized = [
        json.dumps(value, sort_keys=True, separators=(",", ":"))
        for value in JSON_DATA.values()
    ]
    storage = pa.array(serialized, type=pa.string())

    wrapped = db_dtypes.JSONArrowType().wrap_array(storage)
    from_storage = pa.ExtensionArray.from_storage(db_dtypes.JSONArrowType(), storage)

    assert isinstance(wrapped, pa.ExtensionArray)
    assert isinstance(from_storage, pa.ExtensionArray)
    assert wrapped == from_storage


def test_json_arrow_to_pandas():
    """Converting a dbjson arrow array to pandas deserializes each element."""
    serialized = [
        json.dumps(value, sort_keys=True, separators=(",", ":"))
        for value in JSON_DATA.values()
    ]
    series = pa.array(serialized, type=db_dtypes.JSONArrowType()).to_pandas()

    assert isinstance(series.dtypes, db_dtypes.JSONDtype)
    assert series[0]
    assert series[1] == 100
    assert math.isclose(series[2], 0.98)
    assert series[3] == "hello world"
    assert math.isclose(series[4][0], 0.1)
    assert math.isclose(series[4][1], 0.2)
    expected_nested = {
        "null_field": None,
        "order": {
            "items": ["book", "pen", "computer"],
            "total": 15,
            "address": {"street": "123 Main St", "city": "Anytown"},
        },
    }
    assert series[5] == expected_nested
    assert pd.isna(series[6])


def test_json_arrow_to_pylist():
    """to_pylist() yields plain Python objects decoded from the JSON storage."""
    serialized = [
        json.dumps(value, sort_keys=True, separators=(",", ":"))
        for value in JSON_DATA.values()
    ]
    values = pa.array(serialized, type=db_dtypes.JSONArrowType()).to_pylist()

    assert isinstance(values, list)
    assert values[0]
    assert values[1] == 100
    assert math.isclose(values[2], 0.98)
    assert values[3] == "hello world"
    assert math.isclose(values[4][0], 0.1)
    assert math.isclose(values[4][1], 0.2)
    expected_nested = {
        "null_field": None,
        "order": {
            "items": ["book", "pen", "computer"],
            "total": 15,
            "address": {"street": "123 Main St", "city": "Anytown"},
        },
    }
    assert values[5] == expected_nested
    assert values[6] is None


def test_json_arrow_record_batch():
    """dbjson columns survive an IPC stream round trip with type and values intact."""
    serialized = [
        json.dumps(value, sort_keys=True, separators=(",", ":"))
        for value in JSON_DATA.values()
    ]
    json_array = pa.array(serialized, type=db_dtypes.JSONArrowType())
    batch = pa.RecordBatch.from_arrays([json_array], ["json_col"])

    # Serialize the batch to an in-memory IPC stream...
    sink = pa.BufferOutputStream()
    with pa.RecordBatchStreamWriter(sink, batch.schema) as writer:
        writer.write_batch(batch)

    # ...and read it back, as a second Python process would.
    with pa.ipc.open_stream(sink.getvalue()) as reader:
        table = reader.read_all()

    round_tripped = table.column("json_col")
    assert isinstance(round_tripped.type, db_dtypes.JSONArrowType)

    values = round_tripped.to_pylist()
    assert isinstance(values, list)
    assert values[0]
    assert values[1] == 100
    assert math.isclose(values[2], 0.98)
    assert values[3] == "hello world"
    assert math.isclose(values[4][0], 0.1)
    assert math.isclose(values[4][1], 0.2)
    assert values[5] == {
        "null_field": None,
        "order": {
            "items": ["book", "pen", "computer"],
            "total": 15,
            "address": {"street": "123 Main St", "city": "Anytown"},
        },
    }
    assert values[6] is None
Loading