Review notes (patch tools ignore text before the first "diff --git" header):
- db_dtypes/__init__.py: the pandas < 1.5.0 fallback now also sets JSONArrowScalar
  and JSONArrowType to None. Both names are listed in __all__, so leaving them
  undefined would break "from db_dtypes import *" on that branch.
- db_dtypes/json.py: JSONArrowScalar.as_py accepts **kwargs and tests the storage
  scalar with "is None" instead of truthiness.
- Hunk headers were recomputed for the added lines; the "index" blob hashes of the
  two files above still describe the originally submitted patch.

diff --git a/db_dtypes/__init__.py b/db_dtypes/__init__.py
index 952643b..d5b05dc 100644
--- a/db_dtypes/__init__.py
+++ b/db_dtypes/__init__.py
@@ -30,8 +30,8 @@
 from db_dtypes import core
 from db_dtypes.version import __version__
 
-from . import _versions_helpers
 
+from . import _versions_helpers
 
 date_dtype_name = "dbdate"
 time_dtype_name = "dbtime"
@@ -50,7 +50,9 @@
 # To use JSONArray and JSONDtype, you'll need Pandas 1.5.0 or later. With the removal
 # of Python 3.7 compatibility, the minimum Pandas version will be updated to 1.5.0.
 if packaging.version.Version(pandas.__version__) >= packaging.version.Version("1.5.0"):
-    from db_dtypes.json import JSONArray, JSONDtype
+    from db_dtypes.json import JSONArray, JSONArrowScalar, JSONArrowType, JSONDtype
 else:
     JSONArray = None
+    JSONArrowScalar = None
+    JSONArrowType = None
     JSONDtype = None
@@ -374,6 +376,8 @@ def __sub__(self, other):
     "DateDtype",
     "JSONDtype",
     "JSONArray",
+    "JSONArrowType",
+    "JSONArrowScalar",
     "TimeArray",
     "TimeDtype",
 ]
diff --git a/db_dtypes/json.py b/db_dtypes/json.py
index c43ebc2..145eec3 100644
--- a/db_dtypes/json.py
+++ b/db_dtypes/json.py
@@ -64,6 +64,10 @@ def construct_array_type(cls):
         """Return the array type associated with this dtype."""
         return JSONArray
 
+    def __from_arrow__(self, array: pa.Array | pa.ChunkedArray) -> JSONArray:
+        """Convert the pyarrow array to the extension array."""
+        return JSONArray(array)
+
 
 class JSONArray(arrays.ArrowExtensionArray):
     """Extension array that handles BigQuery JSON data, leveraging a string-based
@@ -92,6 +96,10 @@ def __init__(self, values) -> None:
         else:
             raise NotImplementedError(f"Unsupported pandas version: {pd.__version__}")
 
+    def __arrow_array__(self, type=None):
+        """Convert to an arrow array. This is required for pyarrow extension."""
+        return pa.array(self.pa_data, type=JSONArrowType())
+
     @classmethod
     def _box_pa(
         cls, value, pa_type: pa.DataType | None = None
@@ -208,6 +216,8 @@ def __getitem__(self, item):
         value = self.pa_data[item]
         if isinstance(value, pa.ChunkedArray):
             return type(self)(value)
+        elif isinstance(value, pa.ExtensionScalar):
+            return value.as_py()
         else:
             scalar = JSONArray._deserialize_json(value.as_py())
             if scalar is None:
@@ -244,3 +254,43 @@ def __array__(self, dtype=None, copy: bool | None = None) -> np.ndarray:
         result[mask] = self._dtype.na_value
         result[~mask] = data[~mask].pa_data.to_numpy()
         return result
+
+
+class JSONArrowScalar(pa.ExtensionScalar):
+    """Arrow scalar whose Python value is the deserialized JSON document."""
+
+    def as_py(self, **kwargs):
+        """Return the stored JSON string decoded into a Python object.
+
+        Extra keyword arguments are accepted and ignored so that callers which
+        forward options to ``as_py`` (e.g. ``maps_as_pydicts``) keep working.
+        """
+        storage = self.value
+        return JSONArray._deserialize_json(
+            None if storage is None else storage.as_py()
+        )
+
+
+class JSONArrowType(pa.ExtensionType):
+    """Arrow extension type for the `dbjson` Pandas extension type."""
+
+    def __init__(self) -> None:
+        super().__init__(pa.string(), "dbjson")
+
+    def __arrow_ext_serialize__(self) -> bytes:
+        return b""
+
+    @classmethod
+    def __arrow_ext_deserialize__(cls, storage_type, serialized) -> JSONArrowType:
+        return JSONArrowType()
+
+    def to_pandas_dtype(self):
+        return JSONDtype()
+
+    def __arrow_ext_scalar_class__(self):
+        return JSONArrowScalar
+
+
+# Register the type to be included in RecordBatches, sent over IPC and received in
+# another Python process.
+pa.register_extension_type(JSONArrowType())
diff --git a/tests/compliance/json/test_json_compliance.py b/tests/compliance/json/test_json_compliance.py
index 2a8e69a..9a0d0ef 100644
--- a/tests/compliance/json/test_json_compliance.py
+++ b/tests/compliance/json/test_json_compliance.py
@@ -22,10 +22,6 @@
 import pytest
 
 
-class TestJSONArrayAccumulate(base.BaseAccumulateTests):
-    pass
-
-
 class TestJSONArrayCasting(base.BaseCastingTests):
     def test_astype_str(self, data):
         # Use `json.dumps(str)` instead of passing `str(obj)` directly to the super method.
diff --git a/tests/unit/test_json.py b/tests/unit/test_json.py
index 112b50c..055eef0 100644
--- a/tests/unit/test_json.py
+++ b/tests/unit/test_json.py
@@ -13,9 +13,11 @@
 # limitations under the License.
 
 import json
+import math
 
 import numpy as np
 import pandas as pd
+import pyarrow as pa
 import pytest
 
 import db_dtypes
@@ -36,7 +38,7 @@
         "null_field": None,
         "order": {
             "items": ["book", "pen", "computer"],
-            "total": 15.99,
+            "total": 15,
             "address": {"street": "123 Main St", "city": "Anytown"},
         },
     },
@@ -114,3 +116,122 @@ def test_as_numpy_array():
         ]
     )
     pd._testing.assert_equal(result, expected)
+
+
+def test_json_arrow_array():
+    data = db_dtypes.JSONArray._from_sequence(JSON_DATA.values())
+    assert isinstance(data.__arrow_array__(), pa.ExtensionArray)
+
+
+def test_json_arrow_storage_type():
+    arrow_json_type = db_dtypes.JSONArrowType()
+    assert arrow_json_type.extension_name == "dbjson"
+    assert pa.types.is_string(arrow_json_type.storage_type)
+
+
+def test_json_arrow_constructors():
+    data = [
+        json.dumps(value, sort_keys=True, separators=(",", ":"))
+        for value in JSON_DATA.values()
+    ]
+    storage_array = pa.array(data, type=pa.string())
+
+    arr_1 = db_dtypes.JSONArrowType().wrap_array(storage_array)
+    assert isinstance(arr_1, pa.ExtensionArray)
+
+    arr_2 = pa.ExtensionArray.from_storage(db_dtypes.JSONArrowType(), storage_array)
+    assert isinstance(arr_2, pa.ExtensionArray)
+
+    assert arr_1 == arr_2
+
+
+def test_json_arrow_to_pandas():
+    data = [
+        json.dumps(value, sort_keys=True, separators=(",", ":"))
+        for value in JSON_DATA.values()
+    ]
+    arr = pa.array(data, type=db_dtypes.JSONArrowType())
+
+    s = arr.to_pandas()
+    assert isinstance(s.dtypes, db_dtypes.JSONDtype)
+    assert s[0]
+    assert s[1] == 100
+    assert math.isclose(s[2], 0.98)
+    assert s[3] == "hello world"
+    assert math.isclose(s[4][0], 0.1)
+    assert math.isclose(s[4][1], 0.2)
+    assert s[5] == {
+        "null_field": None,
+        "order": {
+            "items": ["book", "pen", "computer"],
+            "total": 15,
+            "address": {"street": "123 Main St", "city": "Anytown"},
+        },
+    }
+    assert pd.isna(s[6])
+
+
+def test_json_arrow_to_pylist():
+    data = [
+        json.dumps(value, sort_keys=True, separators=(",", ":"))
+        for value in JSON_DATA.values()
+    ]
+    arr = pa.array(data, type=db_dtypes.JSONArrowType())
+
+    s = arr.to_pylist()
+    assert isinstance(s, list)
+    assert s[0]
+    assert s[1] == 100
+    assert math.isclose(s[2], 0.98)
+    assert s[3] == "hello world"
+    assert math.isclose(s[4][0], 0.1)
+    assert math.isclose(s[4][1], 0.2)
+    assert s[5] == {
+        "null_field": None,
+        "order": {
+            "items": ["book", "pen", "computer"],
+            "total": 15,
+            "address": {"street": "123 Main St", "city": "Anytown"},
+        },
+    }
+    assert s[6] is None
+
+
+def test_json_arrow_record_batch():
+    data = [
+        json.dumps(value, sort_keys=True, separators=(",", ":"))
+        for value in JSON_DATA.values()
+    ]
+    arr = pa.array(data, type=db_dtypes.JSONArrowType())
+    batch = pa.RecordBatch.from_arrays([arr], ["json_col"])
+    sink = pa.BufferOutputStream()
+
+    with pa.RecordBatchStreamWriter(sink, batch.schema) as writer:
+        writer.write_batch(batch)
+
+    buf = sink.getvalue()
+
+    with pa.ipc.open_stream(buf) as reader:
+        result = reader.read_all()
+
+    json_col = result.column("json_col")
+    assert isinstance(json_col.type, db_dtypes.JSONArrowType)
+
+    s = json_col.to_pylist()
+
+    assert isinstance(s, list)
+    assert s[0]
+    assert s[1] == 100
+    assert math.isclose(s[2], 0.98)
+    assert s[3] == "hello world"
+    assert math.isclose(s[4][0], 0.1)
+    assert math.isclose(s[4][1], 0.2)
+    assert s[5] == {
+        "null_field": None,
+        "order": {
+            "items": ["book", "pen", "computer"],
+            "total": 15,
+            "address": {"street": "123 Main St", "city": "Anytown"},
+        },
+    }
+    assert s[6] is None