Skip to content

Commit f413f35

Browse files
authored
fix: support correct numpy construction for dbjson dtype in pandas 1.5 (#297)
* fix: support correct numpy construction for dbjson dtype in pandas 1.5
* add unit tests for pandas 1.5
* nit
* fixing import error in python 3.7
* update unit tests
* nit
1 parent 2bc6a1c commit f413f35

File tree

3 files changed

+47
-1
lines changed

3 files changed

+47
-1
lines changed

db_dtypes/json.py

+13
Original file line numberDiff line numberDiff line change
@@ -231,3 +231,16 @@ def _reduce(
231231
if name in ["min", "max"]:
232232
raise TypeError("JSONArray does not support min/max reducntion.")
233233
super()._reduce(name, skipna=skipna, keepdims=keepdims, **kwargs)
234+
235+
def __array__(self, dtype=None, copy: bool | None = None) -> np.ndarray:
    """Build the NumPy array returned for `np.asarray()` on this array.

    Args:
        dtype: Target NumPy dtype. When None, the dtype is taken from the
            NumPy conversion of an empty pyarrow array that has the same
            arrow type as ``self.pa_data``.
        copy: Accepted as part of the signature; this implementation does
            not consult it.

    Returns:
        A newly allocated array of ``len(self)`` elements in which missing
        positions hold ``self._dtype.na_value`` and all other positions
        hold the pyarrow values converted with ``to_numpy()``.
    """
    arrow_type = self.pa_data.type
    if dtype is None:
        # Probe an empty arrow array to learn which NumPy dtype pyarrow
        # would pick for this arrow type.
        probe = pa.array([], type=arrow_type)
        dtype = probe.to_numpy(zero_copy_only=False).dtype

    out = np.empty(len(self), dtype=dtype)
    is_missing = self.isna()
    present = ~is_missing

    # Fill the missing slots first, then copy over the non-missing values.
    out[is_missing] = self._dtype.na_value
    out[present] = self[present].pa_data.to_numpy()
    return out

testing/constraints-3.9.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
# Make sure we test with pandas 1.5.0. The Python version isn't that relevant.
1+
# Make sure we test with pandas 1.5.3. The Python version isn't that relevant.
22
pandas==1.5.3
33
numpy==1.24.0

tests/unit/test_json.py

+33
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import json
1516

17+
import numpy as np
1618
import pandas as pd
1719
import pytest
1820

@@ -81,3 +83,34 @@ def test_deterministic_json_serialization():
8183
y = {"b": 1, "a": 0}
8284
data = db_dtypes.JSONArray._from_sequence([y])
8385
assert data[0] == x
86+
87+
88+
def test_to_numpy():
    """
    Verifies that JSONArray can be cast to a NumPy array.

    Both `JSONArray.to_numpy()` and `pd.Series(...).to_numpy()` must match
    the result of `np.asarray()` on the same array.

    This test ensures compatibility with Python 3.9 and replicates the behavior
    of the `test_to_numpy` test from `test_json_compliance.py::TestJSONArrayCasting`,
    which is run with Python 3.12 environments only.
    """
    json_array = db_dtypes.JSONArray._from_sequence(JSON_DATA.values())
    expected = np.asarray(json_array)

    # Conversion directly on the extension array.
    from_array = json_array.to_numpy()
    pd._testing.assert_equal(from_array, expected)

    # Conversion through a Series wrapping the same extension array.
    from_series = pd.Series(json_array).to_numpy()
    pd._testing.assert_equal(from_series, expected)
103+
104+
105+
def test_as_numpy_array():
    """Check `np.asarray()` on a JSONArray against hand-built expectations.

    The expected array holds each JSON_DATA value serialized with
    `json.dumps` (sorted keys, compact separators), with `pd.NA` standing
    in for `None` values.
    """

    def _expected_item(value):
        # None entries are expected to surface as pd.NA.
        if value is None:
            return pd.NA
        return json.dumps(value, sort_keys=True, separators=(",", ":"))

    json_array = db_dtypes.JSONArray._from_sequence(JSON_DATA.values())
    actual = np.asarray(json_array)

    expected_items = [_expected_item(value) for value in JSON_DATA.values()]
    wanted = np.asarray(expected_items)

    pd._testing.assert_equal(actual, wanted)

0 commit comments

Comments
 (0)