Skip to content

[WIP]: API ExtensionDtype for DTA & TDA #24674

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 13 commits into from
3 changes: 2 additions & 1 deletion doc/source/api/arrays.rst
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ Methods

A collection of timestamps may be stored in a :class:`arrays.DatetimeArray`.
For timezone-aware data, the ``.dtype`` of a ``DatetimeArray`` is a
:class:`DatetimeTZDtype`. For timezone-naive data, ``np.dtype("datetime64[ns]")``
:class:`DatetimeTZDtype`. For timezone-naive data, :class:`DatetimeDtype`
is used.

If the data are tz-aware, then every value in the array must have the same timezone.
Expand All @@ -145,6 +145,7 @@ If the data are tz-aware, then every value in the array must have the same timez
:toctree: generated/

arrays.DatetimeArray
DatetimeDtype
DatetimeTZDtype

.. _api.arrays.timedelta:
Expand Down
2 changes: 2 additions & 0 deletions pandas/core/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
CategoricalDtype,
PeriodDtype,
IntervalDtype,
DatetimeDtype,
DatetimeTZDtype,
TimedeltaDtype,
)
from pandas.core.arrays import Categorical, array
from pandas.core.groupby import Grouper
Expand Down
7 changes: 5 additions & 2 deletions pandas/core/arrays/datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
is_datetime64_ns_dtype, is_datetime64tz_dtype, is_dtype_equal,
is_extension_type, is_float_dtype, is_object_dtype, is_period_dtype,
is_string_dtype, is_timedelta64_dtype, pandas_dtype)
from pandas.core.dtypes.dtypes import DatetimeTZDtype
from pandas.core.dtypes.dtypes import DatetimeDtype, DatetimeTZDtype
from pandas.core.dtypes.generic import (
ABCDataFrame, ABCIndexClass, ABCPandasArray, ABCSeries)
from pandas.core.dtypes.missing import isna
Expand Down Expand Up @@ -334,6 +334,8 @@ def __init__(self, values, dtype=_NS_DTYPE, freq=None, copy=False):
# a tz-aware Timestamp (with a tz specific to its datetime) will
# be incorrect(ish?) for the array as a whole
dtype = DatetimeTZDtype(tz=timezones.tz_standardize(dtype.tz))
else:
dtype = DatetimeDtype()

self._data = values
self._dtype = dtype
Expand Down Expand Up @@ -1987,7 +1989,8 @@ def _validate_dt64_dtype(dtype):
if dtype is not None:
dtype = pandas_dtype(dtype)
if ((isinstance(dtype, np.dtype) and dtype != _NS_DTYPE)
or not isinstance(dtype, (np.dtype, DatetimeTZDtype))):
or not isinstance(dtype, (np.dtype, DatetimeTZDtype,
DatetimeDtype))):
raise ValueError("Unexpected value for 'dtype': '{dtype}'. "
"Must be 'datetime64[ns]' or DatetimeTZDtype'."
.format(dtype=dtype))
Expand Down
61 changes: 44 additions & 17 deletions pandas/core/arrays/timedeltas.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
is_integer_dtype, is_list_like, is_object_dtype, is_scalar,
is_string_dtype, is_timedelta64_dtype, is_timedelta64_ns_dtype,
pandas_dtype)
from pandas.core.dtypes.dtypes import DatetimeTZDtype
from pandas.core.dtypes.dtypes import DatetimeTZDtype, TimedeltaDtype
from pandas.core.dtypes.generic import (
ABCDataFrame, ABCIndexClass, ABCSeries, ABCTimedeltaIndex)
from pandas.core.dtypes.missing import isna
Expand Down Expand Up @@ -127,7 +127,7 @@ def _box_func(self):

@property
def dtype(self):
return _TD_DTYPE
return self._dtype

# ----------------------------------------------------------------
# Constructors
Expand Down Expand Up @@ -160,16 +160,8 @@ def __init__(self, values, dtype=_TD_DTYPE, freq=None, copy=False):
# nanosecond UTC (or tz-naive) unix timestamps
values = values.view(_TD_DTYPE)

if values.dtype != _TD_DTYPE:
raise TypeError(_BAD_DTYPE.format(dtype=values.dtype))

try:
dtype_mismatch = dtype != _TD_DTYPE
except TypeError:
raise TypeError(_BAD_DTYPE.format(dtype=dtype))
else:
if dtype_mismatch:
raise TypeError(_BAD_DTYPE.format(dtype=dtype))
_validate_td64_dtype(values.dtype)
dtype = _validate_td64_dtype(dtype)

if freq == "infer":
msg = (
Expand All @@ -192,21 +184,19 @@ def __init__(self, values, dtype=_TD_DTYPE, freq=None, copy=False):

@classmethod
def _simple_new(cls, values, freq=None, dtype=_TD_DTYPE):
assert dtype == _TD_DTYPE, dtype
dtype = _validate_td64_dtype(dtype)
assert isinstance(values, np.ndarray), type(values)

result = object.__new__(cls)
result._data = values.view(_TD_DTYPE)
result._freq = to_offset(freq)
result._dtype = _TD_DTYPE
result._dtype = dtype
return result

@classmethod
def _from_sequence(cls, data, dtype=_TD_DTYPE, copy=False,
freq=None, unit=None):
if dtype != _TD_DTYPE:
raise ValueError("Only timedelta64[ns] dtype is valid.")

_validate_td64_dtype(dtype)
freq, freq_infer = dtl.maybe_infer_freq(freq)

data, inferred_freq = sequence_to_td64ns(data, copy=copy, unit=unit)
Expand Down Expand Up @@ -1015,3 +1005,40 @@ def _generate_regular_range(start, end, periods, offset):

data = np.arange(b, e, stride, dtype=np.int64)
return data


def _validate_td64_dtype(dtype):
    """
    Check that ``dtype`` is usable for a TimedeltaArray and normalize it.

    Parameters
    ----------
    dtype : str, numpy.dtype or TimedeltaDtype
        A string is first passed to ``np.dtype``; if NumPy rejects it
        with a TypeError it is kept as a string and must then be the
        unit ``"ns"``. A NumPy dtype must be equal to ``_TD_DTYPE``.
        Any other object must compare equal to ``TimedeltaDtype()``.

    Returns
    -------
    TimedeltaDtype

    Raises
    ------
    TypeError
        If ``dtype`` is a NumPy dtype other than ``_TD_DTYPE``, or a
        non-NumPy string other than ``"ns"``.
    ValueError
        If the resulting dtype does not compare equal to
        ``TimedeltaDtype()``.
    """
    candidate = dtype

    if isinstance(candidate, compat.string_types):
        try:
            candidate = np.dtype(candidate)
        except TypeError:
            # NumPy does not understand this string; leave it as-is so
            # the unit-string check below gets a chance to handle it.
            pass

    if isinstance(candidate, np.dtype):
        if candidate != _TD_DTYPE:
            raise TypeError(_BAD_DTYPE.format(dtype=candidate))
        validated = TimedeltaDtype()
    elif isinstance(candidate, compat.string_types):
        if candidate != "ns":
            raise TypeError(_BAD_DTYPE.format(dtype=candidate))
        validated = TimedeltaDtype(candidate)
    else:
        # Neither a NumPy dtype nor a string: judged solely by the
        # equality check that follows.
        validated = candidate

    if validated != TimedeltaDtype():
        raise ValueError("Only timedelta64[ns] dtype is valid")

    return validated
36 changes: 28 additions & 8 deletions pandas/core/dtypes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
import numpy as np

from pandas._libs import algos, lib
from pandas._libs.tslibs import conversion
from pandas._libs.tslibs import Timedelta, Timestamp, conversion
from pandas.compat import PY3, PY36, string_types

from pandas.core.dtypes.dtypes import (
CategoricalDtype, DatetimeTZDtype, ExtensionDtype, IntervalDtype,
PandasExtensionDtype, PeriodDtype, registry)
PandasExtensionDtype, PeriodDtype, TimedeltaDtype, registry)
from pandas.core.dtypes.generic import (
ABCCategorical, ABCDateOffset, ABCDatetimeIndex, ABCIndexClass,
ABCPeriodArray, ABCPeriodIndex, ABCSeries)
Expand Down Expand Up @@ -426,9 +426,24 @@ def is_datetime64_dtype(arr_or_dtype):
True
>>> is_datetime64_dtype([1, 2, 3])
False
>>> is_datetime64_dtype(pd.DatetimeDtype())
True
>>> is_datetime64_dtype(pd.DatetimeTZDtype(tz="CET"))
False
"""
# It's somewhat tricky to support both of the following:
# 1. is_datetime64_dtype(DatetimeDtype()) == True
# 2. is_datetime64_dtype(DatetimeTZDtype()) == False
# because both use `Timestamp` as the `type`.
# So we look at the `dtype` to see if there's a `.tz` attached.
dtype = getattr(arr_or_dtype, 'dtype', arr_or_dtype)
try:
dtype = pandas_dtype(dtype)
except (ValueError, TypeError):
dtype = None

return _is_dtype_type(arr_or_dtype, classes(np.datetime64))
return (_is_dtype_type(arr_or_dtype, classes(np.datetime64, Timestamp))
and getattr(dtype, 'tz', None) is None)


def is_datetime64tz_dtype(arr_or_dtype):
Expand Down Expand Up @@ -497,7 +512,7 @@ def is_timedelta64_dtype(arr_or_dtype):
False
"""

return _is_dtype_type(arr_or_dtype, classes(np.timedelta64))
return _is_dtype_type(arr_or_dtype, classes(np.timedelta64, Timedelta))


def is_period_dtype(arr_or_dtype):
Expand Down Expand Up @@ -1192,7 +1207,10 @@ def is_timedelta64_ns_dtype(arr_or_dtype):
>>> is_timedelta64_ns_dtype(np.array([1, 2], dtype=np.timedelta64))
False
"""
return _is_dtype(arr_or_dtype, lambda dtype: dtype == _TD_DTYPE)
def condition(dtype):
return isinstance(dtype, TimedeltaDtype) or dtype == _TD_DTYPE

return _is_dtype(arr_or_dtype, condition)


def is_datetime_or_timedelta_dtype(arr_or_dtype):
Expand Down Expand Up @@ -1229,9 +1247,11 @@ def is_datetime_or_timedelta_dtype(arr_or_dtype):
>>> is_datetime_or_timedelta_dtype(np.array([], dtype=np.datetime64))
True
"""

return _is_dtype_type(
arr_or_dtype, classes(np.datetime64, np.timedelta64))
dtype = getattr(arr_or_dtype, 'dtype', arr_or_dtype)
return (_is_dtype_type(
arr_or_dtype, classes(np.datetime64, np.timedelta64,
Timestamp, Timedelta))
and getattr(dtype, 'tz', None) is None)


def _is_unorderable_exception(e):
Expand Down
Loading