diff --git a/pandas/_libs/lib.pyx b/pandas/_libs/lib.pyx index 2a5e5b27973e1..4890f82c5fdda 100644 --- a/pandas/_libs/lib.pyx +++ b/pandas/_libs/lib.pyx @@ -109,12 +109,16 @@ from pandas._libs.missing cimport ( is_null_datetime64, is_null_timedelta64, ) -from pandas._libs.tslibs.conversion cimport convert_to_tsobject +from pandas._libs.tslibs.conversion cimport ( + _TSObject, + convert_to_tsobject, +) from pandas._libs.tslibs.nattype cimport ( NPY_NAT, c_NaT as NaT, checknull_with_nat, ) +from pandas._libs.tslibs.np_datetime cimport NPY_FR_ns from pandas._libs.tslibs.offsets cimport is_offset_object from pandas._libs.tslibs.period cimport is_period_object from pandas._libs.tslibs.timedeltas cimport convert_to_timedelta64 @@ -2378,6 +2382,7 @@ def maybe_convert_objects(ndarray[object] objects, ndarray[uint8_t] bools Seen seen = Seen() object val + _TSObject tsobj float64_t fnan = np.nan if dtype_if_all_nat is not None: @@ -2470,7 +2475,8 @@ def maybe_convert_objects(ndarray[object] objects, else: seen.datetime_ = True try: - convert_to_tsobject(val, None, None, 0, 0) + tsobj = convert_to_tsobject(val, None, None, 0, 0) + tsobj.ensure_reso(NPY_FR_ns) except OutOfBoundsDatetime: seen.object_ = True break diff --git a/pandas/_libs/src/ujson/python/objToJSON.c b/pandas/_libs/src/ujson/python/objToJSON.c index 591dff72e3872..d4ec21f38cdad 100644 --- a/pandas/_libs/src/ujson/python/objToJSON.c +++ b/pandas/_libs/src/ujson/python/objToJSON.c @@ -278,11 +278,42 @@ static int is_simple_frame(PyObject *obj) { } static npy_int64 get_long_attr(PyObject *o, const char *attr) { + // NB we are implicitly assuming that o is a Timedelta or Timestamp, or NaT + npy_int64 long_val; PyObject *value = PyObject_GetAttrString(o, attr); long_val = (PyLong_Check(value) ? PyLong_AsLongLong(value) : PyLong_AsLong(value)); + Py_DECREF(value); + + if (object_is_nat_type(o)) { + // i.e. o is NaT, long_val will be NPY_MIN_INT64 + return long_val; + } + + // ensure we are in nanoseconds, similar to Timestamp._as_creso or _as_unit + PyObject* reso = PyObject_GetAttrString(o, "_creso"); + if (!PyLong_Check(reso)) { + // https://github.com/pandas-dev/pandas/pull/49034#discussion_r1023165139 + Py_DECREF(reso); + return -1; + } + + long cReso = PyLong_AsLong(reso); + Py_DECREF(reso); + if (cReso == -1 && PyErr_Occurred()) { + return -1; + } + + if (cReso == NPY_FR_us) { + long_val = long_val * 1000L; + } else if (cReso == NPY_FR_ms) { + long_val = long_val * 1000000L; + } else if (cReso == NPY_FR_s) { + long_val = long_val * 1000000000L; + } + return long_val; } @@ -1265,6 +1296,7 @@ char **NpyArr_encodeLabels(PyArrayObject *labels, PyObjectEncoder *enc, } else if (PyDate_Check(item) || PyDelta_Check(item)) { is_datetimelike = 1; if (PyObject_HasAttrString(item, "value")) { + // see test_date_index_and_values for case with non-nano nanosecVal = get_long_attr(item, "value"); } else { if (PyDelta_Check(item)) { diff --git a/pandas/_libs/tslib.pyx b/pandas/_libs/tslib.pyx index 5679ea9306c72..7fee48c0a5d1f 100644 --- a/pandas/_libs/tslib.pyx +++ b/pandas/_libs/tslib.pyx @@ -841,16 +841,19 @@ cdef _array_to_datetime_object( cdef inline bint _parse_today_now(str val, int64_t* iresult, bint utc): # We delay this check for as long as possible # because it catches relatively rare cases + + # Multiply by 1000 to convert to nanos, since these methods naturally have + # microsecond resolution if val == "now": if utc: - iresult[0] = Timestamp.utcnow().value + iresult[0] = Timestamp.utcnow().value * 1000 else: # GH#18705 make sure to_datetime("now") matches Timestamp("now") # Note using Timestamp.now() is faster than Timestamp("now") - iresult[0] = Timestamp.now().value + iresult[0] = Timestamp.now().value * 1000 return True elif val == "today": - iresult[0] = Timestamp.today().value + iresult[0] = Timestamp.today().value * 1000 return True return False diff --git a/pandas/_libs/tslibs/conversion.pyx b/pandas/_libs/tslibs/conversion.pyx index d0d6dc3f42d85..17facf9e16f4b 100644 --- a/pandas/_libs/tslibs/conversion.pyx +++ b/pandas/_libs/tslibs/conversion.pyx @@ -32,6 +32,7 @@ from pandas._libs.tslibs.dtypes cimport ( from pandas._libs.tslibs.np_datetime cimport ( NPY_DATETIMEUNIT, NPY_FR_ns, + NPY_FR_us, check_dts_bounds, convert_reso, get_datetime64_unit, @@ -212,7 +213,12 @@ cdef class _TSObject: cdef int64_t ensure_reso(self, NPY_DATETIMEUNIT creso) except? -1: if self.creso != creso: - self.value = convert_reso(self.value, self.creso, creso, False) + try: + self.value = convert_reso(self.value, self.creso, creso, False) + except OverflowError as err: + raise OutOfBoundsDatetime from err + + self.creso = creso return self.value @@ -288,11 +294,22 @@ cdef _TSObject convert_to_tsobject(object ts, tzinfo tz, str unit, obj.value = ts pandas_datetime_to_datetimestruct(ts, NPY_FR_ns, &obj.dts) elif PyDateTime_Check(ts): - return convert_datetime_to_tsobject(ts, tz, nanos) + if nanos == 0: + if isinstance(ts, ABCTimestamp): + reso = abbrev_to_npy_unit(ts.unit) # TODO: faster way to do this? + else: + # TODO: what if user explicitly passes nanos=0? + reso = NPY_FR_us + else: + reso = NPY_FR_ns + return convert_datetime_to_tsobject(ts, tz, nanos, reso=reso) elif PyDate_Check(ts): # Keep the converter same as PyDateTime's + # For date object we give the lowest supported resolution, i.e. "s" ts = datetime.combine(ts, time()) - return convert_datetime_to_tsobject(ts, tz) + return convert_datetime_to_tsobject( + ts, tz, nanos=0, reso=NPY_DATETIMEUNIT.NPY_FR_s + ) else: from .period import Period if isinstance(ts, Period): @@ -346,6 +363,7 @@ cdef _TSObject convert_datetime_to_tsobject( _TSObject obj = _TSObject() int64_t pps + obj.creso = reso obj.fold = ts.fold if tz is not None: tz = maybe_get_tz(tz) diff --git a/pandas/_libs/tslibs/offsets.pyx b/pandas/_libs/tslibs/offsets.pyx index 82f134a348fee..4c6493652b216 100644 --- a/pandas/_libs/tslibs/offsets.pyx +++ b/pandas/_libs/tslibs/offsets.pyx @@ -162,7 +162,8 @@ def apply_wraps(func): result = func(self, other) - result = Timestamp(result) + result = (<_Timestamp>Timestamp(result))._as_creso(other._creso) + if self._adjust_dst: result = result.tz_localize(tz) @@ -175,9 +176,10 @@ def apply_wraps(func): if result.nanosecond != nano: if result.tz is not None: # convert to UTC - value = result.tz_localize(None).value + res = result.tz_localize(None) else: - value = result.value + res = result + value = res.as_unit("ns").value result = Timestamp(value + nano) if tz is not None and result.tzinfo is None: diff --git a/pandas/_libs/tslibs/timestamps.pxd b/pandas/_libs/tslibs/timestamps.pxd index fc62e04961dcb..1b87d2ba4eb25 100644 --- a/pandas/_libs/tslibs/timestamps.pxd +++ b/pandas/_libs/tslibs/timestamps.pxd @@ -33,4 +33,4 @@ cdef class _Timestamp(ABCTimestamp): cdef bint _compare_outside_nanorange(_Timestamp self, datetime other, int op) except -1 cdef bint _compare_mismatched_resos(_Timestamp self, _Timestamp other, int op) - cdef _Timestamp _as_creso(_Timestamp self, NPY_DATETIMEUNIT reso, bint round_ok=*) + cdef _Timestamp _as_creso(_Timestamp self, NPY_DATETIMEUNIT creso, bint round_ok=*) diff --git a/pandas/_libs/tslibs/timestamps.pyx b/pandas/_libs/tslibs/timestamps.pyx index c40712251ae5b..8e9c8d40398d9 100644 --- a/pandas/_libs/tslibs/timestamps.pyx +++ b/pandas/_libs/tslibs/timestamps.pyx @@ -497,9 +497,9 @@ cdef class _Timestamp(ABCTimestamp): # Matching numpy, we cast to the higher resolution. Unlike numpy, # we raise instead of silently overflowing during this casting. if self._creso < other._creso: - self = (<_Timestamp>self)._as_creso(other._creso, round_ok=False) + self = (<_Timestamp>self)._as_creso(other._creso, round_ok=True) elif self._creso > other._creso: - other = (<_Timestamp>other)._as_creso(self._creso, round_ok=False) + other = (<_Timestamp>other)._as_creso(self._creso, round_ok=True) # scalar Timestamp/datetime - Timestamp/datetime -> yields a # Timedelta @@ -983,15 +983,22 @@ cdef class _Timestamp(ABCTimestamp): # Conversion Methods @cython.cdivision(False) - cdef _Timestamp _as_creso(self, NPY_DATETIMEUNIT reso, bint round_ok=True): + cdef _Timestamp _as_creso(self, NPY_DATETIMEUNIT creso, bint round_ok=True): cdef: int64_t value - if reso == self._creso: + if creso == self._creso: return self - value = convert_reso(self.value, self._creso, reso, round_ok=round_ok) - return type(self)._from_value_and_reso(value, reso=reso, tz=self.tzinfo) + try: + value = convert_reso(self.value, self._creso, creso, round_ok=round_ok) + except OverflowError as err: + unit = npy_unit_to_abbrev(creso) + raise OutOfBoundsDatetime( + f"Cannot cast {self} to unit='{unit}' without overflow." + ) from err + + return type(self)._from_value_and_reso(value, reso=creso, tz=self.tzinfo) def as_unit(self, str unit, bint round_ok=True): """ @@ -1025,7 +1032,7 @@ cdef class _Timestamp(ABCTimestamp): -------- >>> ts = pd.Timestamp(2020, 3, 14, 15) >>> ts.asm8 - numpy.datetime64('2020-03-14T15:00:00.000000000') + numpy.datetime64('2020-03-14T15:00:00.000000') """ return self.to_datetime64() diff --git a/pandas/core/array_algos/take.py b/pandas/core/array_algos/take.py index 8dc855bd25f78..00b1c898942b3 100644 --- a/pandas/core/array_algos/take.py +++ b/pandas/core/array_algos/take.py @@ -360,7 +360,14 @@ def wrapper( if out_dtype is not None: out = out.view(out_dtype) if fill_wrap is not None: + # FIXME: if we get here with dt64/td64 we need to be sure we have + # matching resos + if fill_value.dtype.kind == "m": + fill_value = fill_value.astype("m8[ns]") + else: + fill_value = fill_value.astype("M8[ns]") fill_value = fill_wrap(fill_value) + f(arr, indexer, out, fill_value=fill_value) return wrapper diff --git a/pandas/core/arrays/datetimes.py b/pandas/core/arrays/datetimes.py index 5fdf2c88503a5..286cec3afdc45 100644 --- a/pandas/core/arrays/datetimes.py +++ b/pandas/core/arrays/datetimes.py @@ -445,7 +445,7 @@ def _generate_range( # type: ignore[override] i8values = generate_regular_range(start, end, periods, freq, unit=unit) else: xdr = _generate_range( - start=start, end=end, periods=periods, offset=freq + start=start, end=end, periods=periods, offset=freq, unit=unit ) i8values = np.array([x.value for x in xdr], dtype=np.int64) @@ -508,7 +508,10 @@ def _unbox_scalar(self, value) -> np.datetime64: if not isinstance(value, self._scalar_type) and value is not NaT: raise ValueError("'value' should be a Timestamp.") self._check_compatible_with(value) - return value.asm8 + if value is NaT: + return np.datetime64(value.value, self.unit) + else: + return value.as_unit(self.unit).asm8 def _scalar_from_string(self, value) -> Timestamp | NaTType: return Timestamp(value, tz=self.tz) @@ -2475,6 +2478,8 @@ def _generate_range( end: Timestamp | None, periods: int | None, offset: BaseOffset, + *, + unit: str, ): """ Generates a sequence of dates corresponding to the specified time @@ -2486,7 +2491,8 @@ def _generate_range( start : Timestamp or None end : Timestamp or None periods : int or None - offset : DateOffset, + offset : DateOffset + unit : str Notes ----- @@ -2506,13 +2512,20 @@ def _generate_range( start = Timestamp(start) # type: ignore[arg-type] # Non-overlapping identity check (left operand type: "Timestamp", right # operand type: "NaTType") - start = start if start is not NaT else None # type: ignore[comparison-overlap] + if start is not NaT: # type: ignore[comparison-overlap] + start = start.as_unit(unit) + else: + start = None + # Argument 1 to "Timestamp" has incompatible type "Optional[Timestamp]"; # expected "Union[integer[Any], float, str, date, datetime64]" end = Timestamp(end) # type: ignore[arg-type] # Non-overlapping identity check (left operand type: "Timestamp", right # operand type: "NaTType") - end = end if end is not NaT else None # type: ignore[comparison-overlap] + if end is not NaT: # type: ignore[comparison-overlap] + end = end.as_unit(unit) + else: + end = None if start and not offset.is_on_offset(start): # Incompatible types in assignment (expression has type "datetime", @@ -2553,7 +2566,7 @@ def _generate_range( break # faster than cur + offset - next_date = offset._apply(cur) + next_date = offset._apply(cur).as_unit(unit) if next_date <= cur: raise ValueError(f"Offset {offset} did not increment date") cur = next_date @@ -2567,7 +2580,7 @@ def _generate_range( break # faster than cur + offset - next_date = offset._apply(cur) + next_date = offset._apply(cur).as_unit(unit) if next_date >= cur: raise ValueError(f"Offset {offset} did not decrement date") cur = next_date diff --git a/pandas/core/computation/pytables.py b/pandas/core/computation/pytables.py index 93928d8bf6b83..4055be3f943fa 100644 --- a/pandas/core/computation/pytables.py +++ b/pandas/core/computation/pytables.py @@ -11,6 +11,7 @@ import numpy as np from pandas._libs.tslibs import ( + NaT, Timedelta, Timestamp, ) @@ -216,6 +217,8 @@ def stringify(value): v = stringify(v) v = ensure_decoded(v) v = Timestamp(v) + if v is not NaT: + v = v.as_unit("ns") # pyright: ignore[reportGeneralTypeIssues] if v.tz is not None: v = v.tz_convert("UTC") return TermValue(v, v.value, kind) diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index a668dec6e073e..60488a8ef9715 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -754,16 +754,21 @@ def infer_dtype_from_scalar(val, pandas_dtype: bool = False) -> tuple[DtypeObj, elif isinstance(val, (np.datetime64, dt.datetime)): try: val = Timestamp(val) + # error: Non-overlapping identity check (left operand type: + # "Timestamp", right operand type: "NaTType") + if val is not NaT: # type: ignore[comparison-overlap] + val = val.as_unit("ns") except OutOfBoundsDatetime: return _dtype_obj, val # error: Non-overlapping identity check (left operand type: "Timestamp", # right operand type: "NaTType") if val is NaT or val.tz is None: # type: ignore[comparison-overlap] - dtype = np.dtype("M8[ns]") val = val.to_datetime64() + dtype = val.dtype # TODO(2.0): this should be dtype = val.dtype # to get the correct M8 resolution + # TODO: test with datetime(2920, 10, 1) based on test_replace_dtypes else: if pandas_dtype: dtype = DatetimeTZDtype(unit="ns", tz=val.tz) diff --git a/pandas/core/indexes/datetimes.py b/pandas/core/indexes/datetimes.py index 022476af1e173..c30323338e676 100644 --- a/pandas/core/indexes/datetimes.py +++ b/pandas/core/indexes/datetimes.py @@ -9,6 +9,7 @@ import warnings import numpy as np +import pytz from pandas._libs import ( NaT, @@ -578,7 +579,7 @@ def get_loc(self, key, method=None, tolerance=None): try: parsed, reso = self._parse_with_reso(key) - except ValueError as err: + except (ValueError, pytz.NonExistentTimeError) as err: raise KeyError(key) from err self._disallow_mismatched_indexing(parsed) diff --git a/pandas/core/resample.py b/pandas/core/resample.py index f5c76aade9956..f5aeb61df633a 100644 --- a/pandas/core/resample.py +++ b/pandas/core/resample.py @@ -2074,6 +2074,11 @@ def _adjust_dates_anchored( # not a multiple of the frequency. See GH 8683 # To handle frequencies that are not multiple or divisible by a day we let # the possibility to define a fixed origin timestamp. See GH 31809 + first = first.as_unit("ns") + last = last.as_unit("ns") + if offset is not None: + offset = offset.as_unit("ns") + origin_nanos = 0 # origin == "epoch" if origin == "start_day": origin_nanos = first.normalize().value diff --git a/pandas/io/stata.py b/pandas/io/stata.py index 269dd169cdeaa..18f89754cb672 100644 --- a/pandas/io/stata.py +++ b/pandas/io/stata.py @@ -418,7 +418,7 @@ def parse_dates_safe( d = {} if is_datetime64_dtype(dates.dtype): if delta: - time_delta = dates - stata_epoch + time_delta = dates - Timestamp(stata_epoch).as_unit("ns") d["delta"] = time_delta._values.view(np.int64) // 1000 # microseconds if days or year: date_index = DatetimeIndex(dates) diff --git a/pandas/tests/arrays/test_datetimes.py b/pandas/tests/arrays/test_datetimes.py index 8c4701bb2f8ee..89c9ba85fcfa9 100644 --- a/pandas/tests/arrays/test_datetimes.py +++ b/pandas/tests/arrays/test_datetimes.py @@ -12,11 +12,11 @@ import numpy as np import pytest -from pandas._libs.tslibs import tz_compare -from pandas._libs.tslibs.dtypes import ( - NpyDatetimeUnit, +from pandas._libs.tslibs import ( npy_unit_to_abbrev, + tz_compare, ) +from pandas._libs.tslibs.dtypes import NpyDatetimeUnit from pandas.core.dtypes.dtypes import DatetimeTZDtype diff --git a/pandas/tests/frame/methods/test_replace.py b/pandas/tests/frame/methods/test_replace.py index 626bc658b199c..1923299476a32 100644 --- a/pandas/tests/frame/methods/test_replace.py +++ b/pandas/tests/frame/methods/test_replace.py @@ -839,7 +839,7 @@ def test_replace_for_new_dtypes(self, datetime_frame): ], ) def test_replace_dtypes(self, frame, to_replace, value, expected): - result = getattr(frame, "replace")(to_replace, value) + result = frame.replace(to_replace, value) tm.assert_frame_equal(result, expected) def test_replace_input_formats_listlike(self): diff --git a/pandas/tests/frame/test_constructors.py b/pandas/tests/frame/test_constructors.py index 767912a7d2667..32a4dc06d08e2 100644 --- a/pandas/tests/frame/test_constructors.py +++ b/pandas/tests/frame/test_constructors.py @@ -3026,10 +3026,10 @@ def test_from_scalar_datetimelike_mismatched(self, constructor, cls): "but DatetimeArray._from_sequence has not" ) @pytest.mark.parametrize("cls", [datetime, np.datetime64]) - def test_from_out_of_ns_bounds_datetime(self, constructor, cls, request): + def test_from_out_of_bounds_ns_datetime(self, constructor, cls): # scalar that won't fit in nanosecond dt64, but will fit in microsecond scalar = datetime(9999, 1, 1) - exp_dtype = "M8[us]" # smallest reso that fits + exp_dtype = "M8[us]" # pydatetime objects default to this reso if cls is np.datetime64: scalar = np.datetime64(scalar, "D") exp_dtype = "M8[s]" # closest reso to input @@ -3071,11 +3071,12 @@ def test_from_out_of_bounds_ns_timedelta(self, constructor, cls): assert item.asm8.dtype == exp_dtype assert dtype == exp_dtype - def test_out_of_s_bounds_timedelta64(self, constructor): - scalar = np.timedelta64(np.iinfo(np.int64).max, "D") + @pytest.mark.parametrize("cls", [np.datetime64, np.timedelta64]) + def test_out_of_s_bounds_timedelta64(self, constructor, cls): + scalar = cls(np.iinfo(np.int64).max, "D") result = constructor(scalar) item = get1(result) - assert type(item) is np.timedelta64 + assert type(item) is cls dtype = result.dtype if isinstance(result, Series) else result.dtypes.iloc[0] assert dtype == object diff --git a/pandas/tests/indexes/datetimes/test_constructors.py b/pandas/tests/indexes/datetimes/test_constructors.py index 7b99d40ae0e16..effabafb4db67 100644 --- a/pandas/tests/indexes/datetimes/test_constructors.py +++ b/pandas/tests/indexes/datetimes/test_constructors.py @@ -1204,8 +1204,8 @@ def test_timestamp_constructor_infer_fold_from_value(tz, ts_input, fold_out): @pytest.mark.parametrize( "ts_input,fold,value_out", [ - (datetime(2019, 10, 27, 1, 30, 0, 0), 0, 1572136200000000000), - (datetime(2019, 10, 27, 1, 30, 0, 0), 1, 1572139800000000000), + (datetime(2019, 10, 27, 1, 30, 0, 0), 0, 1572136200000000), + (datetime(2019, 10, 27, 1, 30, 0, 0), 1, 1572139800000000), ], ) def test_timestamp_constructor_adjust_value_for_fold(tz, ts_input, fold, value_out): diff --git a/pandas/tests/indexes/datetimes/test_date_range.py b/pandas/tests/indexes/datetimes/test_date_range.py index e90f9fb2b5e36..14bfb14d27239 100644 --- a/pandas/tests/indexes/datetimes/test_date_range.py +++ b/pandas/tests/indexes/datetimes/test_date_range.py @@ -847,19 +847,23 @@ def test_date_range_with_tz(self, tzstr): class TestGenRangeGeneration: def test_generate(self): - rng1 = list(generate_range(START, END, periods=None, offset=BDay())) - rng2 = list(generate_range(START, END, periods=None, offset="B")) + rng1 = list(generate_range(START, END, periods=None, offset=BDay(), unit="ns")) + rng2 = list(generate_range(START, END, periods=None, offset="B", unit="ns")) assert rng1 == rng2 def test_generate_cday(self): - rng1 = list(generate_range(START, END, periods=None, offset=CDay())) - rng2 = list(generate_range(START, END, periods=None, offset="C")) + rng1 = list(generate_range(START, END, periods=None, offset=CDay(), unit="ns")) + rng2 = list(generate_range(START, END, periods=None, offset="C", unit="ns")) assert rng1 == rng2 def test_1(self): rng = list( generate_range( - start=datetime(2009, 3, 25), end=None, periods=2, offset=BDay() + start=datetime(2009, 3, 25), + end=None, + periods=2, + offset=BDay(), + unit="ns", ) ) expected = [datetime(2009, 3, 25), datetime(2009, 3, 26)] @@ -872,6 +876,7 @@ def test_2(self): end=datetime(2008, 1, 3), periods=None, offset=BDay(), + unit="ns", ) ) expected = [datetime(2008, 1, 1), datetime(2008, 1, 2), datetime(2008, 1, 3)] @@ -884,6 +889,7 @@ def test_3(self): end=datetime(2008, 1, 6), periods=None, offset=BDay(), + unit="ns", ) ) expected = [] diff --git a/pandas/tests/indexes/datetimes/test_timezones.py b/pandas/tests/indexes/datetimes/test_timezones.py index 624aec7d978e9..e8bb1252c3033 100644 --- a/pandas/tests/indexes/datetimes/test_timezones.py +++ b/pandas/tests/indexes/datetimes/test_timezones.py @@ -1137,7 +1137,7 @@ def test_dti_convert_tz_aware_datetime_datetime(self, tz): assert timezones.tz_compare(result.tz, tz) converted = to_datetime(dates_aware, utc=True) - ex_vals = np.array([Timestamp(x).value for x in dates_aware]) + ex_vals = np.array([Timestamp(x).as_unit("ns").value for x in dates_aware]) tm.assert_numpy_array_equal(converted.asi8, ex_vals) assert converted.tz is pytz.utc diff --git a/pandas/tests/io/json/test_ujson.py b/pandas/tests/io/json/test_ujson.py index 109c6dbb469c9..3c841d829efd7 100644 --- a/pandas/tests/io/json/test_ujson.py +++ b/pandas/tests/io/json/test_ujson.py @@ -383,7 +383,7 @@ def test_encode_as_null(self, decoded_input): def test_datetime_units(self): val = datetime.datetime(2013, 8, 17, 21, 17, 12, 215504) - stamp = Timestamp(val) + stamp = Timestamp(val).as_unit("ns") roundtrip = ujson.decode(ujson.encode(val, date_unit="s")) assert roundtrip == stamp.value // 10**9 diff --git a/pandas/tests/scalar/period/test_period.py b/pandas/tests/scalar/period/test_period.py index b5491c606ec7b..112f23b3b0f16 100644 --- a/pandas/tests/scalar/period/test_period.py +++ b/pandas/tests/scalar/period/test_period.py @@ -821,7 +821,7 @@ def test_end_time(self): p = Period("2012", freq="A") def _ex(*args): - return Timestamp(Timestamp(datetime(*args)).value - 1) + return Timestamp(Timestamp(datetime(*args)).as_unit("ns").value - 1) xp = _ex(2013, 1, 1) assert xp == p.end_time @@ -873,7 +873,7 @@ def test_end_time_business_friday(self): def test_anchor_week_end_time(self): def _ex(*args): - return Timestamp(Timestamp(datetime(*args)).value - 1) + return Timestamp(Timestamp(datetime(*args)).as_unit("ns").value - 1) p = Period("2013-1-1", "W-SAT") xp = _ex(2013, 1, 6) diff --git a/pandas/tests/scalar/timestamp/test_arithmetic.py b/pandas/tests/scalar/timestamp/test_arithmetic.py index 17fee1ff3f949..0ddbdddef5465 100644 --- a/pandas/tests/scalar/timestamp/test_arithmetic.py +++ b/pandas/tests/scalar/timestamp/test_arithmetic.py @@ -197,11 +197,15 @@ def test_radd_tdscalar(self, td, fixed_now_ts): ], ) def test_timestamp_add_timedelta64_unit(self, other, expected_difference): - ts = Timestamp(datetime.utcnow()) + now = datetime.utcnow() + ts = Timestamp(now).as_unit("ns") result = ts + other valdiff = result.value - ts.value assert valdiff == expected_difference + ts2 = Timestamp(now) + assert ts2 + other == result + @pytest.mark.parametrize( "ts", [ diff --git a/pandas/tests/scalar/timestamp/test_constructors.py b/pandas/tests/scalar/timestamp/test_constructors.py index 9c3fa0f64153a..4294bf326950c 100644 --- a/pandas/tests/scalar/timestamp/test_constructors.py +++ b/pandas/tests/scalar/timestamp/test_constructors.py @@ -1,5 +1,6 @@ import calendar from datetime import ( + date, datetime, timedelta, timezone, @@ -23,6 +24,13 @@ class TestTimestampConstructors: + def test_constructor_from_date_second_reso(self): + # GH#49034 constructing from a pydate object gets lowest supported + # reso, i.e. seconds + obj = date(2012, 9, 1) + ts = Timestamp(obj) + assert ts.unit == "s" + @pytest.mark.parametrize("typ", [int, float]) def test_constructor_int_float_with_YM_unit(self, typ): # GH#47266 avoid the conversions in cast_from_unit @@ -97,8 +105,9 @@ def test_constructor(self): (dateutil.tz.tzoffset(None, 18000), 5), ] - for date_str, date, expected in tests: - for result in [Timestamp(date_str), Timestamp(date)]: + for date_str, date_obj, expected in tests: + for result in [Timestamp(date_str), Timestamp(date_obj)]: + result = result.as_unit("ns") # test originally written before non-nano # only with timestring assert result.value == expected @@ -108,7 +117,10 @@ def test_constructor(self): # with timezone for tz, offset in timezones: - for result in [Timestamp(date_str, tz=tz), Timestamp(date, tz=tz)]: + for result in [Timestamp(date_str, tz=tz), Timestamp(date_obj, tz=tz)]: + result = result.as_unit( + "ns" + ) # test originally written before non-nano expected_tz = expected - offset * 3600 * 1_000_000_000 assert result.value == expected_tz diff --git a/pandas/tests/scalar/timestamp/test_timestamp.py b/pandas/tests/scalar/timestamp/test_timestamp.py index f5b9a35a53a24..c20e6052b1f7e 100644 --- a/pandas/tests/scalar/timestamp/test_timestamp.py +++ b/pandas/tests/scalar/timestamp/test_timestamp.py @@ -579,7 +579,9 @@ def test_to_datetime_bijective(self): with tm.assert_produces_warning(exp_warning): pydt_max = Timestamp.max.to_pydatetime() - assert Timestamp(pydt_max).value / 1000 == Timestamp.max.value / 1000 + assert ( + Timestamp(pydt_max).as_unit("ns").value / 1000 == Timestamp.max.value / 1000 + ) exp_warning = None if Timestamp.min.nanosecond == 0 else UserWarning with tm.assert_produces_warning(exp_warning): @@ -590,7 +592,10 @@ def test_to_datetime_bijective(self): tdus = timedelta(microseconds=1) assert pydt_min + tdus > Timestamp.min - assert Timestamp(pydt_min + tdus).value / 1000 == Timestamp.min.value / 1000 + assert ( + Timestamp(pydt_min + tdus).as_unit("ns").value / 1000 + == Timestamp.min.value / 1000 + ) def test_to_period_tz_warning(self): # GH#21333 make sure a warning is issued when timezone diff --git a/pandas/tests/tools/test_to_datetime.py b/pandas/tests/tools/test_to_datetime.py index 4c70aeb3e36aa..125d42def65e3 100644 --- a/pandas/tests/tools/test_to_datetime.py +++ b/pandas/tests/tools/test_to_datetime.py @@ -654,7 +654,7 @@ def test_to_datetime_today(self, tz): pdtoday2 = to_datetime(["today"])[0] tstoday = Timestamp("today") - tstoday2 = Timestamp.today() + tstoday2 = Timestamp.today().as_unit("ns") # These should all be equal with infinite perf; this gives # a generous margin of 10 seconds @@ -2750,7 +2750,13 @@ def test_epoch(self, units, epochs, epoch_1960, units_from_epochs): ) def test_invalid_origins(self, origin, exc, units, units_from_epochs): - msg = f"origin {origin} (is Out of Bounds|cannot be converted to a Timestamp)" + msg = "|".join( + [ + f"origin {origin} is Out of Bounds", + f"origin {origin} cannot be converted to a Timestamp", + "Cannot cast .* to unit='ns' without overflow", + ] + ) with pytest.raises(exc, match=msg): to_datetime(units_from_epochs, unit=units, origin=origin)