Skip to content

Commit e9d41d1

Browse files
authored
fix: use public pandas APIs where possible (#60)
* refactor: use public pandas APIs where possible
* no need to override take
* backport take implementation
* move remaining private pandas methods to backports
* add note about _validate_scalar to docstring
* comment why we can't use public mixin
1 parent 5cb2c6b commit e9d41d1

File tree

3 files changed

+69
-74
lines changed

3 files changed

+69
-74
lines changed

db_dtypes/__init__.py

+5-11
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,7 @@
2222
import numpy
2323
import packaging.version
2424
import pandas
25-
import pandas.compat.numpy.function
26-
import pandas.core.algorithms
27-
import pandas.core.arrays
28-
import pandas.core.dtypes.base
29-
import pandas.core.dtypes.dtypes
30-
import pandas.core.dtypes.generic
31-
import pandas.core.nanops
25+
import pandas.api.extensions
3226
import pyarrow
3327
import pyarrow.compute
3428

@@ -44,7 +38,7 @@
4438
pandas_release = packaging.version.parse(pandas.__version__).release
4539

4640

47-
@pandas.core.dtypes.dtypes.register_extension_dtype
41+
@pandas.api.extensions.register_extension_dtype
4842
class TimeDtype(core.BaseDatetimeDtype):
4943
"""
5044
Extension dtype for time data.
@@ -113,7 +107,7 @@ def _datetime(
113107
.as_py()
114108
)
115109

116-
if scalar is None:
110+
if pandas.isna(scalar):
117111
return None
118112
if isinstance(scalar, datetime.time):
119113
return pandas.Timestamp(
@@ -194,7 +188,7 @@ def __arrow_array__(self, type=None):
194188
)
195189

196190

197-
@pandas.core.dtypes.dtypes.register_extension_dtype
191+
@pandas.api.extensions.register_extension_dtype
198192
class DateDtype(core.BaseDatetimeDtype):
199193
"""
200194
Extension dtype for date data.
@@ -238,7 +232,7 @@ def _datetime(
238232
if isinstance(scalar, (pyarrow.Date32Scalar, pyarrow.Date64Scalar)):
239233
scalar = scalar.as_py()
240234

241-
if scalar is None:
235+
if pandas.isna(scalar):
242236
return None
243237
elif isinstance(scalar, datetime.date):
244238
return pandas.Timestamp(

db_dtypes/core.py

+18-62
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,13 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Any, Optional, Sequence
15+
from typing import Optional
1616

1717
import numpy
1818
import pandas
19-
from pandas._libs import NaT
19+
from pandas import NaT
2020
import pandas.api.extensions
21-
import pandas.compat.numpy.function
22-
import pandas.core.algorithms
23-
import pandas.core.arrays
24-
import pandas.core.dtypes.base
25-
from pandas.core.dtypes.common import is_dtype_equal, is_list_like, pandas_dtype
26-
import pandas.core.dtypes.dtypes
27-
import pandas.core.dtypes.generic
28-
import pandas.core.nanops
21+
from pandas.api.types import is_dtype_equal, is_list_like, pandas_dtype
2922

3023
from db_dtypes import pandas_backports
3124

@@ -107,42 +100,11 @@ def isna(self):
107100
return pandas.isna(self._ndarray)
108101

109102
def _validate_scalar(self, value):
110-
if pandas.isna(value):
111-
return None
112-
113-
if not isinstance(value, self.dtype.type):
114-
raise ValueError(value)
115-
116-
return value
117-
118-
def take(
119-
self,
120-
indices: Sequence[int],
121-
*,
122-
allow_fill: bool = False,
123-
fill_value: Any = None,
124-
):
125-
indices = numpy.asarray(indices, dtype=numpy.intp)
126-
data = self._ndarray
127-
if allow_fill:
128-
fill_value = self._validate_scalar(fill_value)
129-
fill_value = (
130-
numpy.datetime64() if fill_value is None else self._datetime(fill_value)
131-
)
132-
if (indices < -1).any():
133-
raise ValueError(
134-
"take called with negative indexes other than -1,"
135-
" when a fill value is provided."
136-
)
137-
out = data.take(indices)
138-
if allow_fill:
139-
out[indices == -1] = fill_value
140-
141-
return self.__class__(out)
142-
143-
# TODO: provide implementations of dropna, fillna, unique,
144-
# factorize, argsort, searchsorted for better performance over
145-
# abstract implementations.
103+
"""
104+
Validate and convert a scalar value to datetime64[ns] for storage in
105+
the backing NumPy array.
106+
"""
107+
return self._datetime(value)
146108

147109
def any(
148110
self,
@@ -152,10 +114,8 @@ def any(
152114
keepdims: bool = False,
153115
skipna: bool = True,
154116
):
155-
pandas.compat.numpy.function.validate_any(
156-
(), {"out": out, "keepdims": keepdims}
157-
)
158-
result = pandas.core.nanops.nanany(self._ndarray, axis=axis, skipna=skipna)
117+
pandas_backports.numpy_validate_any((), {"out": out, "keepdims": keepdims})
118+
result = pandas_backports.nanany(self._ndarray, axis=axis, skipna=skipna)
159119
return result
160120

161121
def all(
@@ -166,22 +126,20 @@ def all(
166126
keepdims: bool = False,
167127
skipna: bool = True,
168128
):
169-
pandas.compat.numpy.function.validate_all(
170-
(), {"out": out, "keepdims": keepdims}
171-
)
172-
result = pandas.core.nanops.nanall(self._ndarray, axis=axis, skipna=skipna)
129+
pandas_backports.numpy_validate_all((), {"out": out, "keepdims": keepdims})
130+
result = pandas_backports.nanall(self._ndarray, axis=axis, skipna=skipna)
173131
return result
174132

175133
def min(self, *, axis: Optional[int] = None, skipna: bool = True, **kwargs):
176-
pandas.compat.numpy.function.validate_min((), kwargs)
177-
result = pandas.core.nanops.nanmin(
134+
pandas_backports.numpy_validate_min((), kwargs)
135+
result = pandas_backports.nanmin(
178136
values=self._ndarray, axis=axis, mask=self.isna(), skipna=skipna
179137
)
180138
return self._box_func(result)
181139

182140
def max(self, *, axis: Optional[int] = None, skipna: bool = True, **kwargs):
183-
pandas.compat.numpy.function.validate_max((), kwargs)
184-
result = pandas.core.nanops.nanmax(
141+
pandas_backports.numpy_validate_max((), kwargs)
142+
result = pandas_backports.nanmax(
185143
values=self._ndarray, axis=axis, mask=self.isna(), skipna=skipna
186144
)
187145
return self._box_func(result)
@@ -197,11 +155,9 @@ def median(
197155
keepdims: bool = False,
198156
skipna: bool = True,
199157
):
200-
pandas.compat.numpy.function.validate_median(
158+
pandas_backports.numpy_validate_median(
201159
(),
202160
{"out": out, "overwrite_input": overwrite_input, "keepdims": keepdims},
203161
)
204-
result = pandas.core.nanops.nanmedian(
205-
self._ndarray, axis=axis, skipna=skipna
206-
)
162+
result = pandas_backports.nanmedian(self._ndarray, axis=axis, skipna=skipna)
207163
return self._box_func(result)

db_dtypes/pandas_backports.py

+46-1
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,32 @@
2020
"""
2121

2222
import operator
23+
from typing import Any
2324

2425
import numpy
2526
import packaging.version
2627
import pandas
27-
from pandas._libs.lib import is_integer
28+
from pandas.api.types import is_integer
29+
import pandas.compat.numpy.function
30+
import pandas.core.nanops
2831

2932

3033
pandas_release = packaging.version.parse(pandas.__version__).release
3134

35+
# Create aliases for private methods in case they move in a future version.
36+
nanall = pandas.core.nanops.nanall
37+
nanany = pandas.core.nanops.nanany
38+
nanmax = pandas.core.nanops.nanmax
39+
nanmin = pandas.core.nanops.nanmin
40+
numpy_validate_all = pandas.compat.numpy.function.validate_all
41+
numpy_validate_any = pandas.compat.numpy.function.validate_any
42+
numpy_validate_max = pandas.compat.numpy.function.validate_max
43+
numpy_validate_min = pandas.compat.numpy.function.validate_min
44+
45+
if pandas_release >= (1, 2):
46+
nanmedian = pandas.core.nanops.nanmedian
47+
numpy_validate_median = pandas.compat.numpy.function.validate_median
48+
3249

3350
def import_default(module_name, force=False, default=None):
3451
"""
@@ -55,6 +72,10 @@ def import_default(module_name, force=False, default=None):
5572
return getattr(module, name, default)
5673

5774

75+
# pandas.core.arraylike.OpsMixin is private, but the related public API
76+
# "ExtensionScalarOpsMixin" is not sufficient for adding dates to times.
77+
# It results in unsupported operand type(s) for +: 'datetime.time' and
78+
# 'datetime.date'
5879
@import_default("pandas.core.arraylike")
5980
class OpsMixin:
6081
def _cmp_method(self, other, op): # pragma: NO COVER
@@ -81,6 +102,8 @@ def __ge__(self, other):
81102
__add__ = __radd__ = __sub__ = lambda self, other: NotImplemented
82103

83104

105+
# TODO: use public API once pandas 1.5 / 2.x is released.
106+
# See: https://github.com/pandas-dev/pandas/pull/45544
84107
@import_default("pandas.core.arrays._mixins", pandas_release < (1, 3))
85108
class NDArrayBackedExtensionArray(pandas.core.arrays.base.ExtensionArray):
86109

@@ -130,6 +153,28 @@ def copy(self):
130153
def repeat(self, n):
131154
return self.__class__(self._ndarray.repeat(n), self._dtype)
132155

156+
def take(
157+
self,
158+
indices,
159+
*,
160+
allow_fill: bool = False,
161+
fill_value: Any = None,
162+
axis: int = 0,
163+
):
164+
from pandas.core.algorithms import take
165+
166+
if allow_fill:
167+
fill_value = self._validate_scalar(fill_value)
168+
169+
new_data = take(
170+
self._ndarray,
171+
indices,
172+
allow_fill=allow_fill,
173+
fill_value=fill_value,
174+
axis=axis,
175+
)
176+
return self._from_backing_data(new_data)
177+
133178
@classmethod
134179
def _concat_same_type(cls, to_concat, axis=0):
135180
dtypes = {str(x.dtype) for x in to_concat}

0 commit comments

Comments
 (0)