diff --git a/db_dtypes/core.py b/db_dtypes/core.py index f577960..68123e1 100644 --- a/db_dtypes/core.py +++ b/db_dtypes/core.py @@ -90,14 +90,14 @@ def _cmp_method(self, other, op): if is_scalar(other) and (pandas.isna(other) or type(other) == self.dtype.type): other = type(self)([other]) + if type(other) != type(self): + return NotImplemented + oshape = getattr(other, "shape", None) if oshape != self.shape and oshape != (1,) and self.shape != (1,): raise TypeError( "Can't compare arrays with different shapes", self.shape, oshape ) - - if type(other) != type(self): - return NotImplemented return op(self._ndarray, other._ndarray) def _from_factorized(self, unique, original): diff --git a/tests/compliance/conftest.py b/tests/compliance/conftest.py index 54b767c..b891ed6 100644 --- a/tests/compliance/conftest.py +++ b/tests/compliance/conftest.py @@ -12,10 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. +import operator + import pandas import pytest +@pytest.fixture(params=[True, False]) +def as_array(request): + """ + Boolean fixture to support ExtensionDtype _from_sequence method testing. + + See: + https://github.com/pandas-dev/pandas/blob/main/pandas/tests/extension/conftest.py + """ + return request.param + + @pytest.fixture(params=[True, False]) def as_frame(request): """ @@ -38,6 +51,36 @@ def as_series(request): return request.param +@pytest.fixture(params=[True, False]) +def box_in_series(request): + """ + Whether to box the data in a Series + + See: + https://github.com/pandas-dev/pandas/blob/main/pandas/tests/extension/conftest.py + """ + return request.param + + +@pytest.fixture( + params=[ + operator.eq, + operator.ne, + operator.gt, + operator.ge, + operator.lt, + operator.le, + ] +) +def comparison_op(request): + """ + Fixture for operator module comparison functions. + + See: https://github.com/pandas-dev/pandas/blob/main/pandas/conftest.py + """ + return request.param + + @pytest.fixture(params=["ffill", "bfill"]) def fillna_method(request): """ @@ -50,6 +93,25 @@ def fillna_method(request): return request.param +@pytest.fixture( + params=[ + lambda x: 1, + lambda x: [1] * len(x), + lambda x: pandas.Series([1] * len(x)), + lambda x: x, + ], + ids=["scalar", "list", "series", "object"], +) +def groupby_apply_op(request): + """ + Functions to test groupby.apply(). + + See: + https://github.com/pandas-dev/pandas/blob/main/pandas/tests/extension/conftest.py + """ + return request.param + + @pytest.fixture def invalid_scalar(data): """ diff --git a/tests/compliance/date/test_date_compliance.py b/tests/compliance/date/test_date_compliance.py index 13327a7..6281986 100644 --- a/tests/compliance/date/test_date_compliance.py +++ b/tests/compliance/date/test_date_compliance.py @@ -74,3 +74,27 @@ def test_value_counts(self, all_data, dropna): expected = pandas.Series(other).value_counts(dropna=dropna).sort_index() self.assert_series_equal(result, expected) + + +class TestCasting(base.BaseCastingTests): + pass + + +class TestGroupby(base.BaseGroupbyTests): + pass + + +class TestSetitem(base.BaseSetitemTests): + pass + + +class TestPrinting(base.BasePrintingTests): + pass + + +# TODO(https://github.com/googleapis/python-db-dtypes-pandas/issues/78): Add +# compliance tests for arithmetic operations. + + +class TestComparisonOps(base.BaseComparisonOpsTests): + pass diff --git a/tests/unit/test_dtypes.py b/tests/unit/test_dtypes.py index 66074d8..dc1613b 100644 --- a/tests/unit/test_dtypes.py +++ b/tests/unit/test_dtypes.py @@ -169,16 +169,12 @@ def test_timearray_comparisons( np.testing.assert_array_equal(comparisons[op](left, r), expected) np.testing.assert_array_equal(complements[op](left, r), ~expected) - # Bad shape - for bad_shape in ([], [1, 2, 3]): + # Bad shape, but same type + for bad_shape in ([], sample_values[:3]): with pytest.raises( TypeError, match="Can't compare arrays with different shapes" ): - comparisons[op](left, np.array(bad_shape)) - with pytest.raises( - TypeError, match="Can't compare arrays with different shapes" - ): - complements[op](left, np.array(bad_shape)) + comparisons[op](left, _cls(dtype)._from_sequence(bad_shape)) # Bad items for bad_items in (