From ae224958327ff75e60833cc3c8ae07b9508c5004 Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Mon, 21 Mar 2022 12:12:11 -0500 Subject: [PATCH] fix: dbdate and dbtime support set item --- db_dtypes/__init__.py | 10 ++++- db_dtypes/core.py | 18 ++++---- db_dtypes/pandas_backports.py | 2 +- tests/unit/test_date.py | 82 ++++++++++++++++++++++++++++------- tests/unit/test_time.py | 6 +++ 5 files changed, 91 insertions(+), 27 deletions(-) diff --git a/db_dtypes/__init__.py b/db_dtypes/__init__.py index d8e2ae5..7889dac 100644 --- a/db_dtypes/__init__.py +++ b/db_dtypes/__init__.py @@ -106,6 +106,9 @@ def _datetime( r"(?:\.(?P\d*))?)?)?\s*$" ).match, ) -> Optional[numpy.datetime64]: + if isinstance(scalar, numpy.datetime64): + return scalar + # Convert pyarrow values to datetime.time. if isinstance(scalar, (pyarrow.Time32Scalar, pyarrow.Time64Scalar)): scalar = ( @@ -116,7 +119,7 @@ def _datetime( ) if pandas.isna(scalar): - return None + return numpy.datetime64("NaT") if isinstance(scalar, datetime.time): return pandas.Timestamp( year=1970, @@ -238,12 +241,15 @@ def _datetime( scalar, match_fn=re.compile(r"\s*(?P\d+)-(?P\d+)-(?P\d+)\s*$").match, ) -> Optional[numpy.datetime64]: + if isinstance(scalar, numpy.datetime64): + return scalar + # Convert pyarrow values to datetime.date. if isinstance(scalar, (pyarrow.Date32Scalar, pyarrow.Date64Scalar)): scalar = scalar.as_py() if pandas.isna(scalar): - return None + return numpy.datetime64("NaT") elif isinstance(scalar, datetime.date): return pandas.Timestamp( year=scalar.year, month=scalar.month, day=scalar.day diff --git a/db_dtypes/core.py b/db_dtypes/core.py index 14d76aa..7879571 100644 --- a/db_dtypes/core.py +++ b/db_dtypes/core.py @@ -100,14 +100,6 @@ def _cmp_method(self, other, op): return NotImplemented return op(self._ndarray, other._ndarray) - def __setitem__(self, key, value): - if is_list_like(value): - _datetime = self._datetime - value = [_datetime(v) for v in value] - elif not pandas.isna(value): - value = self._datetime(value) - return super().__setitem__(key, value) - def _from_factorized(self, unique, original): return self.__class__(unique) @@ -121,6 +113,16 @@ def _validate_scalar(self, value): """ return self._datetime(value) + def _validate_setitem_value(self, value): + """ + Convert a value for use in setting a value in the backing numpy array. + """ + if is_list_like(value): + _datetime = self._datetime + return [_datetime(v) for v in value] + + return self._datetime(value) + def any( self, *, diff --git a/db_dtypes/pandas_backports.py b/db_dtypes/pandas_backports.py index f53adff..0e39986 100644 --- a/db_dtypes/pandas_backports.py +++ b/db_dtypes/pandas_backports.py @@ -126,7 +126,7 @@ def __getitem__(self, index): return self.__class__(value, self._dtype) def __setitem__(self, index, value): - self._ndarray[index] = value + self._ndarray[index] = self._validate_setitem_value(value) def __len__(self): return len(self._ndarray) diff --git a/tests/unit/test_date.py b/tests/unit/test_date.py index 79c97ac..fb41620 100644 --- a/tests/unit/test_date.py +++ b/tests/unit/test_date.py @@ -24,6 +24,33 @@ from db_dtypes import pandas_backports +VALUE_PARSING_TEST_CASES = [ + # Min/Max values for pandas.Timestamp. + ("1677-09-22", datetime.date(1677, 9, 22)), + ("2262-04-11", datetime.date(2262, 4, 11)), + # Typical "zero" values. + ("1900-01-01", datetime.date(1900, 1, 1)), + ("1970-01-01", datetime.date(1970, 1, 1)), + # Assorted values. + ("1993-10-31", datetime.date(1993, 10, 31)), + (datetime.date(1993, 10, 31), datetime.date(1993, 10, 31)), + ("2012-02-29", datetime.date(2012, 2, 29)), + (numpy.datetime64("2012-02-29"), datetime.date(2012, 2, 29)), + ("2021-12-17", datetime.date(2021, 12, 17)), + (pandas.Timestamp("2021-12-17"), datetime.date(2021, 12, 17)), + ("2038-01-19", datetime.date(2038, 1, 19)), +] + +NULL_VALUE_TEST_CASES = [ + None, + pandas.NaT, + float("nan"), +] + +if hasattr(pandas, "NA"): + NULL_VALUE_TEST_CASES.append(pandas.NA) + + def test_box_func(): input_array = db_dtypes.DateArray([]) input_datetime = datetime.datetime(2022, 3, 16) @@ -58,26 +85,49 @@ def test__cmp_method_with_scalar(): assert got[0] -@pytest.mark.parametrize( - "value, expected", - [ - # Min/Max values for pandas.Timestamp. - ("1677-09-22", datetime.date(1677, 9, 22)), - ("2262-04-11", datetime.date(2262, 4, 11)), - # Typical "zero" values. - ("1900-01-01", datetime.date(1900, 1, 1)), - ("1970-01-01", datetime.date(1970, 1, 1)), - # Assorted values. - ("1993-10-31", datetime.date(1993, 10, 31)), - ("2012-02-29", datetime.date(2012, 2, 29)), - ("2021-12-17", datetime.date(2021, 12, 17)), - ("2038-01-19", datetime.date(2038, 1, 19)), - ], -) +@pytest.mark.parametrize("value, expected", VALUE_PARSING_TEST_CASES) def test_date_parsing(value, expected): assert pandas.Series([value], dtype="dbdate")[0] == expected +@pytest.mark.parametrize("value", NULL_VALUE_TEST_CASES) +def test_date_parsing_null(value): + assert pandas.Series([value], dtype="dbdate")[0] is pandas.NaT + + +@pytest.mark.parametrize("value, expected", VALUE_PARSING_TEST_CASES) +def test_date_set_item(value, expected): + series = pandas.Series([None], dtype="dbdate") + series[0] = value + assert series[0] == expected + + +@pytest.mark.parametrize("value", NULL_VALUE_TEST_CASES) +def test_date_set_item_null(value): + series = pandas.Series(["1970-01-01"], dtype="dbdate") + series[0] = value + assert series[0] is pandas.NaT + + +def test_date_set_slice(): + series = pandas.Series([None, None, None], dtype="dbdate") + series[:] = [ + datetime.date(2022, 3, 21), + "2011-12-13", + numpy.datetime64("1998-09-04"), + ] + assert series[0] == datetime.date(2022, 3, 21) + assert series[1] == datetime.date(2011, 12, 13) + assert series[2] == datetime.date(1998, 9, 4) + + +def test_date_set_slice_null(): + series = pandas.Series(["1970-01-01"] * len(NULL_VALUE_TEST_CASES), dtype="dbdate") + series[:] = NULL_VALUE_TEST_CASES + for row_index in range(len(NULL_VALUE_TEST_CASES)): + assert series[row_index] is pandas.NaT + + @pytest.mark.parametrize( "value, error", [ diff --git a/tests/unit/test_time.py b/tests/unit/test_time.py index db533f5..bdfc48b 100644 --- a/tests/unit/test_time.py +++ b/tests/unit/test_time.py @@ -73,8 +73,14 @@ def test_box_func(): # Fractional seconds can cause rounding problems if cast to float. See: # https://github.com/googleapis/python-db-dtypes-pandas/issues/18 ("0:0:59.876543", datetime.time(0, 0, 59, 876543)), + ( + numpy.datetime64("1970-01-01 00:00:59.876543"), + datetime.time(0, 0, 59, 876543), + ), ("01:01:01.010101", datetime.time(1, 1, 1, 10101)), + (pandas.Timestamp("1970-01-01 01:01:01.010101"), datetime.time(1, 1, 1, 10101)), ("09:09:09.090909", datetime.time(9, 9, 9, 90909)), + (datetime.time(9, 9, 9, 90909), datetime.time(9, 9, 9, 90909)), ("11:11:11.111111", datetime.time(11, 11, 11, 111111)), ("19:16:23.987654", datetime.time(19, 16, 23, 987654)), # Microsecond precision