Skip to content

Commit be58cd9

Browse files
authored
REF: implement _unbox to de-duplicate unwrapping (#33906)
1 parent d149f41 commit be58cd9

File tree

2 files changed

+28
-40
lines changed

2 files changed

+28
-40
lines changed

pandas/core/arrays/datetimelike.py

+27-38
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from pandas._libs import NaT, NaTType, Timestamp, algos, iNaT, lib
99
from pandas._libs.tslibs.c_timestamp import integer_op_not_supported
1010
from pandas._libs.tslibs.period import DIFFERENT_FREQ, IncompatibleFrequency, Period
11-
from pandas._libs.tslibs.timedeltas import Timedelta, delta_to_nanoseconds
11+
from pandas._libs.tslibs.timedeltas import delta_to_nanoseconds
1212
from pandas._libs.tslibs.timestamps import RoundTo, round_nsint64
1313
from pandas._typing import DatetimeLikeScalar
1414
from pandas.compat import set_function_name
@@ -52,6 +52,8 @@
5252
from pandas.tseries import frequencies
5353
from pandas.tseries.offsets import DateOffset, Tick
5454

55+
DTScalarOrNaT = Union[DatetimeLikeScalar, NaTType]
56+
5557

5658
def _datetimelike_array_cmp(cls, op):
5759
"""
@@ -122,12 +124,7 @@ def wrapper(self, other):
122124
result = ops.comp_method_OBJECT_ARRAY(op, self.astype(object), other)
123125
return result
124126

125-
if isinstance(other, self._scalar_type) or other is NaT:
126-
other_i8 = self._unbox_scalar(other)
127-
else:
128-
# Then type(other) == type(self)
129-
other_i8 = other.asi8
130-
127+
other_i8 = self._unbox(other)
131128
result = op(self.asi8, other_i8)
132129

133130
o_mask = isna(other)
@@ -157,9 +154,7 @@ def _scalar_type(self) -> Type[DatetimeLikeScalar]:
157154
"""
158155
raise AbstractMethodError(self)
159156

160-
def _scalar_from_string(
161-
self, value: str
162-
) -> Union[Period, Timestamp, Timedelta, NaTType]:
157+
def _scalar_from_string(self, value: str) -> DTScalarOrNaT:
163158
"""
164159
Construct a scalar type from a string.
165160
@@ -179,13 +174,14 @@ def _scalar_from_string(
179174
"""
180175
raise AbstractMethodError(self)
181176

182-
def _unbox_scalar(self, value: Union[Period, Timestamp, Timedelta, NaTType]) -> int:
177+
def _unbox_scalar(self, value: DTScalarOrNaT) -> int:
183178
"""
184179
Unbox the integer value of a scalar `value`.
185180
186181
Parameters
187182
----------
188-
value : Union[Period, Timestamp, Timedelta]
183+
value : Period, Timestamp, Timedelta, or NaT
184+
Depending on subclass.
189185
190186
Returns
191187
-------
@@ -199,7 +195,7 @@ def _unbox_scalar(self, value: Union[Period, Timestamp, Timedelta, NaTType]) ->
199195
raise AbstractMethodError(self)
200196

201197
def _check_compatible_with(
202-
self, other: Union[Period, Timestamp, Timedelta, NaTType], setitem: bool = False
198+
self, other: DTScalarOrNaT, setitem: bool = False
203199
) -> None:
204200
"""
205201
Verify that `self` and `other` are compatible.
@@ -727,17 +723,16 @@ def _validate_fill_value(self, fill_value):
727723
ValueError
728724
"""
729725
if is_valid_nat_for_dtype(fill_value, self.dtype):
730-
fill_value = iNaT
726+
fill_value = NaT
731727
elif isinstance(fill_value, self._recognized_scalars):
732-
self._check_compatible_with(fill_value)
733728
fill_value = self._scalar_type(fill_value)
734-
fill_value = self._unbox_scalar(fill_value)
735729
else:
736730
raise ValueError(
737731
f"'fill_value' should be a {self._scalar_type}. "
738732
f"Got '{str(fill_value)}'."
739733
)
740-
return fill_value
734+
735+
return self._unbox(fill_value)
741736

742737
def _validate_shift_value(self, fill_value):
743738
# TODO(2.0): once this deprecation is enforced, use _validate_fill_value
@@ -764,8 +759,7 @@ def _validate_shift_value(self, fill_value):
764759
)
765760
fill_value = new_fill
766761

767-
fill_value = self._unbox_scalar(fill_value)
768-
return fill_value
762+
return self._unbox(fill_value)
769763

770764
def _validate_searchsorted_value(self, value):
771765
if isinstance(value, str):
@@ -797,13 +791,7 @@ def _validate_searchsorted_value(self, value):
797791
else:
798792
raise TypeError(f"Unexpected type for 'value': {type(value)}")
799793

800-
if isinstance(value, type(self)):
801-
self._check_compatible_with(value)
802-
value = value.asi8
803-
else:
804-
value = self._unbox_scalar(value)
805-
806-
return value
794+
return self._unbox(value)
807795

808796
def _validate_setitem_value(self, value):
809797

@@ -836,19 +824,11 @@ def _validate_setitem_value(self, value):
836824
raise TypeError(msg)
837825

838826
self._check_compatible_with(value, setitem=True)
839-
if isinstance(value, type(self)):
840-
value = value.asi8
841-
else:
842-
value = self._unbox_scalar(value)
843-
844-
return value
827+
return self._unbox(value)
845828

846829
def _validate_insert_value(self, value):
847830
if isinstance(value, self._recognized_scalars):
848831
value = self._scalar_type(value)
849-
self._check_compatible_with(value, setitem=True)
850-
# TODO: if we dont have compat, should we raise or astype(object)?
851-
# PeriodIndex does astype(object)
852832
elif is_valid_nat_for_dtype(value, self.dtype):
853833
# GH#18295
854834
value = NaT
@@ -857,14 +837,16 @@ def _validate_insert_value(self, value):
857837
f"cannot insert {type(self).__name__} with incompatible label"
858838
)
859839

840+
self._check_compatible_with(value, setitem=True)
841+
# TODO: if we dont have compat, should we raise or astype(object)?
842+
# PeriodIndex does astype(object)
860843
return value
861844

862845
def _validate_where_value(self, other):
863846
if is_valid_nat_for_dtype(other, self.dtype):
864847
other = NaT
865848
elif isinstance(other, self._recognized_scalars):
866849
other = self._scalar_type(other)
867-
self._check_compatible_with(other, setitem=True)
868850
elif not is_list_like(other):
869851
raise TypeError(f"Where requires matching dtype, not {type(other)}")
870852

@@ -881,13 +863,20 @@ def _validate_where_value(self, other):
881863

882864
if not type(self)._is_recognized_dtype(other.dtype):
883865
raise TypeError(f"Where requires matching dtype, not {other.dtype}")
884-
self._check_compatible_with(other, setitem=True)
885866

867+
self._check_compatible_with(other, setitem=True)
868+
return self._unbox(other)
869+
870+
def _unbox(self, other) -> Union[np.int64, np.ndarray]:
871+
"""
872+
Unbox either a scalar with _unbox_scalar or an instance of our own type.
873+
"""
886874
if lib.is_scalar(other):
887875
other = self._unbox_scalar(other)
888876
else:
877+
# same type as self
878+
self._check_compatible_with(other)
889879
other = other.view("i8")
890-
891880
return other
892881

893882
# ------------------------------------------------------------------

pandas/core/arrays/period.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -249,8 +249,7 @@ def _unbox_scalar(self, value: Union[Period, NaTType]) -> int:
249249
if value is NaT:
250250
return value.value
251251
elif isinstance(value, self._scalar_type):
252-
if not isna(value):
253-
self._check_compatible_with(value)
252+
self._check_compatible_with(value)
254253
return value.ordinal
255254
else:
256255
raise ValueError(f"'value' should be a Period. Got '{value}' instead.")

0 commit comments

Comments
 (0)