diff --git a/db_dtypes/json.py b/db_dtypes/json.py index ed04b72..a00fe2b 100644 --- a/db_dtypes/json.py +++ b/db_dtypes/json.py @@ -72,14 +72,25 @@ class JSONArray(arrays.ArrowExtensionArray): _dtype = JSONDtype() - def __init__(self, values, dtype=None, copy=False) -> None: + def __init__(self, values) -> None: + super().__init__(values) self._dtype = JSONDtype() if isinstance(values, pa.Array): - self._pa_array = pa.chunked_array([values]) + pa_data = pa.chunked_array([values]) elif isinstance(values, pa.ChunkedArray): - self._pa_array = values + pa_data = values else: - raise ValueError(f"Unsupported type '{type(values)}' for JSONArray") + raise NotImplementedError( + f"Unsupported type '{type(values)}' for JSONArray" + ) + + # Ensures compatibility with pandas version 1.5.3 + if hasattr(self, "_data"): + self._data = pa_data + elif hasattr(self, "_pa_array"): + self._pa_array = pa_data + else: + raise NotImplementedError(f"Unsupported pandas version: {pd.__version__}") @classmethod def _box_pa( @@ -111,7 +122,7 @@ def _box_pa_scalar(cls, value) -> pa.Scalar: def _box_pa_array(cls, value, copy: bool = False) -> pa.Array | pa.ChunkedArray: """Box value into a pyarrow Array or ChunkedArray.""" if isinstance(value, cls): - pa_array = value._pa_array + pa_array = value.pa_data else: value = [JSONArray._serialize_json(x) for x in value] pa_array = pa.array(value, type=cls._dtype.pyarrow_dtype, from_pandas=True) @@ -147,11 +158,22 @@ def dtype(self) -> JSONDtype: """An instance of JSONDtype""" return self._dtype + @property + def pa_data(self): + """An instance of stored pa data""" + # Ensures compatibility with pandas version 1.5.3 + if hasattr(self, "_data"): + return self._data + elif hasattr(self, "_pa_array"): + return self._pa_array + else: + raise NotImplementedError(f"Unsupported pandas version: {pd.__version__}") + def _cmp_method(self, other, op): if op.__name__ == "eq": - result = pyarrow.compute.equal(self._pa_array, self._box_pa(other)) + result = pyarrow.compute.equal(self.pa_data, self._box_pa(other)) elif op.__name__ == "ne": - result = pyarrow.compute.not_equal(self._pa_array, self._box_pa(other)) + result = pyarrow.compute.not_equal(self.pa_data, self._box_pa(other)) else: # Comparison is not a meaningful one. We don't want to support sorting by JSON columns. raise TypeError(f"{op.__name__} not supported for JSONArray") @@ -169,7 +191,7 @@ def __getitem__(self, item): else: # `check_array_indexer` should verify that the assertion hold true. assert item.dtype.kind == "b" - return type(self)(self._pa_array.filter(item)) + return type(self)(self.pa_data.filter(item)) elif isinstance(item, tuple): item = indexers.unpack_tuple_and_ellipses(item) @@ -181,7 +203,7 @@ def __getitem__(self, item): r"(`None`) and integer or boolean arrays are valid indices" ) - value = self._pa_array[item] + value = self.pa_data[item] if isinstance(value, pa.ChunkedArray): return type(self)(value) else: @@ -193,7 +215,7 @@ def __getitem__(self, item): def __iter__(self): """Iterate over elements of the array.""" - for value in self._pa_array: + for value in self.pa_data: val = JSONArray._deserialize_json(value.as_py()) if val is None: yield self._dtype.na_value diff --git a/testing/constraints-3.9.txt b/testing/constraints-3.9.txt index b9ab6bf..4700825 100644 --- a/testing/constraints-3.9.txt +++ b/testing/constraints-3.9.txt @@ -1,3 +1,3 @@ -# Make sure we test with pandas 1.3.0. The Python version isn't that relevant. -pandas==1.3.0 -numpy<2.0.0 +# Make sure we test with pandas 1.5.0. The Python version isn't that relevant. +pandas==1.5.3 +numpy==1.24.0 \ No newline at end of file diff --git a/tests/unit/test_json.py b/tests/unit/test_json.py index c48635d..365bd8f 100644 --- a/tests/unit/test_json.py +++ b/tests/unit/test_json.py @@ -13,8 +13,6 @@ # limitations under the License. -import json - import pandas as pd import pytest @@ -78,18 +76,8 @@ def test_getitems_when_iter_with_null(): assert pd.isna(result) -def test_to_numpy(): - s = pd.Series(db_dtypes.JSONArray._from_sequence(JSON_DATA.values())) - data = s.to_numpy() - for id, key in enumerate(JSON_DATA.keys()): - if key == "null": - assert pd.isna(data[id]) - else: - assert data[id] == json.dumps(JSON_DATA[key], sort_keys=True) - - def test_deterministic_json_serialization(): x = {"a": 0, "b": 1} y = {"b": 1, "a": 0} - data = db_dtypes.JSONArray._from_sequence([x]) - assert y in data + data = db_dtypes.JSONArray._from_sequence([y]) + assert data[0] == x