|
| 1 | +# Copyright 2024 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# https://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +from __future__ import annotations |
| 16 | + |
| 17 | +import json |
| 18 | + |
| 19 | +import numpy as np |
| 20 | +import pandas as pd |
| 21 | +import pandas.arrays as arrays |
| 22 | +import pandas.core.dtypes.common as common |
| 23 | +import pandas.core.indexers as indexers |
| 24 | +import pyarrow as pa |
| 25 | +import pyarrow.compute |
| 26 | + |
| 27 | + |
@pd.api.extensions.register_extension_dtype
class JSONDtype(pd.api.extensions.ExtensionDtype):
    """Extension dtype for BigQuery JSON data."""

    name = "dbjson"

    @classmethod
    def construct_array_type(cls):
        """Return the array type associated with this dtype."""
        return JSONArray

    @property
    def na_value(self) -> pd.NA:
        """Default NA value to use for this type."""
        return pd.NA

    @property
    def type(self) -> type[str]:
        """Return the scalar type for the array elements.

        Although a JSON value may be any of ``dict``, ``list``, ``str``,
        ``int``, ``float``, ``bool`` or ``None``, this reports ``str`` — the
        underlying storage type — because pandas does not support unions of
        multiple scalar types well.
        """
        return str

    @property
    def pyarrow_dtype(self):
        """Return the pyarrow data type used for storing data in the pyarrow array."""
        return pa.string()

    @property
    def _is_numeric(self) -> bool:
        # JSON columns are not treated as numeric by pandas internals.
        return False

    @property
    def _is_boolean(self) -> bool:
        # JSON columns are not treated as boolean by pandas internals.
        return False
| 66 | + |
| 67 | + |
class JSONArray(arrays.ArrowExtensionArray):
    """Extension array that handles BigQuery JSON data, leveraging a string-based
    pyarrow array for storage. It enables seamless conversion to JSON objects when
    accessing individual elements."""

    _dtype = JSONDtype()

    def __init__(self, values, dtype=None, copy=False) -> None:
        """Initialize the array from a pyarrow string Array or ChunkedArray.

        Args:
            values: A ``pa.Array`` or ``pa.ChunkedArray`` of serialized JSON
                strings.
            dtype: Ignored; the dtype is always ``JSONDtype``.
            copy: Ignored; the underlying pyarrow buffers are immutable.

        Raises:
            ValueError: If ``values`` is neither a pyarrow Array nor a
                ChunkedArray.
        """
        self._dtype = JSONDtype()
        if isinstance(values, pa.Array):
            self._pa_array = pa.chunked_array([values])
        elif isinstance(values, pa.ChunkedArray):
            self._pa_array = values
        else:
            raise ValueError(f"Unsupported type '{type(values)}' for JSONArray")

    @classmethod
    def _box_pa(
        cls, value, pa_type: pa.DataType | None = None
    ) -> pa.Array | pa.ChunkedArray | pa.Scalar:
        """Box value into a pyarrow Array, ChunkedArray or Scalar."""
        # Only the string storage type is supported for JSON data.
        assert pa_type is None or pa_type == cls._dtype.pyarrow_dtype

        # Dict-like values are single JSON objects (scalars), not sequences,
        # even though they are technically list-like.
        if isinstance(value, pa.Scalar) or not (
            common.is_list_like(value) and not common.is_dict_like(value)
        ):
            return cls._box_pa_scalar(value)
        return cls._box_pa_array(value)

    @classmethod
    def _box_pa_scalar(cls, value) -> pa.Scalar:
        """Box value into a pyarrow Scalar, serializing JSON values to strings."""
        if pd.isna(value):
            pa_scalar = pa.scalar(None, type=cls._dtype.pyarrow_dtype)
        else:
            value = JSONArray._serialize_json(value)
            pa_scalar = pa.scalar(
                value, type=cls._dtype.pyarrow_dtype, from_pandas=True
            )

        return pa_scalar

    @classmethod
    def _box_pa_array(cls, value, copy: bool = False) -> pa.Array | pa.ChunkedArray:
        """Box value into a pyarrow Array or ChunkedArray, serializing each
        element to its JSON string representation."""
        if isinstance(value, cls):
            # Already backed by a pyarrow array; reuse it directly.
            pa_array = value._pa_array
        else:
            value = [JSONArray._serialize_json(x) for x in value]
            pa_array = pa.array(value, type=cls._dtype.pyarrow_dtype, from_pandas=True)
        return pa_array

    @classmethod
    def _from_sequence(cls, scalars, *, dtype=None, copy=False):
        """Construct a new ExtensionArray from a sequence of scalars."""
        pa_array = cls._box_pa(scalars)
        arr = cls(pa_array)
        return arr

    @staticmethod
    def _serialize_json(value):
        """A static method that converts a JSON value into a string representation."""
        if not common.is_list_like(value) and pd.isna(value):
            # Preserve missing values as-is; they become nulls in pyarrow.
            return value
        else:
            # `sort_keys=True` sorts dictionary keys before serialization, making
            # JSON comparisons deterministic.
            return json.dumps(value, sort_keys=True)

    @staticmethod
    def _deserialize_json(value):
        """A static method that converts a JSON string back into its original value."""
        if not pd.isna(value):
            return json.loads(value)
        else:
            return value

    @property
    def dtype(self) -> JSONDtype:
        """An instance of JSONDtype"""
        return self._dtype

    def _cmp_method(self, other, op):
        """Support only equality comparisons on the serialized string values.

        Raises:
            TypeError: For ordering comparisons (lt/le/gt/ge), which are not
                meaningful for JSON data.
        """
        if op.__name__ == "eq":
            result = pyarrow.compute.equal(self._pa_array, self._box_pa(other))
        elif op.__name__ == "ne":
            result = pyarrow.compute.not_equal(self._pa_array, self._box_pa(other))
        else:
            # Comparison is not a meaningful one. We don't want to support sorting by JSON columns.
            raise TypeError(f"{op.__name__} not supported for JSONArray")
        return arrays.ArrowExtensionArray(result)

    def __getitem__(self, item):
        """Select a subset of self.

        Returns a ``JSONArray`` for array-like selections, or a deserialized
        JSON value (``dict``/``list``/scalar) for a single integer index.
        """
        item = indexers.check_array_indexer(self, item)

        if isinstance(item, np.ndarray):
            if not len(item):
                # Empty selection: return an empty array of the same dtype.
                return type(self)(pa.chunked_array([], type=self.dtype.pyarrow_dtype))
            elif item.dtype.kind in "iu":
                return self.take(item)
            else:
                # `check_array_indexer` should verify that the assertion holds true.
                assert item.dtype.kind == "b"
                return type(self)(self._pa_array.filter(item))
        elif isinstance(item, tuple):
            item = indexers.unpack_tuple_and_ellipses(item)

        if common.is_scalar(item) and not common.is_integer(item):
            # e.g. "foo" or 2.5
            # exception message copied from numpy
            raise IndexError(
                r"only integers, slices (`:`), ellipsis (`...`), numpy.newaxis "
                r"(`None`) and integer or boolean arrays are valid indices"
            )

        value = self._pa_array[item]
        if isinstance(value, pa.ChunkedArray):
            return type(self)(value)
        else:
            scalar = JSONArray._deserialize_json(value.as_py())
            if scalar is None:
                return self._dtype.na_value
            else:
                return scalar

    def __iter__(self):
        """Iterate over elements of the array, yielding deserialized JSON
        values and ``pd.NA`` for nulls."""
        for value in self._pa_array:
            val = JSONArray._deserialize_json(value.as_py())
            if val is None:
                yield self._dtype.na_value
            else:
                yield val

    def _reduce(
        self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
    ):
        """Return a scalar result of performing the reduction operation.

        Raises:
            TypeError: For "min"/"max", which are not meaningful for JSON data.
        """
        if name in ["min", "max"]:
            raise TypeError("JSONArray does not support min/max reduction.")
        # BUG FIX: the original discarded the parent's result, so every
        # delegated reduction returned None. Propagate it to the caller.
        return super()._reduce(name, skipna=skipna, keepdims=keepdims, **kwargs)