diff --git a/.github/workflows/compliance.yml b/.github/workflows/compliance.yml new file mode 100644 index 0000000..77e6b05 --- /dev/null +++ b/.github/workflows/compliance.yml @@ -0,0 +1,27 @@ +on: + pull_request: + branches: + - main +name: unittest +jobs: + compliance: + runs-on: ubuntu-latest + strategy: + matrix: + python: ['3.10'] + steps: + - name: Checkout + uses: actions/checkout@v3 + - name: Setup Python + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python }} + - name: Install nox + run: | + python -m pip install --upgrade setuptools pip wheel + python -m pip install nox + - name: Run compliance tests + env: + COVERAGE_FILE: .coverage-compliance-${{ matrix.python }} + run: | + nox -s compliance diff --git a/db_dtypes/core.py b/db_dtypes/core.py index a06c6d6..b5b0b7a 100644 --- a/db_dtypes/core.py +++ b/db_dtypes/core.py @@ -17,7 +17,7 @@ import numpy import pandas import pandas.api.extensions -from pandas.api.types import is_dtype_equal, is_list_like, pandas_dtype +from pandas.api.types import is_dtype_equal, is_list_like, is_scalar, pandas_dtype from db_dtypes import pandas_backports @@ -31,9 +31,14 @@ class BaseDatetimeDtype(pandas.api.extensions.ExtensionDtype): names = None @classmethod - def construct_from_string(cls, name): + def construct_from_string(cls, name: str): + if not isinstance(name, str): + raise TypeError( + f"'construct_from_string' expects a string, got {type(name)}" + ) + if name != cls.name: - raise TypeError() + raise TypeError(f"Cannot construct a '{cls.__name__}' from 'another_type'") return cls() @@ -74,6 +79,11 @@ def astype(self, dtype, copy=True): return super().astype(dtype, copy=copy) def _cmp_method(self, other, op): + """Compare array values, for use in OpsMixin.""" + + if is_scalar(other) and (pandas.isna(other) or type(other) == self.dtype.type): + other = type(self)([other]) + oshape = getattr(other, "shape", None) if oshape != self.shape and oshape != (1,) and self.shape != (1,): raise TypeError( diff --git a/noxfile.py b/noxfile.py index 5f48361..54421d8 100644 --- a/noxfile.py +++ b/noxfile.py @@ -37,6 +37,7 @@ nox.options.sessions = [ "lint", "unit", + "compliance", "cover", "lint_setup_py", "blacken", @@ -77,7 +78,7 @@ def lint_setup_py(session): session.run("python", "setup.py", "check", "--restructuredtext", "--strict") -def default(session): +def default(session, tests_path): # Install all test dependencies, then install this package in-place. constraints_path = str( @@ -106,15 +107,21 @@ def default(session): "--cov-config=.coveragerc", "--cov-report=", "--cov-fail-under=0", - os.path.join("tests", "unit"), + tests_path, *session.posargs, ) +@nox.session(python=UNIT_TEST_PYTHON_VERSIONS[-1]) +def compliance(session): + """Run the compliance test suite.""" + default(session, os.path.join("tests", "compliance")) + + @nox.session(python=UNIT_TEST_PYTHON_VERSIONS) def unit(session): """Run the unit test suite.""" - default(session) + default(session, os.path.join("tests", "unit")) @nox.session(python=SYSTEM_TEST_PYTHON_VERSIONS) diff --git a/owlbot.py b/owlbot.py index 30f3b3d..6c59671 100644 --- a/owlbot.py +++ b/owlbot.py @@ -64,11 +64,39 @@ new_sessions = """ "lint", "unit", + "compliance", "cover", """ s.replace(["noxfile.py"], old_sessions, new_sessions) +# Add compliance tests. +s.replace( + ["noxfile.py"], r"def default\(session\):", "def default(session, tests_path):" +) +s.replace(["noxfile.py"], r'os.path.join\("tests", "unit"\),', "tests_path,") +s.replace( + ["noxfile.py"], + r''' +@nox.session\(python=UNIT_TEST_PYTHON_VERSIONS\) +def unit\(session\): + """Run the unit test suite.""" + default\(session\) +''', + ''' +@nox.session(python=UNIT_TEST_PYTHON_VERSIONS[-1]) +def compliance(session): + """Run the compliance test suite.""" + default(session, os.path.join("tests", "compliance")) + + +@nox.session(python=UNIT_TEST_PYTHON_VERSIONS) +def unit(session): + """Run the unit test suite.""" + default(session, os.path.join("tests", "unit")) +''', +) + # ---------------------------------------------------------------------------- # Samples templates # ---------------------------------------------------------------------------- diff --git a/tests/compliance/conftest.py b/tests/compliance/conftest.py new file mode 100644 index 0000000..bc76692 --- /dev/null +++ b/tests/compliance/conftest.py @@ -0,0 +1,53 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pandas +import pytest + + +@pytest.fixture(params=["ffill", "bfill"]) +def fillna_method(request): + """ + Parametrized fixture giving method parameters 'ffill' and 'bfill' for + Series.fillna(method=) testing. + + See: + https://github.com/pandas-dev/pandas/blob/main/pandas/tests/extension/conftest.py + """ + return request.param + + +@pytest.fixture +def na_value(): + return pandas.NaT + + +@pytest.fixture +def na_cmp(): + """ + Binary operator for comparing NA values. + + Should return a function of two arguments that returns + True if both arguments are (scalar) NA for your type. + + See: + https://github.com/pandas-dev/pandas/blob/main/pandas/tests/extension/conftest.py + and + https://github.com/pandas-dev/pandas/blob/main/pandas/tests/extension/test_datetime.py + """ + + def cmp(a, b): + return a is pandas.NaT and a is b + + return cmp diff --git a/tests/compliance/date/conftest.py b/tests/compliance/date/conftest.py new file mode 100644 index 0000000..e25ccc9 --- /dev/null +++ b/tests/compliance/date/conftest.py @@ -0,0 +1,47 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime + +import numpy +import pytest + +from db_dtypes import DateArray, DateDtype + + +@pytest.fixture +def data(): + return DateArray( + numpy.arange( + datetime.datetime(1900, 1, 1), + datetime.datetime(2099, 12, 31), + datetime.timedelta(days=731), + dtype="datetime64[ns]", + ) + ) + + +@pytest.fixture +def data_missing(): + """Length-2 array with [NA, Valid] + + See: + https://github.com/pandas-dev/pandas/blob/main/pandas/tests/extension/conftest.py + """ + return DateArray([None, datetime.date(2022, 1, 27)]) + + +@pytest.fixture +def dtype(): + return DateDtype() diff --git a/tests/compliance/date/test_date_compliance.py b/tests/compliance/date/test_date_compliance.py new file mode 100644 index 0000000..a805ecd --- /dev/null +++ b/tests/compliance/date/test_date_compliance.py @@ -0,0 +1,47 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests for extension interface compliance, inherited from pandas. + +See: +https://github.com/pandas-dev/pandas/blob/main/pandas/tests/extension/decimal/test_decimal.py +and +https://github.com/pandas-dev/pandas/blob/main/pandas/tests/extension/test_period.py +""" + +from pandas.tests.extension import base + + +class TestDtype(base.BaseDtypeTests): + pass + + +class TestInterface(base.BaseInterfaceTests): + pass + + +class TestConstructors(base.BaseConstructorsTests): + pass + + +class TestReshaping(base.BaseReshapingTests): + pass + + +class TestGetitem(base.BaseGetitemTests): + pass + + +class TestMissing(base.BaseMissingTests): + pass diff --git a/tests/unit/test_date.py b/tests/unit/test_date.py index bf877ea..bce2dc1 100644 --- a/tests/unit/test_date.py +++ b/tests/unit/test_date.py @@ -13,15 +13,27 @@ # limitations under the License. import datetime +import operator import pandas +import pandas.testing import pytest -# To register the types. -import db_dtypes # noqa +import db_dtypes from db_dtypes import pandas_backports +def test_construct_from_string_with_nonstring(): + with pytest.raises(TypeError): + db_dtypes.DateDtype.construct_from_string(object()) + + +def test__cmp_method_with_scalar(): + input_array = db_dtypes.DateArray([datetime.date(1900, 1, 1)]) + got = input_array._cmp_method(datetime.date(1900, 1, 1), operator.eq) + assert got[0] + + @pytest.mark.parametrize( "value, expected", [