diff --git a/db_dtypes/core.py b/db_dtypes/core.py
index 5d5c053..f577960 100644
--- a/db_dtypes/core.py
+++ b/db_dtypes/core.py
@@ -113,6 +113,14 @@ def _validate_scalar(self, value):
         """
         return self._datetime(value)
 
+    def _validate_searchsorted_value(self, value):
+        """
+        Convert a value for use in searching for a value in the backing numpy array.
+
+        TODO: With pandas 2.0, this may be unnecessary. https://github.com/pandas-dev/pandas/pull/45544#issuecomment-1052809232
+        """
+        return self._validate_setitem_value(value)
+
     def _validate_setitem_value(self, value):
         """
         Convert a value for use in setting a value in the backing numpy array.
diff --git a/tests/compliance/conftest.py b/tests/compliance/conftest.py
index bc76692..54b767c 100644
--- a/tests/compliance/conftest.py
+++ b/tests/compliance/conftest.py
@@ -16,6 +16,28 @@
 import pytest
 
 
+@pytest.fixture(params=[True, False])
+def as_frame(request):
+    """
+    Boolean fixture to support Series and Series.to_frame() comparison testing.
+
+    See:
+    https://github.com/pandas-dev/pandas/blob/main/pandas/tests/extension/conftest.py
+    """
+    return request.param
+
+
+@pytest.fixture(params=[True, False])
+def as_series(request):
+    """
+    Boolean fixture to support arr and Series(arr) comparison testing.
+
+    See:
+    https://github.com/pandas-dev/pandas/blob/main/pandas/tests/extension/conftest.py
+    """
+    return request.param
+
+
 @pytest.fixture(params=["ffill", "bfill"])
 def fillna_method(request):
     """
@@ -28,6 +50,21 @@ def fillna_method(request):
     return request.param
 
 
+@pytest.fixture
+def invalid_scalar(data):
+    """
+    A scalar that *cannot* be held by this ExtensionArray.
+
+    The default should work for most subclasses, but is not guaranteed.
+
+    If the array can hold any item (i.e. object dtype), then use pytest.skip.
+
+    See:
+    https://github.com/pandas-dev/pandas/blob/main/pandas/tests/extension/conftest.py
+    """
+    return object.__new__(object)
+
+
 @pytest.fixture
 def na_value():
     return pandas.NaT
@@ -51,3 +88,26 @@ def cmp(a, b):
         return a is pandas.NaT and a is b
 
     return cmp
+
+
+@pytest.fixture(params=[None, lambda x: x])
+def sort_by_key(request):
+    """
+    Simple fixture for testing keys in sorting methods.
+    Tests None (no key) and the identity key.
+
+    See: https://github.com/pandas-dev/pandas/blob/main/pandas/conftest.py
+    """
+    return request.param
+
+
+@pytest.fixture(params=[True, False])
+def use_numpy(request):
+    """
+    Boolean fixture to support comparison testing of ExtensionDtype array
+    and numpy array.
+
+    See:
+    https://github.com/pandas-dev/pandas/blob/main/pandas/tests/extension/conftest.py
+    """
+    return request.param
diff --git a/tests/compliance/date/conftest.py b/tests/compliance/date/conftest.py
index e25ccc9..6f0a816 100644
--- a/tests/compliance/date/conftest.py
+++ b/tests/compliance/date/conftest.py
@@ -20,6 +20,15 @@
 from db_dtypes import DateArray, DateDtype
 
 
+@pytest.fixture(params=["data", "data_missing"])
+def all_data(request, data, data_missing):
+    """Parametrized fixture giving 'data' and 'data_missing'"""
+    if request.param == "data":
+        return data
+    elif request.param == "data_missing":
+        return data_missing
+
+
 @pytest.fixture
 def data():
     return DateArray(
@@ -32,6 +41,52 @@ def data():
     )
 
 
+@pytest.fixture
+def data_for_grouping():
+    """
+    Data for factorization, grouping, and unique tests.
+
+    Expected to be like [B, B, NA, NA, A, A, B, C]
+
+    Where A < B < C and NA is missing
+
+    See:
+    https://github.com/pandas-dev/pandas/blob/main/pandas/tests/extension/conftest.py
+    """
+    return DateArray(
+        [
+            datetime.date(1980, 1, 27),
+            datetime.date(1980, 1, 27),
+            None,
+            None,
+            datetime.date(1969, 12, 30),
+            datetime.date(1969, 12, 30),
+            datetime.date(1980, 1, 27),
+            datetime.date(2022, 3, 18),
+        ]
+    )
+
+
+@pytest.fixture
+def data_for_sorting():
+    """
+    Length-3 array with a known sort order.
+
+    This should be three items [B, C, A] with
+    A < B < C
+
+    See:
+    https://github.com/pandas-dev/pandas/blob/main/pandas/tests/extension/conftest.py
+    """
+    return DateArray(
+        [
+            datetime.date(1980, 1, 27),
+            datetime.date(2022, 3, 18),
+            datetime.date(1969, 12, 30),
+        ]
+    )
+
+
 @pytest.fixture
 def data_missing():
     """Length-2 array with [NA, Valid]
@@ -42,6 +97,36 @@ def data_missing():
     return DateArray([None, datetime.date(2022, 1, 27)])
 
 
+@pytest.fixture
+def data_missing_for_sorting():
+    """
+    Length-3 array with a known sort order.
+
+    This should be three items [B, NA, A] with
+    A < B and NA missing.
+
+    See:
+    https://github.com/pandas-dev/pandas/blob/main/pandas/tests/extension/conftest.py
+    """
+    return DateArray([datetime.date(1980, 1, 27), None, datetime.date(1969, 12, 30)])
+
+
+@pytest.fixture
+def data_repeated(data):
+    """
+    Generate many datasets.
+
+    See:
+    https://github.com/pandas-dev/pandas/blob/main/pandas/tests/extension/conftest.py
+    """
+
+    def gen(count):
+        for _ in range(count):
+            yield data
+
+    return gen
+
+
 @pytest.fixture
 def dtype():
     return DateDtype()
diff --git a/tests/compliance/date/test_date_compliance.py b/tests/compliance/date/test_date_compliance.py
index a805ecd..13327a7 100644
--- a/tests/compliance/date/test_date_compliance.py
+++ b/tests/compliance/date/test_date_compliance.py
@@ -20,7 +20,11 @@
 https://github.com/pandas-dev/pandas/blob/main/pandas/tests/extension/test_period.py
 """
 
+import pandas
 from pandas.tests.extension import base
+import pytest
+
+import db_dtypes
 
 
 class TestDtype(base.BaseDtypeTests):
@@ -45,3 +49,28 @@ class TestGetitem(base.BaseGetitemTests):
 
 class TestMissing(base.BaseMissingTests):
     pass
+
+
+# TODO(https://github.com/googleapis/python-db-dtypes-pandas/issues/78): Add
+# compliance tests for reduction operations.
+
+
+class TestMethods(base.BaseMethodsTests):
+    def test_combine_add(self):
+        pytest.skip("Cannot add dates.")
+
+    @pytest.mark.parametrize("dropna", [True, False])
+    def test_value_counts(self, all_data, dropna):
+        all_data = all_data[:10]
+        if dropna:
+            # Overridden from
+            # https://github.com/pandas-dev/pandas/blob/main/pandas/tests/extension/base/methods.py
+            # to avoid difference in dtypes.
+            other = db_dtypes.DateArray(all_data[~all_data.isna()])
+        else:
+            other = all_data
+
+        result = pandas.Series(all_data).value_counts(dropna=dropna).sort_index()
+        expected = pandas.Series(other).value_counts(dropna=dropna).sort_index()
+
+        self.assert_series_equal(result, expected)
diff --git a/tests/unit/test_date.py b/tests/unit/test_date.py
index b8f36f6..bbe74cb 100644
--- a/tests/unit/test_date.py
+++ b/tests/unit/test_date.py
@@ -328,3 +328,30 @@ def test_date_median_2d():
             )
         ),
     )
+
+
+@pytest.mark.parametrize(
+    ("search_term", "expected_index"),
+    (
+        (datetime.date(1899, 12, 31), 0),
+        (datetime.date(1900, 1, 1), 0),
+        (datetime.date(1920, 2, 2), 1),
+        (datetime.date(1930, 3, 3), 1),
+        (datetime.date(1950, 5, 5), 2),
+        (datetime.date(1990, 9, 9), 3),
+        (datetime.date(2012, 12, 12), 3),
+        (datetime.date(2022, 3, 24), 4),
+    ),
+)
+def test_date_searchsorted(search_term, expected_index):
+    test_series = pandas.Series(
+        [
+            datetime.date(1900, 1, 1),
+            datetime.date(1930, 3, 3),
+            datetime.date(1980, 8, 8),
+            datetime.date(2012, 12, 12),
+        ],
+        dtype="dbdate",
+    )
+    got = test_series.searchsorted(search_term)
+    assert got == expected_index