Skip to content

Commit 8cda5e8

Browse files
committed
support array type
1 parent 1baecd5 commit 8cda5e8

File tree

3 files changed

+83
-58
lines changed

3 files changed

+83
-58
lines changed

db_dtypes/json.py

+3-38
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
"ge": pyarrow.compute.greater_equal,
3535
}
3636

37+
3738
@pd.api.extensions.register_extension_dtype
3839
class JSONDtype(pd.api.extensions.ExtensionDtype):
3940
"""Extension dtype for BigQuery JSON data."""
@@ -90,6 +91,7 @@ def _box_pa(
9091
cls, value, pa_type: pa.DataType | None = None
9192
) -> pa.Array | pa.ChunkedArray | pa.Scalar:
9293
"""Box value into a pyarrow Array, ChunkedArray or Scalar."""
94+
9395
if isinstance(value, pa.Scalar) or not (
9496
common.is_list_like(value) and not common.is_dict_like(value)
9597
):
@@ -163,7 +165,7 @@ def _from_factorized(cls, values, original):
163165
@staticmethod
164166
def _serialize_json(value):
165167
"""A static method that converts a JSON value into a string representation."""
166-
if pd.isna(value):
168+
if not common.is_list_like(value) and pd.isna(value):
167169
return value
168170
else:
169171
# `sort_keys=True` sorts dictionary keys before serialization, making
@@ -254,40 +256,3 @@ def _reduce(
254256
if name in ["min", "max"]:
255257
raise TypeError("JSONArray does not support min/max reduction.")
256258
super()._reduce(name, skipna=skipna, keepdims=keepdims, **kwargs)
257-
258-
def __array__(
259-
self, dtype = None, copy = None
260-
) -> np.ndarray:
261-
"""Correctly construct numpy arrays when passed to `np.asarray()`."""
262-
return self.to_numpy(dtype=dtype)
263-
264-
def to_numpy(self, dtype = None, copy = False, na_value = pd.NA) -> np.ndarray:
265-
dtype, na_value = self._to_numpy_dtype_inference(dtype, na_value, self._hasna)
266-
pa_type = self._pa_array.type
267-
if not self._hasna or pd.isna(na_value) or pa.types.is_null(pa_type):
268-
data = self
269-
else:
270-
data = self.fillna(na_value)
271-
result = np.array(list(data), dtype=dtype)
272-
273-
if data._hasna:
274-
result[data.isna()] = na_value
275-
return result
276-
277-
def _to_numpy_dtype_inference(
278-
self, dtype, na_value, hasna
279-
):
280-
if dtype is not None:
281-
dtype = np.dtype(dtype)
282-
283-
if dtype is None or not hasna:
284-
na_value = self.dtype.na_value
285-
elif dtype.kind == "f": # type: ignore[union-attr]
286-
na_value = np.nan
287-
elif dtype.kind == "M": # type: ignore[union-attr]
288-
na_value = np.datetime64("nat")
289-
elif dtype.kind == "m": # type: ignore[union-attr]
290-
na_value = np.timedelta64("nat")
291-
else:
292-
na_value = self.dtype.na_value
293-
return dtype, na_value

tests/compliance/json/conftest.py

+20-18
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515

1616
import json
17+
import random
1718

1819
import numpy as np
1920
import pandas as pd
@@ -24,18 +25,29 @@
2425

2526

2627
def make_data():
27-
# Sample data with varied lengths.
28+
# Since the `np.array` constructor needs a consistent shape after the first
29+
# dimension, the sample data in this instance doesn't include the array type.
2830
samples = [
29-
{"id": 1, "bool_value": True}, # Boolean
30-
{"id": 2, "float_num": 3.14159}, # Floating
31-
{"id": 3, "date": "2024-07-16"}, # Dates (as strings)
32-
{"id": 4, "null_field": None}, # Null
33-
{"list_data": [10, 20, 30]}, # Lists
34-
{"person": {"name": "Alice", "age": 35}}, # Nested objects
31+
True, # Boolean
32+
100, # Int
33+
0.98, # Float
34+
"str", # String
35+
{"bool_value": True}, # Dict with a boolean
36+
{"float_num": 3.14159}, # Dict with a float
37+
{"date": "2024-07-16"}, # Dict with a date (as strings)
38+
{"null_field": None}, # Dict with a null
39+
{"list_data": [10, 20, 30]}, # Dict with a list
40+
{"person": {"name": "Alice", "age": 35}}, # Dict with nested objects
3541
{"address": {"street": "123 Main St", "city": "Anytown"}},
3642
{"order": {"items": ["book", "pen"], "total": 15.99}},
3743
]
38-
return np.random.default_rng(2).choice(samples, size=100)
44+
data = np.random.default_rng(2).choice(samples, size=100)
45+
# This replaces a single data item with an array. We are skipping the first two
46+
# items to avoid some `setitem` tests failing, because setting with a list is
47+
# ambiguous in this context.
48+
id = random.randint(3, 99)
49+
data[id] = [0.1, 0.2] # Array
50+
return data
3951

4052

4153
@pytest.fixture
@@ -48,16 +60,6 @@ def data():
4860
"""Length-100 PeriodArray for semantics test."""
4961
data = make_data()
5062

51-
# Why the while loop? NumPy is unable to construct an ndarray from
52-
# equal-length ndarrays. Many of our operations involve coercing the
53-
# EA to an ndarray of objects. To avoid random test failures, we ensure
54-
# that our data is coercible to an ndarray. Several tests deal with only
55-
# the first two elements, so that's what we'll check.
56-
57-
while len(data[0]) == len(data[1]):
58-
print(data)
59-
data = make_data()
60-
6163
return JSONArray._from_sequence(data)
6264

6365

tests/compliance/json/test_json_compliance.py

+60-2
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
from pandas.core.dtypes.cast import construct_1d_object_array_from_listlike
2222
from pandas.tests.extension import base
2323
import pytest
24-
import db_dtypes
2524

2625

2726
class TestJSONArray(base.ExtensionTests):
@@ -126,6 +125,43 @@ def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
126125
def test_searchsorted(self, data_for_sorting, as_series):
127126
super().test_searchsorted(self, data_for_sorting, as_series)
128127

128+
def test_astype_str(self, data):
129+
# Use `json.dumps(str)` instead of passing `str(obj)` directly to the super method.
130+
result = pd.Series(data[:5]).astype(str)
131+
expected = pd.Series(
132+
[json.dumps(x, sort_keys=True) for x in data[:5]], dtype=str
133+
)
134+
tm.assert_series_equal(result, expected)
135+
136+
@pytest.mark.parametrize(
137+
"nullable_string_dtype",
138+
[
139+
"string[python]",
140+
"string[pyarrow]",
141+
],
142+
)
143+
def test_astype_string(self, data, nullable_string_dtype):
144+
# Use `json.dumps(str)` instead of passing `str(obj)` directly to the super method.
145+
result = pd.Series(data[:5]).astype(nullable_string_dtype)
146+
expected = pd.Series(
147+
[json.dumps(x, sort_keys=True) for x in data[:5]],
148+
dtype=nullable_string_dtype,
149+
)
150+
tm.assert_series_equal(result, expected)
151+
152+
def test_array_interface(self, data):
153+
result = np.array(data)
154+
# Use `json.dumps(data[0])` instead of passing `data[0]` directly to the super method.
155+
assert result[0] == json.dumps(data[0])
156+
157+
result = np.array(data, dtype=object)
158+
# Use `json.dumps(x)` instead of passing `x` directly to the super method.
159+
expected = np.array([json.dumps(x) for x in data], dtype=object)
160+
if expected.ndim > 1:
161+
# nested data, explicitly construct as 1D
162+
expected = construct_1d_object_array_from_listlike(list(data))
163+
tm.assert_numpy_array_equal(result, expected)
164+
129165
@pytest.mark.xfail(reason="Setting a dict as a scalar")
130166
def test_fillna_series(self):
131167
"""We treat dictionaries as a mapping in fillna, not a scalar."""
@@ -251,7 +287,6 @@ def test_setitem_mask_boolean_array_with_na(self, data, box_in_series):
251287
super().test_setitem_mask_boolean_array_with_na(data, box_in_series)
252288

253289
@pytest.mark.parametrize("setter", ["loc", "iloc"])
254-
255290
@pytest.mark.xfail(reason="TODO: open an issue for ArrowExtensionArray")
256291
def test_setitem_scalar(self, data, setter):
257292
super().test_setitem_scalar(data, setter)
@@ -310,3 +345,26 @@ def test_setitem_2d_values(self, data):
310345
@pytest.mark.parametrize("engine", ["c", "python"])
311346
def test_EA_types(self, engine, data, request):
312347
super().test_EA_types(engine, data, request)
348+
349+
@pytest.mark.xfail(
350+
reason="`to_numpy` returns serialized JSON, "
351+
+ "while `__getitem__` returns JSON objects."
352+
)
353+
def test_setitem_frame_2d_values(self, data):
354+
super().test_setitem_frame_2d_values(data)
355+
356+
@pytest.mark.xfail(
357+
reason="`to_numpy` returns serialized JSON, "
358+
+ "while `__getitem__` returns JSON objects."
359+
)
360+
def test_transpose_frame(self, data):
361+
# `DataFrame.T` calls `to_numpy` to get results.
362+
super().test_transpose_frame(data)
363+
364+
@pytest.mark.xfail(
365+
reason="`to_numpy` returns serialized JSON, "
366+
+ "while `__getitem__` returns JSON objects."
367+
)
368+
def test_where_series(self, data, na_value, as_frame):
369+
# `Series.where` calls `to_numpy` to get results.
370+
super().test_where_series(data, na_value, as_frame)

0 commit comments

Comments
 (0)