Skip to content

Commit 17f560e

Browse files
committed
address comments
1 parent b4cfcd9 commit 17f560e

File tree

1 file changed

+15
-24
lines changed

1 file changed

+15
-24
lines changed

db_dtypes/json.py

+15-24
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,6 @@
2424
import pyarrow as pa
2525
import pyarrow.compute
2626

27-
ARROW_CMP_FUNCS = {
28-
"eq": pyarrow.compute.equal,
29-
"ne": pyarrow.compute.not_equal,
30-
"lt": pyarrow.compute.less,
31-
"gt": pyarrow.compute.greater,
32-
"le": pyarrow.compute.less_equal,
33-
"ge": pyarrow.compute.greater_equal,
34-
}
35-
3627

3728
@pd.api.extensions.register_extension_dtype
3829
class JSONDtype(pd.api.extensions.ExtensionDtype):
@@ -68,11 +59,6 @@ def construct_array_type(cls):
6859
"""Return the array type associated with this dtype."""
6960
return JSONArray
7061

71-
# @staticmethod
72-
# def __from_arrow__(array: typing.Union[pa.Array, pa.ChunkedArray]) -> JSONArray:
73-
# """Convert to JSONArray from an Arrow array."""
74-
# return JSONArray(array)
75-
7662

7763
class JSONArray(arrays.ArrowExtensionArray):
7864
"""Extension array that handles BigQuery JSON data, leveraging a string-based
@@ -95,26 +81,26 @@ def _box_pa(
9581
cls, value, pa_type: pa.DataType | None = None
9682
) -> pa.Array | pa.ChunkedArray | pa.Scalar:
9783
"""Box value into a pyarrow Array, ChunkedArray or Scalar."""
84+
if pa_type is not None and pa_type != pa.string():
85+
raise ValueError(f"Unsupported type '{pa_type}' for JSONArray")
9886

9987
if isinstance(value, pa.Scalar) or not (
10088
common.is_list_like(value) and not common.is_dict_like(value)
10189
):
102-
return cls._box_pa_scalar(value, pa_type)
103-
return cls._box_pa_array(value, pa_type)
90+
return cls._box_pa_scalar(value)
91+
return cls._box_pa_array(value)
10492

10593
@classmethod
106-
def _box_pa_scalar(cls, value, pa_type: pa.DataType | None = None) -> pa.Scalar:
94+
def _box_pa_scalar(cls, value) -> pa.Scalar:
10795
"""Box value into a pyarrow Scalar."""
10896
if isinstance(value, pa.Scalar):
10997
pa_scalar = value
11098
if pd.isna(value):
111-
pa_scalar = pa.scalar(None, type=pa_type)
99+
pa_scalar = pa.scalar(None, type=pa.string())
112100
else:
113101
value = JSONArray._serialize_json(value)
114-
pa_scalar = pa.scalar(value, type=pa_type, from_pandas=True)
102+
pa_scalar = pa.scalar(value, type=pa.string(), from_pandas=True)
115103

116-
if pa_type is not None and pa_scalar.type != pa_type:
117-
pa_scalar = pa_scalar.cast(pa_type)
118104
return pa_scalar
119105

120106
@classmethod
@@ -131,7 +117,8 @@ def _box_pa_array(
131117
value = [JSONArray._serialize_json(x) for x in value]
132118
pa_array = pa.array(value, type=pa_type, from_pandas=True)
133119
except (pa.ArrowInvalid, pa.ArrowTypeError):
134-
# GH50430: let pyarrow infer type, then cast
120+
# https://github.com/pandas-dev/pandas/pull/50430:
121+
# let pyarrow infer type, then cast
135122
pa_array = pa.array(value, from_pandas=True)
136123

137124
if pa_type is not None and pa_array.type != pa_type:
@@ -181,8 +168,12 @@ def dtype(self) -> JSONDtype:
181168
return self._dtype
182169

183170
def _cmp_method(self, other, op):
184-
pc_func = ARROW_CMP_FUNCS[op.__name__]
185-
result = pc_func(self._pa_array, self._box_pa(other))
171+
if op.__name__ == "eq":
172+
result = pyarrow.compute.equal(self._pa_array, self._box_pa(other))
173+
elif op.__name__ == "ne":
174+
result = pyarrow.compute.not_equal(self._pa_array, self._box_pa(other))
175+
else:
176+
raise NotImplementedError(f"{op.__name__} not implemented for JSONArray")
186177
return arrays.ArrowExtensionArray(result)
187178

188179
def __getitem__(self, item):

0 commit comments

Comments
 (0)