Skip to content

Commit b771e05

Browse files
fix: address failing 2D array compliance tests in DateArray (#64)
* fix: address failing compliance tests in DateArray and TimeArray

  test: add a test session with prerelease versions of dependencies
* fix min/max/median for 2D arrays
* fixes except for null contains
* actually use NaT as 'advertised'
* fix!: use `pandas.NaT` for missing values in dbdate and dbtime dtypes

  This makes them consistent with other date/time dtypes, as well as internally consistent with the advertised `dtype.na_value`.

  BREAKING-CHANGE: dbdate and dbtime dtypes return NaT instead of None for missing values

  Release-As: 0.4.0
* more progress towards compliance
* address errors in TestMethods
* move tests
* add prerelease deps
* fix: address failing tests with pandas 1.5.0

  test: add a test session with prerelease versions of dependencies
* fix owlbot config
* 🦉 Updates from OwlBot post-processor

  See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md
* 🦉 Updates from OwlBot post-processor

  See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md
* document why microsecond precision is used
* use correct units
* add box_func tests
* typo
* add unit tests

Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
1 parent 1db1357 commit b771e05

File tree

4 files changed

+210
-23
lines changed

4 files changed

+210
-23
lines changed

db_dtypes/core.py

+25 -19
Original file line number | Diff line number | Diff line change
@@ -152,29 +152,35 @@ def min(self, *, axis: Optional[int] = None, skipna: bool = True, **kwargs):
152152
result = pandas_backports.nanmin(
153153
values=self._ndarray, axis=axis, mask=self.isna(), skipna=skipna
154154
)
155-
return self._box_func(result)
155+
if axis is None or self.ndim == 1:
156+
return self._box_func(result)
157+
return self._from_backing_data(result)
156158

157159
def max(self, *, axis: Optional[int] = None, skipna: bool = True, **kwargs):
158160
pandas_backports.numpy_validate_max((), kwargs)
159161
result = pandas_backports.nanmax(
160162
values=self._ndarray, axis=axis, mask=self.isna(), skipna=skipna
161163
)
162-
return self._box_func(result)
163-
164-
if pandas_release >= (1, 2):
165-
166-
def median(
167-
self,
168-
*,
169-
axis: Optional[int] = None,
170-
out=None,
171-
overwrite_input: bool = False,
172-
keepdims: bool = False,
173-
skipna: bool = True,
174-
):
175-
pandas_backports.numpy_validate_median(
176-
(),
177-
{"out": out, "overwrite_input": overwrite_input, "keepdims": keepdims},
178-
)
179-
result = pandas_backports.nanmedian(self._ndarray, axis=axis, skipna=skipna)
164+
if axis is None or self.ndim == 1:
165+
return self._box_func(result)
166+
return self._from_backing_data(result)
167+
168+
def median(
169+
self,
170+
*,
171+
axis: Optional[int] = None,
172+
out=None,
173+
overwrite_input: bool = False,
174+
keepdims: bool = False,
175+
skipna: bool = True,
176+
):
177+
if not hasattr(pandas_backports, "numpy_validate_median"):
178+
raise NotImplementedError("Need pandas 1.3 or later to calculate median.")
179+
180+
pandas_backports.numpy_validate_median(
181+
(), {"out": out, "overwrite_input": overwrite_input, "keepdims": keepdims},
182+
)
183+
result = pandas_backports.nanmedian(self._ndarray, axis=axis, skipna=skipna)
184+
if axis is None or self.ndim == 1:
180185
return self._box_func(result)
186+
return self._from_backing_data(result)

db_dtypes/pandas_backports.py

-4
Original file line number | Diff line number | Diff line change
@@ -106,12 +106,8 @@ def __ge__(self, other):
106106
# See: https://github.com/pandas-dev/pandas/pull/45544
107107
@import_default("pandas.core.arrays._mixins", pandas_release < (1, 3))
108108
class NDArrayBackedExtensionArray(pandas.core.arrays.base.ExtensionArray):
109-
110-
ndim = 1
111-
112109
def __init__(self, values, dtype):
113110
assert isinstance(values, numpy.ndarray)
114-
assert values.ndim == 1
115111
self._ndarray = values
116112
self._dtype = dtype
117113

Original file line number | Diff line number | Diff line change
@@ -0,0 +1,35 @@
1+
# Copyright 2022 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""
15+
Tests for extension interface compliance, inherited from pandas.
16+
17+
See:
18+
https://github.com/pandas-dev/pandas/blob/main/pandas/tests/extension/decimal/test_decimal.py
19+
and
20+
https://github.com/pandas-dev/pandas/blob/main/pandas/tests/extension/test_period.py
21+
"""
22+
23+
from pandas.tests.extension import base
24+
import pytest
25+
26+
# NDArrayBacked2DTests suite added in https://github.com/pandas-dev/pandas/pull/44974
27+
pytest.importorskip("pandas", minversion="1.5.0dev")
28+
29+
30+
class Test2DCompat(base.NDArrayBacked2DTests):
31+
pass
32+
33+
34+
class TestIndex(base.BaseIndexTests):
35+
pass

tests/unit/test_date.py

+150
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616
import operator
1717

1818
import numpy
19+
import numpy.testing
1920
import pandas
2021
import pandas.testing
2122
import pytest
@@ -154,6 +155,100 @@ def test_date_parsing_errors(value, error):
154155
pandas.Series([value], dtype="dbdate")
155156

156157

158+
def test_date_max_2d():
159+
input_array = db_dtypes.DateArray(
160+
numpy.array(
161+
[
162+
[
163+
numpy.datetime64("1970-01-01"),
164+
numpy.datetime64("1980-02-02"),
165+
numpy.datetime64("1990-03-03"),
166+
],
167+
[
168+
numpy.datetime64("1971-02-02"),
169+
numpy.datetime64("1981-03-03"),
170+
numpy.datetime64("1991-04-04"),
171+
],
172+
[
173+
numpy.datetime64("1972-03-03"),
174+
numpy.datetime64("1982-04-04"),
175+
numpy.datetime64("1992-05-05"),
176+
],
177+
],
178+
dtype="datetime64[ns]",
179+
)
180+
)
181+
numpy.testing.assert_array_equal(
182+
input_array.max(axis=0)._ndarray,
183+
numpy.array(
184+
[
185+
numpy.datetime64("1972-03-03"),
186+
numpy.datetime64("1982-04-04"),
187+
numpy.datetime64("1992-05-05"),
188+
],
189+
dtype="datetime64[ns]",
190+
),
191+
)
192+
numpy.testing.assert_array_equal(
193+
input_array.max(axis=1)._ndarray,
194+
numpy.array(
195+
[
196+
numpy.datetime64("1990-03-03"),
197+
numpy.datetime64("1991-04-04"),
198+
numpy.datetime64("1992-05-05"),
199+
],
200+
dtype="datetime64[ns]",
201+
),
202+
)
203+
204+
205+
def test_date_min_2d():
206+
input_array = db_dtypes.DateArray(
207+
numpy.array(
208+
[
209+
[
210+
numpy.datetime64("1970-01-01"),
211+
numpy.datetime64("1980-02-02"),
212+
numpy.datetime64("1990-03-03"),
213+
],
214+
[
215+
numpy.datetime64("1971-02-02"),
216+
numpy.datetime64("1981-03-03"),
217+
numpy.datetime64("1991-04-04"),
218+
],
219+
[
220+
numpy.datetime64("1972-03-03"),
221+
numpy.datetime64("1982-04-04"),
222+
numpy.datetime64("1992-05-05"),
223+
],
224+
],
225+
dtype="datetime64[ns]",
226+
)
227+
)
228+
numpy.testing.assert_array_equal(
229+
input_array.min(axis=0)._ndarray,
230+
numpy.array(
231+
[
232+
numpy.datetime64("1970-01-01"),
233+
numpy.datetime64("1980-02-02"),
234+
numpy.datetime64("1990-03-03"),
235+
],
236+
dtype="datetime64[ns]",
237+
),
238+
)
239+
numpy.testing.assert_array_equal(
240+
input_array.min(axis=1)._ndarray,
241+
numpy.array(
242+
[
243+
numpy.datetime64("1970-01-01"),
244+
numpy.datetime64("1971-02-02"),
245+
numpy.datetime64("1972-03-03"),
246+
],
247+
dtype="datetime64[ns]",
248+
),
249+
)
250+
251+
157252
@pytest.mark.skipif(
158253
not hasattr(pandas_backports, "numpy_validate_median"),
159254
reason="median not available with this version of pandas",
@@ -178,3 +273,58 @@ def test_date_parsing_errors(value, error):
178273
def test_date_median(values, expected):
179274
series = pandas.Series(values, dtype="dbdate")
180275
assert series.median() == expected
276+
277+
278+
@pytest.mark.skipif(
279+
not hasattr(pandas_backports, "numpy_validate_median"),
280+
reason="median not available with this version of pandas",
281+
)
282+
def test_date_median_2d():
283+
input_array = db_dtypes.DateArray(
284+
numpy.array(
285+
[
286+
[
287+
numpy.datetime64("1970-01-01"),
288+
numpy.datetime64("1980-02-02"),
289+
numpy.datetime64("1990-03-03"),
290+
],
291+
[
292+
numpy.datetime64("1971-02-02"),
293+
numpy.datetime64("1981-03-03"),
294+
numpy.datetime64("1991-04-04"),
295+
],
296+
[
297+
numpy.datetime64("1972-03-03"),
298+
numpy.datetime64("1982-04-04"),
299+
numpy.datetime64("1992-05-05"),
300+
],
301+
],
302+
dtype="datetime64[ns]",
303+
)
304+
)
305+
pandas.testing.assert_extension_array_equal(
306+
input_array.median(axis=0),
307+
db_dtypes.DateArray(
308+
numpy.array(
309+
[
310+
numpy.datetime64("1971-02-02"),
311+
numpy.datetime64("1981-03-03"),
312+
numpy.datetime64("1991-04-04"),
313+
],
314+
dtype="datetime64[ns]",
315+
)
316+
),
317+
)
318+
pandas.testing.assert_extension_array_equal(
319+
input_array.median(axis=1),
320+
db_dtypes.DateArray(
321+
numpy.array(
322+
[
323+
numpy.datetime64("1980-02-02"),
324+
numpy.datetime64("1981-03-03"),
325+
numpy.datetime64("1982-04-04"),
326+
],
327+
dtype="datetime64[ns]",
328+
)
329+
),
330+
)

0 commit comments

Comments
 (0)