diff --git a/db_dtypes/core.py b/db_dtypes/core.py index 7879571..5d5c053 100644 --- a/db_dtypes/core.py +++ b/db_dtypes/core.py @@ -152,29 +152,35 @@ def min(self, *, axis: Optional[int] = None, skipna: bool = True, **kwargs): result = pandas_backports.nanmin( values=self._ndarray, axis=axis, mask=self.isna(), skipna=skipna ) - return self._box_func(result) + if axis is None or self.ndim == 1: + return self._box_func(result) + return self._from_backing_data(result) def max(self, *, axis: Optional[int] = None, skipna: bool = True, **kwargs): pandas_backports.numpy_validate_max((), kwargs) result = pandas_backports.nanmax( values=self._ndarray, axis=axis, mask=self.isna(), skipna=skipna ) - return self._box_func(result) - - if pandas_release >= (1, 2): - - def median( - self, - *, - axis: Optional[int] = None, - out=None, - overwrite_input: bool = False, - keepdims: bool = False, - skipna: bool = True, - ): - pandas_backports.numpy_validate_median( - (), - {"out": out, "overwrite_input": overwrite_input, "keepdims": keepdims}, - ) - result = pandas_backports.nanmedian(self._ndarray, axis=axis, skipna=skipna) + if axis is None or self.ndim == 1: + return self._box_func(result) + return self._from_backing_data(result) + + def median( + self, + *, + axis: Optional[int] = None, + out=None, + overwrite_input: bool = False, + keepdims: bool = False, + skipna: bool = True, + ): + if not hasattr(pandas_backports, "numpy_validate_median"): + raise NotImplementedError("Need pandas 1.3 or later to calculate median.") + + pandas_backports.numpy_validate_median( + (), {"out": out, "overwrite_input": overwrite_input, "keepdims": keepdims}, + ) + result = pandas_backports.nanmedian(self._ndarray, axis=axis, skipna=skipna) + if axis is None or self.ndim == 1: return self._box_func(result) + return self._from_backing_data(result) diff --git a/db_dtypes/pandas_backports.py b/db_dtypes/pandas_backports.py index 0e39986..0966e83 100644 --- a/db_dtypes/pandas_backports.py +++ b/db_dtypes/pandas_backports.py @@ -106,12 +106,8 @@ def __ge__(self, other): # See: https://github.com/pandas-dev/pandas/pull/45544 @import_default("pandas.core.arrays._mixins", pandas_release < (1, 3)) class NDArrayBackedExtensionArray(pandas.core.arrays.base.ExtensionArray): - - ndim = 1 - def __init__(self, values, dtype): assert isinstance(values, numpy.ndarray) - assert values.ndim == 1 self._ndarray = values self._dtype = dtype diff --git a/tests/compliance/date/test_date_compliance_1_5.py b/tests/compliance/date/test_date_compliance_1_5.py new file mode 100644 index 0000000..9c6da24 --- /dev/null +++ b/tests/compliance/date/test_date_compliance_1_5.py @@ -0,0 +1,35 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests for extension interface compliance, inherited from pandas. + +See: +https://github.com/pandas-dev/pandas/blob/main/pandas/tests/extension/decimal/test_decimal.py +and +https://github.com/pandas-dev/pandas/blob/main/pandas/tests/extension/test_period.py +""" + +from pandas.tests.extension import base +import pytest + +# NDArrayBacked2DTests suite added in https://github.com/pandas-dev/pandas/pull/44974 +pytest.importorskip("pandas", minversion="1.5.0dev") + + +class Test2DCompat(base.NDArrayBacked2DTests): + pass + + +class TestIndex(base.BaseIndexTests): + pass diff --git a/tests/unit/test_date.py b/tests/unit/test_date.py index fb41620..b8f36f6 100644 --- a/tests/unit/test_date.py +++ b/tests/unit/test_date.py @@ -16,6 +16,7 @@ import operator import numpy +import numpy.testing import pandas import pandas.testing import pytest @@ -154,6 +155,100 @@ def test_date_parsing_errors(value, error): pandas.Series([value], dtype="dbdate") +def test_date_max_2d(): + input_array = db_dtypes.DateArray( + numpy.array( + [ + [ + numpy.datetime64("1970-01-01"), + numpy.datetime64("1980-02-02"), + numpy.datetime64("1990-03-03"), + ], + [ + numpy.datetime64("1971-02-02"), + numpy.datetime64("1981-03-03"), + numpy.datetime64("1991-04-04"), + ], + [ + numpy.datetime64("1972-03-03"), + numpy.datetime64("1982-04-04"), + numpy.datetime64("1992-05-05"), + ], + ], + dtype="datetime64[ns]", + ) + ) + numpy.testing.assert_array_equal( + input_array.max(axis=0)._ndarray, + numpy.array( + [ + numpy.datetime64("1972-03-03"), + numpy.datetime64("1982-04-04"), + numpy.datetime64("1992-05-05"), + ], + dtype="datetime64[ns]", + ), + ) + numpy.testing.assert_array_equal( + input_array.max(axis=1)._ndarray, + numpy.array( + [ + numpy.datetime64("1990-03-03"), + numpy.datetime64("1991-04-04"), + numpy.datetime64("1992-05-05"), + ], + dtype="datetime64[ns]", + ), + ) + + +def test_date_min_2d(): + input_array = db_dtypes.DateArray( + numpy.array( + [ + [ + numpy.datetime64("1970-01-01"), + numpy.datetime64("1980-02-02"), + numpy.datetime64("1990-03-03"), + ], + [ + numpy.datetime64("1971-02-02"), + numpy.datetime64("1981-03-03"), + numpy.datetime64("1991-04-04"), + ], + [ + numpy.datetime64("1972-03-03"), + numpy.datetime64("1982-04-04"), + numpy.datetime64("1992-05-05"), + ], + ], + dtype="datetime64[ns]", + ) + ) + numpy.testing.assert_array_equal( + input_array.min(axis=0)._ndarray, + numpy.array( + [ + numpy.datetime64("1970-01-01"), + numpy.datetime64("1980-02-02"), + numpy.datetime64("1990-03-03"), + ], + dtype="datetime64[ns]", + ), + ) + numpy.testing.assert_array_equal( + input_array.min(axis=1)._ndarray, + numpy.array( + [ + numpy.datetime64("1970-01-01"), + numpy.datetime64("1971-02-02"), + numpy.datetime64("1972-03-03"), + ], + dtype="datetime64[ns]", + ), + ) + + @pytest.mark.skipif( not hasattr(pandas_backports, "numpy_validate_median"), reason="median not available with this version of pandas", @@ -178,3 +273,58 @@ def test_date_parsing_errors(value, error): def test_date_median(values, expected): series = pandas.Series(values, dtype="dbdate") assert series.median() == expected + + +@pytest.mark.skipif( + not hasattr(pandas_backports, "numpy_validate_median"), + reason="median not available with this version of pandas", +) +def test_date_median_2d(): + input_array = db_dtypes.DateArray( + numpy.array( + [ + [ + numpy.datetime64("1970-01-01"), + numpy.datetime64("1980-02-02"), + numpy.datetime64("1990-03-03"), + ], + [ + numpy.datetime64("1971-02-02"), + numpy.datetime64("1981-03-03"), + numpy.datetime64("1991-04-04"), + ], + [ + numpy.datetime64("1972-03-03"), + numpy.datetime64("1982-04-04"), + numpy.datetime64("1992-05-05"), + ], + ], + dtype="datetime64[ns]", + ) + ) + pandas.testing.assert_extension_array_equal( + input_array.median(axis=0), + db_dtypes.DateArray( + numpy.array( + [ + numpy.datetime64("1971-02-02"), + numpy.datetime64("1981-03-03"), + numpy.datetime64("1991-04-04"), + ], + dtype="datetime64[ns]", + ) + ), + ) + pandas.testing.assert_extension_array_equal( + input_array.median(axis=1), + db_dtypes.DateArray( + numpy.array( + [ + numpy.datetime64("1980-02-02"), + numpy.datetime64("1981-03-03"), + numpy.datetime64("1982-04-04"), + ], + dtype="datetime64[ns]", + ) + ), + )