Commit 4475f9c

add JSONArrowScalar
1 parent 6a7e82d commit 4475f9c

4 files changed (+118 -32 lines)

db_dtypes/__init__.py (+4 -3)
@@ -50,7 +50,7 @@
 # To use JSONArray and JSONDtype, you'll need Pandas 1.5.0 or later. With the removal
 # of Python 3.7 compatibility, the minimum Pandas version will be updated to 1.5.0.
 if packaging.version.Version(pandas.__version__) >= packaging.version.Version("1.5.0"):
-    from db_dtypes.json import ArrowJSONType, JSONArray, JSONDtype
+    from db_dtypes.json import JSONArray, JSONArrowScalar, JSONArrowType, JSONDtype
 else:
     JSONArray = None
     JSONDtype = None
@@ -359,7 +359,7 @@ def __sub__(self, other):
         )


-if not JSONArray or not JSONDtype or not ArrowJSONType:
+if not JSONArray or not JSONDtype:
     __all__ = [
         "__version__",
         "DateArray",
@@ -370,11 +370,12 @@ def __sub__(self, other):
 else:
     __all__ = [
         "__version__",
-        "ArrowJSONType",
         "DateArray",
        "DateDtype",
         "JSONDtype",
         "JSONArray",
+        "JSONArrowType",
+        "JSONArrowScalar",
         "TimeArray",
         "TimeDtype",
     ]
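Note: with this change the package's public surface trades ArrowJSONType for JSONArrowType and adds JSONArrowScalar. A minimal sketch of checking for them at runtime (assumes pandas >= 1.5.0 so the gated import above succeeds; the prints are illustrative):

    import db_dtypes

    # On pandas < 1.5.0 the JSON support is disabled and JSONArray/JSONDtype are None.
    if db_dtypes.JSONArray is not None:
        # The renamed Arrow extension type and the new scalar class are both exported.
        print(db_dtypes.JSONArrowType().extension_name)  # "dbjson"
        print(db_dtypes.JSONArrowScalar)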

db_dtypes/json.py (+14 -6)
@@ -221,6 +221,8 @@ def __getitem__(self, item):
         value = self.pa_data[item]
         if isinstance(value, pa.ChunkedArray):
             return type(self)(value)
+        elif isinstance(value, pa.ExtensionScalar):
+            return value.as_py()
         else:
             scalar = JSONArray._deserialize_json(value.as_py())
             if scalar is None:
@@ -259,28 +261,34 @@ def __array__(self, dtype=None, copy: bool | None = None) -> np.ndarray:
         return result


-class ArrowJSONType(pa.ExtensionType):
+class JSONArrowScalar(pa.ExtensionScalar):
+    def as_py(self):
+        return JSONArray._deserialize_json(self.value.as_py() if self.value else None)
+
+
+class JSONArrowType(pa.ExtensionType):
     """Arrow extension type for the `dbjson` Pandas extension type."""

     def __init__(self) -> None:
         super().__init__(pa.string(), "dbjson")

     def __arrow_ext_serialize__(self) -> bytes:
-        # No parameters are necessary
         return b""

     @classmethod
-    def __arrow_ext_deserialize__(cls, storage_type, serialized) -> ArrowJSONType:
-        # return an instance of this subclass
-        return ArrowJSONType()
+    def __arrow_ext_deserialize__(cls, storage_type, serialized) -> JSONArrowType:
+        return JSONArrowType()

     def __hash__(self) -> int:
         return hash(str(self))

     def to_pandas_dtype(self):
         return JSONDtype()

+    def __arrow_ext_scalar_class__(self):
+        return JSONArrowScalar
+

 # Register the type to be included in RecordBatches, sent over IPC and received in
 # another Python process.
-pa.register_extension_type(ArrowJSONType())
+pa.register_extension_type(JSONArrowType())

tests/compliance/json/test_json_compliance.py (-4)
@@ -22,10 +22,6 @@
 import pytest


-class TestJSONArrayAccumulate(base.BaseAccumulateTests):
-    pass
-
-
 class TestJSONArrayCasting(base.BaseCastingTests):
     def test_astype_str(self, data):
         # Use `json.dumps(str)` instead of passing `str(obj)` directly to the super method.

tests/unit/test_json.py (+100 -19)
@@ -13,6 +13,7 @@
 # limitations under the License.

 import json
+import math

 import numpy as np
 import pandas as pd
@@ -37,7 +38,7 @@
         "null_field": None,
         "order": {
             "items": ["book", "pen", "computer"],
-            "total": 15.99,
+            "total": 15,
             "address": {"street": "123 Main St", "city": "Anytown"},
         },
     },
@@ -117,35 +118,115 @@ def test_as_numpy_array():
     pd._testing.assert_equal(result, expected)


-def test_arrow_json_storage_type():
-    arrow_json_type = db_dtypes.ArrowJSONType()
+def test_json_arrow_storage_type():
+    arrow_json_type = db_dtypes.JSONArrowType()
     assert arrow_json_type.extension_name == "dbjson"
     assert pa.types.is_string(arrow_json_type.storage_type)


-def test_arrow_json_constructors():
-    storage_array = pa.array(
-        ["0", "str", '{"b": 2}', '{"a": [1, 2, 3]}'], type=pa.string()
-    )
-    arr_1 = db_dtypes.ArrowJSONType().wrap_array(storage_array)
+def test_json_arrow_constructors():
+    data = [
+        json.dumps(value, sort_keys=True, separators=(",", ":"))
+        for value in JSON_DATA.values()
+    ]
+    storage_array = pa.array(data, type=pa.string())
+
+    arr_1 = db_dtypes.JSONArrowType().wrap_array(storage_array)
     assert isinstance(arr_1, pa.ExtensionArray)

-    arr_2 = pa.ExtensionArray.from_storage(db_dtypes.ArrowJSONType(), storage_array)
+    arr_2 = pa.ExtensionArray.from_storage(db_dtypes.JSONArrowType(), storage_array)
     assert isinstance(arr_2, pa.ExtensionArray)

     assert arr_1 == arr_2


-def test_arrow_json_to_pandas():
-    storage_array = pa.array(
-        [None, "0", "str", '{"b": 2}', '{"a": [1, 2, 3]}'], type=pa.string()
-    )
-    arr = db_dtypes.ArrowJSONType().wrap_array(storage_array)
+def test_json_arrow_to_pandas():
+    data = [
+        json.dumps(value, sort_keys=True, separators=(",", ":"))
+        for value in JSON_DATA.values()
+    ]
+    arr = pa.array(data, type=db_dtypes.JSONArrowType())

     s = arr.to_pandas()
     assert isinstance(s.dtypes, db_dtypes.JSONDtype)
-    assert pd.isna(s[0])
-    assert s[1] == 0
-    assert s[2] == "str"
-    assert s[3]["b"] == 2
-    assert s[4]["a"] == [1, 2, 3]
+    assert s[0]
+    assert s[1] == 100
+    assert math.isclose(s[2], 0.98)
+    assert s[3] == "hello world"
+    assert math.isclose(s[4][0], 0.1)
+    assert math.isclose(s[4][1], 0.2)
+    assert s[5] == {
+        "null_field": None,
+        "order": {
+            "items": ["book", "pen", "computer"],
+            "total": 15,
+            "address": {"street": "123 Main St", "city": "Anytown"},
+        },
+    }
+    assert pd.isna(s[6])
+
+
+def test_json_arrow_to_pylist():
+    data = [
+        json.dumps(value, sort_keys=True, separators=(",", ":"))
+        for value in JSON_DATA.values()
+    ]
+    arr = pa.array(data, type=db_dtypes.JSONArrowType())
+
+    s = arr.to_pylist()
+    assert isinstance(s, list)
+    assert s[0]
+    assert s[1] == 100
+    assert math.isclose(s[2], 0.98)
+    assert s[3] == "hello world"
+    assert math.isclose(s[4][0], 0.1)
+    assert math.isclose(s[4][1], 0.2)
+    assert s[5] == {
+        "null_field": None,
+        "order": {
+            "items": ["book", "pen", "computer"],
+            "total": 15,
+            "address": {"street": "123 Main St", "city": "Anytown"},
+        },
+    }
+    assert s[6] is None
+
+
+def test_json_arrow_record_batch():
+    data = [
+        json.dumps(value, sort_keys=True, separators=(",", ":"))
+        for value in JSON_DATA.values()
+    ]
+    arr = pa.array(data, type=db_dtypes.JSONArrowType())
+    batch = pa.RecordBatch.from_arrays([arr], ["json_col"])
+    sink = pa.BufferOutputStream()
+
+    with pa.RecordBatchStreamWriter(sink, batch.schema) as writer:
+        writer.write_batch(batch)
+
+    buf = sink.getvalue()
+
+    with pa.ipc.open_stream(buf) as reader:
+        result = reader.read_all()
+
+    json_col = result.column("json_col")
+    assert isinstance(json_col.type, db_dtypes.JSONArrowType)
+
+    s = json_col.to_pylist()
+
+    assert isinstance(s, list)
+    assert s[0]
+    assert s[1] == 100
+    assert math.isclose(s[2], 0.98)
+    assert s[3] == "hello world"
+    assert math.isclose(s[4][0], 0.1)
+    assert math.isclose(s[4][1], 0.2)
+    assert s[5] == {
+        "null_field": None,
+        "order": {
+            "items": ["book", "pen", "computer"],
+            "total": 15,
+            "address": {"street": "123 Main St", "city": "Anytown"},
+        },
+    }
+    assert s[6] is None

0 commit comments
